Merge branch 'refractor' of https://github.com/tcmofashi/MaiMBot into refractor

This commit is contained in:
tcmofashi
2025-03-10 21:01:06 +08:00
4 changed files with 183 additions and 0 deletions

View File

@@ -1,5 +1,6 @@
import asyncio
from typing import Optional, Union
from typing import Optional, Union
from ...common.database import Database
from .message_base import UserInfo
@@ -15,6 +16,7 @@ class Impression:
class Relationship:
user_id: int = None
platform: str = None
platform: str = None
gender: str = None
age: int = None
nickname: str = None
@@ -33,6 +35,7 @@ class Relationship:
class RelationshipManager:
def __init__(self):
self.relationships: dict[tuple[int, str], Relationship] = {} # 修改为使用(user_id, platform)作为键
self.relationships: dict[tuple[int, str], Relationship] = {} # 修改为使用(user_id, platform)作为键
async def update_relationship(self,
chat_stream:ChatStream,
@@ -63,16 +66,23 @@ class RelationshipManager:
# 检查是否在内存中已存在
relationship = self.relationships.get(key)
relationship = self.relationships.get(key)
if relationship:
# 如果存在,更新现有对象
if isinstance(data, dict):
for k, value in data.items():
if hasattr(relationship, k) and value is not None:
setattr(relationship, k, value)
for k, value in data.items():
if hasattr(relationship, k) and value is not None:
setattr(relationship, k, value)
else:
for k, value in kwargs.items():
if hasattr(relationship, k) and value is not None:
setattr(relationship, k, value)
for k, value in kwargs.items():
if hasattr(relationship, k) and value is not None:
setattr(relationship, k, value)
else:
# 如果不存在,创建新对象
if user_info is not None:
@@ -85,6 +95,16 @@ class RelationshipManager:
kwargs['user_id'] = user_id
relationship = Relationship(**kwargs)
self.relationships[key] = relationship
if user_info is not None:
relationship = Relationship(user_info=user_info, **kwargs)
elif isinstance(data, dict):
data['platform'] = platform
relationship = Relationship(user_id=user_id, data=data)
else:
kwargs['platform'] = platform
kwargs['user_id'] = user_id
relationship = Relationship(**kwargs)
self.relationships[key] = relationship
# 保存到数据库
await self.storage_relationship(relationship)
@@ -92,6 +112,33 @@ class RelationshipManager:
return relationship
async def update_relationship_value(self,
user_id: int = None,
platform: str = None,
user_info: UserInfo = None,
**kwargs) -> Optional[Relationship]:
"""更新关系值
Args:
user_id: 用户ID可选如果提供user_info则不需要
platform: 平台可选如果提供user_info则不需要
user_info: 用户信息对象(可选)
**kwargs: 其他参数
Returns:
Relationship: 关系对象
"""
# 确定user_id和platform
if user_info is not None:
user_id = user_info.user_id
platform = user_info.platform or 'qq'
else:
platform = platform or 'qq'
if user_id is None:
raise ValueError("必须提供user_id或user_info")
# 使用(user_id, platform)作为键
key = (user_id, platform)
async def update_relationship_value(self,
user_id: int = None,
platform: str = None,
@@ -121,7 +168,10 @@ class RelationshipManager:
# 检查是否在内存中已存在
relationship = self.relationships.get(key)
relationship = self.relationships.get(key)
if relationship:
for k, value in kwargs.items():
if k == 'relationship_value':
for k, value in kwargs.items():
if k == 'relationship_value':
relationship.relationship_value += value
@@ -129,12 +179,41 @@ class RelationshipManager:
relationship.saved = True
return relationship
else:
# 如果不存在且提供了user_info则创建新的关系
if user_info is not None:
return await self.update_relationship(user_info=user_info, **kwargs)
print(f"\033[1;31m[关系管理]\033[0m 用户 {user_id}({platform}) 不存在,无法更新")
# 如果不存在且提供了user_info则创建新的关系
if user_info is not None:
return await self.update_relationship(user_info=user_info, **kwargs)
print(f"\033[1;31m[关系管理]\033[0m 用户 {user_id}({platform}) 不存在,无法更新")
return None
def get_relationship(self,
user_id: int = None,
platform: str = None,
user_info: UserInfo = None) -> Optional[Relationship]:
"""获取用户关系对象
Args:
user_id: 用户ID可选如果提供user_info则不需要
platform: 平台可选如果提供user_info则不需要
user_info: 用户信息对象(可选)
Returns:
Relationship: 关系对象
"""
# 确定user_id和platform
if user_info is not None:
user_id = user_info.user_id
platform = user_info.platform or 'qq'
else:
platform = platform or 'qq'
if user_id is None:
raise ValueError("必须提供user_id或user_info")
key = (user_id, platform)
if key in self.relationships:
return self.relationships[key]
def get_relationship(self,
user_id: int = None,
platform: str = None,
@@ -169,10 +248,18 @@ class RelationshipManager:
if 'platform' not in data:
data['platform'] = 'qq'
rela = Relationship(data=data)
"""从数据库加载或创建新的关系对象"""
# 确保data中有platform字段如果没有则默认为'qq'
if 'platform' not in data:
data['platform'] = 'qq'
rela = Relationship(data=data)
rela.saved = True
key = (rela.user_id, rela.platform)
self.relationships[key] = rela
key = (rela.user_id, rela.platform)
self.relationships[key] = rela
return rela
async def load_all_relationships(self):
@@ -190,6 +277,7 @@ class RelationshipManager:
# 依次加载每条记录
for data in all_relationships:
await self.load_relationship(data)
await self.load_relationship(data)
print(f"\033[1;32m[关系管理]\033[0m 已加载 {len(self.relationships)} 条关系记录")
while True:
@@ -200,15 +288,19 @@ class RelationshipManager:
async def _save_all_relationships(self):
"""将所有关系数据保存到数据库"""
# 保存所有关系数据
for (userid, platform), relationship in self.relationships.items():
for (userid, platform), relationship in self.relationships.items():
if not relationship.saved:
relationship.saved = True
await self.storage_relationship(relationship)
async def storage_relationship(self, relationship: Relationship):
"""将关系记录存储到数据库中"""
async def storage_relationship(self, relationship: Relationship):
"""将关系记录存储到数据库中"""
user_id = relationship.user_id
platform = relationship.platform
platform = relationship.platform
nickname = relationship.nickname
relationship_value = relationship.relationship_value
gender = relationship.gender
@@ -217,8 +309,10 @@ class RelationshipManager:
db = Database.get_instance()
db.db.relationships.update_one(
{'user_id': user_id, 'platform': platform},
{'user_id': user_id, 'platform': platform},
{'$set': {
'platform': platform,
'platform': platform,
'nickname': nickname,
'relationship_value': relationship_value,
@@ -229,6 +323,28 @@ class RelationshipManager:
upsert=True
)
def get_name(self,
user_id: int = None,
platform: str = None,
user_info: UserInfo = None) -> str:
"""获取用户昵称
Args:
user_id: 用户ID可选如果提供user_info则不需要
platform: 平台可选如果提供user_info则不需要
user_info: 用户信息对象(可选)
Returns:
str: 用户昵称
"""
# 确定user_id和platform
if user_info is not None:
user_id = user_info.user_id
platform = user_info.platform or 'qq'
else:
platform = platform or 'qq'
if user_id is None:
raise ValueError("必须提供user_id或user_info")
def get_name(self,
user_id: int = None,
platform: str = None,
@@ -254,6 +370,11 @@ class RelationshipManager:
# 确保user_id是整数类型
user_id = int(user_id)
key = (user_id, platform)
if key in self.relationships:
return self.relationships[key].nickname
elif user_info is not None:
return user_info.user_nickname or user_info.user_cardname or "某人"
key = (user_id, platform)
if key in self.relationships:
return self.relationships[key].nickname
elif user_info is not None: