From 371041c9db4ddce9b7b03f8b4d1aa75629a6c56d Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Fri, 31 Oct 2025 19:24:58 +0800 Subject: [PATCH] =?UTF-8?q?=E9=87=8D=E6=9E=84=E6=B6=88=E6=81=AF=E5=A4=84?= =?UTF-8?q?=E7=90=86=E5=B9=B6=E7=94=A8DatabaseMessages=E6=9B=BF=E6=8D=A2Me?= =?UTF-8?q?ssageRecv?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit -更新PlusCommand以使用DatabaseMessages而不是MessageRecv。 -将消息处理逻辑重构到一个新模块message_processor.py中,以处理消息段并从消息字典中创建DatabaseMessages。 -删除了已弃用的MessageRecv类及其相关逻辑。 -调整了各种插件以适应新的DatabaseMessages结构。 -增强了消息处理功能中的错误处理和日志记录。 --- integration_test_relationship_tools.py | 303 ----------- pyrightconfig.json | 2 +- .../processors/message_processor.py | 10 +- src/chat/message_manager/context_manager.py | 30 +- src/chat/message_manager/message_manager.py | 60 +-- .../message_manager/stream_cache_manager.py | 377 -------------- src/chat/message_receive/bot.py | 235 ++++----- src/chat/message_receive/chat_stream.py | 379 ++++---------- src/chat/message_receive/message.py | 426 ++------------- src/chat/message_receive/message_processor.py | 493 ++++++++++++++++++ .../message_receive/message_recv_backup.py | 434 +++++++++++++++ src/chat/message_receive/storage.py | 254 +++++---- src/chat/planner_actions/action_modifier.py | 2 +- src/chat/replyer/default_generator.py | 11 +- src/chat/utils/utils.py | 51 +- src/common/database/db_migration.py | 7 +- src/common/database/sqlalchemy_models.py | 8 +- src/mood/mood_manager.py | 10 +- src/plugin_system/apis/send_api.py | 126 ++--- src/plugin_system/base/base_command.py | 44 +- src/plugin_system/base/plus_command.py | 47 +- .../affinity_flow_chatter/plan_executor.py | 8 +- .../proactive_thinking_event.py | 4 +- 23 files changed, 1520 insertions(+), 1801 deletions(-) delete mode 100644 integration_test_relationship_tools.py delete mode 100644 src/chat/message_manager/stream_cache_manager.py create mode 100644 src/chat/message_receive/message_processor.py create mode 100644 src/chat/message_receive/message_recv_backup.py diff --git a/integration_test_relationship_tools.py b/integration_test_relationship_tools.py deleted file mode 100644 index a2ac3a7fa..000000000 --- a/integration_test_relationship_tools.py +++ /dev/null @@ -1,303 +0,0 @@ -""" -关系追踪工具集成测试脚本 - -注意:此脚本需要在完整的应用环境中运行 -建议通过 bot.py 启动后在交互式环境中测试 -""" - -import asyncio - - -async def test_user_profile_tool(): - """测试用户画像工具""" - print("\n" + "=" * 80) - print("测试 UserProfileTool") - print("=" * 80) - - from src.plugins.built_in.affinity_flow_chatter.user_profile_tool import UserProfileTool - from src.common.database.sqlalchemy_database_api import db_query - from src.common.database.sqlalchemy_models import UserRelationships - - tool = UserProfileTool() - print(f"✅ 工具名称: {tool.name}") - print(f" 工具描述: {tool.description}") - - # 执行工具 - test_user_id = "integration_test_user_001" - result = await tool.execute({ - "target_user_id": test_user_id, - "user_aliases": "测试小明,TestMing,小明君", - "impression_description": "这是一个集成测试用户,性格开朗活泼,喜欢技术讨论,对AI和编程特别感兴趣。经常提出有深度的问题。", - "preference_keywords": "AI,Python,深度学习,游戏开发,科幻小说", - "affection_score": 0.85 - }) - - print(f"\n✅ 工具执行结果:") - print(f" 类型: {result.get('type')}") - print(f" 内容: {result.get('content')}") - - # 验证数据库 - db_data = await db_query( - UserRelationships, - filters={"user_id": test_user_id}, - limit=1 - ) - - if db_data: - data = db_data[0] - print(f"\n✅ 数据库验证:") - print(f" user_id: {data.get('user_id')}") - print(f" user_aliases: {data.get('user_aliases')}") - print(f" relationship_text: {data.get('relationship_text', '')[:80]}...") - print(f" preference_keywords: {data.get('preference_keywords')}") - print(f" relationship_score: {data.get('relationship_score')}") - return True - else: - print(f"\n❌ 数据库中未找到数据") - return False - - -async def test_chat_stream_impression_tool(): - """测试聊天流印象工具""" - print("\n" + "=" * 80) - print("测试 ChatStreamImpressionTool") - print("=" * 80) - - from src.plugins.built_in.affinity_flow_chatter.chat_stream_impression_tool import ChatStreamImpressionTool - from src.common.database.sqlalchemy_database_api import db_query - from src.common.database.sqlalchemy_models import ChatStreams, get_db_session - - # 准备测试数据:先创建一条 ChatStreams 记录 - test_stream_id = "integration_test_stream_001" - print(f"🔧 准备测试数据:创建聊天流记录 {test_stream_id}") - - import time - current_time = time.time() - - async with get_db_session() as session: - new_stream = ChatStreams( - stream_id=test_stream_id, - create_time=current_time, - last_active_time=current_time, - platform="QQ", - user_platform="QQ", - user_id="test_user_123", - user_nickname="测试用户", - group_name="测试技术交流群", - group_platform="QQ", - group_id="test_group_456", - stream_impression_text="", # 初始为空 - stream_chat_style="", - stream_topic_keywords="", - stream_interest_score=0.5 - ) - session.add(new_stream) - await session.commit() - print(f"✅ 测试聊天流记录已创建") - - tool = ChatStreamImpressionTool() - print(f"✅ 工具名称: {tool.name}") - print(f" 工具描述: {tool.description}") - - # 执行工具 - result = await tool.execute({ - "stream_id": test_stream_id, - "impression_description": "这是一个技术交流群,成员主要是程序员和AI爱好者。大家经常分享最新的技术文章,讨论编程问题,氛围友好且专业。", - "chat_style": "专业技术交流,活跃讨论,互帮互助,知识分享", - "topic_keywords": "Python开发,机器学习,AI应用,Web后端,数据分析,开源项目", - "interest_score": 0.90 - }) - - print(f"\n✅ 工具执行结果:") - print(f" 类型: {result.get('type')}") - print(f" 内容: {result.get('content')}") - - # 验证数据库 - db_data = await db_query( - ChatStreams, - filters={"stream_id": test_stream_id}, - limit=1 - ) - - if db_data: - data = db_data[0] - print(f"\n✅ 数据库验证:") - print(f" stream_id: {data.get('stream_id')}") - print(f" stream_impression_text: {data.get('stream_impression_text', '')[:80]}...") - print(f" stream_chat_style: {data.get('stream_chat_style')}") - print(f" stream_topic_keywords: {data.get('stream_topic_keywords')}") - print(f" stream_interest_score: {data.get('stream_interest_score')}") - return True - else: - print(f"\n❌ 数据库中未找到数据") - return False - - -async def test_relationship_info_build(): - """测试关系信息构建""" - print("\n" + "=" * 80) - print("测试关系信息构建(提示词集成)") - print("=" * 80) - - from src.person_info.relationship_fetcher import relationship_fetcher_manager - - test_stream_id = "integration_test_stream_001" - test_person_id = "test_person_999" # 使用一个可能不存在的ID来测试 - - fetcher = relationship_fetcher_manager.get_fetcher(test_stream_id) - print(f"✅ RelationshipFetcher 已创建") - - # 测试聊天流印象构建 - print(f"\n🔍 构建聊天流印象...") - stream_info = await fetcher.build_chat_stream_impression(test_stream_id) - - if stream_info: - print(f"✅ 聊天流印象构建成功") - print(f"\n{'=' * 80}") - print(stream_info) - print(f"{'=' * 80}") - else: - print(f"⚠️ 聊天流印象为空(可能测试数据不存在)") - - return True - - -async def cleanup_test_data(): - """清理测试数据""" - print("\n" + "=" * 80) - print("清理测试数据") - print("=" * 80) - - from src.common.database.sqlalchemy_database_api import db_query - from src.common.database.sqlalchemy_models import UserRelationships, ChatStreams - - try: - # 清理用户数据 - await db_query( - UserRelationships, - query_type="delete", - filters={"user_id": "integration_test_user_001"} - ) - print("✅ 用户测试数据已清理") - - # 清理聊天流数据 - await db_query( - ChatStreams, - query_type="delete", - filters={"stream_id": "integration_test_stream_001"} - ) - print("✅ 聊天流测试数据已清理") - - return True - except Exception as e: - print(f"⚠️ 清理失败: {e}") - return False - - -async def run_all_tests(): - """运行所有测试""" - print("\n" + "=" * 80) - print("关系追踪工具集成测试") - print("=" * 80) - - results = {} - - # 测试1 - try: - results["UserProfileTool"] = await test_user_profile_tool() - except Exception as e: - print(f"\n❌ UserProfileTool 测试失败: {e}") - import traceback - traceback.print_exc() - results["UserProfileTool"] = False - - # 测试2 - try: - results["ChatStreamImpressionTool"] = await test_chat_stream_impression_tool() - except Exception as e: - print(f"\n❌ ChatStreamImpressionTool 测试失败: {e}") - import traceback - traceback.print_exc() - results["ChatStreamImpressionTool"] = False - - # 测试3 - try: - results["RelationshipFetcher"] = await test_relationship_info_build() - except Exception as e: - print(f"\n❌ RelationshipFetcher 测试失败: {e}") - import traceback - traceback.print_exc() - results["RelationshipFetcher"] = False - - # 清理 - try: - await cleanup_test_data() - except Exception as e: - print(f"\n⚠️ 清理测试数据失败: {e}") - - # 总结 - print("\n" + "=" * 80) - print("测试总结") - print("=" * 80) - - passed = sum(1 for r in results.values() if r) - total = len(results) - - for test_name, result in results.items(): - status = "✅ 通过" if result else "❌ 失败" - print(f"{status} - {test_name}") - - print(f"\n总计: {passed}/{total} 测试通过") - - if passed == total: - print("\n🎉 所有测试通过!") - else: - print(f"\n⚠️ {total - passed} 个测试失败") - - return passed == total - - -# 使用说明 -print(""" -============================================================================ -关系追踪工具集成测试脚本 -============================================================================ - -此脚本需要在完整的应用环境中运行。 - -使用方法1: 在 bot.py 中添加测试调用 ------------------------------------ -在 bot.py 的 main() 函数中添加: - - # 测试关系追踪工具 - from tests.integration_test_relationship_tools import run_all_tests - await run_all_tests() - -使用方法2: 在 Python REPL 中运行 ------------------------------------ -启动 bot.py 后,在 Python 调试控制台中执行: - - import asyncio - from tests.integration_test_relationship_tools import run_all_tests - asyncio.create_task(run_all_tests()) - -使用方法3: 直接在此文件底部运行 ------------------------------------ -取消注释下面的代码,然后确保已启动应用环境 -============================================================================ -""") - - -# 如果需要直接运行(需要应用环境已启动) -if __name__ == "__main__": - print("\n⚠️ 警告: 直接运行此脚本可能会失败,因为缺少应用环境") - print("建议在 bot.py 启动后的环境中运行\n") - - try: - asyncio.run(run_all_tests()) - except Exception as e: - print(f"\n❌ 测试失败: {e}") - print("\n建议:") - print("1. 确保已启动 bot.py") - print("2. 在 Python 调试控制台中运行测试") - print("3. 或在 bot.py 中添加测试调用") diff --git a/pyrightconfig.json b/pyrightconfig.json index 3cffac58c..adf9c8dcf 100644 --- a/pyrightconfig.json +++ b/pyrightconfig.json @@ -27,6 +27,6 @@ "venvPath": ".", "venv": ".venv", "executionEnvironments": [ - {"root": "src"} + {"root": "."} ] } diff --git a/src/chat/antipromptinjector/processors/message_processor.py b/src/chat/antipromptinjector/processors/message_processor.py index 0e37efc0d..b13baff13 100644 --- a/src/chat/antipromptinjector/processors/message_processor.py +++ b/src/chat/antipromptinjector/processors/message_processor.py @@ -6,7 +6,7 @@ import re -from src.chat.message_receive.message import MessageRecv +from src.common.data_models.database_data_model import DatabaseMessages from src.common.logger import get_logger logger = get_logger("anti_injector.message_processor") @@ -15,7 +15,7 @@ logger = get_logger("anti_injector.message_processor") class MessageProcessor: """消息内容处理器""" - def extract_text_content(self, message: MessageRecv) -> str: + def extract_text_content(self, message: DatabaseMessages) -> str: """提取消息中的文本内容,过滤掉引用的历史内容 Args: @@ -64,7 +64,7 @@ class MessageProcessor: return new_content @staticmethod - def check_whitelist(message: MessageRecv, whitelist: list) -> tuple | None: + def check_whitelist(message: DatabaseMessages, whitelist: list) -> tuple | None: """检查用户白名单 Args: @@ -74,8 +74,8 @@ class MessageProcessor: Returns: 如果在白名单中返回结果元组,否则返回None """ - user_id = message.message_info.user_info.user_id - platform = message.message_info.platform + user_id = message.user_info.user_id + platform = message.chat_info.platform # 检查用户白名单:格式为 [[platform, user_id], ...] for whitelist_entry in whitelist: diff --git a/src/chat/message_manager/context_manager.py b/src/chat/message_manager/context_manager.py index 41bf47781..c569fde6b 100644 --- a/src/chat/message_manager/context_manager.py +++ b/src/chat/message_manager/context_manager.py @@ -29,7 +29,6 @@ class SingleStreamContextManager: # 配置参数 self.max_context_size = max_context_size or getattr(global_config.chat, "max_context_size", 100) - self.context_ttl = getattr(global_config.chat, "context_ttl", 24 * 3600) # 24小时 # 元数据 self.created_time = time.time() @@ -93,27 +92,24 @@ class SingleStreamContextManager: return True else: logger.warning(f"消息缓存系统添加失败,回退到直接添加: {self.stream_id}") - - except ImportError: - logger.debug("MessageManager不可用,使用直接添加模式") except Exception as e: logger.warning(f"消息缓存系统异常,回退到直接添加: {self.stream_id}, error={e}") - # 回退方案:直接添加到未读消息 - message.is_read = False - self.context.unread_messages.append(message) + # 回退方案:直接添加到未读消息 + message.is_read = False + self.context.unread_messages.append(message) - # 自动检测和更新chat type - self._detect_chat_type(message) + # 自动检测和更新chat type + self._detect_chat_type(message) - # 在上下文管理器中计算兴趣值 - await self._calculate_message_interest(message) - self.total_messages += 1 - self.last_access_time = time.time() - # 启动流的循环任务(如果还未启动) - asyncio.create_task(stream_loop_manager.start_stream_loop(self.stream_id)) - logger.debug(f"添加消息{message.processed_plain_text}到单流上下文: {self.stream_id}") - return True + # 在上下文管理器中计算兴趣值 + await self._calculate_message_interest(message) + self.total_messages += 1 + self.last_access_time = time.time() + # 启动流的循环任务(如果还未启动) + asyncio.create_task(stream_loop_manager.start_stream_loop(self.stream_id)) + logger.debug(f"添加消息{message.processed_plain_text}到单流上下文: {self.stream_id}") + return True except Exception as e: logger.error(f"添加消息到单流上下文失败 {self.stream_id}: {e}", exc_info=True) return False diff --git a/src/chat/message_manager/message_manager.py b/src/chat/message_manager/message_manager.py index 49c169640..b54fc8bdc 100644 --- a/src/chat/message_manager/message_manager.py +++ b/src/chat/message_manager/message_manager.py @@ -71,14 +71,6 @@ class MessageManager: except Exception as e: logger.error(f"启动批量数据库写入器失败: {e}") - # 启动流缓存管理器 - try: - from src.chat.message_manager.stream_cache_manager import init_stream_cache_manager - - await init_stream_cache_manager() - except Exception as e: - logger.error(f"启动流缓存管理器失败: {e}") - # 启动消息缓存系统(内置) logger.info("📦 消息缓存系统已启动") @@ -116,15 +108,6 @@ class MessageManager: except Exception as e: logger.error(f"停止批量数据库写入器失败: {e}") - # 停止流缓存管理器 - try: - from src.chat.message_manager.stream_cache_manager import shutdown_stream_cache_manager - - await shutdown_stream_cache_manager() - logger.info("🗄️ 流缓存管理器已停止") - except Exception as e: - logger.error(f"停止流缓存管理器失败: {e}") - # 停止消息缓存系统(内置) self.message_caches.clear() self.stream_processing_status.clear() @@ -152,7 +135,7 @@ class MessageManager: # 检查是否为notice消息 if self._is_notice_message(message): # Notice消息处理 - 添加到全局管理器 - logger.info(f"📢 检测到notice消息: message_id={message.message_id}, is_notify={message.is_notify}, notice_type={getattr(message, 'notice_type', None)}") + logger.info(f"📢 检测到notice消息: notice_type={getattr(message, 'notice_type', None)}") await self._handle_notice_message(stream_id, message) # 根据配置决定是否继续处理(触发聊天流程) @@ -206,39 +189,6 @@ class MessageManager: except Exception as e: logger.error(f"更新消息 {message_id} 时发生错误: {e}") - async def bulk_update_messages(self, stream_id: str, updates: list[dict[str, Any]]) -> int: - """批量更新消息信息,降低更新频率""" - if not updates: - return 0 - - try: - chat_manager = get_chat_manager() - chat_stream = await chat_manager.get_stream(stream_id) - if not chat_stream: - logger.warning(f"MessageManager.bulk_update_messages: 聊天流 {stream_id} 不存在") - return 0 - - updated_count = 0 - for item in updates: - message_id = item.get("message_id") - if not message_id: - continue - - payload = {key: value for key, value in item.items() if key != "message_id" and value is not None} - - if not payload: - continue - - success = await chat_stream.context_manager.update_message(message_id, payload) - if success: - updated_count += 1 - - if updated_count: - logger.debug(f"批量更新消息 {updated_count} 条 (stream={stream_id})") - return updated_count - except Exception as e: - logger.error(f"批量更新聊天流 {stream_id} 消息失败: {e}") - return 0 async def add_action(self, stream_id: str, message_id: str, action: str): """添加动作到消息""" @@ -266,7 +216,7 @@ class MessageManager: logger.warning(f"停用流失败: 聊天流 {stream_id} 不存在") return - context = chat_stream.stream_context + context = chat_stream.context_manager.context context.is_active = False # 取消处理任务 @@ -288,7 +238,7 @@ class MessageManager: logger.warning(f"激活流失败: 聊天流 {stream_id} 不存在") return - context = chat_stream.stream_context + context = chat_stream.context_manager.context context.is_active = True logger.info(f"激活聊天流: {stream_id}") @@ -304,7 +254,7 @@ class MessageManager: if not chat_stream: return None - context = chat_stream.stream_context + context = chat_stream.context_manager.context unread_count = len(chat_stream.context_manager.get_unread_messages()) return StreamStats( @@ -447,7 +397,7 @@ class MessageManager: await asyncio.sleep(0.1) # 获取当前的stream context - context = chat_stream.stream_context + context = chat_stream.context_manager.context # 确保有未读消息需要处理 unread_messages = context.get_unread_messages() diff --git a/src/chat/message_manager/stream_cache_manager.py b/src/chat/message_manager/stream_cache_manager.py deleted file mode 100644 index ea85c3855..000000000 --- a/src/chat/message_manager/stream_cache_manager.py +++ /dev/null @@ -1,377 +0,0 @@ -""" -流缓存管理器 - 使用优化版聊天流和智能缓存策略 -提供分层缓存和自动清理功能 -""" - -import asyncio -import time -from collections import OrderedDict -from dataclasses import dataclass - -from maim_message import GroupInfo, UserInfo - -from src.chat.message_receive.optimized_chat_stream import OptimizedChatStream, create_optimized_chat_stream -from src.common.logger import get_logger - -logger = get_logger("stream_cache_manager") - - -@dataclass -class StreamCacheStats: - """缓存统计信息""" - - hot_cache_size: int = 0 - warm_storage_size: int = 0 - cold_storage_size: int = 0 - total_memory_usage: int = 0 # 估算的内存使用(字节) - cache_hits: int = 0 - cache_misses: int = 0 - evictions: int = 0 - last_cleanup_time: float = 0 - - -class TieredStreamCache: - """分层流缓存管理器""" - - def __init__( - self, - max_hot_size: int = 100, - max_warm_size: int = 500, - max_cold_size: int = 2000, - cleanup_interval: float = 300.0, # 5分钟清理一次 - hot_timeout: float = 1800.0, # 30分钟未访问降级到warm - warm_timeout: float = 7200.0, # 2小时未访问降级到cold - cold_timeout: float = 86400.0, # 24小时未访问删除 - ): - self.max_hot_size = max_hot_size - self.max_warm_size = max_warm_size - self.max_cold_size = max_cold_size - self.cleanup_interval = cleanup_interval - self.hot_timeout = hot_timeout - self.warm_timeout = warm_timeout - self.cold_timeout = cold_timeout - - # 三层缓存存储 - self.hot_cache: OrderedDict[str, OptimizedChatStream] = OrderedDict() # 热数据(LRU) - self.warm_storage: dict[str, tuple[OptimizedChatStream, float]] = {} # 温数据(最后访问时间) - self.cold_storage: dict[str, tuple[OptimizedChatStream, float]] = {} # 冷数据(最后访问时间) - - # 统计信息 - self.stats = StreamCacheStats() - - # 清理任务 - self.cleanup_task: asyncio.Task | None = None - self.is_running = False - - logger.info(f"分层流缓存管理器初始化完成 (hot:{max_hot_size}, warm:{max_warm_size}, cold:{max_cold_size})") - - async def start(self): - """启动缓存管理器""" - if self.is_running: - logger.warning("缓存管理器已经在运行") - return - - self.is_running = True - self.cleanup_task = asyncio.create_task(self._cleanup_loop(), name="stream_cache_cleanup") - - async def stop(self): - """停止缓存管理器""" - if not self.is_running: - return - - self.is_running = False - - if self.cleanup_task and not self.cleanup_task.done(): - self.cleanup_task.cancel() - try: - await asyncio.wait_for(self.cleanup_task, timeout=10.0) - except asyncio.TimeoutError: - logger.warning("缓存清理任务停止超时") - except Exception as e: - logger.error(f"停止缓存清理任务时出错: {e}") - - logger.info("分层流缓存管理器已停止") - - async def get_or_create_stream( - self, - stream_id: str, - platform: str, - user_info: UserInfo, - group_info: GroupInfo | None = None, - data: dict | None = None, - ) -> OptimizedChatStream: - """获取或创建流 - 优化版本""" - current_time = time.time() - - # 1. 检查热缓存 - if stream_id in self.hot_cache: - stream = self.hot_cache[stream_id] - # 移动到末尾(LRU更新) - self.hot_cache.move_to_end(stream_id) - self.stats.cache_hits += 1 - logger.debug(f"热缓存命中: {stream_id}") - return stream.create_snapshot() - - # 2. 检查温存储 - if stream_id in self.warm_storage: - stream, last_access = self.warm_storage[stream_id] - self.warm_storage[stream_id] = (stream, current_time) - self.stats.cache_hits += 1 - logger.debug(f"温缓存命中: {stream_id}") - # 提升到热缓存 - await self._promote_to_hot(stream_id, stream) - return stream.create_snapshot() - - # 3. 检查冷存储 - if stream_id in self.cold_storage: - stream, last_access = self.cold_storage[stream_id] - self.cold_storage[stream_id] = (stream, current_time) - self.stats.cache_hits += 1 - logger.debug(f"冷缓存命中: {stream_id}") - # 提升到温缓存 - await self._promote_to_warm(stream_id, stream) - return stream.create_snapshot() - - # 4. 缓存未命中,创建新流 - self.stats.cache_misses += 1 - stream = create_optimized_chat_stream( - stream_id=stream_id, platform=platform, user_info=user_info, group_info=group_info, data=data - ) - logger.debug(f"缓存未命中,创建新流: {stream_id}") - - # 添加到热缓存 - await self._add_to_hot(stream_id, stream) - - return stream - - async def _add_to_hot(self, stream_id: str, stream: OptimizedChatStream): - """添加到热缓存""" - # 检查是否需要驱逐 - if len(self.hot_cache) >= self.max_hot_size: - await self._evict_from_hot() - - self.hot_cache[stream_id] = stream - self.stats.hot_cache_size = len(self.hot_cache) - - async def _promote_to_hot(self, stream_id: str, stream: OptimizedChatStream): - """提升到热缓存""" - # 从温存储中移除 - if stream_id in self.warm_storage: - del self.warm_storage[stream_id] - self.stats.warm_storage_size = len(self.warm_storage) - - # 添加到热缓存 - await self._add_to_hot(stream_id, stream) - logger.debug(f"流 {stream_id} 提升到热缓存") - - async def _promote_to_warm(self, stream_id: str, stream: OptimizedChatStream): - """提升到温缓存""" - # 从冷存储中移除 - if stream_id in self.cold_storage: - del self.cold_storage[stream_id] - self.stats.cold_storage_size = len(self.cold_storage) - - # 添加到温存储 - if len(self.warm_storage) >= self.max_warm_size: - await self._evict_from_warm() - - current_time = time.time() - self.warm_storage[stream_id] = (stream, current_time) - self.stats.warm_storage_size = len(self.warm_storage) - logger.debug(f"流 {stream_id} 提升到温缓存") - - async def _evict_from_hot(self): - """从热缓存驱逐最久未使用的流""" - if not self.hot_cache: - return - - # LRU驱逐 - stream_id, stream = self.hot_cache.popitem(last=False) - self.stats.evictions += 1 - logger.debug(f"从热缓存驱逐: {stream_id}") - - # 移动到温存储 - if len(self.warm_storage) < self.max_warm_size: - current_time = time.time() - self.warm_storage[stream_id] = (stream, current_time) - self.stats.warm_storage_size = len(self.warm_storage) - else: - # 温存储也满了,直接删除 - logger.debug(f"温存储已满,删除流: {stream_id}") - - self.stats.hot_cache_size = len(self.hot_cache) - - async def _evict_from_warm(self): - """从温存储驱逐最久未使用的流""" - if not self.warm_storage: - return - - # 找到最久未访问的流 - oldest_stream_id = min(self.warm_storage.keys(), key=lambda k: self.warm_storage[k][1]) - stream, last_access = self.warm_storage.pop(oldest_stream_id) - self.stats.evictions += 1 - logger.debug(f"从温存储驱逐: {oldest_stream_id}") - - # 移动到冷存储 - if len(self.cold_storage) < self.max_cold_size: - current_time = time.time() - self.cold_storage[oldest_stream_id] = (stream, current_time) - self.stats.cold_storage_size = len(self.cold_storage) - else: - # 冷存储也满了,直接删除 - logger.debug(f"冷存储已满,删除流: {oldest_stream_id}") - - self.stats.warm_storage_size = len(self.warm_storage) - - async def _cleanup_loop(self): - """清理循环""" - logger.info("流缓存清理循环启动") - - while self.is_running: - try: - await asyncio.sleep(self.cleanup_interval) - await self._perform_cleanup() - except asyncio.CancelledError: - logger.info("流缓存清理循环被取消") - break - except Exception as e: - logger.error(f"流缓存清理出错: {e}") - - logger.info("流缓存清理循环结束") - - async def _perform_cleanup(self): - """执行清理操作""" - current_time = time.time() - cleanup_stats = { - "hot_to_warm": 0, - "warm_to_cold": 0, - "cold_removed": 0, - } - - # 1. 检查热缓存超时 - hot_to_demote = [] - for stream_id, stream in self.hot_cache.items(): - # 获取最后访问时间(简化:使用创建时间作为近似) - last_access = getattr(stream, "last_active_time", stream.create_time) - if current_time - last_access > self.hot_timeout: - hot_to_demote.append(stream_id) - - for stream_id in hot_to_demote: - stream = self.hot_cache.pop(stream_id) - current_time_local = time.time() - self.warm_storage[stream_id] = (stream, current_time_local) - cleanup_stats["hot_to_warm"] += 1 - - # 2. 检查温存储超时 - warm_to_demote = [] - for stream_id, (stream, last_access) in self.warm_storage.items(): - if current_time - last_access > self.warm_timeout: - warm_to_demote.append(stream_id) - - for stream_id in warm_to_demote: - stream, last_access = self.warm_storage.pop(stream_id) - self.cold_storage[stream_id] = (stream, last_access) - cleanup_stats["warm_to_cold"] += 1 - - # 3. 检查冷存储超时 - cold_to_remove = [] - for stream_id, (stream, last_access) in self.cold_storage.items(): - if current_time - last_access > self.cold_timeout: - cold_to_remove.append(stream_id) - - for stream_id in cold_to_remove: - self.cold_storage.pop(stream_id) - cleanup_stats["cold_removed"] += 1 - - # 更新统计信息 - self.stats.hot_cache_size = len(self.hot_cache) - self.stats.warm_storage_size = len(self.warm_storage) - self.stats.cold_storage_size = len(self.cold_storage) - self.stats.last_cleanup_time = current_time - - # 估算内存使用(粗略估计) - self.stats.total_memory_usage = ( - len(self.hot_cache) * 1024 # 每个热流约1KB - + len(self.warm_storage) * 512 # 每个温流约512B - + len(self.cold_storage) * 256 # 每个冷流约256B - ) - - if sum(cleanup_stats.values()) > 0: - logger.info( - f"缓存清理完成: {cleanup_stats['hot_to_warm']}热→温, " - f"{cleanup_stats['warm_to_cold']}温→冷, " - f"{cleanup_stats['cold_removed']}冷删除" - ) - - def get_stats(self) -> StreamCacheStats: - """获取缓存统计信息""" - # 计算命中率 - total_requests = self.stats.cache_hits + self.stats.cache_misses - hit_rate = self.stats.cache_hits / total_requests if total_requests > 0 else 0 - - stats_copy = StreamCacheStats( - hot_cache_size=self.stats.hot_cache_size, - warm_storage_size=self.stats.warm_storage_size, - cold_storage_size=self.stats.cold_storage_size, - total_memory_usage=self.stats.total_memory_usage, - cache_hits=self.stats.cache_hits, - cache_misses=self.stats.cache_misses, - evictions=self.stats.evictions, - last_cleanup_time=self.stats.last_cleanup_time, - ) - - # 添加命中率信息 - stats_copy.hit_rate = hit_rate - - return stats_copy - - def clear_cache(self): - """清空所有缓存""" - self.hot_cache.clear() - self.warm_storage.clear() - self.cold_storage.clear() - - self.stats.hot_cache_size = 0 - self.stats.warm_storage_size = 0 - self.stats.cold_storage_size = 0 - self.stats.total_memory_usage = 0 - - logger.info("所有缓存已清空") - - async def get_stream_snapshot(self, stream_id: str) -> OptimizedChatStream | None: - """获取流的快照(不修改缓存状态)""" - if stream_id in self.hot_cache: - return self.hot_cache[stream_id].create_snapshot() - elif stream_id in self.warm_storage: - return self.warm_storage[stream_id][0].create_snapshot() - elif stream_id in self.cold_storage: - return self.cold_storage[stream_id][0].create_snapshot() - return None - - def get_cached_stream_ids(self) -> set[str]: - """获取所有缓存的流ID""" - return set(self.hot_cache.keys()) | set(self.warm_storage.keys()) | set(self.cold_storage.keys()) - - -# 全局缓存管理器实例 -_cache_manager: TieredStreamCache | None = None - - -def get_stream_cache_manager() -> TieredStreamCache: - """获取流缓存管理器实例""" - global _cache_manager - if _cache_manager is None: - _cache_manager = TieredStreamCache() - return _cache_manager - - -async def init_stream_cache_manager(): - """初始化流缓存管理器""" - manager = get_stream_cache_manager() - await manager.start() - - -async def shutdown_stream_cache_manager(): - """关闭流缓存管理器""" - manager = get_stream_cache_manager() - await manager.stop() diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index 1096852cf..32e8c90e5 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -9,10 +9,10 @@ from maim_message import UserInfo from src.chat.antipromptinjector import initialize_anti_injector from src.chat.message_manager import message_manager from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager -from src.chat.message_receive.message import MessageRecv from src.chat.message_receive.storage import MessageStorage from src.chat.utils.prompt import create_prompt_async, global_prompt_manager from src.chat.utils.utils import is_mentioned_bot_in_message +from src.common.data_models.database_data_model import DatabaseMessages from src.common.logger import get_logger from src.config.config import global_config from src.mood.mood_manager import mood_manager # 导入情绪管理器 @@ -105,10 +105,10 @@ class ChatBot: self._started = True - async def _process_plus_commands(self, message: MessageRecv): + async def _process_plus_commands(self, message: DatabaseMessages, chat: ChatStream): """独立处理PlusCommand系统""" try: - text = message.processed_plain_text + text = message.processed_plain_text or "" # 获取配置的命令前缀 from src.config.config import global_config @@ -166,10 +166,10 @@ class ChatBot: # 检查命令是否被禁用 if ( - message.chat_stream - and message.chat_stream.stream_id + chat + and chat.stream_id and plus_command_name - in global_announcement_manager.get_disabled_chat_commands(message.chat_stream.stream_id) + in global_announcement_manager.get_disabled_chat_commands(chat.stream_id) ): logger.info("用户禁用的PlusCommand,跳过处理") return False, None, True @@ -181,11 +181,14 @@ class ChatBot: # 创建PlusCommand实例 plus_command_instance = plus_command_class(message, plugin_config) + + # 为插件实例设置 chat_stream 运行时属性 + setattr(plus_command_instance, "chat_stream", chat) try: # 检查聊天类型限制 if not plus_command_instance.is_chat_type_allowed(): - is_group = message.message_info.group_info + is_group = chat.group_info is not None logger.info( f"PlusCommand {plus_command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}" ) @@ -225,11 +228,11 @@ class ChatBot: logger.error(f"处理PlusCommand时出错: {e}") return False, None, True # 出错时继续处理消息 - async def _process_commands_with_new_system(self, message: MessageRecv): + async def _process_commands_with_new_system(self, message: DatabaseMessages, chat: ChatStream): # sourcery skip: use-named-expression """使用新插件系统处理命令""" try: - text = message.processed_plain_text + text = message.processed_plain_text or "" # 使用新的组件注册中心查找命令 command_result = component_registry.find_command_by_text(text) @@ -238,10 +241,10 @@ class ChatBot: plugin_name = command_info.plugin_name command_name = command_info.name if ( - message.chat_stream - and message.chat_stream.stream_id + chat + and chat.stream_id and command_name - in global_announcement_manager.get_disabled_chat_commands(message.chat_stream.stream_id) + in global_announcement_manager.get_disabled_chat_commands(chat.stream_id) ): logger.info("用户禁用的命令,跳过处理") return False, None, True @@ -254,11 +257,14 @@ class ChatBot: # 创建命令实例 command_instance: BaseCommand = command_class(message, plugin_config) command_instance.set_matched_groups(matched_groups) + + # 为插件实例设置 chat_stream 运行时属性 + setattr(command_instance, "chat_stream", chat) try: # 检查聊天类型限制 if not command_instance.is_chat_type_allowed(): - is_group = message.message_info.group_info + is_group = chat.group_info is not None logger.info( f"命令 {command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}" ) @@ -295,13 +301,20 @@ class ChatBot: logger.error(f"处理命令时出错: {e}") return False, None, True # 出错时继续处理消息 - async def handle_notice_message(self, message: MessageRecv): + async def handle_notice_message(self, message: DatabaseMessages): """处理notice消息 notice消息是系统事件通知(如禁言、戳一戳等),具有以下特点: 1. 默认不触发聊天流程,只记录 2. 可通过配置开启触发聊天流程 3. 会在提示词中展示 + + Args: + message: DatabaseMessages 对象 + + Returns: + bool: True表示notice已完整处理(需要存储并终止后续流程) + False表示不是notice或notice需要继续处理(触发聊天流程) """ # 检查是否是notice消息 if message.is_notify: @@ -309,53 +322,42 @@ class ChatBot: # 根据配置决定是否触发聊天流程 if not global_config.notice.enable_notice_trigger_chat: - logger.debug("notice消息不触发聊天流程(配置已关闭)") - return True # 返回True表示已处理,不继续后续流程 + logger.debug("notice消息不触发聊天流程(配置已关闭),将存储后终止") + return True # 返回True:需要在调用处存储并终止 else: - logger.debug("notice消息触发聊天流程(配置已开启)") - return False # 返回False表示继续处理,触发聊天流程 + logger.debug("notice消息触发聊天流程(配置已开启),继续处理") + return False # 返回False:继续正常流程,作为普通消息处理 # 兼容旧的notice判断方式 - if message.message_info.message_id == "notice": - message.is_notify = True + if message.message_id == "notice": + # 为 DatabaseMessages 设置 is_notify 运行时属性 + from src.chat.message_receive.message_processor import set_db_message_runtime_attr + set_db_message_runtime_attr(message, "is_notify", True) logger.info("旧格式notice消息") # 同样根据配置决定 if not global_config.notice.enable_notice_trigger_chat: - return True + logger.debug("旧格式notice消息不触发聊天流程,将存储后终止") + return True # 需要存储并终止 else: - return False + logger.debug("旧格式notice消息触发聊天流程,继续处理") + return False # 继续正常流程 - # 处理适配器响应消息 - if hasattr(message, "message_segment") and message.message_segment: - if message.message_segment.type == "adapter_response": - await self.handle_adapter_response(message) - return True - elif message.message_segment.type == "adapter_command": - # 适配器命令消息不需要进一步处理 - logger.debug("收到适配器命令消息,跳过后续处理") - return True + # DatabaseMessages 不再有 message_segment,适配器响应处理已在消息处理阶段完成 + # 这里保留逻辑以防万一,但实际上不会再执行到 + return False # 不是notice消息,继续正常流程 - return False - - async def handle_adapter_response(self, message: MessageRecv): - """处理适配器命令响应""" + async def handle_adapter_response(self, message: DatabaseMessages): + """处理适配器命令响应 + + 注意: 此方法目前未被调用,但保留以备将来使用 + """ try: from src.plugin_system.apis.send_api import put_adapter_response - seg_data = message.message_segment.data - if isinstance(seg_data, dict): - request_id = seg_data.get("request_id") - response_data = seg_data.get("response") - else: - request_id = None - response_data = None - - if request_id and response_data: - logger.debug(f"收到适配器响应: request_id={request_id}") - put_adapter_response(request_id, response_data) - else: - logger.warning("适配器响应消息格式不正确") + # DatabaseMessages 使用 message_segments 字段存储消息段 + # 注意: 这可能需要根据实际使用情况进行调整 + logger.warning("handle_adapter_response 方法被调用,但目前未实现对 DatabaseMessages 的支持") except Exception as e: logger.error(f"处理适配器响应时出错: {e}") @@ -381,9 +383,6 @@ class ChatBot: await self._ensure_started() # 控制握手等消息可能缺少 message_info,这里直接跳过避免 KeyError - if not isinstance(message_data, dict): - logger.warning(f"收到无法解析的消息类型: {type(message_data)},已跳过") - return message_info = message_data.get("message_info") if not isinstance(message_info, dict): logger.debug( @@ -392,8 +391,6 @@ class ChatBot: ) return - platform = message_info.get("platform") - if message_info.get("group_info") is not None: message_info["group_info"]["group_id"] = str( message_info["group_info"]["group_id"] @@ -404,74 +401,94 @@ class ChatBot: ) # print(message_data) # logger.debug(str(message_data)) - message = MessageRecv(message_data) - - group_info = message.message_info.group_info - user_info = message.message_info.user_info - if message.message_info.additional_config: - sent_message = message.message_info.additional_config.get("echo", False) + + # 先提取基础信息检查是否是自身消息上报 + from maim_message import BaseMessageInfo + temp_message_info = BaseMessageInfo.from_dict(message_data.get("message_info", {})) + if temp_message_info.additional_config: + sent_message = temp_message_info.additional_config.get("echo", False) if sent_message: # 这一段只是为了在一切处理前劫持上报的自身消息,用于更新message_id,需要ada支持上报事件,实际测试中不会对正常使用造成任何问题 - await MessageStorage.update_message(message) + # 直接使用消息字典更新,不再需要创建 MessageRecv + await MessageStorage.update_message(message_data) return + + group_info = temp_message_info.group_info + user_info = temp_message_info.user_info - get_chat_manager().register_message(message) - + # 获取或创建聊天流 chat = await get_chat_manager().get_or_create_stream( - platform=message.message_info.platform, # type: ignore + platform=temp_message_info.platform, # type: ignore user_info=user_info, # type: ignore group_info=group_info, ) - message.update_chat_stream(chat) - - # 处理消息内容,生成纯文本 - await message.process() - + # 使用新的消息处理器直接生成 DatabaseMessages + from src.chat.message_receive.message_processor import process_message_from_dict + message = await process_message_from_dict( + message_dict=message_data, + stream_id=chat.stream_id, + platform=chat.platform + ) + + # 填充聊天流时间信息 + message.chat_info.create_time = chat.create_time + message.chat_info.last_active_time = chat.last_active_time + + # 注册消息到聊天管理器 + get_chat_manager().register_message(message) + + # 检测是否提及机器人 message.is_mentioned, _ = is_mentioned_bot_in_message(message) # 在这里打印[所见]日志,确保在所有处理和过滤之前记录 chat_name = chat.group_info.group_name if chat.group_info else "私聊" - if message.message_info.user_info: - logger.info( - f"[{chat_name}]{message.message_info.user_info.user_nickname}:{message.processed_plain_text}\u001b[0m" - ) + user_nickname = message.user_info.user_nickname if message.user_info else "未知用户" + logger.info( + f"[{chat_name}]{user_nickname}:{message.processed_plain_text}\u001b[0m" + ) # 在此添加硬编码过滤,防止回复图片处理失败的消息 failure_keywords = ["[表情包(描述生成失败)]", "[图片(描述生成失败)]"] - if any(keyword in message.processed_plain_text for keyword in failure_keywords): - logger.info(f"[硬编码过滤] 检测到媒体内容处理失败({message.processed_plain_text}),消息被静默处理。") + processed_text = message.processed_plain_text or "" + if any(keyword in processed_text for keyword in failure_keywords): + logger.info(f"[硬编码过滤] 检测到媒体内容处理失败({processed_text}),消息被静默处理。") return # 处理notice消息 + # notice_handled=True: 表示notice不触发聊天,需要在此存储并终止 + # notice_handled=False: 表示notice触发聊天或不是notice,继续正常流程 notice_handled = await self.handle_notice_message(message) if notice_handled: - # notice消息已处理,使用统一的转换方法 + # notice消息不触发聊天流程,在此进行存储和记录后终止 try: - # 直接转换为 DatabaseMessages - db_message = message.to_database_message() - + # message 已经是 DatabaseMessages,直接使用 # 添加到message_manager(这会将notice添加到全局notice管理器) - await message_manager.add_message(message.chat_stream.stream_id, db_message) - logger.info(f"✅ Notice消息已添加到message_manager: type={message.notice_type}, stream={message.chat_stream.stream_id}") + await message_manager.add_message(chat.stream_id, message) + logger.info(f"✅ Notice消息已添加到message_manager: type={message.notice_type}, stream={chat.stream_id}") except Exception as e: logger.error(f"Notice消息添加到message_manager失败: {e}", exc_info=True) - # 存储后直接返回 - await MessageStorage.store_message(message, chat) - logger.debug("notice消息已存储,跳过后续处理") + # 存储notice消息到数据库(需要更新 storage.py 支持 DatabaseMessages) + # 暂时跳过存储,等待更新 storage.py + logger.debug("notice消息已添加到message_manager(存储功能待更新)") return + + # 如果notice_handled=False,则继续执行后续流程 + # 对于启用触发聊天的notice,会在后续的正常流程中被存储和处理 # 过滤检查 + # DatabaseMessages 使用 display_message 作为原始消息表示 + raw_text = message.display_message or message.processed_plain_text or "" if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore - message.raw_message, # type: ignore + raw_text, chat, user_info, # type: ignore ): return # 命令处理 - 首先尝试PlusCommand独立处理 - is_plus_command, plus_cmd_result, plus_continue_process = await self._process_plus_commands(message) + is_plus_command, plus_cmd_result, plus_continue_process = await self._process_plus_commands(message, chat) # 如果是PlusCommand且不需要继续处理,则直接返回 if is_plus_command and not plus_continue_process: @@ -481,7 +498,7 @@ class ChatBot: # 如果不是PlusCommand,尝试传统的BaseCommand处理 if not is_plus_command: - is_command, cmd_result, continue_process = await self._process_commands_with_new_system(message) + is_command, cmd_result, continue_process = await self._process_commands_with_new_system(message, chat) # 如果是命令且不需要继续处理,则直接返回 if is_command and not continue_process: @@ -493,24 +510,14 @@ class ChatBot: if result and not result.all_continue_process(): raise UserWarning(f"插件{result.get_summary().get('stopped_handlers', '')}于消息到达时取消了消息处理") - # TODO:暂不可用 + # TODO:暂不可用 - DatabaseMessages 不再有 message_info.template_info # 确认从接口发来的message是否有自定义的prompt模板信息 - if message.message_info.template_info and not message.message_info.template_info.template_default: - template_group_name: str | None = message.message_info.template_info.template_name # type: ignore - template_items = message.message_info.template_info.template_items - async with global_prompt_manager.async_message_scope(template_group_name): - if isinstance(template_items, dict): - for k in template_items.keys(): - await create_prompt_async(template_items[k], k) - logger.debug(f"注册{template_items[k]},{k}") - else: - template_group_name = None + # 这个功能需要在 adapter 层通过 additional_config 传递 + template_group_name = None async def preprocess(): - # 使用统一的转换方法创建数据库消息对象 - db_message = message.to_database_message() - - group_info = getattr(message.chat_stream, "group_info", None) + # message 已经是 DatabaseMessages,直接使用 + group_info = chat.group_info # 先交给消息管理器处理,计算兴趣度等衍生数据 try: @@ -527,31 +534,15 @@ class ChatBot: should_process_in_manager = False if should_process_in_manager: - await message_manager.add_message(message.chat_stream.stream_id, db_message) - logger.debug(f"消息已添加到消息管理器: {message.chat_stream.stream_id}") + await message_manager.add_message(chat.stream_id, message) + logger.debug(f"消息已添加到消息管理器: {chat.stream_id}") except Exception as e: logger.error(f"消息添加到消息管理器失败: {e}") - # 将兴趣度结果同步回原始消息,便于后续流程使用 - message.interest_value = getattr(db_message, "interest_value", getattr(message, "interest_value", 0.0)) - setattr( - message, - "should_reply", - getattr(db_message, "should_reply", getattr(message, "should_reply", False)), - ) - setattr(message, "should_act", getattr(db_message, "should_act", getattr(message, "should_act", False))) - # 存储消息到数据库,只进行一次写入 try: - await MessageStorage.store_message(message, message.chat_stream) - logger.debug( - "消息已存储到数据库: %s (interest=%.3f, should_reply=%s, should_act=%s)", - message.message_info.message_id, - getattr(message, "interest_value", -1.0), - getattr(message, "should_reply", None), - getattr(message, "should_act", None), - ) + await MessageStorage.store_message(message, chat) except Exception as e: logger.error(f"存储消息到数据库失败: {e}") traceback.print_exc() @@ -560,13 +551,13 @@ class ChatBot: try: if global_config.mood.enable_mood: # 获取兴趣度用于情绪更新 - interest_rate = getattr(message, "interest_value", 0.0) + interest_rate = message.interest_value if interest_rate is None: interest_rate = 0.0 logger.debug(f"开始更新情绪状态,兴趣度: {interest_rate:.2f}") # 获取当前聊天的情绪对象并更新情绪状态 - chat_mood = mood_manager.get_mood_by_chat_id(message.chat_stream.stream_id) + chat_mood = mood_manager.get_mood_by_chat_id(chat.stream_id) await chat_mood.update_mood_by_message(message, interest_rate) logger.debug("情绪状态更新完成") except Exception as e: diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index c22d755fb..a79b91400 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -12,13 +12,10 @@ from sqlalchemy.dialects.sqlite import insert as sqlite_insert from src.common.database.sqlalchemy_database_api import get_db_session from src.common.database.sqlalchemy_models import ChatStreams # 新增导入 +from src.common.data_models.database_data_model import DatabaseMessages from src.common.logger import get_logger from src.config.config import global_config # 新增导入 -# 避免循环导入,使用TYPE_CHECKING进行类型提示 -if TYPE_CHECKING: - from .message import MessageRecv - install(extra_lines=3) @@ -33,7 +30,7 @@ class ChatStream: self, stream_id: str, platform: str, - user_info: UserInfo, + user_info: UserInfo | None = None, group_info: GroupInfo | None = None, data: dict | None = None, ): @@ -46,20 +43,18 @@ class ChatStream: self.sleep_pressure = data.get("sleep_pressure", 0.0) if data else 0.0 self.saved = False - # 使用StreamContext替代ChatMessageContext + # 创建单流上下文管理器(包含StreamContext) + from src.chat.message_manager.context_manager import SingleStreamContextManager from src.common.data_models.message_manager_data_model import StreamContext from src.plugin_system.base.component_types import ChatMode, ChatType - # 创建StreamContext - self.stream_context: StreamContext = StreamContext( - stream_id=stream_id, chat_type=ChatType.GROUP if group_info else ChatType.PRIVATE, chat_mode=ChatMode.NORMAL - ) - - # 创建单流上下文管理器 - from src.chat.message_manager.context_manager import SingleStreamContextManager - self.context_manager: SingleStreamContextManager = SingleStreamContextManager( - stream_id=stream_id, context=self.stream_context + stream_id=stream_id, + context=StreamContext( + stream_id=stream_id, + chat_type=ChatType.GROUP if group_info else ChatType.PRIVATE, + chat_mode=ChatMode.NORMAL, + ), ) # 基础参数 @@ -88,13 +83,12 @@ class ChatStream: new_stream._focus_energy = self._focus_energy new_stream.no_reply_consecutive = self.no_reply_consecutive - # 复制 stream_context,但跳过 processing_task - new_stream.stream_context = copy.deepcopy(self.stream_context, memo) - if hasattr(new_stream.stream_context, "processing_task"): - new_stream.stream_context.processing_task = None - - # 复制 context_manager + # 复制 context_manager(包含 stream_context) new_stream.context_manager = copy.deepcopy(self.context_manager, memo) + + # 清理 processing_task(如果存在) + if hasattr(new_stream.context_manager.context, "processing_task"): + new_stream.context_manager.context.processing_task = None return new_stream @@ -111,11 +105,11 @@ class ChatStream: "focus_energy": self.focus_energy, # 基础兴趣度 "base_interest_energy": self.base_interest_energy, - # stream_context基本信息 - "stream_context_chat_type": self.stream_context.chat_type.value, - "stream_context_chat_mode": self.stream_context.chat_mode.value, + # stream_context基本信息(通过context_manager访问) + "stream_context_chat_type": self.context_manager.context.chat_type.value, + "stream_context_chat_mode": self.context_manager.context.chat_mode.value, # 统计信息 - "interruption_count": self.stream_context.interruption_count, + "interruption_count": self.context_manager.context.interruption_count, } @classmethod @@ -132,27 +126,19 @@ class ChatStream: data=data, ) - # 恢复stream_context信息 + # 恢复stream_context信息(通过context_manager访问) if "stream_context_chat_type" in data: from src.plugin_system.base.component_types import ChatMode, ChatType - instance.stream_context.chat_type = ChatType(data["stream_context_chat_type"]) + instance.context_manager.context.chat_type = ChatType(data["stream_context_chat_type"]) if "stream_context_chat_mode" in data: from src.plugin_system.base.component_types import ChatMode, ChatType - instance.stream_context.chat_mode = ChatMode(data["stream_context_chat_mode"]) + instance.context_manager.context.chat_mode = ChatMode(data["stream_context_chat_mode"]) # 恢复interruption_count信息 if "interruption_count" in data: - instance.stream_context.interruption_count = data["interruption_count"] - - # 确保 context_manager 已初始化 - if not hasattr(instance, "context_manager"): - from src.chat.message_manager.context_manager import SingleStreamContextManager - - instance.context_manager = SingleStreamContextManager( - stream_id=instance.stream_id, context=instance.stream_context - ) + instance.context_manager.context.interruption_count = data["interruption_count"] return instance @@ -160,156 +146,44 @@ class ChatStream: """获取原始的、未哈希的聊天流ID字符串""" if self.group_info: return f"{self.platform}:{self.group_info.group_id}:group" - else: + elif self.user_info: return f"{self.platform}:{self.user_info.user_id}:private" + else: + return f"{self.platform}:unknown:private" def update_active_time(self): """更新最后活跃时间""" self.last_active_time = time.time() self.saved = False - async def set_context(self, message: "MessageRecv"): - """设置聊天消息上下文""" - # 将MessageRecv转换为DatabaseMessages并设置到stream_context - import json - - from src.common.data_models.database_data_model import DatabaseMessages - - # 安全获取message_info中的数据 - message_info = getattr(message, "message_info", {}) - user_info = getattr(message_info, "user_info", {}) - group_info = getattr(message_info, "group_info", {}) - - # 提取reply_to信息(从message_segment中查找reply类型的段) - reply_to = None - if hasattr(message, "message_segment") and message.message_segment: - reply_to = self._extract_reply_from_segment(message.message_segment) - - # 完整的数据转移逻辑 - db_message = DatabaseMessages( - # 基础消息信息 - message_id=getattr(message, "message_id", ""), - time=getattr(message, "time", time.time()), - chat_id=self._generate_chat_id(message_info), - reply_to=reply_to, - # 兴趣度相关 - interest_value=getattr(message, "interest_value", 0.0), - # 关键词 - key_words=json.dumps(getattr(message, "key_words", []), ensure_ascii=False) - if getattr(message, "key_words", None) - else None, - key_words_lite=json.dumps(getattr(message, "key_words_lite", []), ensure_ascii=False) - if getattr(message, "key_words_lite", None) - else None, - # 消息状态标记 - is_mentioned=getattr(message, "is_mentioned", None), - is_at=getattr(message, "is_at", False), - is_emoji=getattr(message, "is_emoji", False), - is_picid=getattr(message, "is_picid", False), - is_voice=getattr(message, "is_voice", False), - is_video=getattr(message, "is_video", False), - is_command=getattr(message, "is_command", False), - is_notify=getattr(message, "is_notify", False), - is_public_notice=getattr(message, "is_public_notice", False), - notice_type=getattr(message, "notice_type", None), - # 消息内容 - processed_plain_text=getattr(message, "processed_plain_text", ""), - display_message=getattr(message, "processed_plain_text", ""), # 默认使用processed_plain_text - # 优先级信息 - priority_mode=getattr(message, "priority_mode", None), - priority_info=json.dumps(getattr(message, "priority_info", None)) - if getattr(message, "priority_info", None) - else None, - # 额外配置 - 需要将 format_info 嵌入到 additional_config 中 - additional_config=self._prepare_additional_config(message_info), - # 用户信息 - user_id=str(getattr(user_info, "user_id", "")), - user_nickname=getattr(user_info, "user_nickname", ""), - user_cardname=getattr(user_info, "user_cardname", None), - user_platform=getattr(user_info, "platform", ""), - # 群组信息 - chat_info_group_id=getattr(group_info, "group_id", None), - chat_info_group_name=getattr(group_info, "group_name", None), - chat_info_group_platform=getattr(group_info, "platform", None), - # 聊天流信息 - chat_info_user_id=str(getattr(user_info, "user_id", "")), - chat_info_user_nickname=getattr(user_info, "user_nickname", ""), - chat_info_user_cardname=getattr(user_info, "user_cardname", None), - chat_info_user_platform=getattr(user_info, "platform", ""), - chat_info_stream_id=self.stream_id, - chat_info_platform=self.platform, - chat_info_create_time=self.create_time, - chat_info_last_active_time=self.last_active_time, - # 新增兴趣度系统字段 - 添加安全处理 - actions=self._safe_get_actions(message), - should_reply=getattr(message, "should_reply", False), - should_act=getattr(message, "should_act", False), - ) - - self.stream_context.set_current_message(db_message) - self.stream_context.priority_mode = getattr(message, "priority_mode", None) - self.stream_context.priority_info = getattr(message, "priority_info", None) - - # 调试日志:记录数据转移情况 - logger.debug( - f"消息数据转移完成 - message_id: {db_message.message_id}, " - f"chat_id: {db_message.chat_id}, " - f"is_mentioned: {db_message.is_mentioned}, " - f"is_emoji: {db_message.is_emoji}, " - f"is_picid: {db_message.is_picid}, " - f"interest_value: {db_message.interest_value}" - ) - - def _prepare_additional_config(self, message_info) -> str | None: - """ - 准备 additional_config,将 format_info 嵌入其中 - - 这个方法模仿 storage.py 中的逻辑,确保 DatabaseMessages 中的 additional_config - 包含 format_info,使得 action_modifier 能够正确获取适配器支持的消息类型 - + async def set_context(self, message: DatabaseMessages): + """设置聊天消息上下文 + Args: - message_info: BaseMessageInfo 对象 - - Returns: - str | None: JSON 字符串格式的 additional_config,如果为空则返回 None + message: DatabaseMessages 对象,直接使用不需要转换 """ - import orjson - - # 首先获取adapter传递的additional_config - additional_config_data = {} - if hasattr(message_info, 'additional_config') and message_info.additional_config: - if isinstance(message_info.additional_config, dict): - additional_config_data = message_info.additional_config.copy() - elif isinstance(message_info.additional_config, str): - # 如果是字符串,尝试解析 - try: - additional_config_data = orjson.loads(message_info.additional_config) - except Exception as e: - logger.warning(f"无法解析 additional_config JSON: {e}") - additional_config_data = {} + # 直接使用传入的 DatabaseMessages,设置到上下文中 + self.context_manager.context.set_current_message(message) - # 然后添加format_info到additional_config中 - if hasattr(message_info, 'format_info') and message_info.format_info: - try: - format_info_dict = message_info.format_info.to_dict() - additional_config_data["format_info"] = format_info_dict - logger.debug(f"嵌入 format_info 到 additional_config: {format_info_dict}") - except Exception as e: - logger.warning(f"将 format_info 转换为字典失败: {e}") - else: - logger.warning(f"[问题] 消息缺少 format_info: message_id={getattr(message_info, 'message_id', 'unknown')}") - logger.warning("[问题] 这可能导致 Action 无法正确检查适配器支持的类型") - - # 序列化为JSON字符串 - if additional_config_data: - try: - return orjson.dumps(additional_config_data).decode("utf-8") - except Exception as e: - logger.error(f"序列化 additional_config 失败: {e}") - return None - return None + # 设置优先级信息(如果存在) + priority_mode = getattr(message, "priority_mode", None) + priority_info = getattr(message, "priority_info", None) + if priority_mode: + self.context_manager.context.priority_mode = priority_mode + if priority_info: + self.context_manager.context.priority_info = priority_info - def _safe_get_actions(self, message: "MessageRecv") -> list | None: + # 调试日志 + logger.debug( + f"消息上下文已设置 - message_id: {message.message_id}, " + f"chat_id: {message.chat_id}, " + f"is_mentioned: {message.is_mentioned}, " + f"is_emoji: {message.is_emoji}, " + f"is_picid: {message.is_picid}, " + f"interest_value: {message.interest_value}" + ) + + def _safe_get_actions(self, message: DatabaseMessages) -> list | None: """安全获取消息的actions字段""" import json @@ -380,23 +254,6 @@ class ChatStream: if hasattr(db_message, "should_act"): db_message.should_act = False - def _extract_reply_from_segment(self, segment) -> str | None: - """从消息段中提取reply_to信息""" - try: - if hasattr(segment, "type") and segment.type == "seglist": - # 递归搜索seglist中的reply段 - if hasattr(segment, "data") and segment.data: - for seg in segment.data: - reply_id = self._extract_reply_from_segment(seg) - if reply_id: - return reply_id - elif hasattr(segment, "type") and segment.type == "reply": - # 找到reply段,返回message_id - return str(segment.data) if segment.data else None - except Exception as e: - logger.warning(f"提取reply_to信息失败: {e}") - return None - def _generate_chat_id(self, message_info) -> str: """生成chat_id,基于群组或用户信息""" try: @@ -493,8 +350,10 @@ class ChatManager: def __init__(self): if not self._initialized: + from src.common.data_models.database_data_model import DatabaseMessages + self.streams: dict[str, ChatStream] = {} # stream_id -> ChatStream - self.last_messages: dict[str, "MessageRecv"] = {} # stream_id -> last_message + self.last_messages: dict[str, DatabaseMessages] = {} # stream_id -> last_message # try: # async with get_db_session() as session: # db.connect(reuse_if_open=True) @@ -528,12 +387,30 @@ class ChatManager: except Exception as e: logger.error(f"聊天流自动保存失败: {e!s}") - def register_message(self, message: "MessageRecv"): + def register_message(self, message: DatabaseMessages): """注册消息到聊天流""" + # 从 DatabaseMessages 提取平台和用户/群组信息 + from maim_message import UserInfo, GroupInfo + + user_info = UserInfo( + platform=message.user_info.platform, + user_id=message.user_info.user_id, + user_nickname=message.user_info.user_nickname, + user_cardname=message.user_info.user_cardname or "" + ) + + group_info = None + if message.group_info: + group_info = GroupInfo( + platform=message.group_info.group_platform or "", + group_id=message.group_info.group_id, + group_name=message.group_info.group_name + ) + stream_id = self._generate_stream_id( - message.message_info.platform, # type: ignore - message.message_info.user_info, - message.message_info.group_info, + message.chat_info.platform, + user_info, + group_info, ) self.last_messages[stream_id] = message # logger.debug(f"注册消息到聊天流: {stream_id}") @@ -578,32 +455,6 @@ class ChatManager: try: stream_id = self._generate_stream_id(platform, user_info, group_info) - # 优先使用缓存管理器(优化版本) - try: - from src.chat.message_manager.stream_cache_manager import get_stream_cache_manager - - cache_manager = get_stream_cache_manager() - - if cache_manager.is_running: - optimized_stream = await cache_manager.get_or_create_stream( - stream_id=stream_id, platform=platform, user_info=user_info, group_info=group_info - ) - - # 设置消息上下文 - from .message import MessageRecv - - if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv): - optimized_stream.set_context(self.last_messages[stream_id]) - - # 转换为原始ChatStream以保持兼容性 - original_stream = self._convert_to_original_stream(optimized_stream) - - return original_stream - - except Exception as e: - logger.debug(f"缓存管理器获取流失败,使用原始方法: {e}") - - # 回退到原始方法 # 检查内存中是否存在 if stream_id in self.streams: stream = self.streams[stream_id] @@ -615,12 +466,13 @@ class ChatManager: stream.user_info = user_info if group_info: stream.group_info = group_info - from .message import MessageRecv # 延迟导入,避免循环引用 - - if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv): + + # 检查是否有最后一条消息(现在使用 DatabaseMessages) + from src.common.data_models.database_data_model import DatabaseMessages + if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], DatabaseMessages): await stream.set_context(self.last_messages[stream_id]) else: - logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的") + logger.debug(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的或还没有消息") return stream # 检查数据库中是否存在 @@ -679,19 +531,27 @@ class ChatManager: raise e stream = copy.deepcopy(stream) - from .message import MessageRecv # 延迟导入,避免循环引用 + from src.common.data_models.database_data_model import DatabaseMessages - if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv): + if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], DatabaseMessages): await stream.set_context(self.last_messages[stream_id]) else: - logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的") + logger.debug(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的") # 确保 ChatStream 有自己的 context_manager if not hasattr(stream, "context_manager"): - # 创建新的单流上下文管理器 from src.chat.message_manager.context_manager import SingleStreamContextManager + from src.common.data_models.message_manager_data_model import StreamContext + from src.plugin_system.base.component_types import ChatMode, ChatType - stream.context_manager = SingleStreamContextManager(stream_id=stream_id, context=stream.stream_context) + stream.context_manager = SingleStreamContextManager( + stream_id=stream_id, + context=StreamContext( + stream_id=stream_id, + chat_type=ChatType.GROUP if stream.group_info else ChatType.PRIVATE, + chat_mode=ChatMode.NORMAL, + ), + ) # 保存到内存和数据库 self.streams[stream_id] = stream @@ -700,10 +560,12 @@ class ChatManager: async def get_stream(self, stream_id: str) -> ChatStream | None: """通过stream_id获取聊天流""" + from src.common.data_models.database_data_model import DatabaseMessages + stream = self.streams.get(stream_id) if not stream: return None - if stream_id in self.last_messages: + if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], DatabaseMessages): await stream.set_context(self.last_messages[stream_id]) return stream @@ -921,9 +783,16 @@ class ChatManager: # 确保 ChatStream 有自己的 context_manager if not hasattr(stream, "context_manager"): from src.chat.message_manager.context_manager import SingleStreamContextManager + from src.common.data_models.message_manager_data_model import StreamContext + from src.plugin_system.base.component_types import ChatMode, ChatType stream.context_manager = SingleStreamContextManager( - stream_id=stream.stream_id, context=stream.stream_context + stream_id=stream.stream_id, + context=StreamContext( + stream_id=stream.stream_id, + chat_type=ChatType.GROUP if stream.group_info else ChatType.PRIVATE, + chat_mode=ChatMode.NORMAL, + ), ) except Exception as e: logger.error(f"从数据库加载所有聊天流失败 (SQLAlchemy): {e}", exc_info=True) @@ -932,46 +801,6 @@ class ChatManager: chat_manager = None -def _convert_to_original_stream(self, optimized_stream) -> "ChatStream": - """将OptimizedChatStream转换为原始ChatStream以保持兼容性""" - try: - # 创建原始ChatStream实例 - original_stream = ChatStream( - stream_id=optimized_stream.stream_id, - platform=optimized_stream.platform, - user_info=optimized_stream._get_effective_user_info(), - group_info=optimized_stream._get_effective_group_info(), - ) - - # 复制状态 - original_stream.create_time = optimized_stream.create_time - original_stream.last_active_time = optimized_stream.last_active_time - original_stream.sleep_pressure = optimized_stream.sleep_pressure - original_stream.base_interest_energy = optimized_stream.base_interest_energy - original_stream._focus_energy = optimized_stream._focus_energy - original_stream.no_reply_consecutive = optimized_stream.no_reply_consecutive - original_stream.saved = optimized_stream.saved - - # 复制上下文信息(如果存在) - if hasattr(optimized_stream, "_stream_context") and optimized_stream._stream_context: - original_stream.stream_context = optimized_stream._stream_context - - if hasattr(optimized_stream, "_context_manager") and optimized_stream._context_manager: - original_stream.context_manager = optimized_stream._context_manager - - return original_stream - - except Exception as e: - logger.error(f"转换OptimizedChatStream失败: {e}") - # 如果转换失败,创建一个新的原始流 - return ChatStream( - stream_id=optimized_stream.stream_id, - platform=optimized_stream.platform, - user_info=optimized_stream._get_effective_user_info(), - group_info=optimized_stream._get_effective_group_info(), - ) - - def get_chat_manager(): global chat_manager if chat_manager is None: diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index b7603b2f3..9cdaa337f 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -2,7 +2,7 @@ import base64 import time from abc import ABCMeta, abstractmethod from dataclasses import dataclass -from typing import Any, Optional +from typing import Any, Optional, Union import urllib3 from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo @@ -13,6 +13,7 @@ from src.chat.utils.self_voice_cache import consume_self_voice_text from src.chat.utils.utils_image import get_image_manager from src.chat.utils.utils_video import get_video_analyzer, is_video_analysis_available from src.chat.utils.utils_voice import get_voice_text +from src.common.data_models.database_data_model import DatabaseMessages from src.common.logger import get_logger from src.config.config import global_config @@ -43,7 +44,7 @@ class Message(MessageBase, metaclass=ABCMeta): user_info: UserInfo, message_segment: Seg | None = None, timestamp: float | None = None, - reply: Optional["MessageRecv"] = None, + reply: Optional["DatabaseMessages"] = None, processed_plain_text: str = "", ): # 使用传入的时间戳或当前时间 @@ -95,346 +96,12 @@ class Message(MessageBase, metaclass=ABCMeta): @dataclass -class MessageRecv(Message): - """接收消息类,用于处理从MessageCQ序列化的消息""" - def __init__(self, message_dict: dict[str, Any]): - """从MessageCQ的字典初始化 - - Args: - message_dict: MessageCQ序列化后的字典 - """ - # Manually initialize attributes from MessageBase and Message - self.message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {})) - self.message_segment = Seg.from_dict(message_dict.get("message_segment", {})) - self.raw_message = message_dict.get("raw_message") - - self.chat_stream = None - self.reply = None - self.processed_plain_text = message_dict.get("processed_plain_text", "") - self.memorized_times = 0 - - # MessageRecv specific attributes - self.is_emoji = False - self.has_emoji = False - self.is_picid = False - self.has_picid = False - self.is_voice = False - self.is_video = False - self.is_mentioned = None - self.is_notify = False # 是否为notice消息 - self.is_public_notice = False # 是否为公共notice - self.notice_type = None # notice类型 - self.is_at = False - self.is_command = False - - self.priority_mode = "interest" - self.priority_info = None - self.interest_value: float = 0.0 - - self.key_words = [] - self.key_words_lite = [] - - # 解析additional_config中的notice信息 - if self.message_info.additional_config and isinstance(self.message_info.additional_config, dict): - self.is_notify = self.message_info.additional_config.get("is_notice", False) - self.is_public_notice = self.message_info.additional_config.get("is_public_notice", False) - self.notice_type = self.message_info.additional_config.get("notice_type") - - def update_chat_stream(self, chat_stream: "ChatStream"): - self.chat_stream = chat_stream - - def to_database_message(self) -> "DatabaseMessages": - """将 MessageRecv 转换为 DatabaseMessages 对象 - - Returns: - DatabaseMessages: 数据库消息对象 - """ - from src.common.data_models.database_data_model import DatabaseMessages - import json - import time - - message_info = self.message_info - msg_user_info = getattr(message_info, "user_info", None) - stream_user_info = getattr(self.chat_stream, "user_info", None) if self.chat_stream else None - group_info = getattr(self.chat_stream, "group_info", None) if self.chat_stream else None - - message_id = message_info.message_id or "" - message_time = message_info.time if hasattr(message_info, "time") and message_info.time is not None else time.time() - is_mentioned = None - if isinstance(self.is_mentioned, bool): - is_mentioned = self.is_mentioned - elif isinstance(self.is_mentioned, int | float): - is_mentioned = self.is_mentioned != 0 - - # 提取用户信息 - user_id = "" - user_nickname = "" - user_cardname = None - user_platform = "" - if msg_user_info: - user_id = str(getattr(msg_user_info, "user_id", "") or "") - user_nickname = getattr(msg_user_info, "user_nickname", "") or "" - user_cardname = getattr(msg_user_info, "user_cardname", None) - user_platform = getattr(msg_user_info, "platform", "") or "" - elif stream_user_info: - user_id = str(getattr(stream_user_info, "user_id", "") or "") - user_nickname = getattr(stream_user_info, "user_nickname", "") or "" - user_cardname = getattr(stream_user_info, "user_cardname", None) - user_platform = getattr(stream_user_info, "platform", "") or "" - - # 提取聊天流信息 - chat_user_id = str(getattr(stream_user_info, "user_id", "") or "") if stream_user_info else "" - chat_user_nickname = getattr(stream_user_info, "user_nickname", "") or "" if stream_user_info else "" - chat_user_cardname = getattr(stream_user_info, "user_cardname", None) if stream_user_info else None - chat_user_platform = getattr(stream_user_info, "platform", "") or "" if stream_user_info else "" - - group_id = getattr(group_info, "group_id", None) if group_info else None - group_name = getattr(group_info, "group_name", None) if group_info else None - group_platform = getattr(group_info, "platform", None) if group_info else None - - # 准备 additional_config - additional_config_str = None - try: - import orjson - - additional_config_data = {} - - # 首先获取adapter传递的additional_config - if hasattr(message_info, 'additional_config') and message_info.additional_config: - if isinstance(message_info.additional_config, dict): - additional_config_data = message_info.additional_config.copy() - elif isinstance(message_info.additional_config, str): - try: - additional_config_data = orjson.loads(message_info.additional_config) - except Exception as e: - logger.warning(f"无法解析 additional_config JSON: {e}") - additional_config_data = {} - - # 添加notice相关标志 - if self.is_notify: - additional_config_data["is_notice"] = True - additional_config_data["notice_type"] = self.notice_type or "unknown" - additional_config_data["is_public_notice"] = bool(self.is_public_notice) - - # 添加format_info到additional_config中 - if hasattr(message_info, 'format_info') and message_info.format_info: - try: - format_info_dict = message_info.format_info.to_dict() - additional_config_data["format_info"] = format_info_dict - logger.debug(f"[message.py] 嵌入 format_info 到 additional_config: {format_info_dict}") - except Exception as e: - logger.warning(f"将 format_info 转换为字典失败: {e}") - - # 序列化为JSON字符串 - if additional_config_data: - additional_config_str = orjson.dumps(additional_config_data).decode("utf-8") - except Exception as e: - logger.error(f"准备 additional_config 失败: {e}") - - # 创建数据库消息对象 - db_message = DatabaseMessages( - message_id=message_id, - time=float(message_time), - chat_id=self.chat_stream.stream_id if self.chat_stream else "", - processed_plain_text=self.processed_plain_text, - display_message=self.processed_plain_text, - is_mentioned=is_mentioned, - is_at=bool(self.is_at) if self.is_at is not None else None, - is_emoji=bool(self.is_emoji), - is_picid=bool(self.is_picid), - is_command=bool(self.is_command), - is_notify=bool(self.is_notify), - is_public_notice=bool(self.is_public_notice), - notice_type=self.notice_type, - additional_config=additional_config_str, - user_id=user_id, - user_nickname=user_nickname, - user_cardname=user_cardname, - user_platform=user_platform, - chat_info_stream_id=self.chat_stream.stream_id if self.chat_stream else "", - chat_info_platform=self.chat_stream.platform if self.chat_stream else "", - chat_info_create_time=float(self.chat_stream.create_time) if self.chat_stream else 0.0, - chat_info_last_active_time=float(self.chat_stream.last_active_time) if self.chat_stream else 0.0, - chat_info_user_id=chat_user_id, - chat_info_user_nickname=chat_user_nickname, - chat_info_user_cardname=chat_user_cardname, - chat_info_user_platform=chat_user_platform, - chat_info_group_id=group_id, - chat_info_group_name=group_name, - chat_info_group_platform=group_platform, - ) - - # 同步兴趣度等衍生属性 - db_message.interest_value = getattr(self, "interest_value", 0.0) - setattr(db_message, "should_reply", getattr(self, "should_reply", False)) - setattr(db_message, "should_act", getattr(self, "should_act", False)) - - return db_message - - async def process(self) -> None: - """处理消息内容,生成纯文本和详细文本 - - 这个方法必须在创建实例后显式调用,因为它包含异步操作。 - """ - self.processed_plain_text = await self._process_message_segments(self.message_segment) - - async def _process_single_segment(self, segment: Seg) -> str: - """处理单个消息段 - - Args: - segment: 消息段 - - Returns: - str: 处理后的文本 - """ - try: - if segment.type == "text": - self.is_picid = False - self.is_emoji = False - self.is_video = False - return segment.data # type: ignore - elif segment.type == "at": - self.is_picid = False - self.is_emoji = False - self.is_video = False - # 处理at消息,格式为"昵称:QQ号" - if isinstance(segment.data, str) and ":" in segment.data: - nickname, qq_id = segment.data.split(":", 1) - return f"@{nickname}" - return f"@{segment.data}" if isinstance(segment.data, str) else "@未知用户" - elif segment.type == "image": - # 如果是base64图片数据 - if isinstance(segment.data, str): - self.has_picid = True - self.is_picid = True - self.is_emoji = False - self.is_video = False - image_manager = get_image_manager() - # print(f"segment.data: {segment.data}") - _, processed_text = await image_manager.process_image(segment.data) - return processed_text - return "[发了一张图片,网卡了加载不出来]" - elif segment.type == "emoji": - self.has_emoji = True - self.is_emoji = True - self.is_picid = False - self.is_voice = False - self.is_video = False - if isinstance(segment.data, str): - return await get_image_manager().get_emoji_description(segment.data) - return "[发了一个表情包,网卡了加载不出来]" - elif segment.type == "voice": - self.is_picid = False - self.is_emoji = False - self.is_voice = True - self.is_video = False - - # 检查消息是否由机器人自己发送 - if self.message_info and self.message_info.user_info and str(self.message_info.user_info.user_id) == str(global_config.bot.qq_account): - logger.info(f"检测到机器人自身发送的语音消息 (User ID: {self.message_info.user_info.user_id}),尝试从缓存获取文本。") - if isinstance(segment.data, str): - cached_text = consume_self_voice_text(segment.data) - if cached_text: - logger.info(f"成功从缓存中获取语音文本: '{cached_text[:70]}...'") - return f"[语音:{cached_text}]" - else: - logger.warning("机器人自身语音消息缓存未命中,将回退到标准语音识别。") - - # 标准语音识别流程 (也作为缓存未命中的后备方案) - if isinstance(segment.data, str): - return await get_voice_text(segment.data) - return "[发了一段语音,网卡了加载不出来]" - elif segment.type == "mention_bot": - self.is_picid = False - self.is_emoji = False - self.is_voice = False - self.is_video = False - self.is_mentioned = float(segment.data) # type: ignore - return "" - elif segment.type == "priority_info": - self.is_picid = False - self.is_emoji = False - self.is_voice = False - if isinstance(segment.data, dict): - # 处理优先级信息 - self.priority_mode = "priority" - self.priority_info = segment.data - """ - { - 'message_type': 'vip', # vip or normal - 'message_priority': 1.0, # 优先级,大为优先,float - } - """ - return "" - elif segment.type == "file": - if isinstance(segment.data, dict): - file_name = segment.data.get('name', '未知文件') - file_size = segment.data.get('size', '未知大小') - return f"[文件:{file_name} ({file_size}字节)]" - return "[收到一个文件]" - elif segment.type == "video": - self.is_picid = False - self.is_emoji = False - self.is_voice = False - self.is_video = True - logger.info(f"接收到视频消息,数据类型: {type(segment.data)}") - - # 检查视频分析功能是否可用 - if not is_video_analysis_available(): - logger.warning("⚠️ Rust视频处理模块不可用,跳过视频分析") - return "[视频]" - - if global_config.video_analysis.enable: - logger.info("已启用视频识别,开始识别") - if isinstance(segment.data, dict): - try: - # 从Adapter接收的视频数据 - video_base64 = segment.data.get("base64") - filename = segment.data.get("filename", "video.mp4") - - logger.info(f"视频文件名: {filename}") - logger.info(f"Base64数据长度: {len(video_base64) if video_base64 else 0}") - - if video_base64: - # 解码base64视频数据 - video_bytes = base64.b64decode(video_base64) - logger.info(f"解码后视频大小: {len(video_bytes)} 字节") - - # 使用video analyzer分析视频 - video_analyzer = get_video_analyzer() - result = await video_analyzer.analyze_video_from_bytes( - video_bytes, filename, prompt=global_config.video_analysis.batch_analysis_prompt - ) - - logger.info(f"视频分析结果: {result}") - - # 返回视频分析结果 - summary = result.get("summary", "") - if summary: - return f"[视频内容] {summary}" - else: - return "[已收到视频,但分析失败]" - else: - logger.warning("视频消息中没有base64数据") - return "[收到视频消息,但数据异常]" - except Exception as e: - logger.error(f"视频处理失败: {e!s}") - import traceback - - logger.error(f"错误详情: {traceback.format_exc()}") - return "[收到视频,但处理时出现错误]" - else: - logger.warning(f"视频消息数据不是字典格式: {type(segment.data)}") - return "[发了一个视频,但格式不支持]" - else: - return "" - else: - logger.warning(f"未知的消息段类型: {segment.type}") - return f"[{segment.type} 消息]" - except Exception as e: - logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}") - return f"[处理失败的{segment.type}消息]" +# MessageRecv 类已被完全移除,现在统一使用 DatabaseMessages +# 如需从消息字典创建 DatabaseMessages,请使用: +# from src.chat.message_receive.message_processor import process_message_from_dict +# +# 迁移完成日期: 2025-10-31 @dataclass @@ -447,7 +114,7 @@ class MessageProcessBase(Message): chat_stream: "ChatStream", bot_user_info: UserInfo, message_segment: Seg | None = None, - reply: Optional["MessageRecv"] = None, + reply: Optional["DatabaseMessages"] = None, thinking_start_time: float = 0, timestamp: float | None = None, ): @@ -548,7 +215,7 @@ class MessageSending(MessageProcessBase): sender_info: UserInfo | None, # 用来记录发送者信息 message_segment: Seg, display_message: str = "", - reply: Optional["MessageRecv"] = None, + reply: Optional["DatabaseMessages"] = None, is_head: bool = False, is_emoji: bool = False, thinking_start_time: float = 0, @@ -567,7 +234,11 @@ class MessageSending(MessageProcessBase): # 发送状态特有属性 self.sender_info = sender_info - self.reply_to_message_id = reply.message_info.message_id if reply else None + # 从 DatabaseMessages 获取 message_id + if reply: + self.reply_to_message_id = reply.message_id + else: + self.reply_to_message_id = None self.is_head = is_head self.is_emoji = is_emoji self.apply_set_reply_logic = apply_set_reply_logic @@ -582,14 +253,18 @@ class MessageSending(MessageProcessBase): def build_reply(self): """设置回复消息""" if self.reply: - self.reply_to_message_id = self.reply.message_info.message_id - self.message_segment = Seg( - type="seglist", - data=[ - Seg(type="reply", data=self.reply.message_info.message_id), # type: ignore - self.message_segment, - ], - ) + # 从 DatabaseMessages 获取 message_id + message_id = self.reply.message_id + + if message_id: + self.reply_to_message_id = message_id + self.message_segment = Seg( + type="seglist", + data=[ + Seg(type="reply", data=message_id), # type: ignore + self.message_segment, + ], + ) async def process(self) -> None: """处理消息内容,生成纯文本和详细文本""" @@ -607,48 +282,5 @@ class MessageSending(MessageProcessBase): return self.message_info.group_info is None or self.message_info.group_info.group_id is None -def message_recv_from_dict(message_dict: dict) -> MessageRecv: - return MessageRecv(message_dict) - -def message_from_db_dict(db_dict: dict) -> MessageRecv: - """从数据库字典创建MessageRecv实例""" - # 转换扁平的数据库字典为嵌套结构 - message_info_dict = { - "platform": db_dict.get("chat_info_platform"), - "message_id": db_dict.get("message_id"), - "time": db_dict.get("time"), - "group_info": { - "platform": db_dict.get("chat_info_group_platform"), - "group_id": db_dict.get("chat_info_group_id"), - "group_name": db_dict.get("chat_info_group_name"), - }, - "user_info": { - "platform": db_dict.get("user_platform"), - "user_id": db_dict.get("user_id"), - "user_nickname": db_dict.get("user_nickname"), - "user_cardname": db_dict.get("user_cardname"), - }, - } - - processed_text = db_dict.get("processed_plain_text", "") - - # 构建 MessageRecv 需要的字典 - recv_dict = { - "message_info": message_info_dict, - "message_segment": {"type": "text", "data": processed_text}, # 从纯文本重建消息段 - "raw_message": None, # 数据库中未存储原始消息 - "processed_plain_text": processed_text, - } - - # 创建 MessageRecv 实例 - msg = MessageRecv(recv_dict) - - # 从数据库字典中填充其他可选字段 - msg.interest_value = db_dict.get("interest_value", 0.0) - msg.is_mentioned = db_dict.get("is_mentioned") - msg.priority_mode = db_dict.get("priority_mode", "interest") - msg.priority_info = db_dict.get("priority_info") - msg.is_emoji = db_dict.get("is_emoji", False) - msg.is_picid = db_dict.get("is_picid", False) - - return msg +# message_recv_from_dict 和 message_from_db_dict 函数已被移除 +# 请使用: from src.chat.message_receive.message_processor import process_message_from_dict diff --git a/src/chat/message_receive/message_processor.py b/src/chat/message_receive/message_processor.py new file mode 100644 index 000000000..5da582710 --- /dev/null +++ b/src/chat/message_receive/message_processor.py @@ -0,0 +1,493 @@ +"""消息处理工具模块 +将原 MessageRecv 的消息处理逻辑提取为独立函数, +直接从适配器消息字典生成 DatabaseMessages +""" +import base64 +import time +from typing import Any + +import orjson +from maim_message import BaseMessageInfo, Seg + +from src.chat.utils.self_voice_cache import consume_self_voice_text +from src.chat.utils.utils_image import get_image_manager +from src.chat.utils.utils_video import get_video_analyzer, is_video_analysis_available +from src.chat.utils.utils_voice import get_voice_text +from src.common.data_models.database_data_model import DatabaseMessages +from src.common.logger import get_logger +from src.config.config import global_config + +logger = get_logger("message_processor") + + +async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str, platform: str) -> DatabaseMessages: + """从适配器消息字典处理并生成 DatabaseMessages + + 这个函数整合了原 MessageRecv 的所有处理逻辑: + 1. 解析 message_segment 并异步处理内容(图片、语音、视频等) + 2. 提取所有消息元数据 + 3. 直接构造 DatabaseMessages 对象 + + Args: + message_dict: MessageCQ序列化后的字典 + stream_id: 聊天流ID + platform: 平台标识 + + Returns: + DatabaseMessages: 处理完成的数据库消息对象 + """ + # 解析基础信息 + message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {})) + message_segment = Seg.from_dict(message_dict.get("message_segment", {})) + + # 初始化处理状态 + processing_state = { + "is_emoji": False, + "has_emoji": False, + "is_picid": False, + "has_picid": False, + "is_voice": False, + "is_video": False, + "is_mentioned": None, + "is_at": False, + "priority_mode": "interest", + "priority_info": None, + } + + # 异步处理消息段,生成纯文本 + processed_plain_text = await _process_message_segments(message_segment, processing_state, message_info) + + # 解析 notice 信息 + is_notify = False + is_public_notice = False + notice_type = None + if message_info.additional_config and isinstance(message_info.additional_config, dict): + is_notify = message_info.additional_config.get("is_notice", False) + is_public_notice = message_info.additional_config.get("is_public_notice", False) + notice_type = message_info.additional_config.get("notice_type") + + # 提取用户信息 + user_info = message_info.user_info + user_id = str(user_info.user_id) if user_info and user_info.user_id else "" + user_nickname = (user_info.user_nickname or "") if user_info else "" + user_cardname = user_info.user_cardname if user_info else None + user_platform = (user_info.platform or "") if user_info else "" + + # 提取群组信息 + group_info = message_info.group_info + group_id = group_info.group_id if group_info else None + group_name = group_info.group_name if group_info else None + group_platform = group_info.platform if group_info else None + + # 生成 chat_id + if group_info and group_id: + chat_id = f"{platform}_{group_id}" + elif user_info and user_info.user_id: + chat_id = f"{platform}_{user_info.user_id}_private" + else: + chat_id = stream_id + + # 准备 additional_config + additional_config_str = _prepare_additional_config(message_info, is_notify, is_public_notice, notice_type) + + # 提取 reply_to + reply_to = _extract_reply_from_segment(message_segment) + + # 构造 DatabaseMessages + message_time = message_info.time if hasattr(message_info, "time") and message_info.time is not None else time.time() + message_id = message_info.message_id or "" + + # 处理 is_mentioned + is_mentioned = None + mentioned_value = processing_state.get("is_mentioned") + if isinstance(mentioned_value, bool): + is_mentioned = mentioned_value + elif isinstance(mentioned_value, (int, float)): + is_mentioned = mentioned_value != 0 + + db_message = DatabaseMessages( + message_id=message_id, + time=float(message_time), + chat_id=chat_id, + reply_to=reply_to, + processed_plain_text=processed_plain_text, + display_message=processed_plain_text, + is_mentioned=is_mentioned, + is_at=bool(processing_state.get("is_at", False)), + is_emoji=bool(processing_state.get("is_emoji", False)), + is_picid=bool(processing_state.get("is_picid", False)), + is_command=False, # 将在后续处理中设置 + is_notify=bool(is_notify), + is_public_notice=bool(is_public_notice), + notice_type=notice_type, + additional_config=additional_config_str, + user_id=user_id, + user_nickname=user_nickname, + user_cardname=user_cardname, + user_platform=user_platform, + chat_info_stream_id=stream_id, + chat_info_platform=platform, + chat_info_create_time=0.0, # 将由 ChatStream 填充 + chat_info_last_active_time=0.0, # 将由 ChatStream 填充 + chat_info_user_id=user_id, + chat_info_user_nickname=user_nickname, + chat_info_user_cardname=user_cardname, + chat_info_user_platform=user_platform, + chat_info_group_id=group_id, + chat_info_group_name=group_name, + chat_info_group_platform=group_platform, + ) + + # 设置优先级信息 + if processing_state.get("priority_mode"): + setattr(db_message, "priority_mode", processing_state["priority_mode"]) + if processing_state.get("priority_info"): + setattr(db_message, "priority_info", processing_state["priority_info"]) + + # 设置其他运行时属性 + setattr(db_message, "is_voice", bool(processing_state.get("is_voice", False))) + setattr(db_message, "is_video", bool(processing_state.get("is_video", False))) + setattr(db_message, "has_emoji", bool(processing_state.get("has_emoji", False))) + setattr(db_message, "has_picid", bool(processing_state.get("has_picid", False))) + + return db_message + + +async def _process_message_segments(segment: Seg, state: dict, message_info: BaseMessageInfo) -> str: + """递归处理消息段,转换为文字描述 + + Args: + segment: 要处理的消息段 + state: 处理状态字典(用于记录消息类型标记) + message_info: 消息基础信息(用于某些处理逻辑) + + Returns: + str: 处理后的文本 + """ + if segment.type == "seglist": + # 处理消息段列表 + segments_text = [] + for seg in segment.data: + processed = await _process_message_segments(seg, state, message_info) + if processed: + segments_text.append(processed) + return " ".join(segments_text) + else: + # 处理单个消息段 + return await _process_single_segment(segment, state, message_info) + + +async def _process_single_segment(segment: Seg, state: dict, message_info: BaseMessageInfo) -> str: + """处理单个消息段 + + Args: + segment: 消息段 + state: 处理状态字典 + message_info: 消息基础信息 + + Returns: + str: 处理后的文本 + """ + try: + if segment.type == "text": + state["is_picid"] = False + state["is_emoji"] = False + state["is_video"] = False + return segment.data + + elif segment.type == "at": + state["is_picid"] = False + state["is_emoji"] = False + state["is_video"] = False + state["is_at"] = True + # 处理at消息,格式为"昵称:QQ号" + if isinstance(segment.data, str) and ":" in segment.data: + nickname, qq_id = segment.data.split(":", 1) + return f"@{nickname}" + return f"@{segment.data}" if isinstance(segment.data, str) else "@未知用户" + + elif segment.type == "image": + # 如果是base64图片数据 + if isinstance(segment.data, str): + state["has_picid"] = True + state["is_picid"] = True + state["is_emoji"] = False + state["is_video"] = False + image_manager = get_image_manager() + _, processed_text = await image_manager.process_image(segment.data) + return processed_text + return "[发了一张图片,网卡了加载不出来]" + + elif segment.type == "emoji": + state["has_emoji"] = True + state["is_emoji"] = True + state["is_picid"] = False + state["is_voice"] = False + state["is_video"] = False + if isinstance(segment.data, str): + return await get_image_manager().get_emoji_description(segment.data) + return "[发了一个表情包,网卡了加载不出来]" + + elif segment.type == "voice": + state["is_picid"] = False + state["is_emoji"] = False + state["is_voice"] = True + state["is_video"] = False + + # 检查消息是否由机器人自己发送 + if message_info and message_info.user_info and str(message_info.user_info.user_id) == str(global_config.bot.qq_account): + logger.info(f"检测到机器人自身发送的语音消息 (User ID: {message_info.user_info.user_id}),尝试从缓存获取文本。") + if isinstance(segment.data, str): + cached_text = consume_self_voice_text(segment.data) + if cached_text: + logger.info(f"成功从缓存中获取语音文本: '{cached_text[:70]}...'") + return f"[语音:{cached_text}]" + else: + logger.warning("机器人自身语音消息缓存未命中,将回退到标准语音识别。") + + # 标准语音识别流程 + if isinstance(segment.data, str): + return await get_voice_text(segment.data) + return "[发了一段语音,网卡了加载不出来]" + + elif segment.type == "mention_bot": + state["is_picid"] = False + state["is_emoji"] = False + state["is_voice"] = False + state["is_video"] = False + state["is_mentioned"] = float(segment.data) + return "" + + elif segment.type == "priority_info": + state["is_picid"] = False + state["is_emoji"] = False + state["is_voice"] = False + if isinstance(segment.data, dict): + # 处理优先级信息 + state["priority_mode"] = "priority" + state["priority_info"] = segment.data + return "" + + elif segment.type == "file": + if isinstance(segment.data, dict): + file_name = segment.data.get('name', '未知文件') + file_size = segment.data.get('size', '未知大小') + return f"[文件:{file_name} ({file_size}字节)]" + return "[收到一个文件]" + + elif segment.type == "video": + state["is_picid"] = False + state["is_emoji"] = False + state["is_voice"] = False + state["is_video"] = True + logger.info(f"接收到视频消息,数据类型: {type(segment.data)}") + + # 检查视频分析功能是否可用 + if not is_video_analysis_available(): + logger.warning("⚠️ Rust视频处理模块不可用,跳过视频分析") + return "[视频]" + + if global_config.video_analysis.enable: + logger.info("已启用视频识别,开始识别") + if isinstance(segment.data, dict): + try: + # 从Adapter接收的视频数据 + video_base64 = segment.data.get("base64") + filename = segment.data.get("filename", "video.mp4") + + logger.info(f"视频文件名: {filename}") + logger.info(f"Base64数据长度: {len(video_base64) if video_base64 else 0}") + + if video_base64: + # 解码base64视频数据 + video_bytes = base64.b64decode(video_base64) + logger.info(f"解码后视频大小: {len(video_bytes)} 字节") + + # 使用video analyzer分析视频 + video_analyzer = get_video_analyzer() + result = await video_analyzer.analyze_video_from_bytes( + video_bytes, filename, prompt=global_config.video_analysis.batch_analysis_prompt + ) + + logger.info(f"视频分析结果: {result}") + + # 返回视频分析结果 + summary = result.get("summary", "") + if summary: + return f"[视频内容] {summary}" + else: + return "[已收到视频,但分析失败]" + else: + logger.warning("视频消息中没有base64数据") + return "[收到视频消息,但数据异常]" + except Exception as e: + logger.error(f"视频处理失败: {e!s}") + import traceback + logger.error(f"错误详情: {traceback.format_exc()}") + return "[收到视频,但处理时出现错误]" + else: + logger.warning(f"视频消息数据不是字典格式: {type(segment.data)}") + return "[发了一个视频,但格式不支持]" + else: + return "" + else: + logger.warning(f"未知的消息段类型: {segment.type}") + return f"[{segment.type} 消息]" + + except Exception as e: + logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}") + return f"[处理失败的{segment.type}消息]" + + +def _prepare_additional_config(message_info: BaseMessageInfo, is_notify: bool, is_public_notice: bool, notice_type: str | None) -> str | None: + """准备 additional_config,包含 format_info 和 notice 信息 + + Args: + message_info: 消息基础信息 + is_notify: 是否为notice消息 + is_public_notice: 是否为公共notice + notice_type: notice类型 + + Returns: + str | None: JSON 字符串格式的 additional_config,如果为空则返回 None + """ + try: + additional_config_data = {} + + # 首先获取adapter传递的additional_config + if hasattr(message_info, 'additional_config') and message_info.additional_config: + if isinstance(message_info.additional_config, dict): + additional_config_data = message_info.additional_config.copy() + elif isinstance(message_info.additional_config, str): + try: + additional_config_data = orjson.loads(message_info.additional_config) + except Exception as e: + logger.warning(f"无法解析 additional_config JSON: {e}") + additional_config_data = {} + + # 添加notice相关标志 + if is_notify: + additional_config_data["is_notice"] = True + additional_config_data["notice_type"] = notice_type or "unknown" + additional_config_data["is_public_notice"] = bool(is_public_notice) + + # 添加format_info到additional_config中 + if hasattr(message_info, 'format_info') and message_info.format_info: + try: + format_info_dict = message_info.format_info.to_dict() + additional_config_data["format_info"] = format_info_dict + logger.debug(f"[message_processor] 嵌入 format_info 到 additional_config: {format_info_dict}") + except Exception as e: + logger.warning(f"将 format_info 转换为字典失败: {e}") + + # 序列化为JSON字符串 + if additional_config_data: + return orjson.dumps(additional_config_data).decode("utf-8") + except Exception as e: + logger.error(f"准备 additional_config 失败: {e}") + + return None + + +def _extract_reply_from_segment(segment: Seg) -> str | None: + """从消息段中提取reply_to信息 + + Args: + segment: 消息段 + + Returns: + str | None: 回复的消息ID,如果没有则返回None + """ + try: + if hasattr(segment, "type") and segment.type == "seglist": + # 递归搜索seglist中的reply段 + if hasattr(segment, "data") and segment.data: + for seg in segment.data: + reply_id = _extract_reply_from_segment(seg) + if reply_id: + return reply_id + elif hasattr(segment, "type") and segment.type == "reply": + # 找到reply段,返回message_id + return str(segment.data) if segment.data else None + except Exception as e: + logger.warning(f"提取reply_to信息失败: {e}") + return None + + +# ============================================================================= +# DatabaseMessages 扩展工具函数 +# ============================================================================= + +def get_message_info_from_db_message(db_message: DatabaseMessages) -> BaseMessageInfo: + """从 DatabaseMessages 重建 BaseMessageInfo(用于需要 message_info 的遗留代码) + + Args: + db_message: DatabaseMessages 对象 + + Returns: + BaseMessageInfo: 重建的消息信息对象 + """ + from maim_message import UserInfo, GroupInfo + + # 从 DatabaseMessages 的 user_info 转换为 maim_message.UserInfo + user_info = UserInfo( + platform=db_message.user_info.platform, + user_id=db_message.user_info.user_id, + user_nickname=db_message.user_info.user_nickname, + user_cardname=db_message.user_info.user_cardname or "" + ) + + # 从 DatabaseMessages 的 group_info 转换为 maim_message.GroupInfo(如果存在) + group_info = None + if db_message.group_info: + group_info = GroupInfo( + platform=db_message.group_info.group_platform or "", + group_id=db_message.group_info.group_id, + group_name=db_message.group_info.group_name + ) + + # 解析 additional_config(从 JSON 字符串到字典) + additional_config = None + if db_message.additional_config: + try: + additional_config = orjson.loads(db_message.additional_config) + except Exception: + # 如果解析失败,保持为字符串 + pass + + # 创建 BaseMessageInfo + message_info = BaseMessageInfo( + platform=db_message.chat_info.platform, + message_id=db_message.message_id, + time=db_message.time, + user_info=user_info, + group_info=group_info, + additional_config=additional_config # type: ignore + ) + + return message_info + + +def set_db_message_runtime_attr(db_message: DatabaseMessages, attr_name: str, value: Any) -> None: + """安全地为 DatabaseMessages 设置运行时属性 + + Args: + db_message: DatabaseMessages 对象 + attr_name: 属性名 + value: 属性值 + """ + setattr(db_message, attr_name, value) + + +def get_db_message_runtime_attr(db_message: DatabaseMessages, attr_name: str, default: Any = None) -> Any: + """安全地获取 DatabaseMessages 的运行时属性 + + Args: + db_message: DatabaseMessages 对象 + attr_name: 属性名 + default: 默认值 + + Returns: + 属性值或默认值 + """ + return getattr(db_message, attr_name, default) diff --git a/src/chat/message_receive/message_recv_backup.py b/src/chat/message_receive/message_recv_backup.py new file mode 100644 index 000000000..3f1943ebf --- /dev/null +++ b/src/chat/message_receive/message_recv_backup.py @@ -0,0 +1,434 @@ +# MessageRecv 类备份 - 已从 message.py 中移除 +# 备份日期: 2025-10-31 +# 此类已被 DatabaseMessages 完全取代 + +# MessageRecv 类已被移除 +# 现在所有消息处理都使用 DatabaseMessages +# 如果需要从消息字典创建 DatabaseMessages,请使用 message_processor.process_message_from_dict() +# +# 历史参考: MessageRecv 曾经是接收消息的包装类,现已被 DatabaseMessages 完全取代 +# 迁移完成日期: 2025-10-31 + +""" +# 以下是已删除的 MessageRecv 类(保留作为参考) +class MessageRecv: + 接收消息类 - DatabaseMessages 的轻量级包装器 + + 这个类现在主要作为适配器层,处理外部消息格式并内部使用 DatabaseMessages。 + 保留此类是为了向后兼容性和处理 message_segment 的异步逻辑。 +""" + + def __init__(self, message_dict: dict[str, Any]): + """从MessageCQ的字典初始化 + + Args: + message_dict: MessageCQ序列化后的字典 + """ + # 保留原始消息信息用于某些场景 + self.message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {})) + self.message_segment = Seg.from_dict(message_dict.get("message_segment", {})) + self.raw_message = message_dict.get("raw_message") + + # 处理状态(在process()之前临时使用) + self._processing_state = { + "is_emoji": False, + "has_emoji": False, + "is_picid": False, + "has_picid": False, + "is_voice": False, + "is_video": False, + "is_mentioned": None, + "is_at": False, + "priority_mode": "interest", + "priority_info": None, + } + + self.chat_stream = None + self.reply = None + self.processed_plain_text = message_dict.get("processed_plain_text", "") + + # 解析additional_config中的notice信息 + self.is_notify = False + self.is_public_notice = False + self.notice_type = None + if self.message_info.additional_config and isinstance(self.message_info.additional_config, dict): + self.is_notify = self.message_info.additional_config.get("is_notice", False) + self.is_public_notice = self.message_info.additional_config.get("is_public_notice", False) + self.notice_type = self.message_info.additional_config.get("notice_type") + + # 兼容性属性 - 代理到 _processing_state + @property + def is_emoji(self) -> bool: + return self._processing_state["is_emoji"] + + @is_emoji.setter + def is_emoji(self, value: bool): + self._processing_state["is_emoji"] = value + + @property + def has_emoji(self) -> bool: + return self._processing_state["has_emoji"] + + @has_emoji.setter + def has_emoji(self, value: bool): + self._processing_state["has_emoji"] = value + + @property + def is_picid(self) -> bool: + return self._processing_state["is_picid"] + + @is_picid.setter + def is_picid(self, value: bool): + self._processing_state["is_picid"] = value + + @property + def has_picid(self) -> bool: + return self._processing_state["has_picid"] + + @has_picid.setter + def has_picid(self, value: bool): + self._processing_state["has_picid"] = value + + @property + def is_voice(self) -> bool: + return self._processing_state["is_voice"] + + @is_voice.setter + def is_voice(self, value: bool): + self._processing_state["is_voice"] = value + + @property + def is_video(self) -> bool: + return self._processing_state["is_video"] + + @is_video.setter + def is_video(self, value: bool): + self._processing_state["is_video"] = value + + @property + def is_mentioned(self): + return self._processing_state["is_mentioned"] + + @is_mentioned.setter + def is_mentioned(self, value): + self._processing_state["is_mentioned"] = value + + @property + def is_at(self) -> bool: + return self._processing_state["is_at"] + + @is_at.setter + def is_at(self, value: bool): + self._processing_state["is_at"] = value + + @property + def priority_mode(self) -> str: + return self._processing_state["priority_mode"] + + @priority_mode.setter + def priority_mode(self, value: str): + self._processing_state["priority_mode"] = value + + @property + def priority_info(self): + return self._processing_state["priority_info"] + + @priority_info.setter + def priority_info(self, value): + self._processing_state["priority_info"] = value + + # 其他常用属性 + interest_value: float = 0.0 + is_command: bool = False + memorized_times: int = 0 + + def __post_init__(self): + """dataclass 初始化后处理""" + self.key_words = [] + self.key_words_lite = [] + + def update_chat_stream(self, chat_stream: "ChatStream"): + self.chat_stream = chat_stream + + def to_database_message(self) -> "DatabaseMessages": + """将 MessageRecv 转换为 DatabaseMessages 对象 + + Returns: + DatabaseMessages: 数据库消息对象 + """ + import time + + message_info = self.message_info + msg_user_info = getattr(message_info, "user_info", None) + stream_user_info = getattr(self.chat_stream, "user_info", None) if self.chat_stream else None + group_info = getattr(self.chat_stream, "group_info", None) if self.chat_stream else None + + message_id = message_info.message_id or "" + message_time = message_info.time if hasattr(message_info, "time") and message_info.time is not None else time.time() + is_mentioned = None + if isinstance(self.is_mentioned, bool): + is_mentioned = self.is_mentioned + elif isinstance(self.is_mentioned, int | float): + is_mentioned = self.is_mentioned != 0 + + # 提取用户信息 + user_id = "" + user_nickname = "" + user_cardname = None + user_platform = "" + if msg_user_info: + user_id = str(getattr(msg_user_info, "user_id", "") or "") + user_nickname = getattr(msg_user_info, "user_nickname", "") or "" + user_cardname = getattr(msg_user_info, "user_cardname", None) + user_platform = getattr(msg_user_info, "platform", "") or "" + elif stream_user_info: + user_id = str(getattr(stream_user_info, "user_id", "") or "") + user_nickname = getattr(stream_user_info, "user_nickname", "") or "" + user_cardname = getattr(stream_user_info, "user_cardname", None) + user_platform = getattr(stream_user_info, "platform", "") or "" + + # 提取聊天流信息 + chat_user_id = str(getattr(stream_user_info, "user_id", "") or "") if stream_user_info else "" + chat_user_nickname = getattr(stream_user_info, "user_nickname", "") or "" if stream_user_info else "" + chat_user_cardname = getattr(stream_user_info, "user_cardname", None) if stream_user_info else None + chat_user_platform = getattr(stream_user_info, "platform", "") or "" if stream_user_info else "" + + group_id = getattr(group_info, "group_id", None) if group_info else None + group_name = getattr(group_info, "group_name", None) if group_info else None + group_platform = getattr(group_info, "platform", None) if group_info else None + + # 准备 additional_config + additional_config_str = None + try: + import orjson + + additional_config_data = {} + + # 首先获取adapter传递的additional_config + if hasattr(message_info, 'additional_config') and message_info.additional_config: + if isinstance(message_info.additional_config, dict): + additional_config_data = message_info.additional_config.copy() + elif isinstance(message_info.additional_config, str): + try: + additional_config_data = orjson.loads(message_info.additional_config) + except Exception as e: + logger.warning(f"无法解析 additional_config JSON: {e}") + additional_config_data = {} + + # 添加notice相关标志 + if self.is_notify: + additional_config_data["is_notice"] = True + additional_config_data["notice_type"] = self.notice_type or "unknown" + additional_config_data["is_public_notice"] = bool(self.is_public_notice) + + # 添加format_info到additional_config中 + if hasattr(message_info, 'format_info') and message_info.format_info: + try: + format_info_dict = message_info.format_info.to_dict() + additional_config_data["format_info"] = format_info_dict + logger.debug(f"[message.py] 嵌入 format_info 到 additional_config: {format_info_dict}") + except Exception as e: + logger.warning(f"将 format_info 转换为字典失败: {e}") + + # 序列化为JSON字符串 + if additional_config_data: + additional_config_str = orjson.dumps(additional_config_data).decode("utf-8") + except Exception as e: + logger.error(f"准备 additional_config 失败: {e}") + + # 创建数据库消息对象 + db_message = DatabaseMessages( + message_id=message_id, + time=float(message_time), + chat_id=self.chat_stream.stream_id if self.chat_stream else "", + processed_plain_text=self.processed_plain_text, + display_message=self.processed_plain_text, + is_mentioned=is_mentioned, + is_at=bool(self.is_at) if self.is_at is not None else None, + is_emoji=bool(self.is_emoji), + is_picid=bool(self.is_picid), + is_command=bool(self.is_command), + is_notify=bool(self.is_notify), + is_public_notice=bool(self.is_public_notice), + notice_type=self.notice_type, + additional_config=additional_config_str, + user_id=user_id, + user_nickname=user_nickname, + user_cardname=user_cardname, + user_platform=user_platform, + chat_info_stream_id=self.chat_stream.stream_id if self.chat_stream else "", + chat_info_platform=self.chat_stream.platform if self.chat_stream else "", + chat_info_create_time=float(self.chat_stream.create_time) if self.chat_stream else 0.0, + chat_info_last_active_time=float(self.chat_stream.last_active_time) if self.chat_stream else 0.0, + chat_info_user_id=chat_user_id, + chat_info_user_nickname=chat_user_nickname, + chat_info_user_cardname=chat_user_cardname, + chat_info_user_platform=chat_user_platform, + chat_info_group_id=group_id, + chat_info_group_name=group_name, + chat_info_group_platform=group_platform, + ) + + # 同步兴趣度等衍生属性 + db_message.interest_value = getattr(self, "interest_value", 0.0) + setattr(db_message, "should_reply", getattr(self, "should_reply", False)) + setattr(db_message, "should_act", getattr(self, "should_act", False)) + + return db_message + + async def process(self) -> None: + """处理消息内容,生成纯文本和详细文本 + + 这个方法必须在创建实例后显式调用,因为它包含异步操作。 + """ + self.processed_plain_text = await self._process_message_segments(self.message_segment) + + async def _process_single_segment(self, segment: Seg) -> str: + """处理单个消息段 + + Args: + segment: 消息段 + + Returns: + str: 处理后的文本 + """ + try: + if segment.type == "text": + self.is_picid = False + self.is_emoji = False + self.is_video = False + return segment.data # type: ignore + elif segment.type == "at": + self.is_picid = False + self.is_emoji = False + self.is_video = False + # 处理at消息,格式为"昵称:QQ号" + if isinstance(segment.data, str) and ":" in segment.data: + nickname, qq_id = segment.data.split(":", 1) + return f"@{nickname}" + return f"@{segment.data}" if isinstance(segment.data, str) else "@未知用户" + elif segment.type == "image": + # 如果是base64图片数据 + if isinstance(segment.data, str): + self.has_picid = True + self.is_picid = True + self.is_emoji = False + self.is_video = False + image_manager = get_image_manager() + # print(f"segment.data: {segment.data}") + _, processed_text = await image_manager.process_image(segment.data) + return processed_text + return "[发了一张图片,网卡了加载不出来]" + elif segment.type == "emoji": + self.has_emoji = True + self.is_emoji = True + self.is_picid = False + self.is_voice = False + self.is_video = False + if isinstance(segment.data, str): + return await get_image_manager().get_emoji_description(segment.data) + return "[发了一个表情包,网卡了加载不出来]" + elif segment.type == "voice": + self.is_picid = False + self.is_emoji = False + self.is_voice = True + self.is_video = False + + # 检查消息是否由机器人自己发送 + if self.message_info and self.message_info.user_info and str(self.message_info.user_info.user_id) == str(global_config.bot.qq_account): + logger.info(f"检测到机器人自身发送的语音消息 (User ID: {self.message_info.user_info.user_id}),尝试从缓存获取文本。") + if isinstance(segment.data, str): + cached_text = consume_self_voice_text(segment.data) + if cached_text: + logger.info(f"成功从缓存中获取语音文本: '{cached_text[:70]}...'") + return f"[语音:{cached_text}]" + else: + logger.warning("机器人自身语音消息缓存未命中,将回退到标准语音识别。") + + # 标准语音识别流程 (也作为缓存未命中的后备方案) + if isinstance(segment.data, str): + return await get_voice_text(segment.data) + return "[发了一段语音,网卡了加载不出来]" + elif segment.type == "mention_bot": + self.is_picid = False + self.is_emoji = False + self.is_voice = False + self.is_video = False + self.is_mentioned = float(segment.data) # type: ignore + return "" + elif segment.type == "priority_info": + self.is_picid = False + self.is_emoji = False + self.is_voice = False + if isinstance(segment.data, dict): + # 处理优先级信息 + self.priority_mode = "priority" + self.priority_info = segment.data + """ + { + 'message_type': 'vip', # vip or normal + 'message_priority': 1.0, # 优先级,大为优先,float + } + """ + return "" + elif segment.type == "file": + if isinstance(segment.data, dict): + file_name = segment.data.get('name', '未知文件') + file_size = segment.data.get('size', '未知大小') + return f"[文件:{file_name} ({file_size}字节)]" + return "[收到一个文件]" + elif segment.type == "video": + self.is_picid = False + self.is_emoji = False + self.is_voice = False + self.is_video = True + logger.info(f"接收到视频消息,数据类型: {type(segment.data)}") + + # 检查视频分析功能是否可用 + if not is_video_analysis_available(): + logger.warning("⚠️ Rust视频处理模块不可用,跳过视频分析") + return "[视频]" + + if global_config.video_analysis.enable: + logger.info("已启用视频识别,开始识别") + if isinstance(segment.data, dict): + try: + # 从Adapter接收的视频数据 + video_base64 = segment.data.get("base64") + filename = segment.data.get("filename", "video.mp4") + + logger.info(f"视频文件名: {filename}") + logger.info(f"Base64数据长度: {len(video_base64) if video_base64 else 0}") + + if video_base64: + # 解码base64视频数据 + video_bytes = base64.b64decode(video_base64) + logger.info(f"解码后视频大小: {len(video_bytes)} 字节") + + # 使用video analyzer分析视频 + video_analyzer = get_video_analyzer() + result = await video_analyzer.analyze_video_from_bytes( + video_bytes, filename, prompt=global_config.video_analysis.batch_analysis_prompt + ) + + logger.info(f"视频分析结果: {result}") + + # 返回视频分析结果 + summary = result.get("summary", "") + if summary: + return f"[视频内容] {summary}" + else: + return "[已收到视频,但分析失败]" + else: + logger.warning("视频消息中没有base64数据") + return "[收到视频消息,但数据异常]" + except Exception as e: + logger.error(f"视频处理失败: {e!s}") + import traceback + + logger.error(f"错误详情: {traceback.format_exc()}") + return "[收到视频,但处理时出现错误]" + else: + logger.warning(f"视频消息数据不是字典格式: {type(segment.data)}") + return "[发了一个视频,但格式不支持]" + else: diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index edf9bb9c8..9b2b54991 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -9,8 +9,10 @@ from src.common.database.sqlalchemy_database_api import get_db_session from src.common.database.sqlalchemy_models import Images, Messages from src.common.logger import get_logger +from src.common.data_models.database_data_model import DatabaseMessages + from .chat_stream import ChatStream -from .message import MessageRecv, MessageSending +from .message import MessageSending logger = get_logger("message_storage") @@ -34,97 +36,166 @@ class MessageStorage: return [] @staticmethod - async def store_message(message: MessageSending | MessageRecv, chat_stream: ChatStream) -> None: + async def store_message(message: DatabaseMessages | MessageSending, chat_stream: ChatStream) -> None: """存储消息到数据库""" try: # 过滤敏感信息的正则模式 pattern = r".*?|.*?|.*?" - processed_plain_text = message.processed_plain_text - - if processed_plain_text: - processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text) - # 增加对None的防御性处理 - safe_processed_plain_text = processed_plain_text or "" - filtered_processed_plain_text = re.sub(pattern, "", safe_processed_plain_text, flags=re.DOTALL) - else: - filtered_processed_plain_text = "" - - if isinstance(message, MessageSending): - display_message = message.display_message - if display_message: - filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL) + # 如果是 DatabaseMessages,直接使用它的字段 + if isinstance(message, DatabaseMessages): + processed_plain_text = message.processed_plain_text + if processed_plain_text: + processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text) + safe_processed_plain_text = processed_plain_text or "" + filtered_processed_plain_text = re.sub(pattern, "", safe_processed_plain_text, flags=re.DOTALL) else: - # 如果没有设置display_message,使用processed_plain_text作为显示消息 - filtered_display_message = ( - re.sub(pattern, "", (message.processed_plain_text or ""), flags=re.DOTALL) - ) - interest_value = 0 - is_mentioned = False - reply_to = message.reply_to - priority_mode = "" - priority_info = {} - is_emoji = False - is_picid = False - is_notify = False - is_command = False - key_words = "" - key_words_lite = "" - else: - filtered_display_message = "" - interest_value = message.interest_value + filtered_processed_plain_text = "" + + display_message = message.display_message or message.processed_plain_text or "" + filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL) + + # 直接从 DatabaseMessages 获取所有字段 + msg_id = message.message_id + msg_time = message.time + chat_id = message.chat_id + reply_to = "" # DatabaseMessages 没有 reply_to 字段 is_mentioned = message.is_mentioned - reply_to = "" - priority_mode = message.priority_mode - priority_info = message.priority_info - is_emoji = message.is_emoji - is_picid = message.is_picid - is_notify = message.is_notify - is_command = message.is_command - # 序列化关键词列表为JSON字符串 - key_words = MessageStorage._serialize_keywords(message.key_words) - key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite) + interest_value = message.interest_value or 0.0 + priority_mode = "" # DatabaseMessages 没有 priority_mode + priority_info_json = None # DatabaseMessages 没有 priority_info + is_emoji = message.is_emoji or False + is_picid = message.is_picid or False + is_notify = message.is_notify or False + is_command = message.is_command or False + key_words = "" # DatabaseMessages 没有 key_words + key_words_lite = "" + memorized_times = 0 # DatabaseMessages 没有 memorized_times + + # 使用 DatabaseMessages 中的嵌套对象信息 + user_platform = message.user_info.platform if message.user_info else "" + user_id = message.user_info.user_id if message.user_info else "" + user_nickname = message.user_info.user_nickname if message.user_info else "" + user_cardname = message.user_info.user_cardname if message.user_info else None + + chat_info_stream_id = message.chat_info.stream_id if message.chat_info else "" + chat_info_platform = message.chat_info.platform if message.chat_info else "" + chat_info_create_time = message.chat_info.create_time if message.chat_info else 0.0 + chat_info_last_active_time = message.chat_info.last_active_time if message.chat_info else 0.0 + chat_info_user_platform = message.chat_info.user_info.platform if message.chat_info and message.chat_info.user_info else "" + chat_info_user_id = message.chat_info.user_info.user_id if message.chat_info and message.chat_info.user_info else "" + chat_info_user_nickname = message.chat_info.user_info.user_nickname if message.chat_info and message.chat_info.user_info else "" + chat_info_user_cardname = message.chat_info.user_info.user_cardname if message.chat_info and message.chat_info.user_info else None + chat_info_group_platform = message.group_info.group_platform if message.group_info else None + chat_info_group_id = message.group_info.group_id if message.group_info else None + chat_info_group_name = message.group_info.group_name if message.group_info else None + + else: + # MessageSending 处理逻辑 + processed_plain_text = message.processed_plain_text - chat_info_dict = chat_stream.to_dict() - user_info_dict = message.message_info.user_info.to_dict() # type: ignore + if processed_plain_text: + processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text) + # 增加对None的防御性处理 + safe_processed_plain_text = processed_plain_text or "" + filtered_processed_plain_text = re.sub(pattern, "", safe_processed_plain_text, flags=re.DOTALL) + else: + filtered_processed_plain_text = "" - # message_id 现在是 TextField,直接使用字符串值 - msg_id = message.message_info.message_id + if isinstance(message, MessageSending): + display_message = message.display_message + if display_message: + filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL) + else: + # 如果没有设置display_message,使用processed_plain_text作为显示消息 + filtered_display_message = ( + re.sub(pattern, "", (message.processed_plain_text or ""), flags=re.DOTALL) + ) + interest_value = 0 + is_mentioned = False + reply_to = message.reply_to + priority_mode = "" + priority_info = {} + is_emoji = False + is_picid = False + is_notify = False + is_command = False + key_words = "" + key_words_lite = "" + else: + filtered_display_message = "" + interest_value = message.interest_value + is_mentioned = message.is_mentioned + reply_to = "" + priority_mode = message.priority_mode + priority_info = message.priority_info + is_emoji = message.is_emoji + is_picid = message.is_picid + is_notify = message.is_notify + is_command = message.is_command + # 序列化关键词列表为JSON字符串 + key_words = MessageStorage._serialize_keywords(message.key_words) + key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite) - # 安全地获取 group_info, 如果为 None 则视为空字典 - group_info_from_chat = chat_info_dict.get("group_info") or {} - # 安全地获取 user_info, 如果为 None 则视为空字典 (以防万一) - user_info_from_chat = chat_info_dict.get("user_info") or {} + chat_info_dict = chat_stream.to_dict() + user_info_dict = message.message_info.user_info.to_dict() # type: ignore - # 将priority_info字典序列化为JSON字符串,以便存储到数据库的Text字段 - priority_info_json = orjson.dumps(priority_info).decode("utf-8") if priority_info else None + # message_id 现在是 TextField,直接使用字符串值 + msg_id = message.message_info.message_id + msg_time = float(message.message_info.time or time.time()) + chat_id = chat_stream.stream_id + memorized_times = message.memorized_times + + # 安全地获取 group_info, 如果为 None 则视为空字典 + group_info_from_chat = chat_info_dict.get("group_info") or {} + # 安全地获取 user_info, 如果为 None 则视为空字典 (以防万一) + user_info_from_chat = chat_info_dict.get("user_info") or {} + + # 将priority_info字典序列化为JSON字符串,以便存储到数据库的Text字段 + priority_info_json = orjson.dumps(priority_info).decode("utf-8") if priority_info else None + + user_platform = user_info_dict.get("platform") + user_id = user_info_dict.get("user_id") + user_nickname = user_info_dict.get("user_nickname") + user_cardname = user_info_dict.get("user_cardname") + + chat_info_stream_id = chat_info_dict.get("stream_id") + chat_info_platform = chat_info_dict.get("platform") + chat_info_create_time = float(chat_info_dict.get("create_time", 0.0)) + chat_info_last_active_time = float(chat_info_dict.get("last_active_time", 0.0)) + chat_info_user_platform = user_info_from_chat.get("platform") + chat_info_user_id = user_info_from_chat.get("user_id") + chat_info_user_nickname = user_info_from_chat.get("user_nickname") + chat_info_user_cardname = user_info_from_chat.get("user_cardname") + chat_info_group_platform = group_info_from_chat.get("platform") + chat_info_group_id = group_info_from_chat.get("group_id") + chat_info_group_name = group_info_from_chat.get("group_name") # 获取数据库会话 - new_message = Messages( message_id=msg_id, - time=float(message.message_info.time or time.time()), - chat_id=chat_stream.stream_id, + time=msg_time, + chat_id=chat_id, reply_to=reply_to, is_mentioned=is_mentioned, - chat_info_stream_id=chat_info_dict.get("stream_id"), - chat_info_platform=chat_info_dict.get("platform"), - chat_info_user_platform=user_info_from_chat.get("platform"), - chat_info_user_id=user_info_from_chat.get("user_id"), - chat_info_user_nickname=user_info_from_chat.get("user_nickname"), - chat_info_user_cardname=user_info_from_chat.get("user_cardname"), - chat_info_group_platform=group_info_from_chat.get("platform"), - chat_info_group_id=group_info_from_chat.get("group_id"), - chat_info_group_name=group_info_from_chat.get("group_name"), - chat_info_create_time=float(chat_info_dict.get("create_time", 0.0)), - chat_info_last_active_time=float(chat_info_dict.get("last_active_time", 0.0)), - user_platform=user_info_dict.get("platform"), - user_id=user_info_dict.get("user_id"), - user_nickname=user_info_dict.get("user_nickname"), - user_cardname=user_info_dict.get("user_cardname"), + chat_info_stream_id=chat_info_stream_id, + chat_info_platform=chat_info_platform, + chat_info_user_platform=chat_info_user_platform, + chat_info_user_id=chat_info_user_id, + chat_info_user_nickname=chat_info_user_nickname, + chat_info_user_cardname=chat_info_user_cardname, + chat_info_group_platform=chat_info_group_platform, + chat_info_group_id=chat_info_group_id, + chat_info_group_name=chat_info_group_name, + chat_info_create_time=chat_info_create_time, + chat_info_last_active_time=chat_info_last_active_time, + user_platform=user_platform, + user_id=user_id, + user_nickname=user_nickname, + user_cardname=user_cardname, processed_plain_text=filtered_processed_plain_text, display_message=filtered_display_message, - memorized_times=message.memorized_times, + memorized_times=memorized_times, interest_value=interest_value, priority_mode=priority_mode, priority_info=priority_info_json, @@ -145,36 +216,43 @@ class MessageStorage: traceback.print_exc() @staticmethod - async def update_message(message): - """更新消息ID""" + async def update_message(message_data: dict): + """更新消息ID(从消息字典)""" try: - mmc_message_id = message.message_info.message_id + # 从字典中提取信息 + message_info = message_data.get("message_info", {}) + mmc_message_id = message_info.get("message_id") + + message_segment = message_data.get("message_segment", {}) + segment_type = message_segment.get("type") if isinstance(message_segment, dict) else None + segment_data = message_segment.get("data", {}) if isinstance(message_segment, dict) else {} + qq_message_id = None - logger.debug(f"尝试更新消息ID: {mmc_message_id}, 消息段类型: {message.message_segment.type}") + logger.debug(f"尝试更新消息ID: {mmc_message_id}, 消息段类型: {segment_type}") # 根据消息段类型提取message_id - if message.message_segment.type == "notify": - qq_message_id = message.message_segment.data.get("id") - elif message.message_segment.type == "text": - qq_message_id = message.message_segment.data.get("id") - elif message.message_segment.type == "reply": - qq_message_id = message.message_segment.data.get("id") + if segment_type == "notify": + qq_message_id = segment_data.get("id") + elif segment_type == "text": + qq_message_id = segment_data.get("id") + elif segment_type == "reply": + qq_message_id = segment_data.get("id") if qq_message_id: logger.debug(f"从reply消息段获取到消息ID: {qq_message_id}") - elif message.message_segment.type == "adapter_response": + elif segment_type == "adapter_response": logger.debug("适配器响应消息,不需要更新ID") return - elif message.message_segment.type == "adapter_command": + elif segment_type == "adapter_command": logger.debug("适配器命令消息,不需要更新ID") return else: - logger.debug(f"未知的消息段类型: {message.message_segment.type},跳过ID更新") + logger.debug(f"未知的消息段类型: {segment_type},跳过ID更新") return if not qq_message_id: - logger.debug(f"消息段类型 {message.message_segment.type} 中未找到有效的message_id,跳过更新") - logger.debug(f"消息段数据: {message.message_segment.data}") + logger.debug(f"消息段类型 {segment_type} 中未找到有效的message_id,跳过更新") + logger.debug(f"消息段数据: {segment_data}") return # 使用上下文管理器确保session正确管理 diff --git a/src/chat/planner_actions/action_modifier.py b/src/chat/planner_actions/action_modifier.py index 35a17d675..7ea2b4785 100644 --- a/src/chat/planner_actions/action_modifier.py +++ b/src/chat/planner_actions/action_modifier.py @@ -137,7 +137,7 @@ class ActionModifier: logger.debug(f"{self.log_prefix}阶段一移除动作: {disabled_action_name},原因: 用户自行禁用") # === 第二阶段:检查动作的关联类型 === - chat_context = self.chat_stream.stream_context + chat_context = self.chat_stream.context_manager.context current_actions_s2 = self.action_manager.get_using_actions() type_mismatched_actions = self._check_action_associated_types(current_actions_s2, chat_context) diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index 8a07948fb..66f33ce09 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -13,7 +13,7 @@ from typing import Any from src.chat.express.expression_selector import expression_selector from src.chat.message_receive.chat_stream import ChatStream -from src.chat.message_receive.message import MessageRecv, MessageSending, Seg, UserInfo +from src.chat.message_receive.message import MessageSending, Seg, UserInfo from src.chat.message_receive.uni_message_sender import HeartFCSender from src.chat.utils.chat_message_builder import ( build_readable_messages, @@ -1733,7 +1733,7 @@ class DefaultReplyer: is_emoji: bool, thinking_start_time: float, display_message: str, - anchor_message: MessageRecv | None = None, + anchor_message: DatabaseMessages | None = None, ) -> MessageSending: """构建单个发送消息""" @@ -1743,8 +1743,11 @@ class DefaultReplyer: platform=self.chat_stream.platform, ) - # await anchor_message.process() - sender_info = anchor_message.message_info.user_info if anchor_message else None + # 从 DatabaseMessages 获取 sender_info + if anchor_message: + sender_info = anchor_message.user_info + else: + sender_info = None return MessageSending( message_id=message_id, # 使用片段的唯一ID diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index 496e50673..bc0641c9d 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -11,7 +11,7 @@ import rjieba from maim_message import UserInfo from src.chat.message_receive.chat_stream import get_chat_manager -from src.chat.message_receive.message import MessageRecv +# MessageRecv 已被移除,现在使用 DatabaseMessages from src.common.logger import get_logger from src.common.message_repository import count_messages, find_messages from src.config.config import global_config, model_config @@ -41,34 +41,58 @@ def db_message_to_str(message_dict: dict) -> str: return result -def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]: - """检查消息是否提到了机器人""" +def is_mentioned_bot_in_message(message) -> tuple[bool, float]: + """检查消息是否提到了机器人 + + Args: + message: DatabaseMessages 消息对象 + + Returns: + tuple[bool, float]: (是否提及, 提及概率) + """ keywords = [global_config.bot.nickname] nicknames = global_config.bot.alias_names reply_probability = 0.0 is_at = False is_mentioned = False - if message.is_mentioned is not None: - return bool(message.is_mentioned), message.is_mentioned - if ( - message.message_info.additional_config is not None - and message.message_info.additional_config.get("is_mentioned") is not None - ): + + # 检查 is_mentioned 属性 + mentioned_attr = getattr(message, "is_mentioned", None) + if mentioned_attr is not None: try: - reply_probability = float(message.message_info.additional_config.get("is_mentioned")) # type: ignore + return bool(mentioned_attr), float(mentioned_attr) + except (ValueError, TypeError): + pass + + # 检查 additional_config + additional_config = None + + # DatabaseMessages: additional_config 是 JSON 字符串 + if message.additional_config: + try: + import orjson + additional_config = orjson.loads(message.additional_config) + except Exception: + pass + + if additional_config and additional_config.get("is_mentioned") is not None: + try: + reply_probability = float(additional_config.get("is_mentioned")) # type: ignore is_mentioned = True return is_mentioned, reply_probability except Exception as e: logger.warning(str(e)) logger.warning( - f"消息中包含不合理的设置 is_mentioned: {message.message_info.additional_config.get('is_mentioned')}" + f"消息中包含不合理的设置 is_mentioned: {additional_config.get('is_mentioned')}" ) - if global_config.bot.nickname in message.processed_plain_text: + # 检查消息文本内容 + processed_text = message.processed_plain_text or "" + if global_config.bot.nickname in processed_text: is_mentioned = True for alias_name in global_config.bot.alias_names: - if alias_name in message.processed_plain_text: + if alias_name in processed_text: is_mentioned = True # 判断是否被@ @@ -110,7 +134,6 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]: logger.debug("被提及,回复概率设置为100%") return is_mentioned, reply_probability - async def get_embedding(text, request_type="embedding") -> list[float] | None: """获取文本的embedding向量""" # 每次都创建新的LLMRequest实例以避免事件循环冲突 diff --git a/src/common/database/db_migration.py b/src/common/database/db_migration.py index 2ab7ba13e..fad348bf9 100644 --- a/src/common/database/db_migration.py +++ b/src/common/database/db_migration.py @@ -9,15 +9,18 @@ from src.common.logger import get_logger logger = get_logger("db_migration") -async def check_and_migrate_database(): +async def check_and_migrate_database(existing_engine=None): """ 异步检查数据库结构并自动迁移。 - 自动创建不存在的表。 - 自动为现有表添加缺失的列。 - 自动为现有表创建缺失的索引。 + + Args: + existing_engine: 可选的已存在的数据库引擎。如果提供,将使用该引擎;否则获取全局引擎。 """ logger.info("正在检查数据库结构并执行自动迁移...") - engine = await get_engine() + engine = existing_engine if existing_engine is not None else await get_engine() async with engine.connect() as connection: # 在同步上下文中运行inspector操作 diff --git a/src/common/database/sqlalchemy_models.py b/src/common/database/sqlalchemy_models.py index 9f03aa43c..287f0fc29 100644 --- a/src/common/database/sqlalchemy_models.py +++ b/src/common/database/sqlalchemy_models.py @@ -780,12 +780,8 @@ async def initialize_database() -> tuple["AsyncEngine", async_sessionmaker[Async _SessionLocal = async_sessionmaker(bind=_engine, class_=AsyncSession, expire_on_commit=False) # 迁移 - try: - from src.common.database.db_migration import check_and_migrate_database - await check_and_migrate_database(existing_engine=_engine) - except TypeError: - from src.common.database.db_migration import check_and_migrate_database as _legacy_migrate - await _legacy_migrate() + from src.common.database.db_migration import check_and_migrate_database + await check_and_migrate_database(existing_engine=_engine) if config.database_type == "sqlite": await enable_sqlite_wal_mode(_engine) diff --git a/src/mood/mood_manager.py b/src/mood/mood_manager.py index 14f1dfef5..ef52b93a1 100644 --- a/src/mood/mood_manager.py +++ b/src/mood/mood_manager.py @@ -2,7 +2,6 @@ import math import random import time -from src.chat.message_receive.message import MessageRecv from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive from src.chat.utils.prompt import Prompt, global_prompt_manager from src.common.data_models.database_data_model import DatabaseMessages @@ -98,7 +97,7 @@ class ChatMood: if not hasattr(self, "last_change_time"): self.last_change_time = 0 - async def update_mood_by_message(self, message: MessageRecv | DatabaseMessages, interested_rate: float): + async def update_mood_by_message(self, message: DatabaseMessages, interested_rate: float): # 确保异步初始化已完成 await self._initialize() @@ -109,11 +108,8 @@ class ChatMood: self.regression_count = 0 - # 处理不同类型的消息对象 - if isinstance(message, MessageRecv): - message_time = message.message_info.time - else: # DatabaseMessages - message_time = message.time + # 使用 DatabaseMessages 的时间字段 + message_time = message.time # 防止负时间差 during_last_time = max(0, message_time - self.last_change_time) diff --git a/src/plugin_system/apis/send_api.py b/src/plugin_system/apis/send_api.py index 96f0e4b09..71786562d 100644 --- a/src/plugin_system/apis/send_api.py +++ b/src/plugin_system/apis/send_api.py @@ -86,13 +86,16 @@ async def file_to_stream( import asyncio import time import traceback -from typing import Any +from typing import Any, TYPE_CHECKING from maim_message import Seg, UserInfo +if TYPE_CHECKING: + from src.common.data_models.database_data_model import DatabaseMessages + # 导入依赖 from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager -from src.chat.message_receive.message import MessageRecv, MessageSending +from src.chat.message_receive.message import MessageSending from src.chat.message_receive.uni_message_sender import HeartFCSender from src.common.logger import get_logger from src.config.config import global_config @@ -104,84 +107,53 @@ logger = get_logger("send_api") _adapter_response_pool: dict[str, asyncio.Future] = {} -def message_dict_to_message_recv(message_dict: dict[str, Any]) -> MessageRecv | None: - """查找要回复的消息 +def message_dict_to_db_message(message_dict: dict[str, Any]) -> "DatabaseMessages | None": + """从消息字典构建 DatabaseMessages 对象 Args: message_dict: 消息字典或 DatabaseMessages 对象 Returns: - Optional[MessageRecv]: 找到的消息,如果没找到则返回None + Optional[DatabaseMessages]: 构建的消息对象,如果构建失败则返回None """ - # 兼容 DatabaseMessages 对象和字典 - if isinstance(message_dict, dict): - user_platform = message_dict.get("user_platform", "") - user_id = message_dict.get("user_id", "") - user_nickname = message_dict.get("user_nickname", "") - user_cardname = message_dict.get("user_cardname", "") - chat_info_group_id = message_dict.get("chat_info_group_id") - chat_info_group_platform = message_dict.get("chat_info_group_platform", "") - chat_info_group_name = message_dict.get("chat_info_group_name", "") - chat_info_platform = message_dict.get("chat_info_platform", "") - message_id = message_dict.get("message_id") or message_dict.get("chat_info_message_id") or message_dict.get("id") - time_val = message_dict.get("time") - additional_config = message_dict.get("additional_config") - processed_plain_text = message_dict.get("processed_plain_text") - else: - # DatabaseMessages 对象 - user_platform = getattr(message_dict, "user_platform", "") - user_id = getattr(message_dict, "user_id", "") - user_nickname = getattr(message_dict, "user_nickname", "") - user_cardname = getattr(message_dict, "user_cardname", "") - chat_info_group_id = getattr(message_dict, "chat_info_group_id", None) - chat_info_group_platform = getattr(message_dict, "chat_info_group_platform", "") - chat_info_group_name = getattr(message_dict, "chat_info_group_name", "") - chat_info_platform = getattr(message_dict, "chat_info_platform", "") - message_id = getattr(message_dict, "message_id", None) - time_val = getattr(message_dict, "time", None) - additional_config = getattr(message_dict, "additional_config", None) - processed_plain_text = getattr(message_dict, "processed_plain_text", "") + from src.common.data_models.database_data_model import DatabaseMessages - # 构建MessageRecv对象 - user_info = { - "platform": user_platform, - "user_id": user_id, - "user_nickname": user_nickname, - "user_cardname": user_cardname, - } - - group_info = {} - if chat_info_group_id: - group_info = { - "platform": chat_info_group_platform, - "group_id": chat_info_group_id, - "group_name": chat_info_group_name, - } - - format_info = {"content_format": "", "accept_format": ""} - template_info = {"template_items": {}} - - message_info = { - "platform": chat_info_platform, - "message_id": message_id, - "time": time_val, - "group_info": group_info, - "user_info": user_info, - "additional_config": additional_config, - "format_info": format_info, - "template_info": template_info, - } - - new_message_dict = { - "message_info": message_info, - "raw_message": processed_plain_text, - "processed_plain_text": processed_plain_text, - } - - message_recv = MessageRecv(new_message_dict) - - logger.info(f"[SendAPI] 找到匹配的回复消息,发送者: {user_nickname}") - return message_recv + # 如果已经是 DatabaseMessages,直接返回 + if isinstance(message_dict, DatabaseMessages): + return message_dict + + # 从字典提取信息 + user_platform = message_dict.get("user_platform", "") + user_id = message_dict.get("user_id", "") + user_nickname = message_dict.get("user_nickname", "") + user_cardname = message_dict.get("user_cardname", "") + chat_info_group_id = message_dict.get("chat_info_group_id") + chat_info_group_platform = message_dict.get("chat_info_group_platform", "") + chat_info_group_name = message_dict.get("chat_info_group_name", "") + chat_info_platform = message_dict.get("chat_info_platform", "") + message_id = message_dict.get("message_id") or message_dict.get("chat_info_message_id") or message_dict.get("id") + time_val = message_dict.get("time", time.time()) + additional_config = message_dict.get("additional_config") + processed_plain_text = message_dict.get("processed_plain_text", "") + + # DatabaseMessages 使用扁平参数构造 + db_message = DatabaseMessages( + message_id=message_id or "temp_reply_id", + time=time_val, + user_id=user_id, + user_nickname=user_nickname, + user_cardname=user_cardname, + user_platform=user_platform, + chat_info_group_id=chat_info_group_id, + chat_info_group_name=chat_info_group_name, + chat_info_group_platform=chat_info_group_platform, + chat_info_platform=chat_info_platform, + processed_plain_text=processed_plain_text, + additional_config=additional_config + ) + + logger.info(f"[SendAPI] 构建回复消息对象,发送者: {user_nickname}") + return db_message def put_adapter_response(request_id: str, response_data: dict) -> None: @@ -285,17 +257,17 @@ async def _send_to_target( "message_id": "temp_reply_id", # 临时ID "time": time.time() } - anchor_message = message_dict_to_message_recv(message_dict=temp_message_dict) + anchor_message = message_dict_to_db_message(message_dict=temp_message_dict) else: anchor_message = None reply_to_platform_id = f"{target_stream.platform}:{sender_id}" if anchor_message else None elif reply_to_message: - anchor_message = message_dict_to_message_recv(message_dict=reply_to_message) + anchor_message = message_dict_to_db_message(message_dict=reply_to_message) if anchor_message: - anchor_message.update_chat_stream(target_stream) + # DatabaseMessages 不需要 update_chat_stream,它是纯数据对象 reply_to_platform_id = ( - f"{anchor_message.message_info.platform}:{anchor_message.message_info.user_info.user_id}" + f"{anchor_message.chat_info.platform}:{anchor_message.user_info.user_id}" ) else: reply_to_platform_id = None diff --git a/src/plugin_system/base/base_command.py b/src/plugin_system/base/base_command.py index 9cb41ed04..7076bbba6 100644 --- a/src/plugin_system/base/base_command.py +++ b/src/plugin_system/base/base_command.py @@ -1,10 +1,14 @@ from abc import ABC, abstractmethod +from typing import TYPE_CHECKING -from src.chat.message_receive.message import MessageRecv +from src.common.data_models.database_data_model import DatabaseMessages from src.common.logger import get_logger from src.plugin_system.apis import send_api from src.plugin_system.base.component_types import ChatType, CommandInfo, ComponentType +if TYPE_CHECKING: + from src.chat.message_receive.chat_stream import ChatStream + logger = get_logger("base_command") @@ -29,11 +33,11 @@ class BaseCommand(ABC): chat_type_allow: ChatType = ChatType.ALL """允许的聊天类型,默认为所有类型""" - def __init__(self, message: MessageRecv, plugin_config: dict | None = None): + def __init__(self, message: DatabaseMessages, plugin_config: dict | None = None): """初始化Command组件 Args: - message: 接收到的消息对象 + message: 接收到的消息对象(DatabaseMessages) plugin_config: 插件配置字典 """ self.message = message @@ -41,6 +45,9 @@ class BaseCommand(ABC): self.plugin_config = plugin_config or {} # 直接存储插件配置字典 self.log_prefix = "[Command]" + + # chat_stream 会在运行时被 bot.py 设置 + self.chat_stream: "ChatStream | None" = None # 从类属性获取chat_type_allow设置 self.chat_type_allow = getattr(self.__class__, "chat_type_allow", ChatType.ALL) @@ -49,7 +56,7 @@ class BaseCommand(ABC): # 验证聊天类型限制 if not self._validate_chat_type(): - is_group = hasattr(self.message, "is_group_message") and self.message.is_group_message + is_group = message.group_info is not None logger.warning( f"{self.log_prefix} Command '{self.command_name}' 不支持当前聊天类型: " f"{'群聊' if is_group else '私聊'}, 允许类型: {self.chat_type_allow.value}" @@ -72,8 +79,8 @@ class BaseCommand(ABC): if self.chat_type_allow == ChatType.ALL: return True - # 检查是否为群聊消息 - is_group = self.message.message_info.group_info + # 检查是否为群聊消息(DatabaseMessages使用group_info来判断) + is_group = self.message.group_info is not None if self.chat_type_allow == ChatType.GROUP and is_group: return True @@ -137,12 +144,11 @@ class BaseCommand(ABC): bool: 是否发送成功 """ # 获取聊天流信息 - chat_stream = self.message.chat_stream - if not chat_stream or not hasattr(chat_stream, "stream_id"): + if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"): logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") return False - return await send_api.text_to_stream(text=content, stream_id=chat_stream.stream_id, reply_to=reply_to) + return await send_api.text_to_stream(text=content, stream_id=self.chat_stream.stream_id, reply_to=reply_to) async def send_type( self, message_type: str, content: str, display_message: str = "", typing: bool = False, reply_to: str = "" @@ -160,15 +166,14 @@ class BaseCommand(ABC): bool: 是否发送成功 """ # 获取聊天流信息 - chat_stream = self.message.chat_stream - if not chat_stream or not hasattr(chat_stream, "stream_id"): + if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"): logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") return False return await send_api.custom_to_stream( message_type=message_type, content=content, - stream_id=chat_stream.stream_id, + stream_id=self.chat_stream.stream_id, display_message=display_message, typing=typing, reply_to=reply_to, @@ -190,8 +195,7 @@ class BaseCommand(ABC): """ try: # 获取聊天流信息 - chat_stream = self.message.chat_stream - if not chat_stream or not hasattr(chat_stream, "stream_id"): + if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"): logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") return False @@ -200,7 +204,7 @@ class BaseCommand(ABC): success = await send_api.command_to_stream( command=command_data, - stream_id=chat_stream.stream_id, + stream_id=self.chat_stream.stream_id, storage_message=storage_message, display_message=display_message, ) @@ -225,12 +229,11 @@ class BaseCommand(ABC): Returns: bool: 是否发送成功 """ - chat_stream = self.message.chat_stream - if not chat_stream or not hasattr(chat_stream, "stream_id"): + if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"): logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") return False - return await send_api.emoji_to_stream(emoji_base64, chat_stream.stream_id) + return await send_api.emoji_to_stream(emoji_base64, self.chat_stream.stream_id) async def send_image(self, image_base64: str) -> bool: """发送图片 @@ -241,12 +244,11 @@ class BaseCommand(ABC): Returns: bool: 是否发送成功 """ - chat_stream = self.message.chat_stream - if not chat_stream or not hasattr(chat_stream, "stream_id"): + if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"): logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") return False - return await send_api.image_to_stream(image_base64, chat_stream.stream_id) + return await send_api.image_to_stream(image_base64, self.chat_stream.stream_id) @classmethod def get_command_info(cls) -> "CommandInfo": diff --git a/src/plugin_system/base/plus_command.py b/src/plugin_system/base/plus_command.py index e442d76c1..b53846fc2 100644 --- a/src/plugin_system/base/plus_command.py +++ b/src/plugin_system/base/plus_command.py @@ -5,8 +5,9 @@ import re from abc import ABC, abstractmethod +from typing import TYPE_CHECKING -from src.chat.message_receive.message import MessageRecv +from src.common.data_models.database_data_model import DatabaseMessages from src.common.logger import get_logger from src.config.config import global_config from src.plugin_system.apis import send_api @@ -14,6 +15,9 @@ from src.plugin_system.base.base_command import BaseCommand from src.plugin_system.base.command_args import CommandArgs from src.plugin_system.base.component_types import ChatType, ComponentType, PlusCommandInfo +if TYPE_CHECKING: + from src.chat.message_receive.chat_stream import ChatStream + logger = get_logger("plus_command") @@ -50,23 +54,26 @@ class PlusCommand(ABC): intercept_message: bool = False """是否拦截消息,不进行后续处理""" - def __init__(self, message: MessageRecv, plugin_config: dict | None = None): + def __init__(self, message: DatabaseMessages, plugin_config: dict | None = None): """初始化命令组件 Args: - message: 接收到的消息对象 + message: 接收到的消息对象(DatabaseMessages) plugin_config: 插件配置字典 """ self.message = message self.plugin_config = plugin_config or {} self.log_prefix = "[PlusCommand]" + + # chat_stream 会在运行时被 bot.py 设置 + self.chat_stream: "ChatStream | None" = None # 解析命令参数 self._parse_command() # 验证聊天类型限制 if not self._validate_chat_type(): - is_group = self.message.message_info.group_info.group_id + is_group = message.group_info is not None logger.warning( f"{self.log_prefix} 命令 '{self.command_name}' 不支持当前聊天类型: " f"{'群聊' if is_group else '私聊'}, 允许类型: {self.chat_type_allow.value}" @@ -124,8 +131,8 @@ class PlusCommand(ABC): if self.chat_type_allow == ChatType.ALL: return True - # 检查是否为群聊消息 - is_group = hasattr(self.message.message_info, "group_info") and self.message.message_info.group_info + # 检查是否为群聊消息(DatabaseMessages使用group_info判断) + is_group = self.message.group_info is not None if self.chat_type_allow == ChatType.GROUP and is_group: return True @@ -152,7 +159,7 @@ class PlusCommand(ABC): def _is_exact_command_call(self) -> bool: """检查是否是精确的命令调用(无参数)""" - if not hasattr(self.message, "plain_text") or not self.message.processed_plain_text: + if not self.message.processed_plain_text: return False plain_text = self.message.processed_plain_text.strip() @@ -218,12 +225,11 @@ class PlusCommand(ABC): bool: 是否发送成功 """ # 获取聊天流信息 - chat_stream = self.message.chat_stream - if not chat_stream or not hasattr(chat_stream, "stream_id"): + if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"): logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") return False - return await send_api.text_to_stream(text=content, stream_id=chat_stream.stream_id, reply_to=reply_to) + return await send_api.text_to_stream(text=content, stream_id=self.chat_stream.stream_id, reply_to=reply_to) async def send_type( self, message_type: str, content: str, display_message: str = "", typing: bool = False, reply_to: str = "" @@ -241,15 +247,14 @@ class PlusCommand(ABC): bool: 是否发送成功 """ # 获取聊天流信息 - chat_stream = self.message.chat_stream - if not chat_stream or not hasattr(chat_stream, "stream_id"): + if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"): logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") return False return await send_api.custom_to_stream( message_type=message_type, content=content, - stream_id=chat_stream.stream_id, + stream_id=self.chat_stream.stream_id, display_message=display_message, typing=typing, reply_to=reply_to, @@ -264,12 +269,11 @@ class PlusCommand(ABC): Returns: bool: 是否发送成功 """ - chat_stream = self.message.chat_stream - if not chat_stream or not hasattr(chat_stream, "stream_id"): + if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"): logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") return False - return await send_api.emoji_to_stream(emoji_base64, chat_stream.stream_id) + return await send_api.emoji_to_stream(emoji_base64, self.chat_stream.stream_id) async def send_image(self, image_base64: str) -> bool: """发送图片 @@ -280,12 +284,11 @@ class PlusCommand(ABC): Returns: bool: 是否发送成功 """ - chat_stream = self.message.chat_stream - if not chat_stream or not hasattr(chat_stream, "stream_id"): + if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"): logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") return False - return await send_api.image_to_stream(image_base64, chat_stream.stream_id) + return await send_api.image_to_stream(image_base64, self.chat_stream.stream_id) @classmethod def get_plus_command_info(cls) -> "PlusCommandInfo": @@ -340,12 +343,12 @@ class PlusCommandAdapter(BaseCommand): 将PlusCommand适配到现有的插件系统,继承BaseCommand """ - def __init__(self, plus_command_class, message: MessageRecv, plugin_config: dict | None = None): + def __init__(self, plus_command_class, message: DatabaseMessages, plugin_config: dict | None = None): """初始化适配器 Args: plus_command_class: PlusCommand子类 - message: 消息对象 + message: 消息对象(DatabaseMessages) plugin_config: 插件配置 """ # 先设置必要的类属性 @@ -400,7 +403,7 @@ def create_plus_command_adapter(plus_command_class): command_pattern = plus_command_class._generate_command_pattern() chat_type_allow = getattr(plus_command_class, "chat_type_allow", ChatType.ALL) - def __init__(self, message: MessageRecv, plugin_config: dict | None = None): + def __init__(self, message: DatabaseMessages, plugin_config: dict | None = None): super().__init__(message, plugin_config) self.plus_command = plus_command_class(message, plugin_config) self.priority = getattr(plus_command_class, "priority", 0) diff --git a/src/plugins/built_in/affinity_flow_chatter/plan_executor.py b/src/plugins/built_in/affinity_flow_chatter/plan_executor.py index 3af389f9f..e150e7e62 100644 --- a/src/plugins/built_in/affinity_flow_chatter/plan_executor.py +++ b/src/plugins/built_in/affinity_flow_chatter/plan_executor.py @@ -410,11 +410,9 @@ class ChatterPlanExecutor: ) # 添加到chat_stream的已读消息中 - if hasattr(chat_stream, "stream_context") and chat_stream.stream_context: - chat_stream.stream_context.history_messages.append(bot_message) - logger.debug(f"机器人回复已添加到已读消息: {reply_content[:50]}...") - else: - logger.warning("chat_stream没有stream_context,无法添加已读消息") + chat_stream.context_manager.context.history_messages.append(bot_message) + logger.debug(f"机器人回复已添加到已读消息: {reply_content[:50]}...") + except Exception as e: logger.error(f"添加机器人回复到已读消息时出错: {e}") diff --git a/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_event.py b/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_event.py index 2a719da83..7b8ffdad1 100644 --- a/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_event.py +++ b/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_event.py @@ -96,7 +96,7 @@ class ProactiveThinkingMessageHandler(BaseEventHandler): """处理消息事件 Args: - kwargs: 事件参数,格式为 {"message": MessageRecv} + kwargs: 事件参数,格式为 {"message": DatabaseMessages} Returns: HandlerResult: 处理结果 @@ -104,7 +104,7 @@ class ProactiveThinkingMessageHandler(BaseEventHandler): if not kwargs: return HandlerResult(success=True, continue_process=True, message=None) - # 从 kwargs 中获取 MessageRecv 对象 + # 从 kwargs 中获取 DatabaseMessages 对象 message = kwargs.get("message") if not message or not hasattr(message, "chat_stream"): return HandlerResult(success=True, continue_process=True, message=None)