diff --git a/bot.py b/bot.py index a4a1030c5..0399b0d19 100644 --- a/bot.py +++ b/bot.py @@ -185,12 +185,12 @@ class MaiBotMain(BaseMain): check_eula() logger.info("检查EULA和隐私条款完成") - def initialize_database(self): + async def initialize_database(self): """初始化数据库""" logger.info("正在初始化数据库连接...") try: - initialize_sql_database(global_config.database) + await initialize_sql_database(global_config.database) logger.info(f"数据库连接初始化成功,使用 {global_config.database.database_type} 数据库") except Exception as e: logger.error(f"数据库连接初始化失败: {e}") @@ -211,11 +211,11 @@ class MaiBotMain(BaseMain): self.main_system = MainSystem() return self.main_system - def run(self): + async def run(self): """运行主程序""" self.setup_timezone() self.check_and_confirm_eula() - self.initialize_database() + await self.initialize_database() return self.create_main_system() @@ -225,14 +225,14 @@ if __name__ == "__main__": try: # 创建MaiBotMain实例并获取MainSystem maibot = MaiBotMain() - main_system = maibot.run() # 创建事件循环 loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: - # 异步初始化数据库表结构 + # 异步初始化数据库和表结构 + main_system = loop.run_until_complete(maibot.run()) loop.run_until_complete(maibot.initialize_database_async()) # 执行初始化和任务调度 loop.run_until_complete(main_system.initialize()) @@ -269,3 +269,4 @@ if __name__ == "__main__": # 在程序退出前暂停,让你有机会看到输出 # input("按 Enter 键退出...") # <--- 添加这行 sys.exit(exit_code) # <--- 使用记录的退出码 + \ No newline at end of file diff --git a/src/chat/antipromptinjector/management/statistics.py b/src/chat/antipromptinjector/management/statistics.py index 2cfe3e13c..4df22f152 100644 --- a/src/chat/antipromptinjector/management/statistics.py +++ b/src/chat/antipromptinjector/management/statistics.py @@ -8,6 +8,8 @@ import datetime from typing import Dict, Any +from sqlalchemy import select + from src.common.logger import get_logger from src.common.database.sqlalchemy_models import AntiInjectionStats, get_db_session from src.config.config import global_config @@ -27,9 +29,11 @@ class AntiInjectionStatistics: async def get_or_create_stats(): """获取或创建统计记录""" try: - with get_db_session() as session: + async with get_db_session() as session: # 获取最新的统计记录,如果没有则创建 - stats = session.query(AntiInjectionStats).order_by(AntiInjectionStats.id.desc()).first() + stats = (await session.execute( + select(AntiInjectionStats).order_by(AntiInjectionStats.id.desc()) + )).scalars().first() if not stats: stats = AntiInjectionStats() session.add(stats) @@ -44,8 +48,10 @@ class AntiInjectionStatistics: async def update_stats(**kwargs): """更新统计数据""" try: - with get_db_session() as session: - stats = session.query(AntiInjectionStats).order_by(AntiInjectionStats.id.desc()).first() + async with get_db_session() as session: + stats = (await session.execute( + select(AntiInjectionStats).order_by(AntiInjectionStats.id.desc()) + )).scalars().first() if not stats: stats = AntiInjectionStats() session.add(stats) @@ -53,7 +59,7 @@ class AntiInjectionStatistics: # 更新统计字段 for key, value in kwargs.items(): if key == "processing_time_delta": - # 处理时间累加 - 确保不为None + # 处理 时间累加 - 确保不为None if stats.processing_time_total is None: stats.processing_time_total = 0.0 stats.processing_time_total += value @@ -138,9 +144,9 @@ class AntiInjectionStatistics: async def reset_stats(): """重置统计信息""" try: - with get_db_session() as session: + async with get_db_session() as session: # 删除现有统计记录 - session.query(AntiInjectionStats).delete() + await session.execute(select(AntiInjectionStats).delete()) await session.commit() logger.info("统计信息已重置") except Exception as e: diff --git a/src/chat/antipromptinjector/management/user_ban.py b/src/chat/antipromptinjector/management/user_ban.py index 676436c42..b965a08af 100644 --- a/src/chat/antipromptinjector/management/user_ban.py +++ b/src/chat/antipromptinjector/management/user_ban.py @@ -8,6 +8,8 @@ import datetime from typing import Optional, Tuple +from sqlalchemy import select + from src.common.logger import get_logger from src.common.database.sqlalchemy_models import BanUser, get_db_session from ..types import DetectionResult @@ -37,8 +39,9 @@ class UserBanManager: 如果用户被封禁则返回拒绝结果,否则返回None """ try: - with get_db_session() as session: - ban_record = session.query(BanUser).filter_by(user_id=user_id, platform=platform).first() + async with get_db_session() as session: + result = await session.execute(select(BanUser).filter_by(user_id=user_id, platform=platform)) + ban_record = result.scalar_one_or_none() if ban_record: # 只有违规次数达到阈值时才算被封禁 @@ -70,9 +73,10 @@ class UserBanManager: detection_result: 检测结果 """ try: - with get_db_session() as session: + async with get_db_session() as session: # 查找或创建违规记录 - ban_record = session.query(BanUser).filter_by(user_id=user_id, platform=platform).first() + result = await session.execute(select(BanUser).filter_by(user_id=user_id, platform=platform)) + ban_record = result.scalar_one_or_none() if ban_record: ban_record.violation_num += 1 diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index b614345f0..9e4829a56 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -149,7 +149,7 @@ class MaiEmoji: # --- 数据库操作 --- try: # 准备数据库记录 for emoji collection - with get_db_session() as session: + async with get_db_session() as session: emotion_str = ",".join(self.emotion) if self.emotion else "" emoji = Emoji( @@ -167,7 +167,7 @@ class MaiEmoji: last_used_time=self.last_used_time, ) session.add(emoji) - session.commit() + await session.commit() logger.info(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})") @@ -203,17 +203,18 @@ class MaiEmoji: # 2. 删除数据库记录 try: - with get_db_session() as session: - will_delete_emoji = session.execute( + async with get_db_session() as session: + result = await session.execute( select(Emoji).where(Emoji.emoji_hash == self.hash) - ).scalar_one_or_none() + ) + will_delete_emoji = result.scalar_one_or_none() if will_delete_emoji is None: logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。") result = 0 # Indicate no DB record was deleted else: - session.delete(will_delete_emoji) + await session.delete(will_delete_emoji) result = 1 # Successfully deleted one record - session.commit() + await session.commit() except Exception as e: logger.error(f"[错误] 删除数据库记录时出错: {str(e)}") result = 0 @@ -424,17 +425,19 @@ class EmojiManager: # if not self._initialized: # raise RuntimeError("EmojiManager not initialized") - def record_usage(self, emoji_hash: str) -> None: + async def record_usage(self, emoji_hash: str) -> None: """记录表情使用次数""" try: - with get_db_session() as session: - emoji_update = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalar_one_or_none() + async with get_db_session() as session: + stmt = select(Emoji).where(Emoji.emoji_hash == emoji_hash) + result = await session.execute(stmt) + emoji_update = result.scalar_one_or_none() if emoji_update is None: logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包") else: emoji_update.usage_count += 1 emoji_update.last_used_time = time.time() # Update last used time - session.commit() + await session.commit() except Exception as e: logger.error(f"记录表情使用失败: {str(e)}") @@ -521,7 +524,7 @@ class EmojiManager: # 7. 获取选中的表情包并更新使用记录 selected_emoji = candidate_emojis[selected_index] - self.record_usage(selected_emoji.hash) + await self.record_usage(selected_emoji.hash) _time_end = time.time() logger.info(f"找到匹配描述的表情包: {selected_emoji.description}, 耗时: {(_time_end - _time_start):.2f}s") @@ -657,10 +660,11 @@ class EmojiManager: async def get_all_emoji_from_db(self) -> None: """获取所有表情包并初始化为MaiEmoji类对象,更新 self.emoji_objects""" try: - with get_db_session() as session: + async with get_db_session() as session: logger.debug("[数据库] 开始加载所有表情包记录 ...") - emoji_instances = session.execute(select(Emoji)).scalars().all() + result = await session.execute(select(Emoji)) + emoji_instances = result.scalars().all() emoji_objects, load_errors = _to_emoji_objects(emoji_instances) # 更新内存中的列表和数量 @@ -686,14 +690,16 @@ class EmojiManager: list[MaiEmoji]: 表情包对象列表 """ try: - with get_db_session() as session: + async with get_db_session() as session: if emoji_hash: - query = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalars().all() + result = await session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)) + query = result.scalars().all() else: logger.warning( "[查询] 未提供 hash,将尝试加载所有表情包,建议使用 get_all_emoji_from_db 更新管理器状态。" ) - query = session.execute(select(Emoji)).scalars().all() + result = await session.execute(select(Emoji)) + query = result.scalars().all() emoji_instances = query emoji_objects, load_errors = _to_emoji_objects(emoji_instances) @@ -770,10 +776,10 @@ class EmojiManager: # 如果内存中没有,从数据库查找 try: - with get_db_session() as session: - emoji_record = session.execute( - select(Emoji).where(Emoji.emoji_hash == emoji_hash) - ).scalar_one_or_none() + async with get_db_session() as session: + stmt = select(Emoji).where(Emoji.emoji_hash == emoji_hash) + result = await session.execute(stmt) + emoji_record = result.scalar_one_or_none() if emoji_record and emoji_record.description: logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.description[:50]}...") return emoji_record.description @@ -939,12 +945,13 @@ class EmojiManager: # 2. 检查数据库中是否已存在该表情包的描述,实现复用 existing_description = None try: - with get_db_session() as session: - existing_image = ( - session.query(Images) - .filter((Images.emoji_hash == image_hash) & (Images.type == "emoji")) - .one_or_none() + async with get_db_session() as session: + stmt = select(Images).where( + Images.emoji_hash == image_hash, + Images.type == "emoji" ) + result = await session.execute(stmt) + existing_image = result.scalar_one_or_none() if existing_image and existing_image.description: existing_description = existing_image.description logger.info(f"[复用描述] 找到已有详细描述: {existing_description[:50]}...") diff --git a/src/chat/energy_system/energy_manager.py b/src/chat/energy_system/energy_manager.py index 8ee2017cb..6dc5ad8e4 100644 --- a/src/chat/energy_system/energy_manager.py +++ b/src/chat/energy_system/energy_manager.py @@ -198,7 +198,7 @@ class RecencyEnergyCalculator(EnergyCalculator): class RelationshipEnergyCalculator(EnergyCalculator): """关系能量计算器""" - def calculate(self, context: Dict[str, Any]) -> float: + async def calculate(self, context: Dict[str, Any]) -> float: """基于关系计算能量""" user_id = context.get("user_id") if not user_id: @@ -208,7 +208,7 @@ class RelationshipEnergyCalculator(EnergyCalculator): try: from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system - relationship_score = chatter_interest_scoring_system._calculate_relationship_score(user_id) + relationship_score = await chatter_interest_scoring_system._calculate_relationship_score(user_id) logger.debug(f"使用插件内部系统计算关系分: {relationship_score:.3f}") return max(0.0, min(1.0, relationship_score)) @@ -273,7 +273,7 @@ class EnergyManager: except Exception as e: logger.warning(f"加载AFC阈值失败,使用默认值: {e}") - def calculate_focus_energy(self, stream_id: str, messages: List[Any], user_id: Optional[str] = None) -> float: + async def calculate_focus_energy(self, stream_id: str, messages: List[Any], user_id: Optional[str] = None) -> float: """计算聊天流的focus_energy""" start_time = time.time() @@ -303,7 +303,16 @@ class EnergyManager: for calculator in self.calculators: try: - score = calculator.calculate(context) + # 支持同步和异步计算器 + if callable(calculator.calculate): + import inspect + if inspect.iscoroutinefunction(calculator.calculate): + score = await calculator.calculate(context) + else: + score = calculator.calculate(context) + else: + score = calculator.calculate(context) + weight = calculator.get_weight() component_scores[calculator.__class__.__name__] = score diff --git a/src/chat/interest_system/bot_interest_manager.py b/src/chat/interest_system/bot_interest_manager.py index be04dd065..5c70fa744 100644 --- a/src/chat/interest_system/bot_interest_manager.py +++ b/src/chat/interest_system/bot_interest_manager.py @@ -8,6 +8,7 @@ import traceback from typing import List, Dict, Optional, Any from datetime import datetime import numpy as np +from sqlalchemy import select from src.common.logger import get_logger from src.config.config import global_config @@ -610,14 +611,13 @@ class BotInterestManager: from src.common.database.sqlalchemy_database_api import get_db_session import orjson - with get_db_session() as session: + async with get_db_session() as session: # 查询最新的兴趣标签配置 - db_interests = ( - session.query(DBBotPersonalityInterests) - .filter(DBBotPersonalityInterests.personality_id == personality_id) + db_interests = (await session.execute( + select(DBBotPersonalityInterests) + .where(DBBotPersonalityInterests.personality_id == personality_id) .order_by(DBBotPersonalityInterests.version.desc(), DBBotPersonalityInterests.last_updated.desc()) - .first() - ) + )).scalars().first() if db_interests: logger.debug(f"在数据库中找到兴趣标签配置, 版本: {db_interests.version}") @@ -700,13 +700,12 @@ class BotInterestManager: # 序列化为JSON json_data = orjson.dumps(tags_data) - with get_db_session() as session: + async with get_db_session() as session: # 检查是否已存在相同personality_id的记录 - existing_record = ( - session.query(DBBotPersonalityInterests) - .filter(DBBotPersonalityInterests.personality_id == interests.personality_id) - .first() - ) + existing_record = (await session.execute( + select(DBBotPersonalityInterests) + .where(DBBotPersonalityInterests.personality_id == interests.personality_id) + )).scalars().first() if existing_record: # 更新现有记录 @@ -731,19 +730,17 @@ class BotInterestManager: last_updated=interests.last_updated, ) session.add(new_record) - session.commit() + await session.commit() logger.info(f"✅ 成功创建兴趣标签配置,版本: {interests.version}") logger.info("✅ 兴趣标签已成功保存到数据库") # 验证保存是否成功 - with get_db_session() as session: - saved_record = ( - session.query(DBBotPersonalityInterests) - .filter(DBBotPersonalityInterests.personality_id == interests.personality_id) - .first() - ) - session.commit() + async with get_db_session() as session: + saved_record = (await session.execute( + select(DBBotPersonalityInterests) + .where(DBBotPersonalityInterests.personality_id == interests.personality_id) + )).scalars().first() if saved_record: logger.info(f"✅ 验证成功:数据库中存在personality_id为 {interests.personality_id} 的记录") logger.info(f" 版本: {saved_record.version}") diff --git a/src/chat/memory_system/Hippocampus.py b/src/chat/memory_system/Hippocampus.py index ca726c1a8..4141b44e0 100644 --- a/src/chat/memory_system/Hippocampus.py +++ b/src/chat/memory_system/Hippocampus.py @@ -882,7 +882,8 @@ class EntorhinalCortex: # 获取数据库中所有节点和内存中所有节点 async with get_db_session() as session: - db_nodes = {node.concept: node for node in (await session.execute(select(GraphNodes))).scalars()} + result = await session.execute(select(GraphNodes)) + db_nodes = {node.concept: node for node in result.scalars()} memory_nodes = list(self.memory_graph.G.nodes(data=True)) # 批量准备节点数据 @@ -978,7 +979,8 @@ class EntorhinalCortex: await session.execute(delete(GraphNodes).where(GraphNodes.concept.in_(nodes_to_delete))) # 处理边的信息 - db_edges = list((await session.execute(select(GraphEdges))).scalars()) + result = await session.execute(select(GraphEdges)) + db_edges = list(result.scalars()) memory_edges = list(self.memory_graph.G.edges(data=True)) # 创建边的哈希值字典 @@ -1157,7 +1159,8 @@ class EntorhinalCortex: # 从数据库加载所有节点 async with get_db_session() as session: - nodes = list((await session.execute(select(GraphNodes))).scalars()) + result = await session.execute(select(GraphNodes)) + nodes = list(result.scalars()) for node in nodes: concept = node.concept try: @@ -1192,7 +1195,8 @@ class EntorhinalCortex: continue # 从数据库加载所有边 - edges = list((await session.execute(select(GraphEdges))).scalars()) + result = await session.execute(select(GraphEdges)) + edges = list(result.scalars()) for edge in edges: source = edge.source target = edge.target diff --git a/src/chat/memory_system/async_memory_optimizer.py b/src/chat/memory_system/async_memory_optimizer.py index e80ad0efd..1fcacb32d 100644 --- a/src/chat/memory_system/async_memory_optimizer.py +++ b/src/chat/memory_system/async_memory_optimizer.py @@ -184,6 +184,11 @@ class AsyncMemoryQueue: from src.chat.memory_system.Hippocampus import hippocampus_manager if hippocampus_manager._initialized: + # 确保海马体对象已正确初始化 + if not hippocampus_manager._hippocampus.parahippocampal_gyrus: + logger.warning("海马体对象未完全初始化,进行同步初始化") + hippocampus_manager._hippocampus.initialize() + await hippocampus_manager.build_memory() return True return False diff --git a/src/chat/memory_system/instant_memory.py b/src/chat/memory_system/instant_memory.py index a8675f5c0..ce388a4de 100644 --- a/src/chat/memory_system/instant_memory.py +++ b/src/chat/memory_system/instant_memory.py @@ -108,7 +108,7 @@ class InstantMemory: @staticmethod async def store_memory(memory_item: MemoryItem): - with get_db_session() as session: + async with get_db_session() as session: memory = Memory( memory_id=memory_item.memory_id, chat_id=memory_item.chat_id, @@ -161,20 +161,21 @@ class InstantMemory: logger.info(f"start_time: {start_time}, end_time: {end_time}") # 检索包含关键词的记忆 memories_set = set() - with get_db_session() as session: + async with get_db_session() as session: if start_time and end_time: start_ts = start_time.timestamp() end_ts = end_time.timestamp() - query = session.execute( + query = (await session.execute( select(Memory).where( (Memory.chat_id == self.chat_id) & (Memory.create_time >= start_ts) & (Memory.create_time < end_ts) ) - ).scalars() + )).scalars() else: - query = session.execute(select(Memory).where(Memory.chat_id == self.chat_id)).scalars() + query = result = await session.execute(select(Memory).where(Memory.chat_id == self.chat_id)) + result.scalars() for mem in query: # 对每条记忆 mem_keywords_str = mem.keywords or "[]" diff --git a/src/chat/message_manager/__init__.py b/src/chat/message_manager/__init__.py index 2f623fbd0..368a811a5 100644 --- a/src/chat/message_manager/__init__.py +++ b/src/chat/message_manager/__init__.py @@ -4,7 +4,7 @@ """ from .message_manager import MessageManager, message_manager -from .context_manager import StreamContextManager, context_manager +from .context_manager import SingleStreamContextManager from .distribution_manager import ( DistributionManager, DistributionPriority, @@ -16,8 +16,7 @@ from .distribution_manager import ( __all__ = [ "MessageManager", "message_manager", - "StreamContextManager", - "context_manager", + "SingleStreamContextManager", "DistributionManager", "DistributionPriority", "DistributionTask", diff --git a/src/chat/message_manager/context_manager.py b/src/chat/message_manager/context_manager.py index 982b8a8a5..ebf4e37d0 100644 --- a/src/chat/message_manager/context_manager.py +++ b/src/chat/message_manager/context_manager.py @@ -1,12 +1,12 @@ """ 重构后的聊天上下文管理器 提供统一、稳定的聊天上下文管理功能 +每个 context_manager 实例只管理一个 stream 的上下文 """ import asyncio import time -from typing import Dict, List, Optional, Any, Union, Tuple -from abc import ABC, abstractmethod +from typing import Dict, List, Optional, Any from src.common.data_models.message_manager_data_model import StreamContext from src.common.logger import get_logger @@ -17,241 +17,112 @@ from .distribution_manager import distribution_manager logger = get_logger("context_manager") -class StreamContextManager: - """流上下文管理器 - 统一管理所有聊天流上下文""" - def __init__(self, max_context_size: Optional[int] = None, context_ttl: Optional[int] = None): - # 上下文存储 - self.stream_contexts: Dict[str, Any] = {} - self.context_metadata: Dict[str, Dict[str, Any]] = {} +class SingleStreamContextManager: + """单流上下文管理器 - 每个实例只管理一个 stream 的上下文""" - # 统计信息 - self.stats: Dict[str, Union[int, float, str, Dict]] = { - "total_messages": 0, - "total_streams": 0, - "active_streams": 0, - "inactive_streams": 0, - "last_activity": time.time(), - "creation_time": time.time(), - } + def __init__(self, stream_id: str, context: StreamContext, max_context_size: Optional[int] = None): + self.stream_id = stream_id + self.context = context # 配置参数 self.max_context_size = max_context_size or getattr(global_config.chat, "max_context_size", 100) - self.context_ttl = context_ttl or getattr(global_config.chat, "context_ttl", 24 * 3600) # 24小时 - self.cleanup_interval = getattr(global_config.chat, "context_cleanup_interval", 3600) # 1小时 - self.auto_cleanup = getattr(global_config.chat, "auto_cleanup_contexts", True) - self.enable_validation = getattr(global_config.chat, "enable_context_validation", True) + self.context_ttl = getattr(global_config.chat, "context_ttl", 24 * 3600) # 24小时 - # 清理任务 - self.cleanup_task: Optional[Any] = None - self.is_running = False + # 元数据 + self.created_time = time.time() + self.last_access_time = time.time() + self.access_count = 0 + self.total_messages = 0 - logger.info(f"上下文管理器初始化完成 (最大上下文: {self.max_context_size}, TTL: {self.context_ttl}s)") + logger.debug(f"单流上下文管理器初始化: {stream_id}") - def add_stream_context(self, stream_id: str, context: Any, metadata: Optional[Dict[str, Any]] = None) -> bool: - """添加流上下文 + def get_context(self) -> StreamContext: + """获取流上下文""" + self._update_access_stats() + return self.context - Args: - stream_id: 流ID - context: 上下文对象 - metadata: 上下文元数据 - - Returns: - bool: 是否成功添加 - """ - if stream_id in self.stream_contexts: - logger.warning(f"流上下文已存在: {stream_id}") - return False - - # 添加上下文 - self.stream_contexts[stream_id] = context - - # 初始化元数据 - self.context_metadata[stream_id] = { - "created_time": time.time(), - "last_access_time": time.time(), - "access_count": 0, - "last_validation_time": 0.0, - "custom_metadata": metadata or {}, - } - - # 更新统计 - self.stats["total_streams"] += 1 - self.stats["active_streams"] += 1 - self.stats["last_activity"] = time.time() - - logger.debug(f"添加流上下文: {stream_id} (类型: {type(context).__name__})") - return True - - def remove_stream_context(self, stream_id: str) -> bool: - """移除流上下文 - - Args: - stream_id: 流ID - - Returns: - bool: 是否成功移除 - """ - if stream_id in self.stream_contexts: - context = self.stream_contexts[stream_id] - metadata = self.context_metadata.get(stream_id, {}) - - del self.stream_contexts[stream_id] - if stream_id in self.context_metadata: - del self.context_metadata[stream_id] - - self.stats["active_streams"] = max(0, self.stats["active_streams"] - 1) - self.stats["inactive_streams"] += 1 - self.stats["last_activity"] = time.time() - - logger.debug(f"移除流上下文: {stream_id} (类型: {type(context).__name__})") - return True - return False - - def get_stream_context(self, stream_id: str, update_access: bool = True) -> Optional[StreamContext]: - """获取流上下文 - - Args: - stream_id: 流ID - update_access: 是否更新访问统计 - - Returns: - Optional[Any]: 上下文对象 - """ - context = self.stream_contexts.get(stream_id) - if context and update_access: - # 更新访问统计 - if stream_id in self.context_metadata: - metadata = self.context_metadata[stream_id] - metadata["last_access_time"] = time.time() - metadata["access_count"] = metadata.get("access_count", 0) + 1 - return context - - def get_context_metadata(self, stream_id: str) -> Optional[Dict[str, Any]]: - """获取上下文元数据 - - Args: - stream_id: 流ID - - Returns: - Optional[Dict[str, Any]]: 元数据 - """ - return self.context_metadata.get(stream_id) - - def update_context_metadata(self, stream_id: str, updates: Dict[str, Any]) -> bool: - """更新上下文元数据 - - Args: - stream_id: 流ID - updates: 更新的元数据 - - Returns: - bool: 是否成功更新 - """ - if stream_id not in self.context_metadata: - return False - - self.context_metadata[stream_id].update(updates) - return True - - def add_message_to_context(self, stream_id: str, message: DatabaseMessages, skip_energy_update: bool = False) -> bool: + def add_message(self, message: DatabaseMessages, skip_energy_update: bool = False) -> bool: """添加消息到上下文 Args: - stream_id: 流ID message: 消息对象 skip_energy_update: 是否跳过能量更新 Returns: bool: 是否成功添加 """ - context = self.get_stream_context(stream_id) - if not context: - logger.warning(f"流上下文不存在: {stream_id}") - return False - try: # 添加消息到上下文 - context.add_message(message) + self.context.add_message(message) # 计算消息兴趣度 interest_value = self._calculate_message_interest(message) message.interest_value = interest_value # 更新统计 - self.stats["total_messages"] += 1 - self.stats["last_activity"] = time.time() + self.total_messages += 1 + self.last_access_time = time.time() # 更新能量和分发 if not skip_energy_update: - self._update_stream_energy(stream_id) - distribution_manager.add_stream_message(stream_id, 1) + self._update_stream_energy() + distribution_manager.add_stream_message(self.stream_id, 1) - logger.debug(f"添加消息到上下文: {stream_id} (兴趣度: {interest_value:.3f})") + logger.debug(f"添加消息到单流上下文: {self.stream_id} (兴趣度: {interest_value:.3f})") return True except Exception as e: - logger.error(f"添加消息到上下文失败 {stream_id}: {e}", exc_info=True) + logger.error(f"添加消息到单流上下文失败 {self.stream_id}: {e}", exc_info=True) return False - def update_message_in_context(self, stream_id: str, message_id: str, updates: Dict[str, Any]) -> bool: + def update_message(self, message_id: str, updates: Dict[str, Any]) -> bool: """更新上下文中的消息 Args: - stream_id: 流ID message_id: 消息ID updates: 更新的属性 Returns: bool: 是否成功更新 """ - context = self.get_stream_context(stream_id) - if not context: - logger.warning(f"流上下文不存在: {stream_id}") - return False - try: # 更新消息信息 - context.update_message_info(message_id, **updates) + self.context.update_message_info(message_id, **updates) # 如果更新了兴趣度,重新计算能量 if "interest_value" in updates: - self._update_stream_energy(stream_id) + self._update_stream_energy() - logger.debug(f"更新上下文消息: {stream_id}/{message_id}") + logger.debug(f"更新单流上下文消息: {self.stream_id}/{message_id}") return True except Exception as e: - logger.error(f"更新上下文消息失败 {stream_id}/{message_id}: {e}", exc_info=True) + logger.error(f"更新单流上下文消息失败 {self.stream_id}/{message_id}: {e}", exc_info=True) return False - def get_context_messages(self, stream_id: str, limit: Optional[int] = None, include_unread: bool = True) -> List[DatabaseMessages]: + def get_messages(self, limit: Optional[int] = None, include_unread: bool = True) -> List[DatabaseMessages]: """获取上下文消息 Args: - stream_id: 流ID limit: 消息数量限制 include_unread: 是否包含未读消息 Returns: - List[Any]: 消息列表 + List[DatabaseMessages]: 消息列表 """ - context = self.get_stream_context(stream_id) - if not context: - return [] - try: messages = [] if include_unread: - messages.extend(context.get_unread_messages()) + messages.extend(self.context.get_unread_messages()) if limit: - messages.extend(context.get_history_messages(limit=limit)) + messages.extend(self.context.get_history_messages(limit=limit)) else: - messages.extend(context.get_history_messages()) + messages.extend(self.context.get_history_messages()) # 按时间排序 - messages.sort(key=lambda msg: getattr(msg, 'time', 0)) + messages.sort(key=lambda msg: getattr(msg, "time", 0)) # 应用限制 if limit and len(messages) > limit: @@ -260,103 +131,124 @@ class StreamContextManager: return messages except Exception as e: - logger.error(f"获取上下文消息失败 {stream_id}: {e}", exc_info=True) - return [] - - def get_unread_messages(self, stream_id: str) -> List[DatabaseMessages]: - """获取未读消息 - - Args: - stream_id: 流ID - - Returns: - List[Any]: 未读消息列表 - """ - context = self.get_stream_context(stream_id) - if not context: + logger.error(f"获取单流上下文消息失败 {self.stream_id}: {e}", exc_info=True) return [] + def get_unread_messages(self) -> List[DatabaseMessages]: + """获取未读消息""" try: - return context.get_unread_messages() + return self.context.get_unread_messages() except Exception as e: - logger.error(f"获取未读消息失败 {stream_id}: {e}", exc_info=True) + logger.error(f"获取单流未读消息失败 {self.stream_id}: {e}", exc_info=True) return [] - def mark_messages_as_read(self, stream_id: str, message_ids: List[str]) -> bool: - """标记消息为已读 - - Args: - stream_id: 流ID - message_ids: 消息ID列表 - - Returns: - bool: 是否成功标记 - """ - context = self.get_stream_context(stream_id) - if not context: - logger.warning(f"流上下文不存在: {stream_id}") - return False - + def mark_messages_as_read(self, message_ids: List[str]) -> bool: + """标记消息为已读""" try: - if not hasattr(context, 'mark_message_as_read'): - logger.error(f"上下文对象缺少 mark_message_as_read 方法: {stream_id}") + if not hasattr(self.context, "mark_message_as_read"): + logger.error(f"上下文对象缺少 mark_message_as_read 方法: {self.stream_id}") return False marked_count = 0 for message_id in message_ids: try: - context.mark_message_as_read(message_id) + self.context.mark_message_as_read(message_id) marked_count += 1 except Exception as e: logger.warning(f"标记消息已读失败 {message_id}: {e}") - logger.debug(f"标记消息为已读: {stream_id} ({marked_count}/{len(message_ids)}条)") + logger.debug(f"标记消息为已读: {self.stream_id} ({marked_count}/{len(message_ids)}条)") return marked_count > 0 except Exception as e: - logger.error(f"标记消息已读失败 {stream_id}: {e}", exc_info=True) - return False - - def clear_context(self, stream_id: str) -> bool: - """清空上下文 - - Args: - stream_id: 流ID - - Returns: - bool: 是否成功清空 - """ - context = self.get_stream_context(stream_id) - if not context: - logger.warning(f"流上下文不存在: {stream_id}") + logger.error(f"标记消息已读失败 {self.stream_id}: {e}", exc_info=True) return False + def clear_context(self) -> bool: + """清空上下文""" try: # 清空消息 - if hasattr(context, 'unread_messages'): - context.unread_messages.clear() - if hasattr(context, 'history_messages'): - context.history_messages.clear() + if hasattr(self.context, "unread_messages"): + self.context.unread_messages.clear() + if hasattr(self.context, "history_messages"): + self.context.history_messages.clear() # 重置状态 - reset_attrs = ['interruption_count', 'afc_threshold_adjustment', 'last_check_time'] + reset_attrs = ["interruption_count", "afc_threshold_adjustment", "last_check_time"] for attr in reset_attrs: - if hasattr(context, attr): - if attr in ['interruption_count', 'afc_threshold_adjustment']: - setattr(context, attr, 0) + if hasattr(self.context, attr): + if attr in ["interruption_count", "afc_threshold_adjustment"]: + setattr(self.context, attr, 0) else: - setattr(context, attr, time.time()) + setattr(self.context, attr, time.time()) # 重新计算能量 - self._update_stream_energy(stream_id) + self._update_stream_energy() - logger.info(f"清空上下文: {stream_id}") + logger.info(f"清空单流上下文: {self.stream_id}") return True except Exception as e: - logger.error(f"清空上下文失败 {stream_id}: {e}", exc_info=True) + logger.error(f"清空单流上下文失败 {self.stream_id}: {e}", exc_info=True) return False + def get_statistics(self) -> Dict[str, Any]: + """获取流统计信息""" + try: + current_time = time.time() + uptime = current_time - self.created_time + + unread_messages = getattr(self.context, "unread_messages", []) + history_messages = getattr(self.context, "history_messages", []) + + return { + "stream_id": self.stream_id, + "context_type": type(self.context).__name__, + "total_messages": len(history_messages) + len(unread_messages), + "unread_messages": len(unread_messages), + "history_messages": len(history_messages), + "is_active": getattr(self.context, "is_active", True), + "last_check_time": getattr(self.context, "last_check_time", current_time), + "interruption_count": getattr(self.context, "interruption_count", 0), + "afc_threshold_adjustment": getattr(self.context, "afc_threshold_adjustment", 0.0), + "created_time": self.created_time, + "last_access_time": self.last_access_time, + "access_count": self.access_count, + "uptime_seconds": uptime, + "idle_seconds": current_time - self.last_access_time, + } + except Exception as e: + logger.error(f"获取单流统计失败 {self.stream_id}: {e}", exc_info=True) + return {} + + def validate_integrity(self) -> bool: + """验证上下文完整性""" + try: + # 检查基本属性 + required_attrs = ["stream_id", "unread_messages", "history_messages"] + for attr in required_attrs: + if not hasattr(self.context, attr): + logger.warning(f"上下文缺少必要属性: {attr}") + return False + + # 检查消息ID唯一性 + all_messages = getattr(self.context, "unread_messages", []) + getattr(self.context, "history_messages", []) + message_ids = [msg.message_id for msg in all_messages if hasattr(msg, "message_id")] + if len(message_ids) != len(set(message_ids)): + logger.warning(f"上下文中存在重复消息ID: {self.stream_id}") + return False + + return True + + except Exception as e: + logger.error(f"验证单流上下文完整性失败 {self.stream_id}: {e}") + return False + + def _update_access_stats(self): + """更新访问统计""" + self.last_access_time = time.time() + self.access_count += 1 + def _calculate_message_interest(self, message: DatabaseMessages) -> float: """计算消息兴趣度""" try: @@ -373,8 +265,7 @@ class StreamContextManager: interest_score = loop.run_until_complete( chatter_interest_scoring_system._calculate_single_message_score( - message=message, - bot_nickname=global_config.bot.nickname + message=message, bot_nickname=global_config.bot.nickname ) ) interest_value = interest_score.total_score @@ -391,12 +282,12 @@ class StreamContextManager: logger.error(f"计算消息兴趣度失败: {e}") return 0.5 - def _update_stream_energy(self, stream_id: str): + async def _update_stream_energy(self): """更新流能量""" try: # 获取所有消息 - all_messages = self.get_context_messages(stream_id, self.max_context_size) - unread_messages = self.get_unread_messages(stream_id) + all_messages = self.get_messages(self.max_context_size) + unread_messages = self.get_unread_messages() combined_messages = all_messages + unread_messages # 获取用户ID @@ -406,248 +297,12 @@ class StreamContextManager: user_id = last_message.user_info.user_id # 计算能量 - energy = energy_manager.calculate_focus_energy( - stream_id=stream_id, - messages=combined_messages, - user_id=user_id + energy = await energy_manager.calculate_focus_energy( + stream_id=self.stream_id, messages=combined_messages, user_id=user_id ) # 更新分发管理器 - distribution_manager.update_stream_energy(stream_id, energy) + distribution_manager.update_stream_energy(self.stream_id, energy) except Exception as e: - logger.error(f"更新流能量失败 {stream_id}: {e}") - - def get_stream_statistics(self, stream_id: str) -> Optional[Dict[str, Any]]: - """获取流统计信息 - - Args: - stream_id: 流ID - - Returns: - Optional[Dict[str, Any]]: 统计信息 - """ - context = self.get_stream_context(stream_id, update_access=False) - if not context: - return None - - try: - metadata = self.context_metadata.get(stream_id, {}) - current_time = time.time() - created_time = metadata.get("created_time", current_time) - last_access_time = metadata.get("last_access_time", current_time) - access_count = metadata.get("access_count", 0) - - unread_messages = getattr(context, "unread_messages", []) - history_messages = getattr(context, "history_messages", []) - - return { - "stream_id": stream_id, - "context_type": type(context).__name__, - "total_messages": len(history_messages) + len(unread_messages), - "unread_messages": len(unread_messages), - "history_messages": len(history_messages), - "is_active": getattr(context, "is_active", True), - "last_check_time": getattr(context, "last_check_time", current_time), - "interruption_count": getattr(context, "interruption_count", 0), - "afc_threshold_adjustment": getattr(context, "afc_threshold_adjustment", 0.0), - "created_time": created_time, - "last_access_time": last_access_time, - "access_count": access_count, - "uptime_seconds": current_time - created_time, - "idle_seconds": current_time - last_access_time, - } - except Exception as e: - logger.error(f"获取流统计失败 {stream_id}: {e}", exc_info=True) - return None - - def get_manager_statistics(self) -> Dict[str, Any]: - """获取管理器统计信息 - - Returns: - Dict[str, Any]: 管理器统计信息 - """ - current_time = time.time() - uptime = current_time - self.stats.get("creation_time", current_time) - - return { - **self.stats, - "uptime_hours": uptime / 3600, - "stream_count": len(self.stream_contexts), - "metadata_count": len(self.context_metadata), - "auto_cleanup_enabled": self.auto_cleanup, - "cleanup_interval": self.cleanup_interval, - } - - def cleanup_inactive_contexts(self, max_inactive_hours: int = 24) -> int: - """清理不活跃的上下文 - - Args: - max_inactive_hours: 最大不活跃小时数 - - Returns: - int: 清理的上下文数量 - """ - current_time = time.time() - max_inactive_seconds = max_inactive_hours * 3600 - - inactive_streams = [] - for stream_id, context in self.stream_contexts.items(): - try: - # 获取最后活动时间 - metadata = self.context_metadata.get(stream_id, {}) - last_activity = metadata.get("last_access_time", metadata.get("created_time", 0)) - context_last_activity = getattr(context, "last_check_time", 0) - actual_last_activity = max(last_activity, context_last_activity) - - # 检查是否不活跃 - unread_count = len(getattr(context, "unread_messages", [])) - history_count = len(getattr(context, "history_messages", [])) - total_messages = unread_count + history_count - - if (current_time - actual_last_activity > max_inactive_seconds and - total_messages == 0): - inactive_streams.append(stream_id) - except Exception as e: - logger.warning(f"检查上下文活跃状态失败 {stream_id}: {e}") - continue - - # 清理不活跃上下文 - cleaned_count = 0 - for stream_id in inactive_streams: - if self.remove_stream_context(stream_id): - cleaned_count += 1 - - if cleaned_count > 0: - logger.info(f"清理了 {cleaned_count} 个不活跃上下文") - - return cleaned_count - - def validate_context_integrity(self, stream_id: str) -> bool: - """验证上下文完整性 - - Args: - stream_id: 流ID - - Returns: - bool: 是否完整 - """ - context = self.get_stream_context(stream_id) - if not context: - return False - - try: - # 检查基本属性 - required_attrs = ["stream_id", "unread_messages", "history_messages"] - for attr in required_attrs: - if not hasattr(context, attr): - logger.warning(f"上下文缺少必要属性: {attr}") - return False - - # 检查消息ID唯一性 - all_messages = getattr(context, "unread_messages", []) + getattr(context, "history_messages", []) - message_ids = [msg.message_id for msg in all_messages if hasattr(msg, "message_id")] - if len(message_ids) != len(set(message_ids)): - logger.warning(f"上下文中存在重复消息ID: {stream_id}") - return False - - return True - - except Exception as e: - logger.error(f"验证上下文完整性失败 {stream_id}: {e}") - return False - - async def start(self) -> None: - """启动上下文管理器""" - if self.is_running: - logger.warning("上下文管理器已经在运行") - return - - await self.start_auto_cleanup() - logger.info("上下文管理器已启动") - - async def stop(self) -> None: - """停止上下文管理器""" - if not self.is_running: - return - - await self.stop_auto_cleanup() - logger.info("上下文管理器已停止") - - async def start_auto_cleanup(self, interval: Optional[float] = None) -> None: - """启动自动清理 - - Args: - interval: 清理间隔(秒) - """ - if not self.auto_cleanup: - logger.info("自动清理已禁用") - return - - if self.is_running: - logger.warning("自动清理已在运行") - return - - self.is_running = True - cleanup_interval = interval or self.cleanup_interval - logger.info(f"启动自动清理(间隔: {cleanup_interval}s)") - - import asyncio - self.cleanup_task = asyncio.create_task(self._cleanup_loop(cleanup_interval)) - - async def stop_auto_cleanup(self) -> None: - """停止自动清理""" - self.is_running = False - if self.cleanup_task and not self.cleanup_task.done(): - self.cleanup_task.cancel() - try: - await self.cleanup_task - except Exception: - pass - logger.info("自动清理已停止") - - async def _cleanup_loop(self, interval: float) -> None: - """清理循环 - - Args: - interval: 清理间隔 - """ - while self.is_running: - try: - await asyncio.sleep(interval) - self.cleanup_inactive_contexts() - self._cleanup_expired_contexts() - logger.debug("自动清理完成") - except asyncio.CancelledError: - break - except Exception as e: - logger.error(f"清理循环出错: {e}", exc_info=True) - await asyncio.sleep(interval) - - def _cleanup_expired_contexts(self) -> None: - """清理过期上下文""" - current_time = time.time() - expired_contexts = [] - - for stream_id, metadata in self.context_metadata.items(): - created_time = metadata.get("created_time", current_time) - if current_time - created_time > self.context_ttl: - expired_contexts.append(stream_id) - - for stream_id in expired_contexts: - self.remove_stream_context(stream_id) - - if expired_contexts: - logger.info(f"清理了 {len(expired_contexts)} 个过期上下文") - - def get_active_streams(self) -> List[str]: - """获取活跃流列表 - - Returns: - List[str]: 活跃流ID列表 - """ - return list(self.stream_contexts.keys()) - - -# 全局上下文管理器实例 -context_manager = StreamContextManager() \ No newline at end of file + logger.error(f"更新单流能量失败 {self.stream_id}: {e}") \ No newline at end of file diff --git a/src/chat/message_manager/message_manager.py b/src/chat/message_manager/message_manager.py index 7c0d77828..6a0eac5e2 100644 --- a/src/chat/message_manager/message_manager.py +++ b/src/chat/message_manager/message_manager.py @@ -14,11 +14,10 @@ from src.common.data_models.database_data_model import DatabaseMessages from src.common.data_models.message_manager_data_model import StreamContext, MessageManagerStats, StreamStats from src.chat.chatter_manager import ChatterManager from src.chat.planner_actions.action_manager import ChatterActionManager -from src.plugin_system.base.component_types import ChatMode from .sleep_manager.sleep_manager import SleepManager from .sleep_manager.wakeup_manager import WakeUpManager from src.config.config import global_config -from .context_manager import context_manager +from src.plugin_system.apis.chat_api import get_chat_manager if TYPE_CHECKING: from src.common.data_models.message_manager_data_model import StreamContext @@ -45,8 +44,7 @@ class MessageManager: self.sleep_manager = SleepManager() self.wakeup_manager = WakeUpManager(self.sleep_manager) - # 初始化上下文管理器 - self.context_manager = context_manager + # 不再需要全局上下文管理器,直接通过 ChatManager 访问各个 ChatStream 的 context_manager async def start(self): """启动消息管理器""" @@ -57,7 +55,7 @@ class MessageManager: self.is_running = True self.manager_task = asyncio.create_task(self._manager_loop()) await self.wakeup_manager.start() - await self.context_manager.start() + # await self.context_manager.start() # 已删除,需要重构 logger.info("消息管理器已启动") async def stop(self): @@ -73,28 +71,31 @@ class MessageManager: self.manager_task.cancel() await self.wakeup_manager.stop() - await self.context_manager.stop() + # await self.context_manager.stop() # 已删除,需要重构 logger.info("消息管理器已停止") def add_message(self, stream_id: str, message: DatabaseMessages): """添加消息到指定聊天流""" - # 检查流上下文是否存在,不存在则创建 - context = self.context_manager.get_stream_context(stream_id) - if not context: - # 创建新的流上下文 - from src.common.data_models.message_manager_data_model import StreamContext - context = StreamContext(stream_id=stream_id) - # 将创建的上下文添加到管理器 - self.context_manager.add_stream_context(stream_id, context) + try: + # 通过 ChatManager 获取 ChatStream + chat_manager = get_chat_manager() + chat_stream = chat_manager.get_stream(stream_id) - # 使用 context_manager 添加消息 - success = self.context_manager.add_message_to_context(stream_id, message) + if not chat_stream: + logger.warning(f"MessageManager.add_message: 聊天流 {stream_id} 不存在") + return - if success: - logger.debug(f"添加消息到聊天流 {stream_id}: {message.message_id}") - else: - logger.warning(f"添加消息到聊天流 {stream_id} 失败") + # 使用 ChatStream 的 context_manager 添加消息 + success = chat_stream.context_manager.add_message(message) + + if success: + logger.debug(f"添加消息到聊天流 {stream_id}: {message.message_id}") + else: + logger.warning(f"添加消息到聊天流 {stream_id} 失败") + + except Exception as e: + logger.error(f"添加消息到聊天流 {stream_id} 时发生错误: {e}") def update_message( self, @@ -105,17 +106,60 @@ class MessageManager: should_reply: bool = None, ): """更新消息信息""" - # 使用 context_manager 更新消息信息 - context = self.context_manager.get_stream_context(stream_id) - if context: - context.update_message_info(message_id, interest_value, actions, should_reply) + try: + # 通过 ChatManager 获取 ChatStream + chat_manager = get_chat_manager() + chat_stream = chat_manager.get_stream(stream_id) + + if not chat_stream: + logger.warning(f"MessageManager.update_message: 聊天流 {stream_id} 不存在") + return + + # 构建更新字典 + updates = {} + if interest_value is not None: + updates["interest_value"] = interest_value + if actions is not None: + updates["actions"] = actions + if should_reply is not None: + updates["should_reply"] = should_reply + + # 使用 ChatStream 的 context_manager 更新消息 + if updates: + success = chat_stream.context_manager.update_message(message_id, updates) + if success: + logger.debug(f"更新消息 {message_id} 成功") + else: + logger.warning(f"更新消息 {message_id} 失败") + + except Exception as e: + logger.error(f"更新消息 {message_id} 时发生错误: {e}") def add_action(self, stream_id: str, message_id: str, action: str): """添加动作到消息""" - # 使用 context_manager 添加动作到消息 - context = self.context_manager.get_stream_context(stream_id) - if context: - context.add_action_to_message(message_id, action) + try: + # 通过 ChatManager 获取 ChatStream + chat_manager = get_chat_manager() + chat_stream = chat_manager.get_stream(stream_id) + + if not chat_stream: + logger.warning(f"MessageManager.add_action: 聊天流 {stream_id} 不存在") + return + + # 使用 ChatStream 的 context_manager 添加动作 + # 注意:这里需要根据实际的 API 调整 + # 假设我们可以通过 update_message 来添加动作 + success = chat_stream.context_manager.update_message( + message_id, {"actions": [action]} + ) + + if success: + logger.debug(f"为消息 {message_id} 添加动作 {action} 成功") + else: + logger.warning(f"为消息 {message_id} 添加动作 {action} 失败") + + except Exception as e: + logger.error(f"为消息 {message_id} 添加动作时发生错误: {e}") async def _manager_loop(self): """管理器主循环 - 独立聊天流分发周期版本""" @@ -145,38 +189,53 @@ class MessageManager: active_streams = 0 total_unread = 0 - # 使用 context_manager 获取活跃的流 - active_stream_ids = self.context_manager.get_active_streams() + # 通过 ChatManager 获取所有活跃的流 + try: + chat_manager = get_chat_manager() + active_stream_ids = list(chat_manager.streams.keys()) - for stream_id in active_stream_ids: - context = self.context_manager.get_stream_context(stream_id) - if not context: - continue + for stream_id in active_stream_ids: + chat_stream = chat_manager.get_stream(stream_id) + if not chat_stream: + continue - active_streams += 1 + # 检查流是否活跃 + context = chat_stream.stream_context + if not context.is_active: + continue - # 检查是否有未读消息 - unread_messages = self.context_manager.get_unread_messages(stream_id) - if unread_messages: - total_unread += len(unread_messages) + active_streams += 1 - # 如果没有处理任务,创建一个 - if not hasattr(context, 'processing_task') or not context.processing_task or context.processing_task.done(): - context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id)) + # 检查是否有未读消息 + unread_messages = chat_stream.context_manager.get_unread_messages() + if unread_messages: + total_unread += len(unread_messages) - # 更新统计 - self.stats.active_streams = active_streams - self.stats.total_unread_messages = total_unread + # 如果没有处理任务,创建一个 + if not hasattr(context, 'processing_task') or not context.processing_task or context.processing_task.done(): + context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id)) + + # 更新统计 + self.stats.active_streams = active_streams + self.stats.total_unread_messages = total_unread + + except Exception as e: + logger.error(f"检查所有聊天流时发生错误: {e}") async def _process_stream_messages(self, stream_id: str): """处理指定聊天流的消息""" - context = self.context_manager.get_stream_context(stream_id) - if not context: - return - try: + # 通过 ChatManager 获取 ChatStream + chat_manager = get_chat_manager() + chat_stream = chat_manager.get_stream(stream_id) + if not chat_stream: + logger.warning(f"处理消息失败: 聊天流 {stream_id} 不存在") + return + + context = chat_stream.stream_context + # 获取未读消息 - unread_messages = self.context_manager.get_unread_messages(stream_id) + unread_messages = chat_stream.context_manager.get_unread_messages() if not unread_messages: return @@ -250,8 +309,15 @@ class MessageManager: def deactivate_stream(self, stream_id: str): """停用聊天流""" - context = self.context_manager.get_stream_context(stream_id) - if context: + try: + # 通过 ChatManager 获取 ChatStream + chat_manager = get_chat_manager() + chat_stream = chat_manager.get_stream(stream_id) + if not chat_stream: + logger.warning(f"停用流失败: 聊天流 {stream_id} 不存在") + return + + context = chat_stream.stream_context context.is_active = False # 取消处理任务 @@ -260,27 +326,50 @@ class MessageManager: logger.info(f"停用聊天流: {stream_id}") + except Exception as e: + logger.error(f"停用聊天流 {stream_id} 时发生错误: {e}") + def activate_stream(self, stream_id: str): """激活聊天流""" - context = self.context_manager.get_stream_context(stream_id) - if context: + try: + # 通过 ChatManager 获取 ChatStream + chat_manager = get_chat_manager() + chat_stream = chat_manager.get_stream(stream_id) + if not chat_stream: + logger.warning(f"激活流失败: 聊天流 {stream_id} 不存在") + return + + context = chat_stream.stream_context context.is_active = True logger.info(f"激活聊天流: {stream_id}") + except Exception as e: + logger.error(f"激活聊天流 {stream_id} 时发生错误: {e}") + def get_stream_stats(self, stream_id: str) -> Optional[StreamStats]: """获取聊天流统计""" - context = self.context_manager.get_stream_context(stream_id) - if not context: - return None + try: + # 通过 ChatManager 获取 ChatStream + chat_manager = get_chat_manager() + chat_stream = chat_manager.get_stream(stream_id) + if not chat_stream: + return None - return StreamStats( - stream_id=stream_id, - is_active=context.is_active, - unread_count=len(self.context_manager.get_unread_messages(stream_id)), - history_count=len(context.history_messages), - last_check_time=context.last_check_time, - has_active_task=bool(hasattr(context, 'processing_task') and context.processing_task and not context.processing_task.done()), - ) + context = chat_stream.stream_context + unread_count = len(chat_stream.context_manager.get_unread_messages()) + + return StreamStats( + stream_id=stream_id, + is_active=context.is_active, + unread_count=unread_count, + history_count=len(context.history_messages), + last_check_time=context.last_check_time, + has_active_task=bool(hasattr(context, 'processing_task') and context.processing_task and not context.processing_task.done()), + ) + + except Exception as e: + logger.error(f"获取聊天流 {stream_id} 统计时发生错误: {e}") + return None def get_manager_stats(self) -> Dict[str, Any]: """获取管理器统计""" @@ -295,9 +384,36 @@ class MessageManager: def cleanup_inactive_streams(self, max_inactive_hours: int = 24): """清理不活跃的聊天流""" - # 使用 context_manager 的自动清理功能 - self.context_manager.cleanup_inactive_contexts(max_inactive_hours * 3600) - logger.info("已启动不活跃聊天流清理") + try: + # 通过 ChatManager 清理不活跃的流 + chat_manager = get_chat_manager() + current_time = time.time() + max_inactive_seconds = max_inactive_hours * 3600 + + inactive_streams = [] + for stream_id, chat_stream in chat_manager.streams.items(): + # 检查最后活跃时间 + if current_time - chat_stream.last_active_time > max_inactive_seconds: + inactive_streams.append(stream_id) + + # 清理不活跃的流 + for stream_id in inactive_streams: + try: + # 清理流的内容 + chat_stream.context_manager.clear_context() + # 从 ChatManager 中移除 + del chat_manager.streams[stream_id] + logger.info(f"清理不活跃聊天流: {stream_id}") + except Exception as e: + logger.error(f"清理聊天流 {stream_id} 失败: {e}") + + if inactive_streams: + logger.info(f"已清理 {len(inactive_streams)} 个不活跃聊天流") + else: + logger.debug("没有需要清理的不活跃聊天流") + + except Exception as e: + logger.error(f"清理不活跃聊天流时发生错误: {e}") async def _check_and_handle_interruption(self, context: StreamContext, stream_id: str): """检查并处理消息打断""" @@ -376,115 +492,123 @@ class MessageManager: min_delay = float("inf") # 找到最近需要检查的流 - active_stream_ids = self.context_manager.get_active_streams() - for stream_id in active_stream_ids: - context = self.context_manager.get_stream_context(stream_id) - if not context or not context.is_active: - continue + try: + chat_manager = get_chat_manager() + for _stream_id, chat_stream in chat_manager.streams.items(): + context = chat_stream.stream_context + if not context or not context.is_active: + continue - time_until_check = context.next_check_time - current_time - if time_until_check > 0: - min_delay = min(min_delay, time_until_check) - else: - min_delay = 0.1 # 立即检查 - break + time_until_check = context.next_check_time - current_time + if time_until_check > 0: + min_delay = min(min_delay, time_until_check) + else: + min_delay = 0.1 # 立即检查 + break - # 如果没有活跃流,使用默认间隔 - if min_delay == float("inf"): + # 如果没有活跃流,使用默认间隔 + if min_delay == float("inf"): + return self.check_interval + + # 确保最小延迟 + return max(0.1, min(min_delay, self.check_interval)) + + except Exception as e: + logger.error(f"计算下次检查延迟时发生错误: {e}") return self.check_interval - # 确保最小延迟 - return max(0.1, min(min_delay, self.check_interval)) - async def _check_streams_with_individual_intervals(self): """检查所有达到检查时间的聊天流""" current_time = time.time() processed_streams = 0 - # 使用 context_manager 获取活跃的流 - active_stream_ids = self.context_manager.get_active_streams() + # 通过 ChatManager 获取活跃的流 + try: + chat_manager = get_chat_manager() + for stream_id, chat_stream in chat_manager.streams.items(): + context = chat_stream.stream_context + if not context or not context.is_active: + continue - for stream_id in active_stream_ids: - context = self.context_manager.get_stream_context(stream_id) - if not context or not context.is_active: - continue + # 检查是否达到检查时间 + if current_time >= context.next_check_time: + # 更新检查时间 + context.last_check_time = current_time - # 检查是否达到检查时间 - if current_time >= context.next_check_time: - # 更新检查时间 - context.last_check_time = current_time + # 计算下次检查时间和分发周期 + if global_config.chat.dynamic_distribution_enabled: + context.distribution_interval = self._calculate_stream_distribution_interval(context) + else: + context.distribution_interval = self.check_interval - # 计算下次检查时间和分发周期 - if global_config.chat.dynamic_distribution_enabled: - context.distribution_interval = self._calculate_stream_distribution_interval(context) - else: - context.distribution_interval = self.check_interval + # 设置下次检查时间 + context.next_check_time = current_time + context.distribution_interval - # 设置下次检查时间 - context.next_check_time = current_time + context.distribution_interval + # 检查未读消息 + unread_messages = chat_stream.context_manager.get_unread_messages() + if unread_messages: + processed_streams += 1 + self.stats.total_unread_messages = len(unread_messages) - # 检查未读消息 - unread_messages = self.context_manager.get_unread_messages(stream_id) - if unread_messages: - processed_streams += 1 - self.stats.total_unread_messages = len(unread_messages) + # 如果没有处理任务,创建一个 + if not context.processing_task or context.processing_task.done(): + focus_energy = chat_stream.focus_energy - # 如果没有处理任务,创建一个 - if not context.processing_task or context.processing_task.done(): - from src.plugin_system.apis.chat_api import get_chat_manager + # 根据优先级记录日志 + if focus_energy >= 0.7: + logger.info( + f"高优先级流 {stream_id} 开始处理 | " + f"focus_energy: {focus_energy:.3f} | " + f"分发周期: {context.distribution_interval:.2f}s | " + f"未读消息: {len(unread_messages)}" + ) + else: + logger.debug( + f"流 {stream_id} 开始处理 | " + f"focus_energy: {focus_energy:.3f} | " + f"分发周期: {context.distribution_interval:.2f}s" + ) - chat_stream = get_chat_manager().get_stream(context.stream_id) - focus_energy = chat_stream.focus_energy if chat_stream else 0.5 + context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id)) - # 根据优先级记录日志 - if focus_energy >= 0.7: - logger.info( - f"高优先级流 {stream_id} 开始处理 | " - f"focus_energy: {focus_energy:.3f} | " - f"分发周期: {context.distribution_interval:.2f}s | " - f"未读消息: {len(unread_messages)}" - ) - else: - logger.debug( - f"流 {stream_id} 开始处理 | " - f"focus_energy: {focus_energy:.3f} | " - f"分发周期: {context.distribution_interval:.2f}s" - ) - - context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id)) + except Exception as e: + logger.error(f"检查独立分发周期的聊天流时发生错误: {e}") # 更新活跃流计数 - active_count = len(self.context_manager.get_active_streams()) - self.stats.active_streams = active_count + try: + chat_manager = get_chat_manager() + active_count = len([s for s in chat_manager.streams.values() if s.stream_context.is_active]) + self.stats.active_streams = active_count - if processed_streams > 0: - logger.debug(f"本次循环处理了 {processed_streams} 个流 | 活跃流总数: {active_count}") + if processed_streams > 0: + logger.debug(f"本次循环处理了 {processed_streams} 个流 | 活跃流总数: {active_count}") + except Exception as e: + logger.error(f"更新活跃流计数时发生错误: {e}") async def _check_all_streams_with_priority(self): """按优先级检查所有聊天流,高focus_energy的流优先处理""" - if not self.context_manager.get_active_streams(): - return + try: + chat_manager = get_chat_manager() + if not chat_manager.streams: + return - # 获取活跃的聊天流并按focus_energy排序 - active_streams = [] - active_stream_ids = self.context_manager.get_active_streams() + # 获取活跃的聊天流并按focus_energy排序 + active_streams = [] + for stream_id, chat_stream in chat_manager.streams.items(): + context = chat_stream.stream_context + if not context or not context.is_active: + continue - for stream_id in active_stream_ids: - context = self.context_manager.get_stream_context(stream_id) - if not context or not context.is_active: - continue - - # 获取focus_energy,如果不存在则使用默认值 - from src.plugin_system.apis.chat_api import get_chat_manager - - chat_stream = get_chat_manager().get_stream(context.stream_id) - focus_energy = 0.5 - if chat_stream: + # 获取focus_energy focus_energy = chat_stream.focus_energy - # 计算流优先级分数 - priority_score = self._calculate_stream_priority(context, focus_energy) - active_streams.append((priority_score, stream_id, context)) + # 计算流优先级分数 + priority_score = self._calculate_stream_priority(context, focus_energy) + active_streams.append((priority_score, stream_id, context)) + + except Exception as e: + logger.error(f"获取活跃流列表时发生错误: {e}") + return # 按优先级降序排序 active_streams.sort(reverse=True, key=lambda x: x[0]) @@ -497,21 +621,29 @@ class MessageManager: active_stream_count += 1 # 检查是否有未读消息 - unread_messages = self.context_manager.get_unread_messages(stream_id) - if unread_messages: - total_unread += len(unread_messages) + try: + chat_stream = chat_manager.get_stream(stream_id) + if not chat_stream: + continue - # 如果没有处理任务,创建一个 - if not hasattr(context, 'processing_task') or not context.processing_task or context.processing_task.done(): - context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id)) + unread_messages = chat_stream.context_manager.get_unread_messages() + if unread_messages: + total_unread += len(unread_messages) - # 高优先级流的额外日志 - if priority_score > 0.7: - logger.info( - f"高优先级流 {stream_id} 开始处理 | " - f"优先级: {priority_score:.3f} | " - f"未读消息: {len(unread_messages)}" - ) + # 如果没有处理任务,创建一个 + if not hasattr(context, 'processing_task') or not context.processing_task or context.processing_task.done(): + context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id)) + + # 高优先级流的额外日志 + if priority_score > 0.7: + logger.info( + f"高优先级流 {stream_id} 开始处理 | " + f"优先级: {priority_score:.3f} | " + f"未读消息: {len(unread_messages)}" + ) + except Exception as e: + logger.error(f"处理流 {stream_id} 的未读消息时发生错误: {e}") + continue # 更新统计 self.stats.active_streams = active_stream_count @@ -536,22 +668,33 @@ class MessageManager: def _clear_all_unread_messages(self, stream_id: str): """清除指定上下文中的所有未读消息,防止意外情况导致消息一直未读""" - unread_messages = self.context_manager.get_unread_messages(stream_id) - if not unread_messages: - return + try: + # 通过 ChatManager 获取 ChatStream + chat_manager = get_chat_manager() + chat_stream = chat_manager.get_stream(stream_id) + if not chat_stream: + logger.warning(f"清除消息失败: 聊天流 {stream_id} 不存在") + return - logger.warning(f"正在清除 {len(unread_messages)} 条未读消息") + # 获取未读消息 + unread_messages = chat_stream.context_manager.get_unread_messages() + if not unread_messages: + return - # 将所有未读消息标记为已读 - context = self.context_manager.get_stream_context(stream_id) - if context: - for msg in unread_messages[:]: # 使用切片复制避免迭代时修改列表 - try: - context.mark_message_as_read(msg.message_id) - self.stats.total_processed_messages += 1 - logger.debug(f"强制清除消息 {msg.message_id},标记为已读") - except Exception as e: - logger.error(f"清除消息 {msg.message_id} 时出错: {e}") + logger.warning(f"正在清除 {len(unread_messages)} 条未读消息") + + # 将所有未读消息标记为已读 + message_ids = [msg.message_id for msg in unread_messages] + success = chat_stream.context_manager.mark_messages_as_read(message_ids) + + if success: + self.stats.total_processed_messages += len(unread_messages) + logger.debug(f"强制清除 {len(unread_messages)} 条消息,标记为已读") + else: + logger.error("标记未读消息为已读失败") + + except Exception as e: + logger.error(f"清除未读消息时发生错误: {e}") # 创建全局消息管理器实例 diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index 53d9ab0ed..007ab4dc1 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -49,10 +49,18 @@ class ChatStream: from src.common.data_models.message_manager_data_model import StreamContext from src.plugin_system.base.component_types import ChatType, ChatMode + # 创建StreamContext self.stream_context: StreamContext = StreamContext( stream_id=stream_id, chat_type=ChatType.GROUP if group_info else ChatType.PRIVATE, chat_mode=ChatMode.NORMAL ) + # 创建单流上下文管理器 + from src.chat.message_manager.context_manager import SingleStreamContextManager + + self.context_manager: SingleStreamContextManager = SingleStreamContextManager( + stream_id=stream_id, context=self.stream_context + ) + # 基础参数 self.base_interest_energy = 0.5 # 默认基础兴趣度 self._focus_energy = 0.5 # 内部存储的focus_energy值 @@ -61,6 +69,37 @@ class ChatStream: # 自动加载历史消息 self._load_history_messages() + def __deepcopy__(self, memo): + """自定义深拷贝方法,避免复制不可序列化的 asyncio.Task 对象""" + import copy + + # 创建新的实例 + new_stream = ChatStream( + stream_id=self.stream_id, + platform=self.platform, + user_info=copy.deepcopy(self.user_info, memo), + group_info=copy.deepcopy(self.group_info, memo), + ) + + # 复制基本属性 + new_stream.create_time = self.create_time + new_stream.last_active_time = self.last_active_time + new_stream.sleep_pressure = self.sleep_pressure + new_stream.saved = self.saved + new_stream.base_interest_energy = self.base_interest_energy + new_stream._focus_energy = self._focus_energy + new_stream.no_reply_consecutive = self.no_reply_consecutive + + # 复制 stream_context,但跳过 processing_task + new_stream.stream_context = copy.deepcopy(self.stream_context, memo) + if hasattr(new_stream.stream_context, 'processing_task'): + new_stream.stream_context.processing_task = None + + # 复制 context_manager + new_stream.context_manager = copy.deepcopy(self.context_manager, memo) + + return new_stream + def to_dict(self) -> dict: """转换为字典格式""" return { @@ -74,10 +113,10 @@ class ChatStream: "focus_energy": self.focus_energy, # 基础兴趣度 "base_interest_energy": self.base_interest_energy, - # 新增stream_context信息 + # stream_context基本信息 "stream_context_chat_type": self.stream_context.chat_type.value, "stream_context_chat_mode": self.stream_context.chat_mode.value, - # 新增interruption_count信息 + # 统计信息 "interruption_count": self.stream_context.interruption_count, } @@ -109,6 +148,14 @@ class ChatStream: if "interruption_count" in data: instance.stream_context.interruption_count = data["interruption_count"] + # 确保 context_manager 已初始化 + if not hasattr(instance, "context_manager"): + from src.chat.message_manager.context_manager import SingleStreamContextManager + + instance.context_manager = SingleStreamContextManager( + stream_id=instance.stream_id, context=instance.stream_context + ) + return instance def update_active_time(self): @@ -195,12 +242,14 @@ class ChatStream: self.stream_context.priority_info = getattr(message, "priority_info", None) # 调试日志:记录数据转移情况 - logger.debug(f"消息数据转移完成 - message_id: {db_message.message_id}, " - f"chat_id: {db_message.chat_id}, " - f"is_mentioned: {db_message.is_mentioned}, " - f"is_emoji: {db_message.is_emoji}, " - f"is_picid: {db_message.is_picid}, " - f"interest_value: {db_message.interest_value}") + logger.debug( + f"消息数据转移完成 - message_id: {db_message.message_id}, " + f"chat_id: {db_message.chat_id}, " + f"is_mentioned: {db_message.is_mentioned}, " + f"is_emoji: {db_message.is_emoji}, " + f"is_picid: {db_message.is_picid}, " + f"interest_value: {db_message.interest_value}" + ) def _safe_get_actions(self, message: "MessageRecv") -> Optional[list]: """安全获取消息的actions字段""" @@ -213,6 +262,7 @@ class ChatStream: if isinstance(actions, str): try: import json + actions = json.loads(actions) except json.JSONDecodeError: logger.warning(f"无法解析actions JSON字符串: {actions}") @@ -269,14 +319,17 @@ class ChatStream: @property def focus_energy(self) -> float: - """使用重构后的能量管理器计算focus_energy""" - try: - from src.chat.energy_system import energy_manager + """获取缓存的focus_energy值""" + if hasattr(self, "_focus_energy"): + return self._focus_energy + else: + return 0.5 - # 获取所有消息 - history_messages = self.stream_context.get_history_messages(limit=global_config.chat.max_context_size) - unread_messages = self.stream_context.get_unread_messages() - all_messages = history_messages + unread_messages + async def calculate_focus_energy(self) -> float: + """异步计算focus_energy""" + try: + # 使用单流上下文管理器获取消息 + all_messages = self.context_manager.get_messages(limit=global_config.chat.max_context_size) # 获取用户ID user_id = None @@ -284,10 +337,10 @@ class ChatStream: user_id = str(self.user_info.user_id) # 使用能量管理器计算 - energy = energy_manager.calculate_focus_energy( - stream_id=self.stream_id, - messages=all_messages, - user_id=user_id + from src.chat.energy_system import energy_manager + + energy = await energy_manager.calculate_focus_energy( + stream_id=self.stream_id, messages=all_messages, user_id=user_id ) # 更新内部存储 @@ -299,7 +352,7 @@ class ChatStream: except Exception as e: logger.error(f"获取focus_energy失败: {e}", exc_info=True) # 返回缓存的值或默认值 - if hasattr(self, '_focus_energy'): + if hasattr(self, "_focus_energy"): return self._focus_energy else: return 0.5 @@ -309,7 +362,7 @@ class ChatStream: """设置focus_energy值(主要用于初始化或特殊场景)""" self._focus_energy = max(0.0, min(1.0, value)) - def _get_user_relationship_score(self) -> float: + async def _get_user_relationship_score(self) -> float: """获取用户关系分""" # 使用插件内部的兴趣度评分系统 try: @@ -317,7 +370,7 @@ class ChatStream: if self.user_info and hasattr(self.user_info, "user_id"): user_id = str(self.user_info.user_id) - relationship_score = chatter_interest_scoring_system._calculate_relationship_score(user_id) + relationship_score = await chatter_interest_scoring_system._calculate_relationship_score(user_id) logger.debug(f"ChatStream {self.stream_id}: 用户关系分 = {relationship_score:.3f}") return max(0.0, min(1.0, relationship_score)) @@ -346,7 +399,8 @@ class ChatStream: .order_by(desc(Messages.time)) .limit(global_config.chat.max_context_size) ) - results = session.execute(stmt).scalars().all() + result = session.execute(stmt) + results = result.scalars().all() return results # 在线程中执行数据库查询 @@ -404,7 +458,9 @@ class ChatStream: ) # 添加调试日志:检查从数据库加载的interest_value - logger.debug(f"加载历史消息 {db_message.message_id} - interest_value: {db_message.interest_value}") + logger.debug( + f"加载历史消息 {db_message.message_id} - interest_value: {db_message.interest_value}" + ) # 标记为已读并添加到历史消息 db_message.is_read = True @@ -548,7 +604,11 @@ class ChatManager: # 检查数据库中是否存在 async def _db_find_stream_async(s_id: str): async with get_db_session() as session: - return (await session.execute(select(ChatStreams).where(ChatStreams.stream_id == s_id))).scalars().first() + return ( + (await session.execute(select(ChatStreams).where(ChatStreams.stream_id == s_id))) + .scalars() + .first() + ) model_instance = await _db_find_stream_async(stream_id) @@ -603,6 +663,15 @@ class ChatManager: stream.set_context(self.last_messages[stream_id]) else: logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的") + + # 确保 ChatStream 有自己的 context_manager + if not hasattr(stream, "context_manager"): + # 创建新的单流上下文管理器 + from src.chat.message_manager.context_manager import SingleStreamContextManager + stream.context_manager = SingleStreamContextManager( + stream_id=stream_id, context=stream.stream_context + ) + # 保存到内存和数据库 self.streams[stream_id] = stream await self._save_stream(stream) @@ -704,7 +773,8 @@ class ChatManager: async def _db_load_all_streams_async(): loaded_streams_data = [] async with get_db_session() as session: - for model_instance in (await session.execute(select(ChatStreams))).scalars().all(): + result = await session.execute(select(ChatStreams)) + for model_instance in result.scalars().all(): user_info_data = { "platform": model_instance.user_platform, "user_id": model_instance.user_id, @@ -752,6 +822,13 @@ class ChatManager: self.streams[stream.stream_id] = stream if stream.stream_id in self.last_messages: stream.set_context(self.last_messages[stream.stream_id]) + + # 确保 ChatStream 有自己的 context_manager + if not hasattr(stream, "context_manager"): + from src.chat.message_manager.context_manager import SingleStreamContextManager + stream.context_manager = SingleStreamContextManager( + stream_id=stream.stream_id, context=stream.stream_context + ) except Exception as e: logger.error(f"从数据库加载所有聊天流失败 (SQLAlchemy): {e}", exc_info=True) diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index b37301f47..60583c1f8 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -41,7 +41,7 @@ class MessageStorage: processed_plain_text = message.processed_plain_text if processed_plain_text: - processed_plain_text = MessageStorage.replace_image_descriptions(processed_plain_text) + processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text) filtered_processed_plain_text = re.sub(pattern, "", processed_plain_text, flags=re.DOTALL) else: filtered_processed_plain_text = "" @@ -129,9 +129,9 @@ class MessageStorage: key_words=key_words, key_words_lite=key_words_lite, ) - with get_db_session() as session: + async with get_db_session() as session: session.add(new_message) - session.commit() + await session.commit() except Exception: logger.exception("存储消息失败") @@ -174,13 +174,13 @@ class MessageStorage: # 使用上下文管理器确保session正确管理 from src.common.database.sqlalchemy_models import get_db_session - with get_db_session() as session: - matched_message = session.execute( + async with get_db_session() as session: + matched_message = (await session.execute( select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time)) - ).scalar() + )).scalar() if matched_message: - session.execute( + await session.execute( update(Messages).where(Messages.id == matched_message.id).values(message_id=qq_message_id) ) logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}") @@ -195,7 +195,7 @@ class MessageStorage: ) @staticmethod - def replace_image_descriptions(text: str) -> str: + async def replace_image_descriptions(text: str) -> str: """将[图片:描述]替换为[picid:image_id]""" # 先检查文本中是否有图片标记 pattern = r"\[图片:([^\]]+)\]" @@ -205,15 +205,15 @@ class MessageStorage: logger.debug("文本中没有图片标记,直接返回原文本") return text - def replace_match(match): + async def replace_match(match): description = match.group(1).strip() try: from src.common.database.sqlalchemy_models import get_db_session - with get_db_session() as session: - image_record = session.execute( + async with get_db_session() as session: + image_record = (await session.execute( select(Images).where(Images.description == description).order_by(desc(Images.timestamp)) - ).scalar() + )).scalar() return f"[picid:{image_record.image_id}]" if image_record else match.group(0) except Exception: return match.group(0) @@ -271,7 +271,8 @@ class MessageStorage: ) ).limit(50) # 限制每次修复的数量,避免性能问题 - messages_to_fix = session.execute(query).scalars().all() + result = session.execute(query) + messages_to_fix = result.scalars().all() fixed_count = 0 for msg in messages_to_fix: diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index 7335b5546..1ccc916db 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -824,7 +824,8 @@ def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str: description = "[图片内容未知]" # 默认描述 try: with get_db_session() as session: - image = session.execute(select(Images).where(Images.image_id == pic_id)).scalar_one_or_none() + result = session.execute(select(Images).where(Images.image_id == pic_id)) + image = result.scalar_one_or_none() if image and image.description: # type: ignore description = image.description except Exception: diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py index 93ec14957..c468520d9 100644 --- a/src/chat/utils/utils_image.py +++ b/src/chat/utils/utils_image.py @@ -308,7 +308,8 @@ class ImageManager: async with get_db_session() as session: # 优先检查Images表中是否已有完整的描述 - existing_image = (await session.execute(select(Images).where(Images.emoji_hash == image_hash))).scalar() + existing_image = result = await session.execute(select(Images).where(Images.emoji_hash == image_hash)) + result.scalar() if existing_image: # 更新计数 if hasattr(existing_image, "count") and existing_image.count is not None: @@ -527,7 +528,8 @@ class ImageManager: image_bytes = base64.b64decode(image_base64) image_hash = hashlib.md5(image_bytes).hexdigest() async with get_db_session() as session: - existing_image = (await session.execute(select(Images).where(Images.emoji_hash == image_hash))).scalar() + existing_image = result = await session.execute(select(Images).where(Images.emoji_hash == image_hash)) + result.scalar() if existing_image: # 检查是否缺少必要字段,如果缺少则创建新记录 if ( diff --git a/src/chat/utils/utils_video.py b/src/chat/utils/utils_video.py index f6acb1a7d..6ecd599af 100644 --- a/src/chat/utils/utils_video.py +++ b/src/chat/utils/utils_video.py @@ -22,13 +22,14 @@ from src.llm_models.utils_model import LLMRequest from src.config.config import global_config, model_config from src.common.logger import get_logger from src.common.database.sqlalchemy_models import get_db_session, Videos +from sqlalchemy import select logger = get_logger("utils_video") # Rust模块可用性检测 RUST_VIDEO_AVAILABLE = False try: - import rust_video + import rust_video # pyright: ignore[reportMissingImports] RUST_VIDEO_AVAILABLE = True logger.info("✅ Rust 视频处理模块加载成功") @@ -202,19 +203,21 @@ class VideoAnalyzer: hash_obj.update(video_data) return hash_obj.hexdigest() - def _check_video_exists(self, video_hash: str) -> Optional[Videos]: + async def _check_video_exists(self, video_hash: str) -> Optional[Videos]: """检查视频是否已经分析过""" try: - with get_db_session() as session: + async with get_db_session() as session: # 明确刷新会话以确保看到其他事务的最新提交 - session.expire_all() - return session.query(Videos).filter(Videos.video_hash == video_hash).first() + await session.expire_all() + stmt = select(Videos).where(Videos.video_hash == video_hash) + result = await session.execute(stmt) + return result.scalar_one_or_none() except Exception as e: logger.warning(f"检查视频是否存在时出错: {e}") return None - def _store_video_result( - self, video_hash: str, description: str, metadata: Optional[Dict] = None + async def _store_video_result( + self, video_hash: str, description: str, metadata: Optional[Dict] = None ) -> Optional[Videos]: """存储视频分析结果到数据库""" # 检查描述是否为错误信息,如果是则不保存 @@ -223,9 +226,11 @@ class VideoAnalyzer: return None try: - with get_db_session() as session: + async with get_db_session() as session: # 只根据video_hash查找 - existing_video = session.query(Videos).filter(Videos.video_hash == video_hash).first() + stmt = select(Videos).where(Videos.video_hash == video_hash) + result = await session.execute(stmt) + existing_video = result.scalar_one_or_none() if existing_video: # 如果已存在,更新描述和计数 @@ -238,8 +243,8 @@ class VideoAnalyzer: existing_video.fps = metadata.get("fps") existing_video.resolution = metadata.get("resolution") existing_video.file_size = metadata.get("file_size") - session.commit() - session.refresh(existing_video) + await session.commit() + await session.refresh(existing_video) logger.info(f"✅ 更新已存在的视频记录,hash: {video_hash[:16]}..., count: {existing_video.count}") return existing_video else: @@ -254,8 +259,8 @@ class VideoAnalyzer: video_record.file_size = metadata.get("file_size") session.add(video_record) - session.commit() - session.refresh(video_record) + await session.commit() + await session.refresh(video_record) logger.info(f"✅ 新视频分析结果已保存到数据库,hash: {video_hash[:16]}...") return video_record except Exception as e: @@ -704,7 +709,7 @@ class VideoAnalyzer: logger.info("✅ 等待结束,检查是否有处理结果") # 检查是否有结果了 - existing_video = self._check_video_exists(video_hash) + existing_video = await self._check_video_exists(video_hash) if existing_video: logger.info(f"✅ 找到了处理结果,直接返回 (id: {existing_video.id})") return {"summary": existing_video.description} @@ -718,7 +723,7 @@ class VideoAnalyzer: logger.info(f"🔒 获得视频处理锁,开始处理 (hash: {video_hash[:16]}...)") # 再次检查数据库(可能在等待期间已经有结果了) - existing_video = self._check_video_exists(video_hash) + existing_video = await self._check_video_exists(video_hash) if existing_video: logger.info(f"✅ 获得锁后发现已有结果,直接返回 (id: {existing_video.id})") video_event.set() # 通知其他等待者 @@ -749,7 +754,7 @@ class VideoAnalyzer: # 保存分析结果到数据库(仅保存成功的结果) if success and not result.startswith("❌"): metadata = {"filename": filename, "file_size": len(video_bytes), "analysis_timestamp": time.time()} - self._store_video_result(video_hash=video_hash, description=result, metadata=metadata) + await self._store_video_result(video_hash=video_hash, description=result, metadata=metadata) logger.info("✅ 分析结果已保存到数据库") else: logger.warning("⚠️ 分析失败,不保存到数据库以便后续重试") diff --git a/src/common/database/database.py b/src/common/database/database.py index 1815a98ff..92c851edb 100644 --- a/src/common/database/database.py +++ b/src/common/database/database.py @@ -22,9 +22,9 @@ class DatabaseProxy: self._session = None @staticmethod - def initialize(*args, **kwargs): + async def initialize(*args, **kwargs): """初始化数据库连接""" - return initialize_database_compat() + return await initialize_database_compat() class SQLAlchemyTransaction: @@ -88,7 +88,7 @@ async def initialize_sql_database(database_config): logger.info(f" 数据库文件: {db_path}") # 使用SQLAlchemy初始化 - success = initialize_database_compat() + success = await initialize_database_compat() if success: _sql_engine = await get_engine() logger.info("SQLAlchemy数据库初始化成功") diff --git a/src/common/database/sqlalchemy_models.py b/src/common/database/sqlalchemy_models.py index 2469fa642..96ef59135 100644 --- a/src/common/database/sqlalchemy_models.py +++ b/src/common/database/sqlalchemy_models.py @@ -706,7 +706,8 @@ async def get_db_session() -> AsyncGenerator[AsyncSession, None]: raise RuntimeError("Database session not initialized") session = SessionLocal() yield session - except Exception: + except Exception as e: + logger.error(f"数据库会话错误: {e}") if session: await session.rollback() raise diff --git a/src/common/message_repository.py b/src/common/message_repository.py index 992ad3320..f295f8e8a 100644 --- a/src/common/message_repository.py +++ b/src/common/message_repository.py @@ -101,7 +101,8 @@ def find_messages( # 获取时间最早的 limit 条记录,已经是正序 query = query.order_by(Messages.time.asc()).limit(limit) try: - results = session.execute(query).scalars().all() + results = result = session.execute(query) + result.scalars().all() except Exception as e: logger.error(f"执行earliest查询失败: {e}") results = [] @@ -109,7 +110,8 @@ def find_messages( # 获取时间最晚的 limit 条记录 query = query.order_by(Messages.time.desc()).limit(limit) try: - latest_results = session.execute(query).scalars().all() + latest_results = result = session.execute(query) + result.scalars().all() # 将结果按时间正序排列 results = sorted(latest_results, key=lambda msg: msg.time) except Exception as e: @@ -133,7 +135,8 @@ def find_messages( if sort_terms: query = query.order_by(*sort_terms) try: - results = session.execute(query).scalars().all() + results = result = session.execute(query) + result.scalars().all() except Exception as e: logger.error(f"执行无限制查询失败: {e}") results = [] @@ -207,5 +210,5 @@ def count_messages(message_filter: dict[str, Any]) -> int: # 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。 -# 注意:对于 SQLAlchemy,插入操作通常是使用 session.add() 和 session.commit()。 +# 注意:对于 SQLAlchemy,插入操作通常是使用 await session.add() 和 await session.commit()。 # 查找单个消息可以使用 session.execute(select(Messages).where(...)).scalar_one_or_none()。 diff --git a/src/llm_models/utils.py b/src/llm_models/utils.py index 34949e968..c322e2ffb 100644 --- a/src/llm_models/utils.py +++ b/src/llm_models/utils.py @@ -161,7 +161,7 @@ class LLMUsageRecorder: session = None try: # 使用 SQLAlchemy 会话创建记录 - with get_db_session() as session: + async with get_db_session() as session: usage_record = LLMUsage( model_name=model_info.model_identifier, model_assign_name=model_info.name, @@ -172,14 +172,14 @@ class LLMUsageRecorder: prompt_tokens=model_usage.prompt_tokens or 0, completion_tokens=model_usage.completion_tokens or 0, total_tokens=model_usage.total_tokens or 0, - cost=total_cost or 0.0, + cost=1.0, time_cost=round(time_cost or 0.0, 3), status="success", timestamp=datetime.now(), # SQLAlchemy 会处理 DateTime 字段 ) session.add(usage_record) - session.commit() + await session.commit() logger.debug( f"Token使用情况 - 模型: {model_usage.model_name}, " diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index cc84b8967..bba004356 100644 --- a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -163,7 +163,8 @@ class PersonInfoManager: try: # 在需要时获取会话 async with get_db_session() as session: - record = (await session.execute(select(PersonInfo).where(PersonInfo.person_name == person_name))).scalar() + record = result = await session.execute(select(PersonInfo).where(PersonInfo.person_name == person_name)) + result.scalar() return record.person_id if record else "" except Exception as e: logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (SQLAlchemy): {e}") @@ -339,7 +340,8 @@ class PersonInfoManager: start_time = time.time() async with get_db_session() as session: try: - record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar() + result = await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)) + record = result.scalar() query_time = time.time() if record: setattr(record, f_name, val_to_set) @@ -401,7 +403,8 @@ class PersonInfoManager: async def _db_has_field_async(p_id: str, f_name: str): async with get_db_session() as session: - record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar() + result = await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)) + record = result.scalar() return bool(record) try: @@ -512,10 +515,9 @@ class PersonInfoManager: async def _db_check_name_exists_async(name_to_check): async with get_db_session() as session: - return ( - (await session.execute(select(PersonInfo).where(PersonInfo.person_name == name_to_check))).scalar() - is not None - ) + result = await session.execute(select(PersonInfo).where(PersonInfo.person_name == name_to_check)) + record = result.scalar() + return record is not None if await _db_check_name_exists_async(generated_nickname): is_duplicate = True @@ -556,7 +558,8 @@ class PersonInfoManager: async def _db_delete_async(p_id: str): try: async with get_db_session() as session: - record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar() + result = await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)) + record = result.scalar() if record: await session.delete(record) await session.commit() @@ -585,7 +588,9 @@ class PersonInfoManager: async def _get_record_sync(): async with get_db_session() as session: - return (await session.execute(select(PersonInfo).where(PersonInfo.person_id == person_id))).scalar() + result = await session.execute(select(PersonInfo).where(PersonInfo.person_id == person_id)) + record = result.scalar() + return record try: record = asyncio.run(_get_record_sync()) @@ -624,7 +629,9 @@ class PersonInfoManager: async def _db_get_record_async(p_id: str): async with get_db_session() as session: - return (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar() + result = await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)) + record = result.scalar() + return record record = await _db_get_record_async(person_id) @@ -700,7 +707,8 @@ class PersonInfoManager: """原子性的获取或创建操作""" async with get_db_session() as session: # 首先尝试获取现有记录 - record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar() + result = await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)) + record = result.scalar() if record: return record, False # 记录存在,未创建 @@ -715,9 +723,10 @@ class PersonInfoManager: # 如果创建失败(可能是因为竞态条件),再次尝试获取 if "UNIQUE constraint failed" in str(e): logger.debug(f"检测到并发创建用户 {p_id},获取现有记录") - record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar() - if record: - return record, False # 其他协程已创建,返回现有记录 + result = await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)) + record = result.scalar() + if record: + return record, False # 其他协程已创建,返回现有记录 # 如果仍然失败,重新抛出异常 raise e diff --git a/src/plugin_system/apis/emoji_api.py b/src/plugin_system/apis/emoji_api.py index 479f3aec1..4fbadb98f 100644 --- a/src/plugin_system/apis/emoji_api.py +++ b/src/plugin_system/apis/emoji_api.py @@ -122,7 +122,7 @@ async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]: matched_emotion = random.choice(selected_emoji.emotion) if selected_emoji.emotion else "随机表情" # 记录使用次数 - emoji_manager.record_usage(selected_emoji.hash) + await emoji_manager.record_usage(selected_emoji.hash) results.append((emoji_base64, selected_emoji.description, matched_emotion)) if not results and count > 0: @@ -180,7 +180,7 @@ async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]: return None # 记录使用次数 - emoji_manager.record_usage(selected_emoji.hash) + await emoji_manager.record_usage(selected_emoji.hash) logger.info(f"[EmojiAPI] 成功获取情感表情包: {selected_emoji.description}") return emoji_base64, selected_emoji.description, emotion diff --git a/src/plugins/built_in/affinity_flow_chatter/affinity_chatter.py b/src/plugins/built_in/affinity_flow_chatter/affinity_chatter.py index 08f5f7098..9d9786e5d 100644 --- a/src/plugins/built_in/affinity_flow_chatter/affinity_chatter.py +++ b/src/plugins/built_in/affinity_flow_chatter/affinity_chatter.py @@ -65,7 +65,7 @@ class AffinityChatter(BaseChatter): """ try: # 触发表达学习 - learner = expression_learner_manager.get_expression_learner(self.stream_id) + learner = await expression_learner_manager.get_expression_learner(self.stream_id) asyncio.create_task(learner.trigger_learning_for_chat()) unread_messages = context.get_unread_messages() diff --git a/src/plugins/built_in/affinity_flow_chatter/interest_scoring.py b/src/plugins/built_in/affinity_flow_chatter/interest_scoring.py index 0538090bc..391ca58fb 100644 --- a/src/plugins/built_in/affinity_flow_chatter/interest_scoring.py +++ b/src/plugins/built_in/affinity_flow_chatter/interest_scoring.py @@ -69,7 +69,7 @@ class ChatterInterestScoringSystem: keywords = self._extract_keywords_from_database(message) interest_match_score = await self._calculate_interest_match_score(message.processed_plain_text, keywords) - relationship_score = self._calculate_relationship_score(message.user_info.user_id) + relationship_score = await self._calculate_relationship_score(message.user_info.user_id) mentioned_score = self._calculate_mentioned_score(message, bot_nickname) total_score = ( @@ -189,7 +189,7 @@ class ChatterInterestScoringSystem: unique_keywords = list(set(keywords)) return unique_keywords[:10] # 返回前10个唯一关键词 - def _calculate_relationship_score(self, user_id: str) -> float: + async def _calculate_relationship_score(self, user_id: str) -> float: """计算关系分 - 从数据库获取关系分""" # 优先使用内存中的关系分 if user_id in self.user_relationships: @@ -212,7 +212,7 @@ class ChatterInterestScoringSystem: global_tracker = ChatterRelationshipTracker() if global_tracker: - relationship_score = global_tracker.get_user_relationship_score(user_id) + relationship_score = await global_tracker.get_user_relationship_score(user_id) # 同时更新内存缓存 self.user_relationships[user_id] = relationship_score return relationship_score diff --git a/src/plugins/built_in/affinity_flow_chatter/relationship_tracker.py b/src/plugins/built_in/affinity_flow_chatter/relationship_tracker.py index c0050025e..66b3ca31f 100644 --- a/src/plugins/built_in/affinity_flow_chatter/relationship_tracker.py +++ b/src/plugins/built_in/affinity_flow_chatter/relationship_tracker.py @@ -287,7 +287,7 @@ class ChatterRelationshipTracker: # ===== 数据库支持方法 ===== - def get_user_relationship_score(self, user_id: str) -> float: + async def get_user_relationship_score(self, user_id: str) -> float: """获取用户关系分""" # 先检查缓存 if user_id in self.user_relationship_cache: @@ -298,7 +298,7 @@ class ChatterRelationshipTracker: return cache_data.get("relationship_score", global_config.affinity_flow.base_relationship_score) # 缓存过期或不存在,从数据库获取 - relationship_data = self._get_user_relationship_from_db(user_id) + relationship_data = await self._get_user_relationship_from_db(user_id) if relationship_data: # 更新缓存 self.user_relationship_cache[user_id] = { @@ -313,37 +313,38 @@ class ChatterRelationshipTracker: # 数据库中也没有,返回默认值 return global_config.affinity_flow.base_relationship_score - def _get_user_relationship_from_db(self, user_id: str) -> Optional[Dict]: + async def _get_user_relationship_from_db(self, user_id: str) -> Optional[Dict]: """从数据库获取用户关系数据""" try: - with get_db_session() as session: + async with get_db_session() as session: # 查询用户关系表 stmt = select(UserRelationships).where(UserRelationships.user_id == user_id) - result = session.execute(stmt).scalar_one_or_none() + result = await session.execute(stmt) + relationship = result.scalar_one_or_none() - if result: + if relationship: return { - "relationship_text": result.relationship_text or "", - "relationship_score": float(result.relationship_score) - if result.relationship_score is not None + "relationship_text": relationship.relationship_text or "", + "relationship_score": float(relationship.relationship_score) + if relationship.relationship_score is not None else 0.3, - "last_updated": result.last_updated, + "last_updated": relationship.last_updated, } except Exception as e: logger.error(f"从数据库获取用户关系失败: {e}") return None - def _update_user_relationship_in_db(self, user_id: str, relationship_text: str, relationship_score: float): + async def _update_user_relationship_in_db(self, user_id: str, relationship_text: str, relationship_score: float): """更新数据库中的用户关系""" try: current_time = time.time() - with get_db_session() as session: + async with get_db_session() as session: # 检查是否已存在关系记录 - existing = session.execute( - select(UserRelationships).where(UserRelationships.user_id == user_id) - ).scalar_one_or_none() + stmt = select(UserRelationships).where(UserRelationships.user_id == user_id) + result = await session.execute(stmt) + existing = result.scalar_one_or_none() if existing: # 更新现有记录 @@ -362,7 +363,7 @@ class ChatterRelationshipTracker: ) session.add(new_relationship) - session.commit() + await session.commit() logger.info(f"已更新数据库中用户关系: {user_id} -> 分数: {relationship_score:.3f}") except Exception as e: @@ -399,7 +400,7 @@ class ChatterRelationshipTracker: logger.debug(f"💬 [RelationshipTracker] 找到用户 {user_id} 在上次回复后的 {len(user_reactions)} 条反应消息") # 获取当前关系数据 - current_relationship = self._get_user_relationship_from_db(user_id) + current_relationship = await self._get_user_relationship_from_db(user_id) current_score = ( current_relationship.get("relationship_score", global_config.affinity_flow.base_relationship_score) if current_relationship @@ -417,14 +418,14 @@ class ChatterRelationshipTracker: logger.error(f"回复后关系追踪失败: {e}") logger.debug("错误详情:", exc_info=True) - def _get_last_tracked_time(self, user_id: str) -> float: + async def _get_last_tracked_time(self, user_id: str) -> float: """获取上次追踪时间""" # 先检查缓存 if user_id in self.user_relationship_cache: return self.user_relationship_cache[user_id].get("last_tracked", 0) # 从数据库获取 - relationship_data = self._get_user_relationship_from_db(user_id) + relationship_data = await self._get_user_relationship_from_db(user_id) if relationship_data: return relationship_data.get("last_updated", 0) @@ -433,7 +434,7 @@ class ChatterRelationshipTracker: async def _get_last_bot_reply_to_user(self, user_id: str) -> Optional[DatabaseMessages]: """获取上次bot回复该用户的消息""" try: - with get_db_session() as session: + async with get_db_session() as session: # 查询bot回复给该用户的最新消息 stmt = ( select(Messages) @@ -443,10 +444,11 @@ class ChatterRelationshipTracker: .limit(1) ) - result = session.execute(stmt).scalar_one_or_none() - if result: + result = await session.execute(stmt) + message = result.scalar_one_or_none() + if message: # 将SQLAlchemy模型转换为DatabaseMessages对象 - return self._sqlalchemy_to_database_messages(result) + return self._sqlalchemy_to_database_messages(message) except Exception as e: logger.error(f"获取上次回复消息失败: {e}") @@ -456,7 +458,7 @@ class ChatterRelationshipTracker: async def _get_user_reactions_after_reply(self, user_id: str, reply_time: float) -> List[DatabaseMessages]: """获取用户在bot回复后的反应消息""" try: - with get_db_session() as session: + async with get_db_session() as session: # 查询用户在回复时间之后的5分钟内的消息 end_time = reply_time + 5 * 60 # 5分钟 @@ -468,9 +470,10 @@ class ChatterRelationshipTracker: .order_by(Messages.time) ) - results = session.execute(stmt).scalars().all() - if results: - return [self._sqlalchemy_to_database_messages(result) for result in results] + result = await session.execute(stmt) + messages = result.scalars().all() + if messages: + return [self._sqlalchemy_to_database_messages(message) for message in messages] except Exception as e: logger.error(f"获取用户反应消息失败: {e}") @@ -593,7 +596,7 @@ class ChatterRelationshipTracker: quality = response_data.get("interaction_quality", "medium") # 更新数据库 - self._update_user_relationship_in_db(user_id, new_text, new_score) + await self._update_user_relationship_in_db(user_id, new_text, new_score) # 更新缓存 self.user_relationship_cache[user_id] = { @@ -696,7 +699,7 @@ class ChatterRelationshipTracker: ) # 更新数据库和缓存 - self._update_user_relationship_in_db(user_id, new_text, new_score) + await self._update_user_relationship_in_db(user_id, new_text, new_score) self.user_relationship_cache[user_id] = { "relationship_text": new_text, "relationship_score": new_score, diff --git a/src/plugins/built_in/maizone_refactored/services/scheduler_service.py b/src/plugins/built_in/maizone_refactored/services/scheduler_service.py index ed32da48d..770ced8e6 100644 --- a/src/plugins/built_in/maizone_refactored/services/scheduler_service.py +++ b/src/plugins/built_in/maizone_refactored/services/scheduler_service.py @@ -13,6 +13,7 @@ from typing import Callable from src.common.logger import get_logger from src.schedule.schedule_manager import schedule_manager from src.common.database.sqlalchemy_database_api import get_db_session +from sqlalchemy import select from src.common.database.sqlalchemy_models import MaiZoneScheduleStatus from .qzone_service import QZoneService @@ -138,15 +139,13 @@ class SchedulerService: :return: 如果已处理过,返回 True,否则返回 False。 """ try: - with get_db_session() as session: - record = ( - session.query(MaiZoneScheduleStatus) - .filter( - MaiZoneScheduleStatus.datetime_hour == hour_str, - MaiZoneScheduleStatus.is_processed == True, # noqa: E712 - ) - .first() + async with get_db_session() as session: + stmt = select(MaiZoneScheduleStatus).where( + MaiZoneScheduleStatus.datetime_hour == hour_str, + MaiZoneScheduleStatus.is_processed == True, # noqa: E712 ) + result = await session.execute(stmt) + record = result.scalar_one_or_none() return record is not None except Exception as e: logger.error(f"检查日程处理状态时发生数据库错误: {e}") @@ -162,11 +161,11 @@ class SchedulerService: :param content: 最终发送的说说内容或错误信息。 """ try: - with get_db_session() as session: + async with get_db_session() as session: # 查找是否已存在该记录 - record = ( - session.query(MaiZoneScheduleStatus).filter(MaiZoneScheduleStatus.datetime_hour == hour_str).first() - ) + stmt = select(MaiZoneScheduleStatus).where(MaiZoneScheduleStatus.datetime_hour == hour_str) + result = await session.execute(stmt) + record = result.scalar_one_or_none() if record: # 如果存在,则更新状态 @@ -185,7 +184,7 @@ class SchedulerService: send_success=success, ) session.add(new_record) - session.commit() + await session.commit() logger.info(f"已更新日程处理状态: {hour_str} - {activity} - 成功: {success}") except Exception as e: logger.error(f"更新日程处理状态时发生数据库错误: {e}") diff --git a/src/plugins/built_in/napcat_adapter_plugin/plugin.py b/src/plugins/built_in/napcat_adapter_plugin/plugin.py index fa0eeed23..569c0857a 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/plugin.py +++ b/src/plugins/built_in/napcat_adapter_plugin/plugin.py @@ -64,15 +64,9 @@ async def message_recv(server_connection: Server.ServerConnection): # 处理完整消息(可能是重组后的,也可能是原本就完整的) post_type = decoded_raw_message.get("post_type") - - # 兼容没有 post_type 的普通消息 - if not post_type and "message_type" in decoded_raw_message: - decoded_raw_message["post_type"] = "message" - post_type = "message" - if post_type in ["meta_event", "message", "notice"]: await message_queue.put(decoded_raw_message) - else: + elif post_type is None: await put_response(decoded_raw_message) except json.JSONDecodeError as e: @@ -428,8 +422,9 @@ class NapcatAdapterPlugin(BasePlugin): def get_plugin_components(self): self.register_events() - components = [(LauchNapcatAdapterHandler.get_handler_info(), LauchNapcatAdapterHandler), - (StopNapcatAdapterHandler.get_handler_info(), StopNapcatAdapterHandler)] + components = [] + components.append((LauchNapcatAdapterHandler.get_handler_info(), LauchNapcatAdapterHandler)) + components.append((StopNapcatAdapterHandler.get_handler_info(), StopNapcatAdapterHandler)) for handler in get_classes_in_module(event_handlers): if issubclass(handler, BaseEventHandler): components.append((handler.get_handler_info(), handler)) diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/database.py b/src/plugins/built_in/napcat_adapter_plugin/src/database.py index 1620ec304..74842eed5 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/database.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/database.py @@ -1,156 +1,162 @@ -"""Napcat Adapter 插件数据库层 (基于主程序异步SQLAlchemy API) - -本模块替换原先的 sqlmodel + 同步Session 实现: -1. 复用主项目的异步数据库连接与迁移体系 -2. 提供与旧接口名兼容的方法(update_ban_record/create_ban_record/delete_ban_record) -3. 新增首选异步方法: update_ban_records / create_or_update / delete_record / get_ban_records - -数据语义: - user_id == 0 表示群全体禁言 - -注意: 所有方法均为异步, 需要在 async 上下文中调用。 -""" -from __future__ import annotations - +import os +from typing import Optional, List from dataclasses import dataclass -from typing import Optional, List, Sequence +from sqlmodel import Field, Session, SQLModel, create_engine, select -from sqlalchemy import Column, Integer, BigInteger, UniqueConstraint, select, Index -from sqlalchemy.ext.asyncio import AsyncSession - -from src.common.database.sqlalchemy_models import Base, get_db_session from src.common.logger import get_logger logger = get_logger("napcat_adapter") +""" +表记录的方式: +| group_id | user_id | lift_time | +|----------|---------|-----------| -class NapcatBanRecord(Base): - __tablename__ = "napcat_ban_records" - - id = Column(Integer, primary_key=True, autoincrement=True) - group_id = Column(BigInteger, nullable=False, index=True) - user_id = Column(BigInteger, nullable=False, index=True) # 0 == 全体禁言 - lift_time = Column(BigInteger, nullable=True) # -1 / None 表示未知/永久 - - __table_args__ = ( - UniqueConstraint("group_id", "user_id", name="uq_napcat_group_user"), - Index("idx_napcat_ban_group", "group_id"), - Index("idx_napcat_ban_user", "user_id"), - ) +其中使用 user_id == 0 表示群全体禁言 +""" @dataclass class BanUser: + """ + 程序处理使用的实例 + """ + user_id: int group_id: int - lift_time: Optional[int] = -1 - - def identity(self) -> tuple[int, int]: - return self.group_id, self.user_id + lift_time: Optional[int] = Field(default=-1) -class NapcatDatabase: - async def _fetch_all(self, session: AsyncSession) -> Sequence[NapcatBanRecord]: - result = await session.execute(select(NapcatBanRecord)) - return result.scalars().all() +class DB_BanUser(SQLModel, table=True): + """ + 表示数据库中的用户禁言记录。 + 使用双重主键 + """ - async def get_ban_records(self) -> List[BanUser]: - async with get_db_session() as session: - rows = await self._fetch_all(session) - return [BanUser(group_id=r.group_id, user_id=r.user_id, lift_time=r.lift_time) for r in rows] + user_id: int = Field(index=True, primary_key=True) # 被禁言用户的用户 ID + group_id: int = Field(index=True, primary_key=True) # 用户被禁言的群组 ID + lift_time: Optional[int] # 禁言解除的时间(时间戳) - async def update_ban_records(self, ban_list: List[BanUser]) -> None: - target_map = {b.identity(): b for b in ban_list} - async with get_db_session() as session: - rows = await self._fetch_all(session) - existing_map = {(r.group_id, r.user_id): r for r in rows} - changed = 0 - for ident, ban in target_map.items(): - if ident in existing_map: - row = existing_map[ident] - if row.lift_time != ban.lift_time: - row.lift_time = ban.lift_time - changed += 1 +def is_identical(obj1: BanUser, obj2: BanUser) -> bool: + """ + 检查两个 BanUser 对象是否相同。 + """ + return obj1.user_id == obj2.user_id and obj1.group_id == obj2.group_id + + +class DatabaseManager: + """ + 数据库管理类,负责与数据库交互。 + """ + + def __init__(self): + os.makedirs(os.path.join(os.path.dirname(__file__), "..", "data"), exist_ok=True) # 确保数据目录存在 + DATABASE_FILE = os.path.join(os.path.dirname(__file__), "..", "data", "NapcatAdapter.db") + self.sqlite_url = f"sqlite:///{DATABASE_FILE}" # SQLite 数据库 URL + self.engine = create_engine(self.sqlite_url, echo=False) # 创建数据库引擎 + self._ensure_database() # 确保数据库和表已创建 + + def _ensure_database(self) -> None: + """ + 确保数据库和表已创建。 + """ + logger.info("确保数据库文件和表已创建...") + SQLModel.metadata.create_all(self.engine) + logger.info("数据库和表已创建或已存在") + + def update_ban_record(self, ban_list: List[BanUser]) -> None: + # sourcery skip: class-extract-method + """ + 更新禁言列表到数据库。 + 支持在不存在时创建新记录,对于多余的项目自动删除。 + """ + with Session(self.engine) as session: + all_records = session.exec(select(DB_BanUser)).all() + for ban_user in ban_list: + statement = select(DB_BanUser).where( + DB_BanUser.user_id == ban_user.user_id, DB_BanUser.group_id == ban_user.group_id + ) + if existing_record := session.exec(statement).first(): + if existing_record.lift_time == ban_user.lift_time: + logger.debug(f"禁言记录未变更: {existing_record}") + continue + # 更新现有记录的 lift_time + existing_record.lift_time = ban_user.lift_time + session.add(existing_record) + logger.debug(f"更新禁言记录: {existing_record}") else: - session.add( - NapcatBanRecord(group_id=ban.group_id, user_id=ban.user_id, lift_time=ban.lift_time) + # 创建新记录 + db_record = DB_BanUser( + user_id=ban_user.user_id, group_id=ban_user.group_id, lift_time=ban_user.lift_time ) - changed += 1 - - removed = 0 - for ident, row in existing_map.items(): - if ident not in target_map: - await session.delete(row) - removed += 1 - - logger.debug( - f"Napcat ban list sync => total_incoming={len(ban_list)} created_or_updated={changed} removed={removed}" - ) - - async def create_or_update(self, ban_record: BanUser) -> None: - async with get_db_session() as session: - stmt = select(NapcatBanRecord).where( - NapcatBanRecord.group_id == ban_record.group_id, - NapcatBanRecord.user_id == ban_record.user_id, - ) - result = await session.execute(stmt) - row = result.scalars().first() - if row: - if row.lift_time != ban_record.lift_time: - row.lift_time = ban_record.lift_time - logger.debug( - f"更新禁言记录 group={ban_record.group_id} user={ban_record.user_id} lift={ban_record.lift_time}" + session.add(db_record) + logger.debug(f"创建新禁言记录: {ban_user}") + # 删除不在 ban_list 中的记录 + for db_record in all_records: + record = BanUser(user_id=db_record.user_id, group_id=db_record.group_id, lift_time=db_record.lift_time) + if not any(is_identical(record, ban_user) for ban_user in ban_list): + statement = select(DB_BanUser).where( + DB_BanUser.user_id == record.user_id, DB_BanUser.group_id == record.group_id ) + if ban_record := session.exec(statement).first(): + session.delete(ban_record) + + logger.debug(f"删除禁言记录: {ban_record}") + else: + logger.info(f"未找到禁言记录: {ban_record}") + + logger.info("禁言记录已更新") + + def get_ban_records(self) -> List[BanUser]: + """ + 读取所有禁言记录。 + """ + with Session(self.engine) as session: + statement = select(DB_BanUser) + records = session.exec(statement).all() + return [BanUser(user_id=item.user_id, group_id=item.group_id, lift_time=item.lift_time) for item in records] + + def create_ban_record(self, ban_record: BanUser) -> None: + """ + 为特定群组中的用户创建禁言记录。 + 一个简化版本的添加方式,防止 update_ban_record 方法的复杂性。 + 其同时还是简化版的更新方式。 + """ + with Session(self.engine) as session: + # 检查记录是否已存在 + statement = select(DB_BanUser).where( + DB_BanUser.user_id == ban_record.user_id, DB_BanUser.group_id == ban_record.group_id + ) + existing_record = session.exec(statement).first() + if existing_record: + # 如果记录已存在,更新 lift_time + existing_record.lift_time = ban_record.lift_time + session.add(existing_record) + logger.debug(f"更新禁言记录: {ban_record}") else: - session.add( - NapcatBanRecord( - group_id=ban_record.group_id, user_id=ban_record.user_id, lift_time=ban_record.lift_time - ) - ) - logger.debug( - f"创建禁言记录 group={ban_record.group_id} user={ban_record.user_id} lift={ban_record.lift_time}" + # 如果记录不存在,创建新记录 + db_record = DB_BanUser( + user_id=ban_record.user_id, group_id=ban_record.group_id, lift_time=ban_record.lift_time ) + session.add(db_record) + logger.debug(f"创建新禁言记录: {ban_record}") - async def delete_record(self, ban_record: BanUser) -> None: - async with get_db_session() as session: - stmt = select(NapcatBanRecord).where( - NapcatBanRecord.group_id == ban_record.group_id, - NapcatBanRecord.user_id == ban_record.user_id, - ) - result = await session.execute(stmt) - row = result.scalars().first() - if row: - await session.delete(row) - logger.debug( - f"删除禁言记录 group={ban_record.group_id} user={ban_record.user_id} lift={row.lift_time}" - ) + def delete_ban_record(self, ban_record: BanUser): + """ + 删除特定用户在特定群组中的禁言记录。 + 一个简化版本的删除方式,防止 update_ban_record 方法的复杂性。 + """ + user_id = ban_record.user_id + group_id = ban_record.group_id + with Session(self.engine) as session: + statement = select(DB_BanUser).where(DB_BanUser.user_id == user_id, DB_BanUser.group_id == group_id) + if ban_record := session.exec(statement).first(): + session.delete(ban_record) + + logger.debug(f"删除禁言记录: {ban_record}") else: - logger.info( - f"未找到禁言记录 group={ban_record.group_id} user={ban_record.user_id}" - ) - - # 兼容旧命名 - async def update_ban_record(self, ban_list: List[BanUser]) -> None: # old name - await self.update_ban_records(ban_list) - - async def create_ban_record(self, ban_record: BanUser) -> None: # old name - await self.create_or_update(ban_record) - - async def delete_ban_record(self, ban_record: BanUser) -> None: # old name - await self.delete_record(ban_record) + logger.info(f"未找到禁言记录: user_id: {user_id}, group_id: {group_id}") -napcat_db = NapcatDatabase() - - -def is_identical(a: BanUser, b: BanUser) -> bool: - return a.group_id == b.group_id and a.user_id == b.user_id - - -__all__ = [ - "BanUser", - "NapcatBanRecord", - "napcat_db", - "is_identical", -] +db_manager = DatabaseManager() diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/message_chunker.py b/src/plugins/built_in/napcat_adapter_plugin/src/message_chunker.py index 9757e7cf5..0f25bd62e 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/message_chunker.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/message_chunker.py @@ -112,8 +112,7 @@ class MessageChunker: else: return [{"_original_message": message}] - @staticmethod - def is_chunk_message(message: Union[str, Dict[str, Any]]) -> bool: + def is_chunk_message(self, message: Union[str, Dict[str, Any]]) -> bool: """判断是否是切片消息""" try: if isinstance(message, str): diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/meta_event_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/meta_event_handler.py index 7ae743c41..7f310fbfa 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/meta_event_handler.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/meta_event_handler.py @@ -14,7 +14,6 @@ class MetaEventHandler: """ def __init__(self): - self.last_heart_beat = time.time() self.interval = 5.0 # 默认值,稍后通过set_plugin_config设置 self._interval_checking = False self.plugin_config = None @@ -40,6 +39,7 @@ class MetaEventHandler: if message["status"].get("online") and message["status"].get("good"): if not self._interval_checking: asyncio.create_task(self.check_heartbeat()) + self.last_heart_beat = time.time() self.interval = message.get("interval") / 1000 else: self_id = message.get("self_id") diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py index ec4fbe75e..4e4aa3e10 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py @@ -76,7 +76,7 @@ class SendHandler: processed_message = await self.handle_seg_recursive(message_segment, user_info) except Exception as e: logger.error(f"处理消息时发生错误: {e}") - return None + return if not processed_message: logger.critical("现在暂时不支持解析此回复!") @@ -94,7 +94,7 @@ class SendHandler: id_name = "user_id" else: logger.error("无法识别的消息类型") - return None + return logger.info("尝试发送到napcat") logger.debug(f"准备发送到napcat的消息体: action='{action}', {id_name}='{target_id}', message='{processed_message}'") response = await self.send_message_to_napcat( @@ -108,10 +108,8 @@ class SendHandler: logger.info("消息发送成功") qq_message_id = response.get("data", {}).get("message_id") await self.message_sent_back(raw_message_base, qq_message_id) - return None else: logger.warning(f"消息发送失败,napcat返回:{str(response)}") - return None async def send_command(self, raw_message_base: MessageBase) -> None: """ @@ -149,7 +147,7 @@ class SendHandler: command, args_dict = self.handle_send_like_command(args) case _: logger.error(f"未知命令: {command_name}") - return None + return except Exception as e: logger.error(f"处理命令时发生错误: {e}") return None @@ -161,10 +159,8 @@ class SendHandler: response = await self.send_message_to_napcat(command, args_dict) if response.get("status") == "ok": logger.info(f"命令 {command_name} 执行成功") - return None else: logger.warning(f"命令 {command_name} 执行失败,napcat返回:{str(response)}") - return None async def handle_adapter_command(self, raw_message_base: MessageBase) -> None: """ @@ -272,8 +268,7 @@ class SendHandler: new_payload = self.build_payload(payload, self.handle_file_message(file_path), False) return new_payload - @staticmethod - def build_payload(payload: list, addon: dict | list, is_reply: bool = False) -> list: + def build_payload(self, payload: list, addon: dict | list, is_reply: bool = False) -> list: # sourcery skip: for-append-to-extend, merge-list-append, simplify-generator """构建发送的消息体""" if is_reply: @@ -339,13 +334,11 @@ class SendHandler: logger.info(f"最终返回的回复段: {reply_seg}") return reply_seg - @staticmethod - def handle_text_message(message: str) -> dict: + def handle_text_message(self, message: str) -> dict: """处理文本消息""" return {"type": "text", "data": {"text": message}} - @staticmethod - def handle_image_message(encoded_image: str) -> dict: + def handle_image_message(self, encoded_image: str) -> dict: """处理图片消息""" return { "type": "image", @@ -355,8 +348,7 @@ class SendHandler: }, } # base64 编码的图片 - @staticmethod - def handle_emoji_message(encoded_emoji: str) -> dict: + def handle_emoji_message(self, encoded_emoji: str) -> dict: """处理表情消息""" encoded_image = encoded_emoji image_format = get_image_format(encoded_emoji) @@ -387,45 +379,39 @@ class SendHandler: "data": {"file": f"base64://{encoded_voice}"}, } - @staticmethod - def handle_voiceurl_message(voice_url: str) -> dict: + def handle_voiceurl_message(self, voice_url: str) -> dict: """处理语音链接消息""" return { "type": "record", "data": {"file": voice_url}, } - @staticmethod - def handle_music_message(song_id: str) -> dict: + def handle_music_message(self, song_id: str) -> dict: """处理音乐消息""" return { "type": "music", "data": {"type": "163", "id": song_id}, } - @staticmethod - def handle_videourl_message(video_url: str) -> dict: + def handle_videourl_message(self, video_url: str) -> dict: """处理视频链接消息""" return { "type": "video", "data": {"file": video_url}, } - @staticmethod - def handle_file_message(file_path: str) -> dict: + def handle_file_message(self, file_path: str) -> dict: """处理文件消息""" return { "type": "file", "data": {"file": f"file://{file_path}"}, } - @staticmethod - def delete_msg_command(args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]: + def delete_msg_command(self, args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]: """处理删除消息命令""" return "delete_msg", {"message_id": args["message_id"]} - @staticmethod - def handle_ban_command(args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]: + def handle_ban_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]: """处理封禁命令 Args: @@ -453,8 +439,7 @@ class SendHandler: }, ) - @staticmethod - def handle_whole_ban_command(args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]: + def handle_whole_ban_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]: """处理全体禁言命令 Args: @@ -477,8 +462,7 @@ class SendHandler: }, ) - @staticmethod - def handle_kick_command(args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]: + def handle_kick_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]: """处理群成员踢出命令 Args: @@ -503,8 +487,7 @@ class SendHandler: }, ) - @staticmethod - def handle_poke_command(args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]: + def handle_poke_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]: """处理戳一戳命令 Args: @@ -531,8 +514,7 @@ class SendHandler: }, ) - @staticmethod - def handle_set_emoji_like_command(args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]: + def handle_set_emoji_like_command(self, args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]: """处理设置表情回应命令 Args: @@ -554,8 +536,7 @@ class SendHandler: {"message_id": message_id, "emoji_id": emoji_id, "set": set_like}, ) - @staticmethod - def handle_send_like_command(args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]: + def handle_send_like_command(self, args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]: """ 处理发送点赞命令的逻辑。 @@ -576,8 +557,7 @@ class SendHandler: {"user_id": user_id, "times": times}, ) - @staticmethod - def handle_ai_voice_send_command(args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]: + def handle_ai_voice_send_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]: """ 处理AI语音发送命令的逻辑。 并返回 NapCat 兼容的 (action, params) 元组。 @@ -624,8 +604,7 @@ class SendHandler: return {"status": "error", "message": str(e)} return response - @staticmethod - async def message_sent_back(message_base: MessageBase, qq_message_id: str) -> None: + async def message_sent_back(self, message_base: MessageBase, qq_message_id: str) -> None: # 修改 additional_config,添加 echo 字段 if message_base.message_info.additional_config is None: message_base.message_info.additional_config = {} @@ -643,9 +622,8 @@ class SendHandler: logger.debug("已回送消息ID") return - @staticmethod async def send_adapter_command_response( - original_message: MessageBase, response_data: dict, request_id: str + self, original_message: MessageBase, response_data: dict, request_id: str ) -> None: """ 发送适配器命令响应回MaiBot @@ -674,8 +652,7 @@ class SendHandler: except Exception as e: logger.error(f"发送适配器命令响应时出错: {e}") - @staticmethod - def handle_at_message_command(args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]: + def handle_at_message_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]: """处理艾特并发送消息命令 Args: diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/utils.py b/src/plugins/built_in/napcat_adapter_plugin/src/utils.py index 4c47a2570..e36fc93fd 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/utils.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/utils.py @@ -6,7 +6,7 @@ import urllib3 import ssl import io -from .database import BanUser, napcat_db +from .database import BanUser, db_manager from src.common.logger import get_logger logger = get_logger("napcat_adapter") @@ -270,11 +270,10 @@ async def read_ban_list( ] """ try: - ban_list = await napcat_db.get_ban_records() + ban_list = db_manager.get_ban_records() lifted_list: List[BanUser] = [] logger.info("已经读取禁言列表") - # 复制列表以避免迭代中修改原列表问题 - for ban_record in list(ban_list): + for ban_record in ban_list: if ban_record.user_id == 0: fetched_group_info = await get_group_info(websocket, ban_record.group_id) if fetched_group_info is None: @@ -302,12 +301,12 @@ async def read_ban_list( ban_list.remove(ban_record) else: ban_record.lift_time = lift_ban_time - await napcat_db.update_ban_record(ban_list) + db_manager.update_ban_record(ban_list) return ban_list, lifted_list except Exception as e: logger.error(f"读取禁言列表失败: {e}") return [], [] -async def save_ban_record(list: List[BanUser]): - return await napcat_db.update_ban_record(list) +def save_ban_record(list: List[BanUser]): + return db_manager.update_ban_record(list) diff --git a/test_deepcopy_fix.py b/test_deepcopy_fix.py new file mode 100644 index 000000000..c790619b8 --- /dev/null +++ b/test_deepcopy_fix.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python3 +""" +测试 ChatStream 的 deepcopy 功能 +验证 asyncio.Task 序列化问题是否已解决 +""" + +import asyncio +import sys +import os +import copy + +# 添加项目根目录到 Python 路径 +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +from src.chat.message_receive.chat_stream import ChatStream +from maim_message import UserInfo, GroupInfo + + +async def test_chat_stream_deepcopy(): + """测试 ChatStream 的 deepcopy 功能""" + print("[TEST] 开始测试 ChatStream deepcopy 功能...") + + try: + # 创建测试用的用户和群组信息 + user_info = UserInfo( + platform="test_platform", + user_id="test_user_123", + user_nickname="测试用户", + user_cardname="测试卡片名" + ) + + group_info = GroupInfo( + platform="test_platform", + group_id="test_group_456", + group_name="测试群组" + ) + + # 创建 ChatStream 实例 + print("📝 创建 ChatStream 实例...") + stream_id = "test_stream_789" + platform = "test_platform" + + chat_stream = ChatStream( + stream_id=stream_id, + platform=platform, + user_info=user_info, + group_info=group_info + ) + + print(f"[SUCCESS] ChatStream 创建成功: {chat_stream.stream_id}") + + # 等待一下,让异步任务有机会创建 + await asyncio.sleep(0.1) + + # 尝试进行 deepcopy + print("[INFO] 尝试进行 deepcopy...") + copied_stream = copy.deepcopy(chat_stream) + + print("[SUCCESS] deepcopy 成功!") + + # 验证复制后的对象属性 + print("\n[CHECK] 验证复制后的对象属性:") + print(f" - stream_id: {copied_stream.stream_id}") + print(f" - platform: {copied_stream.platform}") + print(f" - user_info: {copied_stream.user_info.user_nickname}") + print(f" - group_info: {copied_stream.group_info.group_name}") + + # 检查 processing_task 是否被正确处理 + if hasattr(copied_stream.stream_context, 'processing_task'): + print(f" - processing_task: {copied_stream.stream_context.processing_task}") + if copied_stream.stream_context.processing_task is None: + print(" [SUCCESS] processing_task 已被正确设置为 None") + else: + print(" [WARNING] processing_task 不为 None") + else: + print(" [SUCCESS] stream_context 没有 processing_task 属性") + + # 验证原始对象和复制对象是不同的实例 + if id(chat_stream) != id(copied_stream): + print("[SUCCESS] 原始对象和复制对象是不同的实例") + else: + print("[ERROR] 原始对象和复制对象是同一个实例") + + # 验证基本属性是否正确复制 + if (chat_stream.stream_id == copied_stream.stream_id and + chat_stream.platform == copied_stream.platform): + print("[SUCCESS] 基本属性正确复制") + else: + print("[ERROR] 基本属性复制失败") + + print("\n[COMPLETE] 测试完成!deepcopy 功能修复成功!") + return True + + except Exception as e: + print(f"[ERROR] 测试失败: {e}") + import traceback + traceback.print_exc() + return False + + +if __name__ == "__main__": + # 运行测试 + result = asyncio.run(test_chat_stream_deepcopy()) + + if result: + print("\n[SUCCESS] 所有测试通过!") + sys.exit(0) + else: + print("\n[ERROR] 测试失败!") + sys.exit(1) \ No newline at end of file diff --git a/test_simple_deepcopy.py b/test_simple_deepcopy.py new file mode 100644 index 000000000..63e680d45 --- /dev/null +++ b/test_simple_deepcopy.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python3 +""" +简单的 ChatStream deepcopy 测试 +""" + +import asyncio +import sys +import os +import copy + +# 添加项目根目录到 Python 路径 +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +from src.chat.message_receive.chat_stream import ChatStream +from maim_message import UserInfo, GroupInfo + + +async def test_deepcopy(): + """测试 deepcopy 功能""" + print("开始测试 ChatStream deepcopy 功能...") + + try: + # 创建测试用的用户和群组信息 + user_info = UserInfo( + platform="test_platform", + user_id="test_user_123", + user_nickname="测试用户", + user_cardname="测试卡片名" + ) + + group_info = GroupInfo( + platform="test_platform", + group_id="test_group_456", + group_name="测试群组" + ) + + # 创建 ChatStream 实例 + print("创建 ChatStream 实例...") + stream_id = "test_stream_789" + platform = "test_platform" + + chat_stream = ChatStream( + stream_id=stream_id, + platform=platform, + user_info=user_info, + group_info=group_info + ) + + print(f"ChatStream 创建成功: {chat_stream.stream_id}") + + # 等待一下,让异步任务有机会创建 + await asyncio.sleep(0.1) + + # 尝试进行 deepcopy + print("尝试进行 deepcopy...") + copied_stream = copy.deepcopy(chat_stream) + + print("deepcopy 成功!") + + # 验证复制后的对象属性 + print("\n验证复制后的对象属性:") + print(f" - stream_id: {copied_stream.stream_id}") + print(f" - platform: {copied_stream.platform}") + print(f" - user_info: {copied_stream.user_info.user_nickname}") + print(f" - group_info: {copied_stream.group_info.group_name}") + + # 检查 processing_task 是否被正确处理 + if hasattr(copied_stream.stream_context, 'processing_task'): + print(f" - processing_task: {copied_stream.stream_context.processing_task}") + if copied_stream.stream_context.processing_task is None: + print(" SUCCESS: processing_task 已被正确设置为 None") + else: + print(" WARNING: processing_task 不为 None") + else: + print(" SUCCESS: stream_context 没有 processing_task 属性") + + # 验证原始对象和复制对象是不同的实例 + if id(chat_stream) != id(copied_stream): + print("SUCCESS: 原始对象和复制对象是不同的实例") + else: + print("ERROR: 原始对象和复制对象是同一个实例") + + # 验证基本属性是否正确复制 + if (chat_stream.stream_id == copied_stream.stream_id and + chat_stream.platform == copied_stream.platform): + print("SUCCESS: 基本属性正确复制") + else: + print("ERROR: 基本属性复制失败") + + print("\n测试完成!deepcopy 功能修复成功!") + return True + + except Exception as e: + print(f"ERROR: 测试失败: {e}") + import traceback + traceback.print_exc() + return False + + +if __name__ == "__main__": + # 运行测试 + result = asyncio.run(test_deepcopy()) + + if result: + print("\n所有测试通过!") + sys.exit(0) + else: + print("\n测试失败!") + sys.exit(1) \ No newline at end of file