diff --git a/src/chat/antipromptinjector/anti_injector.py b/src/chat/antipromptinjector/anti_injector.py index f35070135..b2c2e3232 100644 --- a/src/chat/antipromptinjector/anti_injector.py +++ b/src/chat/antipromptinjector/anti_injector.py @@ -261,7 +261,7 @@ class AntiPromptInjector: logger.warning("无法删除消息:缺少message_id") return - with get_db_session() as session: + async with get_db_session() as session: # 删除对应的消息记录 stmt = delete(Messages).where(Messages.message_id == message_id) result = session.execute(stmt) @@ -287,7 +287,7 @@ class AntiPromptInjector: logger.warning("无法更新消息:缺少message_id") return - with get_db_session() as session: + async with get_db_session() as session: # 更新消息内容 stmt = ( update(Messages) diff --git a/src/chat/message_manager/context_manager.py b/src/chat/message_manager/context_manager.py index ebf4e37d0..a495e04e5 100644 --- a/src/chat/message_manager/context_manager.py +++ b/src/chat/message_manager/context_manager.py @@ -42,7 +42,7 @@ class SingleStreamContextManager: self._update_access_stats() return self.context - def add_message(self, message: DatabaseMessages, skip_energy_update: bool = False) -> bool: + async def add_message(self, message: DatabaseMessages, skip_energy_update: bool = False) -> bool: """添加消息到上下文 Args: @@ -53,30 +53,21 @@ class SingleStreamContextManager: bool: 是否成功添加 """ try: - # 添加消息到上下文 self.context.add_message(message) - - # 计算消息兴趣度 - interest_value = self._calculate_message_interest(message) + interest_value = await self._calculate_message_interest(message) message.interest_value = interest_value - - # 更新统计 self.total_messages += 1 self.last_access_time = time.time() - - # 更新能量和分发 if not skip_energy_update: - self._update_stream_energy() + await self._update_stream_energy() distribution_manager.add_stream_message(self.stream_id, 1) - logger.debug(f"添加消息到单流上下文: {self.stream_id} (兴趣度: {interest_value:.3f})") return True - except Exception as e: logger.error(f"添加消息到单流上下文失败 {self.stream_id}: {e}", exc_info=True) return False - def update_message(self, message_id: str, updates: Dict[str, Any]) -> bool: + async def update_message(self, message_id: str, updates: Dict[str, Any]) -> bool: """更新上下文中的消息 Args: @@ -87,16 +78,11 @@ class SingleStreamContextManager: bool: 是否成功更新 """ try: - # 更新消息信息 self.context.update_message_info(message_id, **updates) - - # 如果更新了兴趣度,重新计算能量 if "interest_value" in updates: - self._update_stream_energy() - + await self._update_stream_energy() logger.debug(f"更新单流上下文消息: {self.stream_id}/{message_id}") return True - except Exception as e: logger.error(f"更新单流上下文消息失败 {self.stream_id}/{message_id}: {e}", exc_info=True) return False @@ -164,16 +150,13 @@ class SingleStreamContextManager: logger.error(f"标记消息已读失败 {self.stream_id}: {e}", exc_info=True) return False - def clear_context(self) -> bool: + async def clear_context(self) -> bool: """清空上下文""" try: - # 清空消息 if hasattr(self.context, "unread_messages"): self.context.unread_messages.clear() if hasattr(self.context, "history_messages"): self.context.history_messages.clear() - - # 重置状态 reset_attrs = ["interruption_count", "afc_threshold_adjustment", "last_check_time"] for attr in reset_attrs: if hasattr(self.context, attr): @@ -181,13 +164,9 @@ class SingleStreamContextManager: setattr(self.context, attr, 0) else: setattr(self.context, attr, time.time()) - - # 重新计算能量 - self._update_stream_energy() - + await self._update_stream_energy() logger.info(f"清空单流上下文: {self.stream_id}") return True - except Exception as e: logger.error(f"清空单流上下文失败 {self.stream_id}: {e}", exc_info=True) return False @@ -249,39 +228,115 @@ class SingleStreamContextManager: self.last_access_time = time.time() self.access_count += 1 - def _calculate_message_interest(self, message: DatabaseMessages) -> float: - """计算消息兴趣度""" + async def _calculate_message_interest(self, message: DatabaseMessages) -> float: + """异步实现:使用插件的异步评分器正确 await 计算兴趣度并返回分数。""" try: - # 使用插件内部的兴趣度评分系统 try: - from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system - - # 使用插件内部的兴趣度评分系统计算(同步方式) + from src.plugins.built_in.affinity_flow_chatter.interest_scoring import ( + chatter_interest_scoring_system, + ) try: - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - interest_score = loop.run_until_complete( - chatter_interest_scoring_system._calculate_single_message_score( + interest_score = await chatter_interest_scoring_system._calculate_single_message_score( message=message, bot_nickname=global_config.bot.nickname ) - ) - interest_value = interest_score.total_score + interest_value = interest_score.total_score + logger.debug(f"使用插件内部系统计算兴趣度: {interest_value:.3f}") + return interest_value + except Exception as e: + logger.warning(f"插件内部兴趣度计算失败: {e}") + return 0.5 + except Exception as e: + logger.warning(f"插件内部兴趣度计算加载失败,使用默认值: {e}") + return 0.5 + except Exception as e: + logger.error(f"计算消息兴趣度失败: {e}") + return 0.5 - logger.debug(f"使用插件内部系统计算兴趣度: {interest_value:.3f}") + async def _calculate_message_interest_async(self, message: DatabaseMessages) -> float: + """异步实现:使用插件的异步评分器正确 await 计算兴趣度并返回分数。""" + try: + try: + from src.plugins.built_in.affinity_flow_chatter.interest_scoring import ( + chatter_interest_scoring_system, + ) + + # 直接 await 插件的异步方法 + try: + interest_score = await chatter_interest_scoring_system._calculate_single_message_score( + message=message, bot_nickname=global_config.bot.nickname + ) + interest_value = interest_score.total_score + logger.debug(f"使用插件内部系统计算兴趣度: {interest_value:.3f}") + return interest_value + except Exception as e: + logger.warning(f"插件内部兴趣度计算失败: {e}") + return 0.5 except Exception as e: - logger.warning(f"插件内部兴趣度计算失败,使用默认值: {e}") - interest_value = 0.5 # 默认中等兴趣度 - - return interest_value + logger.warning(f"插件内部兴趣度计算加载失败,使用默认值: {e}") + return 0.5 except Exception as e: logger.error(f"计算消息兴趣度失败: {e}") return 0.5 + async def add_message_async(self, message: DatabaseMessages, skip_energy_update: bool = False) -> bool: + """异步实现的 add_message:将消息添加到 context,并 await 能量更新与分发。""" + try: + self.context.add_message(message) + + interest_value = await self._calculate_message_interest_async(message) + message.interest_value = interest_value + + self.total_messages += 1 + self.last_access_time = time.time() + + if not skip_energy_update: + await self._update_stream_energy() + distribution_manager.add_stream_message(self.stream_id, 1) + + logger.debug(f"添加消息到单流上下文(异步): {self.stream_id} (兴趣度: {interest_value:.3f})") + return True + except Exception as e: + logger.error(f"添加消息到单流上下文失败 (async) {self.stream_id}: {e}", exc_info=True) + return False + + async def update_message_async(self, message_id: str, updates: Dict[str, Any]) -> bool: + """异步实现的 update_message:更新消息并在需要时 await 能量更新。""" + try: + self.context.update_message_info(message_id, **updates) + if "interest_value" in updates: + await self._update_stream_energy() + + logger.debug(f"更新单流上下文消息(异步): {self.stream_id}/{message_id}") + return True + except Exception as e: + logger.error(f"更新单流上下文消息失败 (async) {self.stream_id}/{message_id}: {e}", exc_info=True) + return False + + async def clear_context_async(self) -> bool: + """异步实现的 clear_context:清空消息并 await 能量重算。""" + try: + if hasattr(self.context, "unread_messages"): + self.context.unread_messages.clear() + if hasattr(self.context, "history_messages"): + self.context.history_messages.clear() + + reset_attrs = ["interruption_count", "afc_threshold_adjustment", "last_check_time"] + for attr in reset_attrs: + if hasattr(self.context, attr): + if attr in ["interruption_count", "afc_threshold_adjustment"]: + setattr(self.context, attr, 0) + else: + setattr(self.context, attr, time.time()) + + await self._update_stream_energy() + logger.info(f"清空单流上下文(异步): {self.stream_id}") + return True + except Exception as e: + logger.error(f"清空单流上下文失败 (async) {self.stream_id}: {e}", exc_info=True) + return False + async def _update_stream_energy(self): """更新流能量""" try: @@ -305,4 +360,4 @@ class SingleStreamContextManager: distribution_manager.update_stream_energy(self.stream_id, energy) except Exception as e: - logger.error(f"更新单流能量失败 {self.stream_id}: {e}") \ No newline at end of file + logger.error(f"更新单流能量失败 {self.stream_id}: {e}") diff --git a/src/chat/message_manager/message_manager.py b/src/chat/message_manager/message_manager.py index 6a0eac5e2..5a8ad98e2 100644 --- a/src/chat/message_manager/message_manager.py +++ b/src/chat/message_manager/message_manager.py @@ -75,29 +75,23 @@ class MessageManager: logger.info("消息管理器已停止") - def add_message(self, stream_id: str, message: DatabaseMessages): + async def add_message(self, stream_id: str, message: DatabaseMessages): """添加消息到指定聊天流""" try: - # 通过 ChatManager 获取 ChatStream chat_manager = get_chat_manager() chat_stream = chat_manager.get_stream(stream_id) - if not chat_stream: logger.warning(f"MessageManager.add_message: 聊天流 {stream_id} 不存在") return - - # 使用 ChatStream 的 context_manager 添加消息 - success = chat_stream.context_manager.add_message(message) - + success = await chat_stream.context_manager.add_message(message) if success: logger.debug(f"添加消息到聊天流 {stream_id}: {message.message_id}") else: logger.warning(f"添加消息到聊天流 {stream_id} 失败") - except Exception as e: logger.error(f"添加消息到聊天流 {stream_id} 时发生错误: {e}") - def update_message( + async def update_message( self, stream_id: str, message_id: str, @@ -107,15 +101,11 @@ class MessageManager: ): """更新消息信息""" try: - # 通过 ChatManager 获取 ChatStream chat_manager = get_chat_manager() chat_stream = chat_manager.get_stream(stream_id) - if not chat_stream: logger.warning(f"MessageManager.update_message: 聊天流 {stream_id} 不存在") return - - # 构建更新字典 updates = {} if interest_value is not None: updates["interest_value"] = interest_value @@ -123,41 +113,30 @@ class MessageManager: updates["actions"] = actions if should_reply is not None: updates["should_reply"] = should_reply - - # 使用 ChatStream 的 context_manager 更新消息 if updates: - success = chat_stream.context_manager.update_message(message_id, updates) + success = await chat_stream.context_manager.update_message(message_id, updates) if success: logger.debug(f"更新消息 {message_id} 成功") else: logger.warning(f"更新消息 {message_id} 失败") - except Exception as e: logger.error(f"更新消息 {message_id} 时发生错误: {e}") - def add_action(self, stream_id: str, message_id: str, action: str): + async def add_action(self, stream_id: str, message_id: str, action: str): """添加动作到消息""" try: - # 通过 ChatManager 获取 ChatStream chat_manager = get_chat_manager() chat_stream = chat_manager.get_stream(stream_id) - if not chat_stream: logger.warning(f"MessageManager.add_action: 聊天流 {stream_id} 不存在") return - - # 使用 ChatStream 的 context_manager 添加动作 - # 注意:这里需要根据实际的 API 调整 - # 假设我们可以通过 update_message 来添加动作 - success = chat_stream.context_manager.update_message( + success = await chat_stream.context_manager.update_message( message_id, {"actions": [action]} ) - if success: logger.debug(f"为消息 {message_id} 添加动作 {action} 成功") else: logger.warning(f"为消息 {message_id} 添加动作 {action} 失败") - except Exception as e: logger.error(f"为消息 {message_id} 添加动作时发生错误: {e}") @@ -382,36 +361,27 @@ class MessageManager: "start_time": self.stats.start_time, } - def cleanup_inactive_streams(self, max_inactive_hours: int = 24): + async def cleanup_inactive_streams(self, max_inactive_hours: int = 24): """清理不活跃的聊天流""" try: - # 通过 ChatManager 清理不活跃的流 chat_manager = get_chat_manager() current_time = time.time() max_inactive_seconds = max_inactive_hours * 3600 - inactive_streams = [] for stream_id, chat_stream in chat_manager.streams.items(): - # 检查最后活跃时间 if current_time - chat_stream.last_active_time > max_inactive_seconds: inactive_streams.append(stream_id) - - # 清理不活跃的流 for stream_id in inactive_streams: try: - # 清理流的内容 - chat_stream.context_manager.clear_context() - # 从 ChatManager 中移除 + await chat_stream.context_manager.clear_context() del chat_manager.streams[stream_id] logger.info(f"清理不活跃聊天流: {stream_id}") except Exception as e: logger.error(f"清理聊天流 {stream_id} 失败: {e}") - if inactive_streams: logger.info(f"已清理 {len(inactive_streams)} 个不活跃聊天流") else: logger.debug("没有需要清理的不活跃聊天流") - except Exception as e: logger.error(f"清理不活跃聊天流时发生错误: {e}") diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index 92a44b443..d8b9c6857 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -514,7 +514,7 @@ class ChatBot: db_message.chat_info_group_platform = message.chat_stream.group_info.platform # 添加消息到消息管理器 - message_manager.add_message(message.chat_stream.stream_id, db_message) + await message_manager.add_message(message.chat_stream.stream_id, db_message) logger.debug(f"消息已添加到消息管理器: {message.chat_stream.stream_id}") if template_group_name: diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index 007ab4dc1..6a4834cb0 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -389,94 +389,105 @@ class ChatStream: from sqlalchemy import select, desc import asyncio - async def _load_messages(): - def _db_query(): - with get_db_session() as session: - # 查询该stream_id的最近20条消息 + async def _load_history_messages_async(): + """异步加载并转换历史消息到 stream_context(在事件循环中运行)。""" + try: + async with get_db_session() as session: stmt = ( select(Messages) .where(Messages.chat_info_stream_id == self.stream_id) .order_by(desc(Messages.time)) .limit(global_config.chat.max_context_size) ) - result = session.execute(stmt) - results = result.scalars().all() - return results + result = await session.execute(stmt) + db_messages = result.scalars().all() - # 在线程中执行数据库查询 - db_messages = await asyncio.to_thread(_db_query) + # 转换为DatabaseMessages对象并添加到StreamContext + for db_msg in db_messages: + try: + import orjson - # 转换为DatabaseMessages对象并添加到StreamContext - for db_msg in db_messages: + actions = None + if db_msg.actions: + try: + actions = orjson.loads(db_msg.actions) + except (orjson.JSONDecodeError, TypeError): + actions = None + + db_message = DatabaseMessages( + message_id=db_msg.message_id, + time=db_msg.time, + chat_id=db_msg.chat_id, + reply_to=db_msg.reply_to, + interest_value=db_msg.interest_value, + key_words=db_msg.key_words, + key_words_lite=db_msg.key_words_lite, + is_mentioned=db_msg.is_mentioned, + processed_plain_text=db_msg.processed_plain_text, + display_message=db_msg.display_message, + priority_mode=db_msg.priority_mode, + priority_info=db_msg.priority_info, + additional_config=db_msg.additional_config, + is_emoji=db_msg.is_emoji, + is_picid=db_msg.is_picid, + is_command=db_msg.is_command, + is_notify=db_msg.is_notify, + user_id=db_msg.user_id, + user_nickname=db_msg.user_nickname, + user_cardname=db_msg.user_cardname, + user_platform=db_msg.user_platform, + chat_info_group_id=db_msg.chat_info_group_id, + chat_info_group_name=db_msg.chat_info_group_name, + chat_info_group_platform=db_msg.chat_info_group_platform, + chat_info_user_id=db_msg.chat_info_user_id, + chat_info_user_nickname=db_msg.chat_info_user_nickname, + chat_info_user_cardname=db_msg.chat_info_user_cardname, + chat_info_user_platform=db_msg.chat_info_user_platform, + chat_info_stream_id=db_msg.chat_info_stream_id, + chat_info_platform=db_msg.chat_info_platform, + chat_info_create_time=db_msg.chat_info_create_time, + chat_info_last_active_time=db_msg.chat_info_last_active_time, + actions=actions, + should_reply=getattr(db_msg, "should_reply", False) or False, + ) + + logger.debug( + f"加载历史消息 {db_message.message_id} - interest_value: {db_message.interest_value}" + ) + + db_message.is_read = True + self.stream_context.history_messages.append(db_message) + + except Exception as e: + logger.warning(f"转换消息 {getattr(db_msg, 'message_id', '')} 失败: {e}") + continue + + if self.stream_context.history_messages: + logger.info( + f"已从数据库加载 {len(self.stream_context.history_messages)} 条历史消息到聊天流 {self.stream_id}" + ) + + except Exception as e: + logger.warning(f"异步加载历史消息失败: {e}") + + # 在已有事件循环中,避免调用 asyncio.run() + try: + loop = asyncio.get_running_loop() + except RuntimeError: + # 没有运行的事件循环,安全地运行并等待完成 + asyncio.run(_load_history_messages_async()) + else: + # 如果事件循环正在运行,在后台创建任务 + if loop.is_running(): try: - # 从SQLAlchemy模型转换为DatabaseMessages数据模型 - import orjson - - # 解析actions字段(JSON格式) - actions = None - if db_msg.actions: - try: - actions = orjson.loads(db_msg.actions) - except (orjson.JSONDecodeError, TypeError): - actions = None - - db_message = DatabaseMessages( - message_id=db_msg.message_id, - time=db_msg.time, - chat_id=db_msg.chat_id, - reply_to=db_msg.reply_to, - interest_value=db_msg.interest_value, - key_words=db_msg.key_words, - key_words_lite=db_msg.key_words_lite, - is_mentioned=db_msg.is_mentioned, - processed_plain_text=db_msg.processed_plain_text, - display_message=db_msg.display_message, - priority_mode=db_msg.priority_mode, - priority_info=db_msg.priority_info, - additional_config=db_msg.additional_config, - is_emoji=db_msg.is_emoji, - is_picid=db_msg.is_picid, - is_command=db_msg.is_command, - is_notify=db_msg.is_notify, - user_id=db_msg.user_id, - user_nickname=db_msg.user_nickname, - user_cardname=db_msg.user_cardname, - user_platform=db_msg.user_platform, - chat_info_group_id=db_msg.chat_info_group_id, - chat_info_group_name=db_msg.chat_info_group_name, - chat_info_group_platform=db_msg.chat_info_group_platform, - chat_info_user_id=db_msg.chat_info_user_id, - chat_info_user_nickname=db_msg.chat_info_user_nickname, - chat_info_user_cardname=db_msg.chat_info_user_cardname, - chat_info_user_platform=db_msg.chat_info_user_platform, - chat_info_stream_id=db_msg.chat_info_stream_id, - chat_info_platform=db_msg.chat_info_platform, - chat_info_create_time=db_msg.chat_info_create_time, - chat_info_last_active_time=db_msg.chat_info_last_active_time, - actions=actions, - should_reply=getattr(db_msg, "should_reply", False) or False, - ) - - # 添加调试日志:检查从数据库加载的interest_value - logger.debug( - f"加载历史消息 {db_message.message_id} - interest_value: {db_message.interest_value}" - ) - - # 标记为已读并添加到历史消息 - db_message.is_read = True - self.stream_context.history_messages.append(db_message) - + asyncio.create_task(_load_history_messages_async()) except Exception as e: - logger.warning(f"转换消息 {db_msg.message_id} 失败: {e}") - continue - - if self.stream_context.history_messages: - logger.info( - f"已从数据库加载 {len(self.stream_context.history_messages)} 条历史消息到聊天流 {self.stream_id}" - ) - - # 创建任务来加载历史消息 - asyncio.create_task(_load_messages()) + # 如果无法创建任务,退回到阻塞运行 + logger.warning(f"无法在事件循环中创建后台任务,尝试阻塞运行: {e}") + asyncio.run(_load_history_messages_async()) + else: + # loop 存在但未运行,使用 asyncio.run + asyncio.run(_load_history_messages_async()) except Exception as e: logger.error(f"加载历史消息失败: {e}") @@ -498,7 +509,7 @@ class ChatManager: self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream self.last_messages: Dict[str, "MessageRecv"] = {} # stream_id -> last_message # try: - # with get_db_session() as session: + # async with get_db_session() as session: # db.connect(reuse_if_open=True) # # 确保 ChatStreams 表存在 # session.execute(text("CREATE TABLE IF NOT EXISTS chat_streams (stream_id TEXT PRIMARY KEY, platform TEXT, create_time REAL, last_active_time REAL, user_platform TEXT, user_id TEXT, user_nickname TEXT, user_cardname TEXT, group_platform TEXT, group_id TEXT, group_name TEXT)")) diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index 60583c1f8..4ab494b7b 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -219,7 +219,7 @@ class MessageStorage: return match.group(0) @staticmethod - def update_message_interest_value(message_id: str, interest_value: float) -> None: + async def update_message_interest_value(message_id: str, interest_value: float) -> None: """ 更新数据库中消息的interest_value字段 @@ -228,11 +228,11 @@ class MessageStorage: interest_value: 兴趣度值 """ try: - with get_db_session() as session: + async with get_db_session() as session: # 更新消息的interest_value字段 stmt = update(Messages).where(Messages.message_id == message_id).values(interest_value=interest_value) - result = session.execute(stmt) - session.commit() + result = await session.execute(stmt) + await session.commit() if result.rowcount > 0: logger.debug(f"成功更新消息 {message_id} 的interest_value为 {interest_value}") @@ -244,7 +244,7 @@ class MessageStorage: raise @staticmethod - def fix_zero_interest_values(chat_id: str, since_time: float) -> int: + async def fix_zero_interest_values(chat_id: str, since_time: float) -> int: """ 修复指定聊天中interest_value为0或null的历史消息记录 @@ -256,7 +256,7 @@ class MessageStorage: 修复的记录数量 """ try: - with get_db_session() as session: + async with get_db_session() as session: from sqlalchemy import select, update from src.common.database.sqlalchemy_models import Messages @@ -271,7 +271,7 @@ class MessageStorage: ) ).limit(50) # 限制每次修复的数量,避免性能问题 - result = session.execute(query) + result = await session.execute(query) messages_to_fix = result.scalars().all() fixed_count = 0 @@ -297,12 +297,12 @@ class MessageStorage: Messages.message_id == msg.message_id ).values(interest_value=default_interest) - result = session.execute(update_stmt) + result = await session.execute(update_stmt) if result.rowcount > 0: fixed_count += 1 logger.debug(f"修复消息 {msg.message_id} 的interest_value为 {default_interest}") - session.commit() + await session.commit() logger.info(f"共修复了 {fixed_count} 条历史消息的interest_value值") return fixed_count diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py index 761c53e86..ea51d5e5b 100644 --- a/src/chat/planner_actions/action_manager.py +++ b/src/chat/planner_actions/action_manager.py @@ -297,15 +297,12 @@ class ChatterActionManager: return # 通过message_manager更新消息的动作记录并刷新focus_energy - if chat_stream.stream_id in message_manager.stream_contexts: - message_manager.add_action( - stream_id=chat_stream.stream_id, - message_id=target_message_id, - action=action_name - ) - logger.debug(f"已记录动作 {action_name} 到消息 {target_message_id} 并更新focus_energy") - else: - logger.debug(f"未找到stream_context: {chat_stream.stream_id}") + await message_manager.add_action( + stream_id=chat_stream.stream_id, + message_id=target_message_id, + action=action_name + ) + logger.debug(f"已记录动作 {action_name} 到消息 {target_message_id} 并更新focus_energy") except Exception as e: logger.error(f"记录动作到消息失败: {e}") @@ -315,8 +312,11 @@ class ChatterActionManager: """在动作执行成功后重置打断计数""" from src.chat.message_manager.message_manager import message_manager try: - if stream_id in message_manager.stream_contexts: - context = message_manager.stream_contexts[stream_id] + from src.plugin_system.apis.chat_api import get_chat_manager + chat_manager = get_chat_manager() + chat_stream = chat_manager.get_stream(stream_id) + if chat_stream: + context = chat_stream.context_manager if context.interruption_count > 0: old_count = context.interruption_count old_afc_adjustment = context.get_afc_threshold_adjustment() diff --git a/src/chat/planner_actions/action_modifier.py b/src/chat/planner_actions/action_modifier.py index 481b848e2..4e144d3f4 100644 --- a/src/chat/planner_actions/action_modifier.py +++ b/src/chat/planner_actions/action_modifier.py @@ -73,7 +73,7 @@ class ActionModifier: from src.chat.utils.utils import get_chat_type_and_target_info # 获取聊天类型 - is_group_chat, _ = await get_chat_type_and_target_info(self.chat_id) + is_group_chat, _ = get_chat_type_and_target_info(self.chat_id) all_registered_actions = component_registry.get_components_by_type(ComponentType.ACTION) chat_type_removals = [] diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index 868e34f21..20fc11d1d 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -684,8 +684,11 @@ class DefaultReplyer: from src.chat.message_manager.message_manager import message_manager # 获取聊天流的上下文 - stream_context = message_manager.stream_contexts.get(chat_id) - if stream_context: + from src.plugin_system.apis.chat_api import get_chat_manager + chat_manager = get_chat_manager() + chat_stream = chat_manager.get_stream(chat_id) + if chat_stream: + stream_context = chat_stream.context_manager # 使用真正的已读和未读消息 read_messages = stream_context.history_messages # 已读消息 unread_messages = stream_context.get_unread_messages() # 未读消息 @@ -693,7 +696,7 @@ class DefaultReplyer: # 构建已读历史消息 prompt read_history_prompt = "" if read_messages: - read_content = build_readable_messages( + read_content = await build_readable_messages( [msg.flatten() for msg in read_messages[-50:]], # 限制数量 replace_bot_name=True, timestamp_mode="normal_no_YMD", @@ -716,7 +719,7 @@ class DefaultReplyer: ] if filtered_fallback_messages: - read_content = build_readable_messages( + read_content = await build_readable_messages( filtered_fallback_messages, replace_bot_name=True, timestamp_mode="normal_no_YMD", @@ -754,7 +757,7 @@ class DefaultReplyer: if platform and user_id: person_id = PersonInfoManager.get_person_id(platform, user_id) person_info_manager = get_person_info_manager() - sender_name = person_info_manager.get_value_sync(person_id, "person_name") or "未知用户" + sender_name = person_info_manager.get_value(person_id, "person_name") or "未知用户" else: sender_name = "未知用户" @@ -819,7 +822,7 @@ class DefaultReplyer: # 构建已读历史消息 prompt read_history_prompt = "" if read_messages: - read_content = build_readable_messages( + read_content = await build_readable_messages( read_messages[-50:], replace_bot_name=True, timestamp_mode="normal_no_YMD", @@ -853,7 +856,7 @@ class DefaultReplyer: if platform and user_id: person_id = PersonInfoManager.get_person_id(platform, user_id) person_info_manager = get_person_info_manager() - sender_name = person_info_manager.get_value_sync(person_id, "person_name") or "未知用户" + sender_name = person_info_manager.get_value(person_id, "person_name") or "未知用户" else: sender_name = "未知用户" @@ -1027,7 +1030,7 @@ class DefaultReplyer: # 检查是否是bot自己的名字,如果是则替换为"(你)" bot_user_id = str(global_config.bot.qq_account) - current_user_id = person_info_manager.get_value_sync(person_id, "user_id") + current_user_id = person_info_manager.get_value(person_id, "user_id") current_platform = reply_message.get("chat_info_platform") if current_user_id == bot_user_id and current_platform == global_config.bot.platform: @@ -1046,7 +1049,7 @@ class DefaultReplyer: target = "(无消息内容)" person_info_manager = get_person_info_manager() - person_id = person_info_manager.get_person_id_by_person_name(sender) + person_id = await person_info_manager.get_person_id_by_person_name(sender) platform = chat_stream.platform target = replace_user_references_sync(target, chat_stream.platform, replace_bot_name=True) @@ -1071,7 +1074,7 @@ class DefaultReplyer: timestamp=time.time(), limit=int(global_config.chat.max_context_size * 0.33), ) - chat_talking_prompt_short = build_readable_messages( + chat_talking_prompt_short = await build_readable_messages( message_list_before_short, replace_bot_name=True, merge_messages=False, @@ -1324,7 +1327,7 @@ class DefaultReplyer: timestamp=time.time(), limit=min(int(global_config.chat.max_context_size * 0.33), 15), ) - chat_talking_prompt_half = build_readable_messages( + chat_talking_prompt_half = await build_readable_messages( message_list_before_now_half, replace_bot_name=True, merge_messages=False, @@ -1523,7 +1526,7 @@ class DefaultReplyer: # 获取用户ID person_info_manager = get_person_info_manager() - person_id = person_info_manager.get_person_id_by_person_name(sender) + person_id = await person_info_manager.get_person_id_by_person_name(sender) if not person_id: logger.warning(f"未找到用户 {sender} 的ID,跳过信息提取") return f"你完全不认识{sender},不理解ta的相关信息。" diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index 1ccc916db..48a396fdc 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -46,7 +46,7 @@ def replace_user_references_sync( if replace_bot_name and user_id == global_config.bot.qq_account: return f"{global_config.bot.nickname}(你)" person_id = PersonInfoManager.get_person_id(platform, user_id) - return person_info_manager.get_value_sync(person_id, "person_name") or user_id # type: ignore + return person_info_manager.get_value(person_id, "person_name") or user_id # type: ignore name_resolver = default_resolver @@ -254,7 +254,7 @@ def get_raw_msg_by_timestamp_with_chat_users( return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode) -def get_actions_by_timestamp_with_chat( +async def get_actions_by_timestamp_with_chat( chat_id: str, timestamp_start: float = 0, timestamp_end: float = time.time(), @@ -273,22 +273,21 @@ def get_actions_by_timestamp_with_chat( f"limit={limit}, limit_mode={limit_mode}" ) - with get_db_session() as session: + async with get_db_session() as session: if limit > 0: - if limit_mode == "latest": - query = session.execute( + result = await session.execute( select(ActionRecords) .where( and_( ActionRecords.chat_id == chat_id, - ActionRecords.time > timestamp_start, - ActionRecords.time < timestamp_end, + ActionRecords.time >= timestamp_start, + ActionRecords.time <= timestamp_end, ) ) .order_by(ActionRecords.time.desc()) .limit(limit) ) - actions = list(query.scalars()) + actions = list(result.scalars()) actions_result = [] for action in reversed(actions): action_dict = { @@ -305,38 +304,39 @@ def get_actions_by_timestamp_with_chat( "chat_info_platform": action.chat_info_platform, } actions_result.append(action_dict) - else: # earliest - query = session.execute( - select(ActionRecords) - .where( - and_( - ActionRecords.chat_id == chat_id, - ActionRecords.time > timestamp_start, - ActionRecords.time < timestamp_end, - ) - ) - .order_by(ActionRecords.time.asc()) - .limit(limit) - ) - actions = list(query.scalars()) - actions_result = [] - for action in actions: - action_dict = { - "id": action.id, - "action_id": action.action_id, - "time": action.time, - "action_name": action.action_name, - "action_data": action.action_data, - "action_done": action.action_done, - "action_build_into_prompt": action.action_build_into_prompt, - "action_prompt_display": action.action_prompt_display, - "chat_id": action.chat_id, - "chat_info_stream_id": action.chat_info_stream_id, - "chat_info_platform": action.chat_info_platform, - } actions_result.append(action_dict) + else: # earliest + result = await session.execute( + select(ActionRecords) + .where( + and_( + ActionRecords.chat_id == chat_id, + ActionRecords.time > timestamp_start, + ActionRecords.time < timestamp_end, + ) + ) + .order_by(ActionRecords.time.asc()) + .limit(limit) + ) + actions = list(result.scalars()) + actions_result = [] + for action in actions: + action_dict = { + "id": action.id, + "action_id": action.action_id, + "time": action.time, + "action_name": action.action_name, + "action_data": action.action_data, + "action_done": action.action_done, + "action_build_into_prompt": action.action_build_into_prompt, + "action_prompt_display": action.action_prompt_display, + "chat_id": action.chat_id, + "chat_info_stream_id": action.chat_info_stream_id, + "chat_info_platform": action.chat_info_platform, + } + actions_result.append(action_dict) else: - query = session.execute( + result = await session.execute( select(ActionRecords) .where( and_( @@ -347,7 +347,7 @@ def get_actions_by_timestamp_with_chat( ) .order_by(ActionRecords.time.asc()) ) - actions = list(query.scalars()) + actions = list(result.scalars()) actions_result = [] for action in actions: action_dict = { @@ -367,14 +367,14 @@ def get_actions_by_timestamp_with_chat( return actions_result -def get_actions_by_timestamp_with_chat_inclusive( +async def get_actions_by_timestamp_with_chat_inclusive( chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest" ) -> List[Dict[str, Any]]: """获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表""" - with get_db_session() as session: + async with get_db_session() as session: if limit > 0: if limit_mode == "latest": - query = session.execute( + result = await session.execute( select(ActionRecords) .where( and_( @@ -386,10 +386,10 @@ def get_actions_by_timestamp_with_chat_inclusive( .order_by(ActionRecords.time.desc()) .limit(limit) ) - actions = list(query.scalars()) + actions = list(result.scalars()) return [action.__dict__ for action in reversed(actions)] else: # earliest - query = session.execute( + result = await session.execute( select(ActionRecords) .where( and_( @@ -402,7 +402,7 @@ def get_actions_by_timestamp_with_chat_inclusive( .limit(limit) ) else: - query = session.execute( + query = await session.execute( select(ActionRecords) .where( and_( @@ -507,7 +507,7 @@ def num_new_messages_since_with_users( return count_messages(message_filter=filter_query) -def _build_readable_messages_internal( +async def _build_readable_messages_internal( messages: List[Dict[str, Any]], replace_bot_name: bool = True, merge_messages: bool = False, @@ -627,7 +627,7 @@ def _build_readable_messages_internal( if replace_bot_name and user_id == global_config.bot.qq_account: person_name = f"{global_config.bot.nickname}(你)" else: - person_name = person_info_manager.get_value_sync(person_id, "person_name") # type: ignore + person_name = await person_info_manager.get_value(person_id, "person_name") # type: ignore # 如果 person_name 未设置,则使用消息中的 nickname 或默认名称 if not person_name: @@ -800,7 +800,7 @@ def _build_readable_messages_internal( ) -def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str: +async def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str: # sourcery skip: use-contextlib-suppress """ 构建图片映射信息字符串,显示图片的具体描述内容 @@ -823,8 +823,8 @@ def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str: # 从数据库中获取图片描述 description = "[图片内容未知]" # 默认描述 try: - with get_db_session() as session: - result = session.execute(select(Images).where(Images.image_id == pic_id)) + async with get_db_session() as session: + result = await session.execute(select(Images).where(Images.image_id == pic_id)) image = result.scalar_one_or_none() if image and image.description: # type: ignore description = image.description @@ -922,17 +922,17 @@ async def build_readable_messages_with_list( 将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。 允许通过参数控制格式化行为。 """ - formatted_string, details_list, pic_id_mapping, _ = _build_readable_messages_internal( + formatted_string, details_list, pic_id_mapping, _ = await _build_readable_messages_internal( messages, replace_bot_name, merge_messages, timestamp_mode, truncate ) - if pic_mapping_info := build_pic_mapping_info(pic_id_mapping): + if pic_mapping_info := await build_pic_mapping_info(pic_id_mapping): formatted_string = f"{pic_mapping_info}\n\n{formatted_string}" return formatted_string, details_list -def build_readable_messages_with_id( +async def build_readable_messages_with_id( messages: List[Dict[str, Any]], replace_bot_name: bool = True, merge_messages: bool = False, @@ -948,7 +948,7 @@ def build_readable_messages_with_id( """ message_id_list = assign_message_ids(messages) - formatted_string = build_readable_messages( + formatted_string = await build_readable_messages( messages=messages, replace_bot_name=replace_bot_name, merge_messages=merge_messages, @@ -960,10 +960,16 @@ def build_readable_messages_with_id( message_id_list=message_id_list, ) + # 如果存在图片映射信息,附加之 + if pic_mapping_info := await build_pic_mapping_info({}): + # 如果当前没有图片映射则不附加 + if pic_mapping_info: + formatted_string = f"{pic_mapping_info}\n\n{formatted_string}" + return formatted_string, message_id_list -def build_readable_messages( +async def build_readable_messages( messages: List[Dict[str, Any]], replace_bot_name: bool = True, merge_messages: bool = False, @@ -1004,9 +1010,9 @@ def build_readable_messages( from src.common.database.sqlalchemy_database_api import get_db_session - with get_db_session() as session: + async with get_db_session() as session: # 获取这个时间范围内的动作记录,并匹配chat_id - actions_in_range = session.execute( + actions_in_range = (await session.execute( select(ActionRecords) .where( and_( @@ -1014,15 +1020,15 @@ def build_readable_messages( ) ) .order_by(ActionRecords.time) - ).scalars() + )).scalars() # 获取最新消息之后的第一个动作记录 - action_after_latest = session.execute( + action_after_latest = (await session.execute( select(ActionRecords) .where(and_(ActionRecords.time > max_time, ActionRecords.chat_id == chat_id)) .order_by(ActionRecords.time) .limit(1) - ).scalars() + )).scalars() # 合并两部分动作记录,并转为 dict,避免 DetachedInstanceError actions = [ @@ -1053,7 +1059,7 @@ def build_readable_messages( if read_mark <= 0: # 没有有效的 read_mark,直接格式化所有消息 - formatted_string, _, pic_id_mapping, _ = _build_readable_messages_internal( + formatted_string, _, pic_id_mapping, _ = await _build_readable_messages_internal( copy_messages, replace_bot_name, merge_messages, @@ -1064,7 +1070,7 @@ def build_readable_messages( ) # 生成图片映射信息并添加到最前面 - pic_mapping_info = build_pic_mapping_info(pic_id_mapping) + pic_mapping_info = await build_pic_mapping_info(pic_id_mapping) if pic_mapping_info: return f"{pic_mapping_info}\n\n{formatted_string}" else: @@ -1079,7 +1085,7 @@ def build_readable_messages( pic_counter = 1 # 分别格式化,但使用共享的图片映射 - formatted_before, _, pic_id_mapping, pic_counter = _build_readable_messages_internal( + formatted_before, _, pic_id_mapping, pic_counter = await _build_readable_messages_internal( messages_before_mark, replace_bot_name, merge_messages, @@ -1090,7 +1096,7 @@ def build_readable_messages( show_pic=show_pic, message_id_list=message_id_list, ) - formatted_after, _, pic_id_mapping, _ = _build_readable_messages_internal( + formatted_after, _, pic_id_mapping, _ = await _build_readable_messages_internal( messages_after_mark, replace_bot_name, merge_messages, @@ -1106,7 +1112,7 @@ def build_readable_messages( # 生成图片映射信息 if pic_id_mapping: - pic_mapping_info = f"图片信息:\n{build_pic_mapping_info(pic_id_mapping)}\n聊天记录信息:\n" + pic_mapping_info = f"图片信息:\n{await build_pic_mapping_info(pic_id_mapping)}\n聊天记录信息:\n" else: pic_mapping_info = "聊天记录信息:\n" @@ -1229,7 +1235,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str: # 在最前面添加图片映射信息 final_output_lines = [] - pic_mapping_info = build_pic_mapping_info(pic_id_mapping) + pic_mapping_info = await build_pic_mapping_info(pic_id_mapping) if pic_mapping_info: final_output_lines.append(pic_mapping_info) final_output_lines.append("\n\n") diff --git a/src/chat/utils/prompt.py b/src/chat/utils/prompt.py index 112db6726..c5f6d3b39 100644 --- a/src/chat/utils/prompt.py +++ b/src/chat/utils/prompt.py @@ -494,7 +494,7 @@ class Prompt: chat_history = "" if self.parameters.message_list_before_now_long: recent_messages = self.parameters.message_list_before_now_long[-10:] - chat_history = build_readable_messages( + chat_history = await build_readable_messages( recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True ) @@ -535,7 +535,7 @@ class Prompt: chat_history = "" if self.parameters.message_list_before_now_long: recent_messages = self.parameters.message_list_before_now_long[-20:] - chat_history = build_readable_messages( + chat_history = await build_readable_messages( recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True ) @@ -589,7 +589,7 @@ class Prompt: chat_history = "" if self.parameters.message_list_before_now_long: recent_messages = self.parameters.message_list_before_now_long[-15:] - chat_history = build_readable_messages( + chat_history = await build_readable_messages( recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True ) @@ -863,7 +863,7 @@ class Prompt: # 获取用户ID person_info_manager = get_person_info_manager() - person_id = person_info_manager.get_person_id_by_person_name(sender) + person_id = await person_info_manager.get_person_id_by_person_name(sender) if not person_id: logger.warning(f"未找到用户 {sender} 的ID,跳过信息提取") return f"你完全不认识{sender},不理解ta的相关信息。" @@ -904,7 +904,7 @@ class Prompt: return "" @staticmethod - def parse_reply_target_id(reply_to: str) -> str: + async def parse_reply_target_id(reply_to: str) -> str: """ 解析回复目标中的用户ID @@ -924,9 +924,9 @@ class Prompt: # 获取用户ID person_info_manager = get_person_info_manager() - person_id = person_info_manager.get_person_id_by_person_name(sender) + person_id = await person_info_manager.get_person_id_by_person_name(sender) if person_id: - user_id = person_info_manager.get_value_sync(person_id, "user_id") + user_id = person_info_manager.get_value(person_id, "user_id") return str(user_id) if user_id else "" return "" diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index c2e4814f8..989428677 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -1,3 +1,4 @@ +import asyncio import random import re import string @@ -662,9 +663,32 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]: person_id = PersonInfoManager.get_person_id(platform, user_id) person_name = None if person_id: - # get_value is async, so await it directly person_info_manager = get_person_info_manager() - person_name = person_info_manager.get_value_sync(person_id, "person_name") + try: + # 如果没有运行的事件循环,直接 asyncio.run + loop = asyncio.get_event_loop() + if loop.is_running(): + # 如果事件循环在运行,从其他线程提交并等待结果 + try: + from concurrent.futures import TimeoutError + + fut = asyncio.run_coroutine_threadsafe( + person_info_manager.get_value(person_id, "person_name"), loop + ) + person_name = fut.result(timeout=2) + except Exception as e: + # 无法在运行循环上安全等待,退回为 None + logger.debug(f"无法通过运行的事件循环获取 person_name: {e}") + person_name = None + else: + person_name = asyncio.run(person_info_manager.get_value(person_id, "person_name")) + except RuntimeError: + # get_event_loop 在某些上下文可能抛出 RuntimeError,退回到 asyncio.run + try: + person_name = asyncio.run(person_info_manager.get_value(person_id, "person_name")) + except Exception as e: + logger.debug(f"获取 person_name 失败: {e}") + person_name = None target_info["person_id"] = person_id target_info["person_name"] = person_name diff --git a/src/common/data_models/message_manager_data_model.py b/src/common/data_models/message_manager_data_model.py index f35c53573..a72b7564c 100644 --- a/src/common/data_models/message_manager_data_model.py +++ b/src/common/data_models/message_manager_data_model.py @@ -344,6 +344,39 @@ class StreamContext(BaseDataModel): """获取优先级信息""" return self.priority_info + def __deepcopy__(self, memo): + """自定义深拷贝,跳过不可序列化的 asyncio.Task (processing_task)。 + + deepcopy 在内部可能会尝试 pickle 某些对象(如 asyncio.Task), + 这会在多线程或运行时事件循环中导致 TypeError。这里我们手动复制 + __dict__ 中的字段,确保 processing_task 被设置为 None,其他字段使用 + copy.deepcopy 递归复制。 + """ + import copy + + # 如果已经复制过,直接返回缓存结果 + obj_id = id(self) + if obj_id in memo: + return memo[obj_id] + + # 创建一个未初始化的新实例,然后逐个字段深拷贝 + cls = self.__class__ + new = cls.__new__(cls) + memo[obj_id] = new + + for k, v in self.__dict__.items(): + if k == "processing_task": + # 不复制 asyncio.Task,避免无法 pickling + setattr(new, k, None) + else: + try: + setattr(new, k, copy.deepcopy(v, memo)) + except Exception: + # 如果某个字段无法深拷贝,退回到原始引用(安全性谨慎) + setattr(new, k, v) + + return new + @dataclass class MessageManagerStats(BaseDataModel): diff --git a/src/common/message_repository.py b/src/common/message_repository.py index f295f8e8a..57f179c36 100644 --- a/src/common/message_repository.py +++ b/src/common/message_repository.py @@ -30,7 +30,7 @@ def _model_to_dict(instance: Base) -> Dict[str, Any]: return {col.name: instance.__dict__.get(col.name) for col in instance.__table__.columns} -def find_messages( +async def find_messages( message_filter: dict[str, Any], sort: Optional[List[tuple[str, int]]] = None, limit: int = 0, @@ -51,7 +51,7 @@ def find_messages( 消息字典列表,如果出错则返回空列表。 """ try: - with get_db_session() as session: + async with get_db_session() as session: query = select(Messages) # 应用过滤器 @@ -101,8 +101,8 @@ def find_messages( # 获取时间最早的 limit 条记录,已经是正序 query = query.order_by(Messages.time.asc()).limit(limit) try: - results = result = session.execute(query) - result.scalars().all() + result = await session.execute(query) + results = result.scalars().all() except Exception as e: logger.error(f"执行earliest查询失败: {e}") results = [] @@ -110,8 +110,8 @@ def find_messages( # 获取时间最晚的 limit 条记录 query = query.order_by(Messages.time.desc()).limit(limit) try: - latest_results = result = session.execute(query) - result.scalars().all() + result = await session.execute(query) + latest_results = result.scalars().all() # 将结果按时间正序排列 results = sorted(latest_results, key=lambda msg: msg.time) except Exception as e: @@ -135,8 +135,8 @@ def find_messages( if sort_terms: query = query.order_by(*sort_terms) try: - results = result = session.execute(query) - result.scalars().all() + result = await session.execute(query) + results = result.scalars().all() except Exception as e: logger.error(f"执行无限制查询失败: {e}") results = [] @@ -152,7 +152,7 @@ def find_messages( return [] -def count_messages(message_filter: dict[str, Any]) -> int: +async def count_messages(message_filter: dict[str, Any]) -> int: """ 根据提供的过滤器计算消息数量。 @@ -163,7 +163,7 @@ def count_messages(message_filter: dict[str, Any]) -> int: 符合条件的消息数量,如果出错则返回 0。 """ try: - with get_db_session() as session: + async with get_db_session() as session: query = select(func.count(Messages.id)) # 应用过滤器 @@ -201,7 +201,7 @@ def count_messages(message_filter: dict[str, Any]) -> int: if conditions: query = query.where(*conditions) - count = session.execute(query).scalar() + count = (await session.execute(query)).scalar() return count or 0 except Exception as e: log_message = f"使用 SQLAlchemy 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}" diff --git a/src/main.py b/src/main.py index 1ff96935c..9fae02a28 100644 --- a/src/main.py +++ b/src/main.py @@ -148,8 +148,8 @@ class MainSystem: # 停止消息重组器 from src.plugin_system.core.event_manager import event_manager from src.plugin_system import EventType - asyncio.run(event_manager.trigger_event(EventType.ON_STOP, permission_group="SYSTEM")) + from src.utils.message_chunker import reassembler loop = asyncio.get_event_loop() diff --git a/src/mood/mood_manager.py b/src/mood/mood_manager.py index caba99ad6..307c47d4e 100644 --- a/src/mood/mood_manager.py +++ b/src/mood/mood_manager.py @@ -110,7 +110,7 @@ class ChatMood: limit=int(global_config.chat.max_context_size / 3), limit_mode="last", ) - chat_talking_prompt = build_readable_messages( + chat_talking_prompt = await build_readable_messages( message_list_before_now, replace_bot_name=True, merge_messages=False, @@ -159,7 +159,7 @@ class ChatMood: limit=15, limit_mode="last", ) - chat_talking_prompt = build_readable_messages( + chat_talking_prompt = await build_readable_messages( message_list_before_now, replace_bot_name=True, merge_messages=False, diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index bba004356..4cfb75f94 100644 --- a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -1,3 +1,4 @@ +import asyncio import copy import datetime import hashlib @@ -57,7 +58,7 @@ class PersonInfoManager: self.person_name_list = {} self.qv_name_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="relation.qv_name") # try: - # with get_db_session() as session: + # async with get_db_session() as session: # db.connect(reuse_if_open=True) # # 设置连接池参数(仅对SQLite有效) # if hasattr(db, "execute_sql"): @@ -75,7 +76,7 @@ class PersonInfoManager: try: pass # 在这里获取会话 - # with get_db_session() as session: + # async with get_db_session() as session: # for record in session.execute( # select(PersonInfo.person_id, PersonInfo.person_name).where(PersonInfo.person_name.is_not(None)) # ).fetchall(): @@ -87,58 +88,25 @@ class PersonInfoManager: @staticmethod def get_person_id(platform: str, user_id: Union[int, str]) -> str: - """获取唯一id""" + """获取唯一id(同步) + + 说明: 原来该方法为异步并在内部尝试执行数据库检查/迁移,导致在许多调用处未 await 时返回 coroutine 对象。 + 为了避免将 coroutine 传递到其它同步调用(例如数据库查询条件)中,这里将方法改为同步并仅返回基于 platform 和 user_id 的 MD5 哈希值。 + + 注意: 这会跳过原有的 napcat->qq 迁移检查逻辑。如需保留迁移,请使用显式的、在合适时机执行的迁移任务。 + """ # 检查platform是否为None或空 if platform is None: platform = "unknown" if "-" in platform: platform = platform.split("-")[1] - # 在此处打一个补丁,如果platform为qq,尝试生成id后检查是否存在,如果不存在,则将平台换为napcat后再次检查,如果存在,则更新原id为platform为qq的id + components = [platform, str(user_id)] key = "_".join(components) - # 如果不是 qq 平台,直接返回计算的 id - if platform != "qq": - return hashlib.md5(key.encode()).hexdigest() - - qq_id = hashlib.md5(key.encode()).hexdigest() - - # 对于 qq 平台,先检查该 person_id 是否已存在;如果存在直接返回 - def _db_check_and_migrate_sync(p_id: str, raw_user_id: str): - try: - with get_db_session() as session: - # 检查 qq_id 是否存在 - existing_qq = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() - if existing_qq: - return p_id - - # 如果 qq_id 不存在,尝试使用 napcat 作为平台生成对应 id 并检查 - nap_components = ["napcat", str(raw_user_id)] - nap_key = "_".join(nap_components) - nap_id = hashlib.md5(nap_key.encode()).hexdigest() - - existing_nap = session.execute(select(PersonInfo).where(PersonInfo.person_id == nap_id)).scalar() - if not existing_nap: - # napcat 也不存在,返回 qq_id(未命中) - return p_id - - # napcat 存在,迁移该记录:更新 person_id 与 platform -> qq - try: - # 更新现有 napcat 记录 - existing_nap.person_id = p_id - existing_nap.platform = "qq" - existing_nap.user_id = str(raw_user_id) - session.commit() - return p_id - except Exception: - session.rollback() - return p_id - except Exception as e: - logger.error(f"检查/迁移 napcat->qq 时出错: {e}") - return p_id - - return _db_check_and_migrate_sync(qq_id, user_id) + # 直接返回计算的 id(同步) + return hashlib.md5(key.encode()).hexdigest() async def is_person_known(self, platform: str, user_id: int): """判断是否认识某人""" @@ -157,17 +125,25 @@ class PersonInfoManager: logger.error(f"检查用户 {person_id} 是否已知时出错 (SQLAlchemy): {e}") return False - @staticmethod - async def get_person_id_by_person_name(person_name: str) -> str: - """根据用户名获取用户ID""" + async def get_person_id_by_person_name(self, person_name: str) -> str: + """ + 根据用户名获取用户ID(同步) + + 说明: 为了避免在多个调用点将 coroutine 误传递到数据库查询中, + 此处提供一个同步实现。优先在内存缓存 `self.person_name_list` 中查找, + 若未命中则返回空字符串。若后续需要更强的一致性,可在异步上下文 + 额外实现带 await 的查询方法。 + """ try: - # 在需要时获取会话 - async with get_db_session() as session: - record = result = await session.execute(select(PersonInfo).where(PersonInfo.person_name == person_name)) - result.scalar() - return record.person_id if record else "" + # 优先使用内存缓存加速查找:self.person_name_list maps person_id -> person_name + for pid, pname in self.person_name_list.items(): + if pname == person_name: + return pid + + # 未找到缓存命中,避免在同步路径中进行阻塞的数据库查询,直接返回空字符串 + return "" except Exception as e: - logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (SQLAlchemy): {e}") + logger.error(f"根据用户名 {person_name} 获取用户ID时出错: {e}") return "" @staticmethod @@ -578,26 +554,15 @@ class PersonInfoManager: @staticmethod - def get_value(person_id: str, field_name: str) -> Any: + async def get_value(person_id: str, field_name: str) -> Any: """获取单个字段值(同步版本)""" if not person_id: logger.debug("get_value获取失败:person_id不能为空") return None - import asyncio - - async def _get_record_sync(): - async with get_db_session() as session: - result = await session.execute(select(PersonInfo).where(PersonInfo.person_id == person_id)) - record = result.scalar() - return record - - try: - record = asyncio.run(_get_record_sync()) - except RuntimeError: - # 如果当前线程已经有事件循环在运行,则使用现有的循环 - loop = asyncio.get_running_loop() - record = loop.run_until_complete(_get_record_sync()) + async with get_db_session() as session: + result = await session.execute(select(PersonInfo).where(PersonInfo.person_id == person_id)) + record = result.scalar() model_fields = [column.name for column in PersonInfo.__table__.columns] diff --git a/src/person_info/relationship_fetcher.py b/src/person_info/relationship_fetcher.py index 89632dd73..93269123b 100644 --- a/src/person_info/relationship_fetcher.py +++ b/src/person_info/relationship_fetcher.py @@ -176,7 +176,7 @@ class RelationshipFetcher: # 查询用户关系数据 relationships = await db_query( UserRelationships, - filters=[UserRelationships.user_id == str(person_info_manager.get_value_sync(person_id, "user_id"))], + filters=[UserRelationships.user_id == str(person_info_manager.get_value(person_id, "user_id"))], limit=1, ) @@ -259,7 +259,7 @@ class RelationshipFetcher: # 记录信息获取请求 self.info_fetching_cache.append( { - "person_id": get_person_info_manager().get_person_id_by_person_name(person_name), + "person_id": await get_person_info_manager().get_person_id_by_person_name(person_name), "person_name": person_name, "info_type": info_type, "start_time": time.time(), diff --git a/src/plugin_system/apis/message_api.py b/src/plugin_system/apis/message_api.py index 76dc8f5cb..92ea0dd80 100644 --- a/src/plugin_system/apis/message_api.py +++ b/src/plugin_system/apis/message_api.py @@ -412,7 +412,7 @@ def count_new_messages_for_users(chat_id: str, start_time: float, end_time: floa # ============================================================================= -def build_readable_messages_to_str( +async def build_readable_messages_to_str( messages: List[Dict[str, Any]], replace_bot_name: bool = True, merge_messages: bool = False, @@ -436,7 +436,7 @@ def build_readable_messages_to_str( Returns: 格式化后的可读字符串 """ - return build_readable_messages( + return await build_readable_messages( messages, replace_bot_name, merge_messages, timestamp_mode, read_mark, truncate, show_actions ) diff --git a/src/plugin_system/apis/person_api.py b/src/plugin_system/apis/person_api.py index a84c5d2bb..e3f7be714 100644 --- a/src/plugin_system/apis/person_api.py +++ b/src/plugin_system/apis/person_api.py @@ -134,7 +134,7 @@ async def is_person_known(platform: str, user_id: int) -> bool: return False -def get_person_id_by_name(person_name: str) -> str: +async def get_person_id_by_name(person_name: str) -> str: """根据用户名获取person_id Args: @@ -148,7 +148,7 @@ def get_person_id_by_name(person_name: str) -> str: """ try: person_info_manager = get_person_info_manager() - return person_info_manager.get_person_id_by_person_name(person_name) + return await person_info_manager.get_person_id_by_person_name(person_name) except Exception as e: logger.error(f"[PersonAPI] 根据用户名获取person_id失败: person_name={person_name}, error={e}") return "" diff --git a/src/plugin_system/core/plugin_manager.py b/src/plugin_system/core/plugin_manager.py index cc7a54d4c..9ec216a9b 100644 --- a/src/plugin_system/core/plugin_manager.py +++ b/src/plugin_system/core/plugin_manager.py @@ -542,7 +542,22 @@ class PluginManager: plugin_instance.on_unload() # 从组件注册表中移除插件的所有组件 - asyncio.run(component_registry.unregister_plugin(plugin_name)) + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + fut = asyncio.run_coroutine_threadsafe( + component_registry.unregister_plugin(plugin_name), loop + ) + fut.result(timeout=5) + else: + asyncio.run(component_registry.unregister_plugin(plugin_name)) + except Exception: + # 最后兜底:直接同步调用(如果 unregister_plugin 为非协程)或忽略错误 + try: + # 如果 unregister_plugin 是普通函数 + component_registry.unregister_plugin(plugin_name) + except Exception as e: + logger.debug(f"卸载插件时调用 component_registry.unregister_plugin 失败: {e}") # 从已加载插件中移除 del self.loaded_plugins[plugin_name] diff --git a/src/plugins/built_in/affinity_flow_chatter/interest_scoring.py b/src/plugins/built_in/affinity_flow_chatter/interest_scoring.py index 391ca58fb..8a331f99e 100644 --- a/src/plugins/built_in/affinity_flow_chatter/interest_scoring.py +++ b/src/plugins/built_in/affinity_flow_chatter/interest_scoring.py @@ -199,7 +199,7 @@ class ChatterInterestScoringSystem: # 如果内存中没有,尝试从关系追踪器获取 if hasattr(self, "relationship_tracker") and self.relationship_tracker: try: - relationship_score = self.relationship_tracker.get_user_relationship_score(user_id) + relationship_score = await self.relationship_tracker.get_user_relationship_score(user_id) # 同时更新内存缓存 self.user_relationships[user_id] = relationship_score return relationship_score diff --git a/src/plugins/built_in/affinity_flow_chatter/plan_filter.py b/src/plugins/built_in/affinity_flow_chatter/plan_filter.py index 09d7c5b67..f00ce2d8d 100644 --- a/src/plugins/built_in/affinity_flow_chatter/plan_filter.py +++ b/src/plugins/built_in/affinity_flow_chatter/plan_filter.py @@ -182,7 +182,7 @@ class ChatterPlanFilter: if plan.mode == ChatMode.PROACTIVE: long_term_memory_block = await self._get_long_term_memory_context() - chat_content_block, message_id_list = build_readable_messages_with_id( + chat_content_block, message_id_list = await build_readable_messages_with_id( messages=[msg.flatten() for msg in plan.chat_history], timestamp_mode="normal", truncate=False, @@ -190,7 +190,7 @@ class ChatterPlanFilter: ) prompt_template = await global_prompt_manager.get_prompt_async("proactive_planner_prompt") - actions_before_now = get_actions_by_timestamp_with_chat( + actions_before_now = await get_actions_by_timestamp_with_chat( chat_id=plan.chat_id, timestamp_start=time.time() - 3600, timestamp_end=time.time(), @@ -216,7 +216,7 @@ class ChatterPlanFilter: ) # 为了兼容性,保留原有的chat_content_block - chat_content_block, _ = build_readable_messages_with_id( + chat_content_block, _ = await build_readable_messages_with_id( messages=[msg.flatten() for msg in plan.chat_history], timestamp_mode="normal", read_mark=self.last_obs_time_mark, @@ -224,7 +224,7 @@ class ChatterPlanFilter: show_actions=True, ) - actions_before_now = get_actions_by_timestamp_with_chat( + actions_before_now = await get_actions_by_timestamp_with_chat( chat_id=plan.chat_id, timestamp_start=time.time() - 3600, timestamp_end=time.time(), @@ -319,7 +319,14 @@ class ChatterPlanFilter: from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat # 获取聊天流的上下文 - stream_context = message_manager.stream_contexts.get(plan.chat_id) + from src.plugin_system.apis.chat_api import get_chat_manager + chat_manager = get_chat_manager() + chat_stream = chat_manager.get_stream(plan.chat_id) + if not chat_stream: + logger.warning(f"[plan_filter] 聊天流 {plan.chat_id} 不存在") + return "最近没有聊天内容。", "没有未读消息。", [] + + stream_context = chat_stream.context_manager # 获取真正的已读和未读消息 read_messages = stream_context.history_messages # 已读消息存储在history_messages中 @@ -338,7 +345,7 @@ class ChatterPlanFilter: # 构建已读历史消息块 if read_messages: - read_content, read_ids = build_readable_messages_with_id( + read_content, read_ids = await build_readable_messages_with_id( messages=[msg.flatten() for msg in read_messages[-50:]], # 限制数量 timestamp_mode="normal_no_YMD", truncate=False, diff --git a/src/plugins/built_in/affinity_flow_chatter/planner.py b/src/plugins/built_in/affinity_flow_chatter/planner.py index 57f98954e..af0b68029 100644 --- a/src/plugins/built_in/affinity_flow_chatter/planner.py +++ b/src/plugins/built_in/affinity_flow_chatter/planner.py @@ -138,7 +138,7 @@ class ChatterActionPlanner: # 更新StreamContext中的消息信息并刷新focus_energy if context: from src.chat.message_manager.message_manager import message_manager - message_manager.update_message( + await message_manager.update_message( stream_id=self.chat_id, message_id=message.message_id, interest_value=message_interest, @@ -148,7 +148,7 @@ class ChatterActionPlanner: # 更新数据库中的消息记录 try: from src.chat.message_receive.storage import MessageStorage - MessageStorage.update_message_interest_value(message.message_id, message_interest) + await MessageStorage.update_message_interest_value(message.message_id, message_interest) logger.debug(f"已更新数据库中消息 {message.message_id} 的兴趣度为: {message_interest:.3f}") except Exception as e: logger.warning(f"更新数据库消息兴趣度失败: {e}") diff --git a/src/plugins/built_in/core_actions/emoji.py b/src/plugins/built_in/core_actions/emoji.py index f7b9c231a..f043e4f49 100644 --- a/src/plugins/built_in/core_actions/emoji.py +++ b/src/plugins/built_in/core_actions/emoji.py @@ -124,10 +124,10 @@ class EmojiAction(BaseAction): emoji_base64, emoji_description = random.choice(all_emojis_data) else: # 获取最近的5条消息内容用于判断 - recent_messages = message_api.get_recent_messages(chat_id=self.chat_id, limit=5) + recent_messages = await message_api.get_recent_messages(chat_id=self.chat_id, limit=5) messages_text = "" if recent_messages: - messages_text = message_api.build_readable_messages( + messages_text = await message_api.build_readable_messages( messages=recent_messages, timestamp_mode="normal_no_YMD", truncate=False, @@ -185,10 +185,10 @@ class EmojiAction(BaseAction): elif global_config.emoji.emoji_selection_mode == "description": # --- 详细描述选择模式 --- # 获取最近的5条消息内容用于判断 - recent_messages = message_api.get_recent_messages(chat_id=self.chat_id, limit=5) + recent_messages = await message_api.get_recent_messages(chat_id=self.chat_id, limit=5) messages_text = "" if recent_messages: - messages_text = message_api.build_readable_messages( + messages_text = await message_api.build_readable_messages( messages=recent_messages, timestamp_mode="normal_no_YMD", truncate=False, diff --git a/src/plugins/built_in/maizone_refactored/services/qzone_service.py b/src/plugins/built_in/maizone_refactored/services/qzone_service.py index 752e27dfa..09e6f5e53 100644 --- a/src/plugins/built_in/maizone_refactored/services/qzone_service.py +++ b/src/plugins/built_in/maizone_refactored/services/qzone_service.py @@ -118,7 +118,7 @@ class QZoneService: async def read_and_process_feeds(self, target_name: str, stream_id: Optional[str]) -> Dict[str, Any]: """读取并处理指定好友的说说""" - target_person_id = person_api.get_person_id_by_name(target_name) + target_person_id = await person_api.get_person_id_by_name(target_name) if not target_person_id: return {"success": False, "message": f"找不到名为'{target_name}'的好友"} target_qq = await person_api.get_person_value(target_person_id, "user_id") diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py index 5ea018f4d..bf9847707 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py @@ -331,6 +331,7 @@ class NoticeHandler: like_emoji_id = raw_message.get("likes")[0].get("emoji_id") await event_manager.trigger_event( +<<<<<<< HEAD NapcatEvent.ON_RECEIVED.EMOJI_LIEK, permission_group=PLUGIN_NAME, group_id=group_id, @@ -342,6 +343,16 @@ class NoticeHandler: type="text", data=f"{user_name}使用Emoji表情{QQ_FACE.get(like_emoji_id, '')}回复了你的消息[{target_message_text}]", ) +======= + NapcatEvent.ON_RECEIVED.EMOJI_LIEK, + permission_group=PLUGIN_NAME, + group_id=group_id, + user_id=user_id, + message_id=raw_message.get("message_id",""), + emoji_id=like_emoji_id + ) + seg_data = Seg(type="text",data=f"{user_name}使用Emoji表情{QQ_FACE.get(like_emoji_id, '')}回复了你的消息[{target_message_text}]") +>>>>>>> 9912d7f643d347cbadcf1e3d618aa78bcbf89cc4 return seg_data, user_info async def handle_ban_notify(self, raw_message: dict, group_id: int) -> Tuple[Seg, UserInfo] | Tuple[None, None]: diff --git a/src/plugins/built_in/web_search_tool/_manifest.json b/src/plugins/built_in/web_search_tool/_manifest.json deleted file mode 100644 index 549781c2a..000000000 --- a/src/plugins/built_in/web_search_tool/_manifest.json +++ /dev/null @@ -1,25 +0,0 @@ -{ - "manifest_version": 1, - "name": "web_search_tool", - "version": "1.0.0", - "description": "一个用于在互联网上搜索信息的工具", - "author": { - "name": "MoFox-Studio", - "url": "https://github.com/MoFox-Studio" - }, - "license": "GPL-v3.0-or-later", - - "host_application": { - "min_version": "0.10.0" - }, - "keywords": ["web_search", "url_parser"], - "categories": ["web_search", "url_parser"], - - "default_locale": "zh-CN", - "locales_path": "_locales", - - "plugin_info": { - "is_built_in": false, - "plugin_type": "web_search" - } -} \ No newline at end of file diff --git a/test_deepcopy_fix.py b/test_deepcopy_fix.py deleted file mode 100644 index c790619b8..000000000 --- a/test_deepcopy_fix.py +++ /dev/null @@ -1,110 +0,0 @@ -#!/usr/bin/env python3 -""" -测试 ChatStream 的 deepcopy 功能 -验证 asyncio.Task 序列化问题是否已解决 -""" - -import asyncio -import sys -import os -import copy - -# 添加项目根目录到 Python 路径 -sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) - -from src.chat.message_receive.chat_stream import ChatStream -from maim_message import UserInfo, GroupInfo - - -async def test_chat_stream_deepcopy(): - """测试 ChatStream 的 deepcopy 功能""" - print("[TEST] 开始测试 ChatStream deepcopy 功能...") - - try: - # 创建测试用的用户和群组信息 - user_info = UserInfo( - platform="test_platform", - user_id="test_user_123", - user_nickname="测试用户", - user_cardname="测试卡片名" - ) - - group_info = GroupInfo( - platform="test_platform", - group_id="test_group_456", - group_name="测试群组" - ) - - # 创建 ChatStream 实例 - print("📝 创建 ChatStream 实例...") - stream_id = "test_stream_789" - platform = "test_platform" - - chat_stream = ChatStream( - stream_id=stream_id, - platform=platform, - user_info=user_info, - group_info=group_info - ) - - print(f"[SUCCESS] ChatStream 创建成功: {chat_stream.stream_id}") - - # 等待一下,让异步任务有机会创建 - await asyncio.sleep(0.1) - - # 尝试进行 deepcopy - print("[INFO] 尝试进行 deepcopy...") - copied_stream = copy.deepcopy(chat_stream) - - print("[SUCCESS] deepcopy 成功!") - - # 验证复制后的对象属性 - print("\n[CHECK] 验证复制后的对象属性:") - print(f" - stream_id: {copied_stream.stream_id}") - print(f" - platform: {copied_stream.platform}") - print(f" - user_info: {copied_stream.user_info.user_nickname}") - print(f" - group_info: {copied_stream.group_info.group_name}") - - # 检查 processing_task 是否被正确处理 - if hasattr(copied_stream.stream_context, 'processing_task'): - print(f" - processing_task: {copied_stream.stream_context.processing_task}") - if copied_stream.stream_context.processing_task is None: - print(" [SUCCESS] processing_task 已被正确设置为 None") - else: - print(" [WARNING] processing_task 不为 None") - else: - print(" [SUCCESS] stream_context 没有 processing_task 属性") - - # 验证原始对象和复制对象是不同的实例 - if id(chat_stream) != id(copied_stream): - print("[SUCCESS] 原始对象和复制对象是不同的实例") - else: - print("[ERROR] 原始对象和复制对象是同一个实例") - - # 验证基本属性是否正确复制 - if (chat_stream.stream_id == copied_stream.stream_id and - chat_stream.platform == copied_stream.platform): - print("[SUCCESS] 基本属性正确复制") - else: - print("[ERROR] 基本属性复制失败") - - print("\n[COMPLETE] 测试完成!deepcopy 功能修复成功!") - return True - - except Exception as e: - print(f"[ERROR] 测试失败: {e}") - import traceback - traceback.print_exc() - return False - - -if __name__ == "__main__": - # 运行测试 - result = asyncio.run(test_chat_stream_deepcopy()) - - if result: - print("\n[SUCCESS] 所有测试通过!") - sys.exit(0) - else: - print("\n[ERROR] 测试失败!") - sys.exit(1) \ No newline at end of file diff --git a/test_simple_deepcopy.py b/test_simple_deepcopy.py deleted file mode 100644 index 63e680d45..000000000 --- a/test_simple_deepcopy.py +++ /dev/null @@ -1,109 +0,0 @@ -#!/usr/bin/env python3 -""" -简单的 ChatStream deepcopy 测试 -""" - -import asyncio -import sys -import os -import copy - -# 添加项目根目录到 Python 路径 -sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) - -from src.chat.message_receive.chat_stream import ChatStream -from maim_message import UserInfo, GroupInfo - - -async def test_deepcopy(): - """测试 deepcopy 功能""" - print("开始测试 ChatStream deepcopy 功能...") - - try: - # 创建测试用的用户和群组信息 - user_info = UserInfo( - platform="test_platform", - user_id="test_user_123", - user_nickname="测试用户", - user_cardname="测试卡片名" - ) - - group_info = GroupInfo( - platform="test_platform", - group_id="test_group_456", - group_name="测试群组" - ) - - # 创建 ChatStream 实例 - print("创建 ChatStream 实例...") - stream_id = "test_stream_789" - platform = "test_platform" - - chat_stream = ChatStream( - stream_id=stream_id, - platform=platform, - user_info=user_info, - group_info=group_info - ) - - print(f"ChatStream 创建成功: {chat_stream.stream_id}") - - # 等待一下,让异步任务有机会创建 - await asyncio.sleep(0.1) - - # 尝试进行 deepcopy - print("尝试进行 deepcopy...") - copied_stream = copy.deepcopy(chat_stream) - - print("deepcopy 成功!") - - # 验证复制后的对象属性 - print("\n验证复制后的对象属性:") - print(f" - stream_id: {copied_stream.stream_id}") - print(f" - platform: {copied_stream.platform}") - print(f" - user_info: {copied_stream.user_info.user_nickname}") - print(f" - group_info: {copied_stream.group_info.group_name}") - - # 检查 processing_task 是否被正确处理 - if hasattr(copied_stream.stream_context, 'processing_task'): - print(f" - processing_task: {copied_stream.stream_context.processing_task}") - if copied_stream.stream_context.processing_task is None: - print(" SUCCESS: processing_task 已被正确设置为 None") - else: - print(" WARNING: processing_task 不为 None") - else: - print(" SUCCESS: stream_context 没有 processing_task 属性") - - # 验证原始对象和复制对象是不同的实例 - if id(chat_stream) != id(copied_stream): - print("SUCCESS: 原始对象和复制对象是不同的实例") - else: - print("ERROR: 原始对象和复制对象是同一个实例") - - # 验证基本属性是否正确复制 - if (chat_stream.stream_id == copied_stream.stream_id and - chat_stream.platform == copied_stream.platform): - print("SUCCESS: 基本属性正确复制") - else: - print("ERROR: 基本属性复制失败") - - print("\n测试完成!deepcopy 功能修复成功!") - return True - - except Exception as e: - print(f"ERROR: 测试失败: {e}") - import traceback - traceback.print_exc() - return False - - -if __name__ == "__main__": - # 运行测试 - result = asyncio.run(test_deepcopy()) - - if result: - print("\n所有测试通过!") - sys.exit(0) - else: - print("\n测试失败!") - sys.exit(1) \ No newline at end of file