diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index e16930ffe..48a7c2740 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -20,6 +20,9 @@ install(extra_lines=3) logger = get_logger("chat_stream") +# 用于存储后台任务的集合,防止被垃圾回收 +_background_tasks: set[asyncio.Task] = set() + class ChatStream: """聊天流对象,存储一个完整的聊天上下文""" @@ -406,6 +409,40 @@ class ChatManager: key = "_".join(components) return hashlib.sha256(key.encode()).hexdigest() + async def _process_message(self, message: DatabaseMessages): + """ + [新] 在消息处理流程中加入用户信息同步。 + """ + # 1. 从消息中提取用户信息 + user_info = getattr(message, "user_info", None) + if not user_info: + return + + platform = getattr(user_info, "platform", None) + user_id = getattr(user_info, "user_id", None) + nickname = getattr(user_info, "user_nickname", None) + cardname = getattr(user_info, "user_cardname", None) + + if not platform or not user_id: + return + + # 2. 异步执行用户信息同步 + try: + from src.person_info.person_info import get_person_info_manager + person_info_manager = get_person_info_manager() + + # 创建一个后台任务来执行同步,不阻塞当前流程 + sync_task = asyncio.create_task( + person_info_manager.sync_user_info(platform, user_id, nickname, cardname) + ) + # 将任务添加到集合中以防止被垃圾回收 + # 可以在适当的地方(如程序关闭时)清理这个集合 + _background_tasks.add(sync_task) + sync_task.add_done_callback(_background_tasks.discard) + + except Exception as e: + logger.error(f"创建用户信息同步任务失败: {e}") + async def get_or_create_stream( self, platform: str, user_info: UserInfo, group_info: GroupInfo | None = None ) -> ChatStream: @@ -437,7 +474,10 @@ class ChatManager: # 检查是否有最后一条消息(现在使用 DatabaseMessages) from src.common.data_models.database_data_model import DatabaseMessages if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], DatabaseMessages): - await stream.set_context(self.last_messages[stream_id]) + last_message = self.last_messages[stream_id] + await stream.set_context(last_message) + # 在这里调用消息处理 + await self._process_message(last_message) else: logger.debug(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的或还没有消息") return stream @@ -634,23 +674,23 @@ class ChatManager: except Exception as e: logger.debug(f"批量写入器保存聊天流失败,使用原始方法: {e}") - # 尝试使用数据库批量调度器(回退方案1) - try: - from src.common.database.db_batch_scheduler import batch_update, get_batch_session - - async with get_batch_session(): - # 使用批量更新 - result = await batch_update( - model_class=ChatStreams, - conditions={"stream_id": stream_data_dict["stream_id"]}, - data=ChatManager._prepare_stream_data(stream_data_dict), - ) - if result and result > 0: - stream.saved = True - logger.debug(f"聊天流 {stream.stream_id} 通过批量调度器保存成功") - return - except (ImportError, Exception) as e: - logger.debug(f"批量调度器保存聊天流失败,使用原始方法: {e}") + # 尝试使用数据库批量调度器(回退方案1) - [已废弃] + # try: + # from src.common.database.optimization.batch_scheduler import batch_update, get_batch_session + # + # async with get_batch_session(): + # # 使用批量更新 + # result = await batch_update( + # model_class=ChatStreams, + # conditions={"stream_id": stream_data_dict["stream_id"]}, + # data=ChatManager._prepare_stream_data(stream_data_dict), + # ) + # if result and result > 0: + # stream.saved = True + # logger.debug(f"聊天流 {stream.stream_id} 通过批量调度器保存成功") + # return + # except (ImportError, Exception) as e: + # logger.debug(f"批量调度器保存聊天流失败,使用原始方法: {e}") # 回退到原始方法(最终方案) async def _db_save_stream_async(s_data_dict: dict): diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index a760e6025..ee4819524 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -1332,8 +1332,8 @@ class DefaultReplyer: ), "cross_context": asyncio.create_task( self._time_and_run_task( - Prompt.build_cross_context(chat_id, "s4u", target_user_info), - "cross_context", + # cross_context 的构建已移至 prompt.py + asyncio.sleep(0, result=""), "cross_context" ) ), "notice_block": asyncio.create_task( @@ -1521,6 +1521,8 @@ class DefaultReplyer: # 使用新的统一Prompt系统 - 创建PromptParameters prompt_parameters = PromptParameters( + platform=platform, + user_id=user_id, chat_scene=chat_scene_prompt, chat_id=chat_id, is_group_chat=is_group_chat, diff --git a/src/chat/utils/prompt.py b/src/chat/utils/prompt.py index 668884d93..6701665d8 100644 --- a/src/chat/utils/prompt.py +++ b/src/chat/utils/prompt.py @@ -709,9 +709,17 @@ class Prompt: async def _build_relation_info(self) -> dict[str, Any]: """构建与对话目标相关的关系信息.""" try: - # 调用静态方法来执行实际的构建逻辑 - relation_info = await Prompt.build_relation_info( - self.parameters.chat_id, self.parameters.reply_to + # [重构] 直接从 PromptParameters 获取稳定的用户身份信息 + platform = self.parameters.platform + user_id = self.parameters.user_id + + if not platform or not user_id: + logger.warning("无法从参数中获取platform或user_id,跳过关系信息构建") + return {"relation_info_block": ""} + + # 调用新的、基于ID的静态方法 + relation_info = await Prompt.build_relation_info_by_user_id( + self.parameters.chat_id, platform, user_id ) return {"relation_info_block": relation_info} except Exception as e: @@ -1063,43 +1071,29 @@ class Prompt: return sender, target @staticmethod - async def build_relation_info(chat_id: str, reply_to: str) -> str: - """构建关于回复目标用户的关系信息字符串. - - Args: - chat_id: 当前聊天的ID。 - reply_to: 被回复的原始消息字符串。 - - Returns: - str: 格式化后的关系信息字符串,或在失败时返回空字符串。 + async def build_relation_info_by_user_id(chat_id: str, platform: str, user_id: str) -> str: + """ + [新] 根据用户ID构建关系信息字符串。 """ from src.person_info.relationship_fetcher import relationship_fetcher_manager + + person_info_manager = get_person_info_manager() + person_id = person_info_manager.get_person_id(platform, user_id) + + if not person_id: + logger.warning(f"构建关系信息时未找到用户 platform={platform}, user_id={user_id}") + return f"你似乎还不认识这位用户(ID: {user_id}),这是你们的第一次互动。" relationship_fetcher = relationship_fetcher_manager.get_fetcher(chat_id) - - if not reply_to: - return "" - # 解析出回复目标的发送者 - sender, text = Prompt.parse_reply_target(reply_to) - if not sender or not text: - return "" - - # 根据发送者名称查找其用户ID - person_info_manager = get_person_info_manager() - person_id = await person_info_manager.get_person_id_by_person_name(sender) - if not person_id: - logger.warning(f"未找到用户 {sender} 的ID,跳过信息提取") - return f"你完全不认识{sender},不理解ta的相关信息。" - - # 使用关系提取器构建用户关系信息和聊天流印象 - user_relation_info = await relationship_fetcher.build_relation_info( - person_id, points_num=5 - ) - stream_impression = await relationship_fetcher.build_chat_stream_impression( - chat_id + + # 并行构建用户信息和聊天流印象 + user_relation_info_task = relationship_fetcher.build_relation_info(person_id, points_num=5) + stream_impression_task = relationship_fetcher.build_chat_stream_impression(chat_id) + + user_relation_info, stream_impression = await asyncio.gather( + user_relation_info_task, stream_impression_task ) - # 组合两部分信息 info_parts = [] if user_relation_info: info_parts.append(user_relation_info) @@ -1149,6 +1143,7 @@ class Prompt: Returns: str: 构建好的跨群聊上下文字符串。 """ + logger.info(f"Building cross context with target_user_info: {target_user_info}") if not global_config.cross_context.enable: return "" @@ -1167,32 +1162,22 @@ class Prompt: return "" - @staticmethod - async def parse_reply_target_id(reply_to: str) -> str: - """从回复目标字符串中解析出原始发送者的用户ID. - - Args: - reply_to: 回复目标字符串。 - - Returns: - str: 找到的用户ID,如果找不到则返回空字符串。 - """ - if not reply_to: - return "" - - # 首先,解析出发送者的名称 - sender, _ = Prompt.parse_reply_target(reply_to) - if not sender: - return "" - - # 然后,通过名称查询用户ID - person_info_manager = get_person_info_manager() - person_id = await person_info_manager.get_person_id_by_person_name(sender) - if person_id: - user_id = await person_info_manager.get_value(person_id, "user_id") - return str(user_id) if user_id else "" - - return "" + # [废弃] 该函数完全依赖于不稳定的名称解析,应被移除 + # @staticmethod + # async def parse_reply_target_id(reply_to: str) -> str: + # """从回复目标字符串中解析出原始发送者的用户ID.""" + # if not reply_to: + # return "" + # sender, _ = Prompt.parse_reply_target(reply_to) + # if not sender: + # return "" + # person_info_manager = get_person_info_manager() + # # [脆弱点] 使用了不稳健的按名称查询 + # person_id = await person_info_manager.get_person_id_by_name_robust(sender) + # if person_id: + # user_id = await person_info_manager.get_value(person_id, "user_id") + # return str(user_id) if user_id else "" + # return "" # 工厂函数 diff --git a/src/chat/utils/prompt_params.py b/src/chat/utils/prompt_params.py index 707b18575..5ca376a2a 100644 --- a/src/chat/utils/prompt_params.py +++ b/src/chat/utils/prompt_params.py @@ -11,6 +11,8 @@ class PromptParameters: # 基础参数 chat_id: str = "" + platform: str = "" + user_id: str = "" is_group_chat: bool = False sender: str = "" target: str = "" diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index 6165d1a2a..4cd1b3eb7 100644 --- a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -135,6 +135,113 @@ class PersonInfoManager: logger.error(f"根据用户名 {person_name} 获取用户ID时出错: {e}") return "" + @staticmethod + @cached(ttl=600, key_prefix="person_info_by_user_id", use_kwargs=False) + async def get_person_info_by_user_id(platform: str, user_id: str) -> dict | None: + """[新] 根据 platform 和 user_id 获取用户信息字典""" + if not platform or not user_id: + return None + + person_id = PersonInfoManager.get_person_id(platform, user_id) + crud = CRUDBase(PersonInfo) + record = await crud.get_by(person_id=person_id) + + if not record: + return None + + # 将 SQLAlchemy 模型对象转换为字典 + return {c.name: getattr(record, c.name) for c in record.__table__.columns} + + @staticmethod + @cached(ttl=600, key_prefix="person_info_by_person_id", use_kwargs=False) + async def get_person_info_by_person_id(person_id: str) -> dict | None: + """[新] 根据 person_id 获取用户信息字典""" + if not person_id: + return None + + crud = CRUDBase(PersonInfo) + record = await crud.get_by(person_id=person_id) + + if not record: + return None + + # 将 SQLAlchemy 模型对象转换为字典 + return {c.name: getattr(record, c.name) for c in record.__table__.columns} + + @staticmethod + async def get_person_id_by_name_robust(name: str) -> str | None: + """[新] 稳健地根据名称获取 person_id,按 person_name -> nickname 顺序回退""" + if not name: + return None + + crud = CRUDBase(PersonInfo) + + # 1. 按 person_name 查询 + records = await crud.get_multi(person_name=name, limit=1) + if records: + return records[0].person_id + + # 2. 按 nickname 查询 + records = await crud.get_multi(nickname=name, limit=1) + if records: + return records[0].person_id + + return None + + @staticmethod + @staticmethod + @cached(ttl=600, key_prefix="person_info_by_name_robust", use_kwargs=False) + async def get_person_info_by_name_robust(name: str) -> dict | None: + """[新] 稳健地根据名称获取用户信息,按 person_name -> nickname 顺序回退""" + person_id = await PersonInfoManager.get_person_id_by_name_robust(name) + if person_id: + return await PersonInfoManager.get_person_info_by_person_id(person_id) + return None + + @staticmethod + async def sync_user_info(platform: str, user_id: str, nickname: str | None, cardname: str | None) -> str: + """ + [新] 同步用户信息。查询或创建用户,并更新易变信息(如昵称)。 + 返回 person_id。 + """ + if not platform or not user_id: + raise ValueError("platform 和 user_id 不能为空") + + person_id = PersonInfoManager.get_person_id(platform, user_id) + crud = CRUDBase(PersonInfo) + record = await crud.get_by(person_id=person_id) + + effective_name = cardname or nickname or "未知用户" + + if record: + # 用户已存在,检查是否需要更新 + updates = {} + if nickname and record.nickname != nickname: + updates["nickname"] = nickname + + if updates: + await crud.update(record.id, updates) + logger.debug(f"用户 {person_id} 信息已更新: {updates}") + else: + # 用户不存在,创建新用户 + logger.info(f"新用户 {platform}:{user_id},将创建记录。") + unique_person_name = await PersonInfoManager._generate_unique_person_name(effective_name) + + new_person_data = { + "person_id": person_id, + "platform": platform, + "user_id": str(user_id), + "nickname": nickname, + "person_name": unique_person_name, + "name_reason": "首次遇见时自动设置", + "know_since": int(time.time()), + "last_know": int(time.time()), + } + await PersonInfoManager._safe_create_person_info(person_id, new_person_data) + + return person_id + + @staticmethod @staticmethod async def first_knowing_some_one(platform: str, user_id: str, user_nickname: str, user_cardname: str): """判断是否认识某人"""