Merge branch 'refractor' of https://github.com/tcmofashi/MaiMBot into refractor
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
from typing import Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
from ...common.database import Database
|
||||
from .message_base import UserInfo
|
||||
@@ -15,6 +16,7 @@ class Impression:
|
||||
class Relationship:
|
||||
user_id: int = None
|
||||
platform: str = None
|
||||
platform: str = None
|
||||
gender: str = None
|
||||
age: int = None
|
||||
nickname: str = None
|
||||
@@ -33,6 +35,7 @@ class Relationship:
|
||||
class RelationshipManager:
|
||||
def __init__(self):
|
||||
self.relationships: dict[tuple[int, str], Relationship] = {} # 修改为使用(user_id, platform)作为键
|
||||
self.relationships: dict[tuple[int, str], Relationship] = {} # 修改为使用(user_id, platform)作为键
|
||||
|
||||
async def update_relationship(self,
|
||||
chat_stream:ChatStream,
|
||||
@@ -63,16 +66,23 @@ class RelationshipManager:
|
||||
|
||||
# 检查是否在内存中已存在
|
||||
relationship = self.relationships.get(key)
|
||||
relationship = self.relationships.get(key)
|
||||
if relationship:
|
||||
# 如果存在,更新现有对象
|
||||
if isinstance(data, dict):
|
||||
for k, value in data.items():
|
||||
if hasattr(relationship, k) and value is not None:
|
||||
setattr(relationship, k, value)
|
||||
for k, value in data.items():
|
||||
if hasattr(relationship, k) and value is not None:
|
||||
setattr(relationship, k, value)
|
||||
else:
|
||||
for k, value in kwargs.items():
|
||||
if hasattr(relationship, k) and value is not None:
|
||||
setattr(relationship, k, value)
|
||||
for k, value in kwargs.items():
|
||||
if hasattr(relationship, k) and value is not None:
|
||||
setattr(relationship, k, value)
|
||||
else:
|
||||
# 如果不存在,创建新对象
|
||||
if user_info is not None:
|
||||
@@ -85,6 +95,16 @@ class RelationshipManager:
|
||||
kwargs['user_id'] = user_id
|
||||
relationship = Relationship(**kwargs)
|
||||
self.relationships[key] = relationship
|
||||
if user_info is not None:
|
||||
relationship = Relationship(user_info=user_info, **kwargs)
|
||||
elif isinstance(data, dict):
|
||||
data['platform'] = platform
|
||||
relationship = Relationship(user_id=user_id, data=data)
|
||||
else:
|
||||
kwargs['platform'] = platform
|
||||
kwargs['user_id'] = user_id
|
||||
relationship = Relationship(**kwargs)
|
||||
self.relationships[key] = relationship
|
||||
|
||||
# 保存到数据库
|
||||
await self.storage_relationship(relationship)
|
||||
@@ -92,6 +112,33 @@ class RelationshipManager:
|
||||
|
||||
return relationship
|
||||
|
||||
async def update_relationship_value(self,
|
||||
user_id: int = None,
|
||||
platform: str = None,
|
||||
user_info: UserInfo = None,
|
||||
**kwargs) -> Optional[Relationship]:
|
||||
"""更新关系值
|
||||
Args:
|
||||
user_id: 用户ID(可选,如果提供user_info则不需要)
|
||||
platform: 平台(可选,如果提供user_info则不需要)
|
||||
user_info: 用户信息对象(可选)
|
||||
**kwargs: 其他参数
|
||||
Returns:
|
||||
Relationship: 关系对象
|
||||
"""
|
||||
# 确定user_id和platform
|
||||
if user_info is not None:
|
||||
user_id = user_info.user_id
|
||||
platform = user_info.platform or 'qq'
|
||||
else:
|
||||
platform = platform or 'qq'
|
||||
|
||||
if user_id is None:
|
||||
raise ValueError("必须提供user_id或user_info")
|
||||
|
||||
# 使用(user_id, platform)作为键
|
||||
key = (user_id, platform)
|
||||
|
||||
async def update_relationship_value(self,
|
||||
user_id: int = None,
|
||||
platform: str = None,
|
||||
@@ -121,7 +168,10 @@ class RelationshipManager:
|
||||
|
||||
# 检查是否在内存中已存在
|
||||
relationship = self.relationships.get(key)
|
||||
relationship = self.relationships.get(key)
|
||||
if relationship:
|
||||
for k, value in kwargs.items():
|
||||
if k == 'relationship_value':
|
||||
for k, value in kwargs.items():
|
||||
if k == 'relationship_value':
|
||||
relationship.relationship_value += value
|
||||
@@ -129,12 +179,41 @@ class RelationshipManager:
|
||||
relationship.saved = True
|
||||
return relationship
|
||||
else:
|
||||
# 如果不存在且提供了user_info,则创建新的关系
|
||||
if user_info is not None:
|
||||
return await self.update_relationship(user_info=user_info, **kwargs)
|
||||
print(f"\033[1;31m[关系管理]\033[0m 用户 {user_id}({platform}) 不存在,无法更新")
|
||||
# 如果不存在且提供了user_info,则创建新的关系
|
||||
if user_info is not None:
|
||||
return await self.update_relationship(user_info=user_info, **kwargs)
|
||||
print(f"\033[1;31m[关系管理]\033[0m 用户 {user_id}({platform}) 不存在,无法更新")
|
||||
return None
|
||||
|
||||
def get_relationship(self,
|
||||
user_id: int = None,
|
||||
platform: str = None,
|
||||
user_info: UserInfo = None) -> Optional[Relationship]:
|
||||
"""获取用户关系对象
|
||||
Args:
|
||||
user_id: 用户ID(可选,如果提供user_info则不需要)
|
||||
platform: 平台(可选,如果提供user_info则不需要)
|
||||
user_info: 用户信息对象(可选)
|
||||
Returns:
|
||||
Relationship: 关系对象
|
||||
"""
|
||||
# 确定user_id和platform
|
||||
if user_info is not None:
|
||||
user_id = user_info.user_id
|
||||
platform = user_info.platform or 'qq'
|
||||
else:
|
||||
platform = platform or 'qq'
|
||||
|
||||
if user_id is None:
|
||||
raise ValueError("必须提供user_id或user_info")
|
||||
|
||||
key = (user_id, platform)
|
||||
if key in self.relationships:
|
||||
return self.relationships[key]
|
||||
def get_relationship(self,
|
||||
user_id: int = None,
|
||||
platform: str = None,
|
||||
@@ -169,10 +248,18 @@ class RelationshipManager:
|
||||
if 'platform' not in data:
|
||||
data['platform'] = 'qq'
|
||||
|
||||
rela = Relationship(data=data)
|
||||
"""从数据库加载或创建新的关系对象"""
|
||||
# 确保data中有platform字段,如果没有则默认为'qq'
|
||||
if 'platform' not in data:
|
||||
data['platform'] = 'qq'
|
||||
|
||||
rela = Relationship(data=data)
|
||||
rela.saved = True
|
||||
key = (rela.user_id, rela.platform)
|
||||
self.relationships[key] = rela
|
||||
key = (rela.user_id, rela.platform)
|
||||
self.relationships[key] = rela
|
||||
return rela
|
||||
|
||||
async def load_all_relationships(self):
|
||||
@@ -190,6 +277,7 @@ class RelationshipManager:
|
||||
# 依次加载每条记录
|
||||
for data in all_relationships:
|
||||
await self.load_relationship(data)
|
||||
await self.load_relationship(data)
|
||||
print(f"\033[1;32m[关系管理]\033[0m 已加载 {len(self.relationships)} 条关系记录")
|
||||
|
||||
while True:
|
||||
@@ -200,15 +288,19 @@ class RelationshipManager:
|
||||
async def _save_all_relationships(self):
|
||||
"""将所有关系数据保存到数据库"""
|
||||
# 保存所有关系数据
|
||||
for (userid, platform), relationship in self.relationships.items():
|
||||
for (userid, platform), relationship in self.relationships.items():
|
||||
if not relationship.saved:
|
||||
relationship.saved = True
|
||||
await self.storage_relationship(relationship)
|
||||
|
||||
async def storage_relationship(self, relationship: Relationship):
|
||||
"""将关系记录存储到数据库中"""
|
||||
async def storage_relationship(self, relationship: Relationship):
|
||||
"""将关系记录存储到数据库中"""
|
||||
user_id = relationship.user_id
|
||||
platform = relationship.platform
|
||||
platform = relationship.platform
|
||||
nickname = relationship.nickname
|
||||
relationship_value = relationship.relationship_value
|
||||
gender = relationship.gender
|
||||
@@ -217,8 +309,10 @@ class RelationshipManager:
|
||||
|
||||
db = Database.get_instance()
|
||||
db.db.relationships.update_one(
|
||||
{'user_id': user_id, 'platform': platform},
|
||||
{'user_id': user_id, 'platform': platform},
|
||||
{'$set': {
|
||||
'platform': platform,
|
||||
'platform': platform,
|
||||
'nickname': nickname,
|
||||
'relationship_value': relationship_value,
|
||||
@@ -229,6 +323,28 @@ class RelationshipManager:
|
||||
upsert=True
|
||||
)
|
||||
|
||||
def get_name(self,
|
||||
user_id: int = None,
|
||||
platform: str = None,
|
||||
user_info: UserInfo = None) -> str:
|
||||
"""获取用户昵称
|
||||
Args:
|
||||
user_id: 用户ID(可选,如果提供user_info则不需要)
|
||||
platform: 平台(可选,如果提供user_info则不需要)
|
||||
user_info: 用户信息对象(可选)
|
||||
Returns:
|
||||
str: 用户昵称
|
||||
"""
|
||||
# 确定user_id和platform
|
||||
if user_info is not None:
|
||||
user_id = user_info.user_id
|
||||
platform = user_info.platform or 'qq'
|
||||
else:
|
||||
platform = platform or 'qq'
|
||||
|
||||
if user_id is None:
|
||||
raise ValueError("必须提供user_id或user_info")
|
||||
|
||||
def get_name(self,
|
||||
user_id: int = None,
|
||||
platform: str = None,
|
||||
@@ -254,6 +370,11 @@ class RelationshipManager:
|
||||
# 确保user_id是整数类型
|
||||
user_id = int(user_id)
|
||||
key = (user_id, platform)
|
||||
if key in self.relationships:
|
||||
return self.relationships[key].nickname
|
||||
elif user_info is not None:
|
||||
return user_info.user_nickname or user_info.user_cardname or "某人"
|
||||
key = (user_id, platform)
|
||||
if key in self.relationships:
|
||||
return self.relationships[key].nickname
|
||||
elif user_info is not None:
|
||||
|
||||
Reference in New Issue
Block a user