Merge remote-tracking branch 'upstream/debug' into tc_refractor
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
from typing import Optional, Union
|
||||
from typing import Optional, Union
|
||||
from loguru import logger
|
||||
|
||||
from ...common.database import Database
|
||||
from .message_base import UserInfo
|
||||
@@ -10,9 +11,10 @@ class Impression:
|
||||
traits: str = None
|
||||
called: str = None
|
||||
know_time: float = None
|
||||
|
||||
|
||||
relationship_value: float = None
|
||||
|
||||
|
||||
class Relationship:
|
||||
user_id: int = None
|
||||
platform: str = None
|
||||
@@ -79,7 +81,7 @@ class RelationshipManager:
|
||||
# 保存到数据库
|
||||
await self.storage_relationship(relationship)
|
||||
relationship.saved = True
|
||||
|
||||
|
||||
return relationship
|
||||
|
||||
async def update_relationship_value(self,
|
||||
@@ -121,7 +123,7 @@ class RelationshipManager:
|
||||
# 如果不存在且提供了user_info,则创建新的关系
|
||||
if user_info is not None:
|
||||
return await self.update_relationship(chat_stream=chat_stream, **kwargs)
|
||||
print(f"\033[1;31m[关系管理]\033[0m 用户 {user_id}({platform}) 不存在,无法更新")
|
||||
logger.warning(f"[关系管理] 用户 {user_id}({platform}) 不存在,无法更新")
|
||||
return None
|
||||
|
||||
def get_relationship(self,
|
||||
@@ -151,7 +153,7 @@ class RelationshipManager:
|
||||
return self.relationships[key]
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
async def load_relationship(self, data: dict) -> Relationship:
|
||||
"""从数据库加载或创建新的关系对象"""
|
||||
# 确保data中有platform字段,如果没有则默认为'qq'
|
||||
@@ -163,14 +165,14 @@ class RelationshipManager:
|
||||
key = (rela.user_id, rela.platform)
|
||||
self.relationships[key] = rela
|
||||
return rela
|
||||
|
||||
|
||||
async def load_all_relationships(self):
|
||||
"""加载所有关系对象"""
|
||||
db = Database.get_instance()
|
||||
all_relationships = db.db.relationships.find({})
|
||||
for data in all_relationships:
|
||||
await self.load_relationship(data)
|
||||
|
||||
|
||||
async def _start_relationship_manager(self):
|
||||
"""每5分钟自动保存一次关系数据"""
|
||||
db = Database.get_instance()
|
||||
@@ -179,15 +181,15 @@ class RelationshipManager:
|
||||
# 依次加载每条记录
|
||||
for data in all_relationships:
|
||||
await self.load_relationship(data)
|
||||
print(f"\033[1;32m[关系管理]\033[0m 已加载 {len(self.relationships)} 条关系记录")
|
||||
logger.debug(f"[关系管理] 已加载 {len(self.relationships)} 条关系记录")
|
||||
|
||||
while True:
|
||||
print("\033[1;32m[关系管理]\033[0m 正在自动保存关系")
|
||||
logger.debug("正在自动保存关系")
|
||||
await asyncio.sleep(300) # 等待300秒(5分钟)
|
||||
await self._save_all_relationships()
|
||||
|
||||
|
||||
async def _save_all_relationships(self):
|
||||
"""将所有关系数据保存到数据库"""
|
||||
"""将所有关系数据保存到数据库"""
|
||||
# 保存所有关系数据
|
||||
for (userid, platform), relationship in self.relationships.items():
|
||||
if not relationship.saved:
|
||||
@@ -203,7 +205,7 @@ class RelationshipManager:
|
||||
gender = relationship.gender
|
||||
age = relationship.age
|
||||
saved = relationship.saved
|
||||
|
||||
|
||||
db = Database.get_instance()
|
||||
db.db.relationships.update_one(
|
||||
{'user_id': user_id, 'platform': platform},
|
||||
@@ -252,4 +254,4 @@ class RelationshipManager:
|
||||
return "某人"
|
||||
|
||||
|
||||
relationship_manager = RelationshipManager()
|
||||
relationship_manager = RelationshipManager()
|
||||
|
||||
Reference in New Issue
Block a user