diff --git a/bot.py b/bot.py index 566263113..da2391a0d 100644 --- a/bot.py +++ b/bot.py @@ -33,6 +33,7 @@ script_dir = os.path.dirname(os.path.abspath(__file__)) os.chdir(script_dir) logger.info("工作目录已设置") + class ConfigManager: """配置管理器""" @@ -96,6 +97,7 @@ class ConfigManager: logger.error(f"加载环境变量失败: {e}") return False + class EULAManager: """EULA管理类""" @@ -134,7 +136,9 @@ class EULAManager: return if attempts % 5 == 0: - confirm_logger.critical(f"请修改 .env 文件中的 EULA_CONFIRMED=true (尝试 {attempts}/{MAX_EULA_CHECK_ATTEMPTS})") + confirm_logger.critical( + f"请修改 .env 文件中的 EULA_CONFIRMED=true (尝试 {attempts}/{MAX_EULA_CHECK_ATTEMPTS})" + ) except KeyboardInterrupt: confirm_logger.info("用户取消,程序退出") @@ -148,16 +152,14 @@ class EULAManager: confirm_logger.error("EULA确认超时,程序退出") sys.exit(1) + class TaskManager: """任务管理器""" @staticmethod async def cancel_pending_tasks(loop, timeout=SHUTDOWN_TIMEOUT): """取消所有待处理的任务""" - remaining_tasks = [ - t for t in asyncio.all_tasks(loop) - if t is not asyncio.current_task(loop) and not t.done() - ] + remaining_tasks = [t for t in asyncio.all_tasks(loop) if t is not asyncio.current_task(loop) and not t.done()] if not remaining_tasks: logger.info("没有待取消的任务") @@ -171,10 +173,7 @@ class TaskManager: # 等待任务完成 try: - results = await asyncio.wait_for( - asyncio.gather(*remaining_tasks, return_exceptions=True), - timeout=timeout - ) + results = await asyncio.wait_for(asyncio.gather(*remaining_tasks, return_exceptions=True), timeout=timeout) # 检查任务结果 for i, result in enumerate(results): @@ -195,6 +194,7 @@ class TaskManager: """停止所有异步任务""" try: from src.manager.async_task_manager import async_task_manager + await async_task_manager.stop_and_wait_all_tasks() return True except ImportError: @@ -204,6 +204,7 @@ class TaskManager: logger.error(f"停止异步任务失败: {e}") return False + class ShutdownManager: """关闭管理器""" @@ -236,6 +237,7 @@ class ShutdownManager: logger.error(f"麦麦关闭失败: {e}", exc_info=True) return False + @asynccontextmanager async def create_event_loop_context(): """创建事件循环的上下文管理器""" @@ -260,6 +262,7 @@ async def create_event_loop_context(): except Exception as e: logger.error(f"关闭事件循环失败: {e}") + class DatabaseManager: """数据库连接管理器""" @@ -278,7 +281,9 @@ class DatabaseManager: # 使用线程执行器运行潜在的阻塞操作 await asyncio.to_thread(initialize_sql_database, global_config.database) elapsed_time = time.time() - start_time - logger.info(f"数据库连接初始化成功,使用 {global_config.database.database_type} 数据库,耗时: {elapsed_time:.2f}秒") + logger.info( + f"数据库连接初始化成功,使用 {global_config.database.database_type} 数据库,耗时: {elapsed_time:.2f}秒" + ) return self except Exception as e: @@ -291,6 +296,7 @@ class DatabaseManager: logger.error(f"数据库操作发生异常: {exc_val}") return False + class ConfigurationValidator: """配置验证器""" @@ -328,6 +334,7 @@ class ConfigurationValidator: logger.error(f"配置验证失败: {e}") return False + class EasterEgg: """彩蛋功能""" @@ -347,6 +354,7 @@ class EasterEgg: rainbow_text += rainbow_colors[i % len(rainbow_colors)] + char logger.info(rainbow_text) + class MaiBotMain: """麦麦机器人主程序类""" @@ -375,6 +383,7 @@ class MaiBotMain: try: start_time = time.time() from src.common.database.sqlalchemy_models import initialize_database as init_db + await init_db() elapsed_time = time.time() - start_time logger.info(f"数据库表结构初始化完成,耗时: {elapsed_time:.2f}秒") @@ -385,6 +394,7 @@ class MaiBotMain: def create_main_system(self): """创建MainSystem实例""" from src.main import MainSystem + self.main_system = MainSystem() return self.main_system @@ -411,11 +421,13 @@ class MaiBotMain: # 初始化知识库 from 
src.chat.knowledge.knowledge_lib import initialize_lpmm_knowledge + initialize_lpmm_knowledge() # 显示彩蛋 EasterEgg.show() + async def wait_for_user_input(): """等待用户输入(异步方式)""" try: @@ -432,6 +444,7 @@ async def wait_for_user_input(): logger.error(f"等待用户输入时发生错误: {e}") return False + async def main_async(): """主异步函数""" exit_code = 0 @@ -455,10 +468,7 @@ async def main_async(): user_input_done = asyncio.create_task(wait_for_user_input()) # 使用wait等待任意一个任务完成 - done, pending = await asyncio.wait( - [main_task, user_input_done], - return_when=asyncio.FIRST_COMPLETED - ) + done, pending = await asyncio.wait([main_task, user_input_done], return_when=asyncio.FIRST_COMPLETED) # 如果用户输入任务完成(用户按了Ctrl+C),取消主任务 if user_input_done in done and main_task not in done: @@ -482,6 +492,7 @@ async def main_async(): return exit_code + if __name__ == "__main__": exit_code = 0 try: diff --git a/src/api/message_router.py b/src/api/message_router.py index 47ae7771f..513d3d2df 100644 --- a/src/api/message_router.py +++ b/src/api/message_router.py @@ -12,10 +12,13 @@ logger = get_logger("HTTP消息API") router = APIRouter() + @router.get("/messages/recent") async def get_message_stats( days: int = Query(1, ge=1, description="指定查询过去多少天的数据"), - message_type: Literal["all", "sent", "received"] = Query("all", description="筛选消息类型: 'sent' (BOT发送的), 'received' (BOT接收的), or 'all' (全部)") + message_type: Literal["all", "sent", "received"] = Query( + "all", description="筛选消息类型: 'sent' (BOT发送的), 'received' (BOT接收的), or 'all' (全部)" + ), ): """ 获取BOT在指定天数内的消息统计数据。 @@ -45,7 +48,7 @@ async def get_message_stats( "message_type": message_type, "sent_count": sent_count, "received_count": received_count, - "total_count": len(messages) + "total_count": len(messages), } except Exception as e: @@ -76,10 +79,7 @@ async def get_message_stats_by_chat( user_id = msg.get("user_id") if chat_id not in stats: - stats[chat_id] = { - "total_stats": {"total": 0}, - "user_stats": {} - } + stats[chat_id] = {"total_stats": {"total": 0}, "user_stats": {}} stats[chat_id]["total_stats"]["total"] += 1 @@ -116,10 +116,7 @@ async def get_message_stats_by_chat( for user_id, count in data["user_stats"].items(): person_id = person_api.get_person_id("qq", user_id) nickname = await person_api.get_person_value(person_id, "nickname", "未知用户") - formatted_data["user_stats"][user_id] = { - "nickname": nickname, - "count": count - } + formatted_data["user_stats"][user_id] = {"nickname": nickname, "count": count} formatted_stats[chat_id] = formatted_data return formatted_stats @@ -129,6 +126,7 @@ async def get_message_stats_by_chat( except Exception as e: raise HTTPException(status_code=500, detail=str(e)) + @router.get("/messages/bot_stats_by_chat") async def get_bot_message_stats_by_chat( days: int = Query(1, ge=1, description="指定查询过去多少天的数据"), @@ -165,10 +163,7 @@ async def get_bot_message_stats_by_chat( elif stream.user_info and stream.user_info.user_nickname: chat_name = stream.user_info.user_nickname - formatted_stats[chat_id] = { - "chat_name": chat_name, - "count": count - } + formatted_stats[chat_id] = {"chat_name": chat_name, "count": count} return formatted_stats return stats diff --git a/src/chat/energy_system/energy_manager.py b/src/chat/energy_system/energy_manager.py index fc84edc26..1120b62cf 100644 --- a/src/chat/energy_system/energy_manager.py +++ b/src/chat/energy_system/energy_manager.py @@ -313,7 +313,9 @@ class EnergyManager: # 确保 score 是 float 类型 if not isinstance(score, int | float): - logger.warning(f"计算器 {calculator.__class__.__name__} 返回了非数值类型: 
{type(score)},跳过此组件") + logger.warning( + f"计算器 {calculator.__class__.__name__} 返回了非数值类型: {type(score)},跳过此组件" + ) continue component_scores[calculator.__class__.__name__] = float(score) diff --git a/src/chat/express/expression_learner.py b/src/chat/express/expression_learner.py index 6d72c171c..0c25b9fc6 100644 --- a/src/chat/express/expression_learner.py +++ b/src/chat/express/expression_learner.py @@ -527,7 +527,7 @@ class ExpressionLearnerManager: os.makedirs(directory, exist_ok=True) logger.debug(f"确保目录存在: {directory}") except Exception as e: - logger.error(f"创建目录失败 {directory}: {e}") + logger.error(f"创建目录失败 {directory}: {e}") @staticmethod async def _auto_migrate_json_to_db(): diff --git a/src/chat/interest_system/bot_interest_manager.py b/src/chat/interest_system/bot_interest_manager.py index 7926d4a8e..b2d9a93cd 100644 --- a/src/chat/interest_system/bot_interest_manager.py +++ b/src/chat/interest_system/bot_interest_manager.py @@ -429,7 +429,9 @@ class BotInterestManager: except Exception as e: logger.error(f"❌ 计算相似度分数失败: {e}") - async def calculate_interest_match(self, message_text: str, keywords: list[str] | None = None) -> InterestMatchResult: + async def calculate_interest_match( + self, message_text: str, keywords: list[str] | None = None + ) -> InterestMatchResult: """计算消息与机器人兴趣的匹配度""" if not self.current_interests or not self._initialized: raise RuntimeError("❌ 兴趣标签系统未初始化") diff --git a/src/chat/interest_system/interest_manager.py b/src/chat/interest_system/interest_manager.py index c77ffd25b..faf77b888 100644 --- a/src/chat/interest_system/interest_manager.py +++ b/src/chat/interest_system/interest_manager.py @@ -79,7 +79,9 @@ class InterestManager: # 如果已有组件在运行,先清理并替换 if self._current_calculator: - logger.info(f"替换现有兴趣值计算组件: {self._current_calculator.component_name} -> {calculator.component_name}") + logger.info( + f"替换现有兴趣值计算组件: {self._current_calculator.component_name} -> {calculator.component_name}" + ) await self._current_calculator.cleanup() else: logger.info(f"注册新的兴趣值计算组件: {calculator.component_name}") @@ -114,7 +116,7 @@ class InterestManager: success=False, message_id=getattr(message, "message_id", ""), interest_value=0.3, - error_message="没有可用的兴趣值计算组件" + error_message="没有可用的兴趣值计算组件", ) # 使用 create_task 异步执行计算 @@ -133,7 +135,7 @@ class InterestManager: interest_value=0.5, # 固定默认兴趣值 should_reply=False, should_act=False, - error_message=f"计算超时({timeout}s),使用默认值" + error_message=f"计算超时({timeout}s),使用默认值", ) except Exception as e: # 发生异常,返回默认结果 @@ -142,7 +144,7 @@ class InterestManager: success=False, message_id=getattr(message, "message_id", ""), interest_value=0.3, - error_message=f"计算异常: {e!s}" + error_message=f"计算异常: {e!s}", ) async def _async_calculate(self, message: "DatabaseMessages") -> InterestCalculationResult: @@ -171,7 +173,7 @@ class InterestManager: message_id=getattr(message, "message_id", ""), interest_value=0.0, error_message=f"计算异常: {e!s}", - calculation_time=time.time() - start_time + calculation_time=time.time() - start_time, ) async def _calculation_worker(self): @@ -179,10 +181,7 @@ class InterestManager: while not self._shutdown_event.is_set(): try: # 等待计算任务或关闭信号 - await asyncio.wait_for( - self._calculation_queue.get(), - timeout=1.0 - ) + await asyncio.wait_for(self._calculation_queue.get(), timeout=1.0) # 处理计算任务 # 这里可以实现批量处理逻辑 @@ -210,7 +209,7 @@ class InterestManager: "failed_calculations": self._failed_calculations, "success_rate": success_rate, "last_calculation_time": self._last_calculation_time, - "current_calculator": 
self._current_calculator.component_name if self._current_calculator else None + "current_calculator": self._current_calculator.component_name if self._current_calculator else None, } } diff --git a/src/chat/knowledge/open_ie.py b/src/chat/knowledge/open_ie.py index d59d6b409..4843174b8 100644 --- a/src/chat/knowledge/open_ie.py +++ b/src/chat/knowledge/open_ie.py @@ -125,19 +125,19 @@ class OpenIE: def extract_entity_dict(self): """提取实体列表""" ner_output_dict = { - doc_item["idx"]: doc_item["extracted_entities"] - for doc_item in self.docs - if len(doc_item["extracted_entities"]) > 0 - } + doc_item["idx"]: doc_item["extracted_entities"] + for doc_item in self.docs + if len(doc_item["extracted_entities"]) > 0 + } return ner_output_dict def extract_triple_dict(self): """提取三元组列表""" triple_output_dict = { - doc_item["idx"]: doc_item["extracted_triples"] - for doc_item in self.docs - if len(doc_item["extracted_triples"]) > 0 - } + doc_item["idx"]: doc_item["extracted_triples"] + for doc_item in self.docs + if len(doc_item["extracted_triples"]) > 0 + } return triple_output_dict def extract_raw_paragraph_dict(self): diff --git a/src/chat/knowledge/utils/dyn_topk.py b/src/chat/knowledge/utils/dyn_topk.py index e14146781..55f45c1b2 100644 --- a/src/chat/knowledge/utils/dyn_topk.py +++ b/src/chat/knowledge/utils/dyn_topk.py @@ -19,10 +19,10 @@ def dyn_select_top_k( for score_item in sorted_score: normalized_score.append( ( - score_item[0], - score_item[1], - (score_item[1] - min_score) / (max_score - min_score), - ) + score_item[0], + score_item[1], + (score_item[1] - min_score) / (max_score - min_score), + ) ) # 寻找跳变点:score变化最大的位置 diff --git a/src/chat/memory_system/hippocampus_sampler.py b/src/chat/memory_system/hippocampus_sampler.py index fba3e439f..24bea1fdd 100644 --- a/src/chat/memory_system/hippocampus_sampler.py +++ b/src/chat/memory_system/hippocampus_sampler.py @@ -29,20 +29,21 @@ logger = get_logger(__name__) @dataclass class HippocampusSampleConfig: """海马体采样配置""" + # 双峰分布参数 recent_mean_hours: float = 12.0 # 近期分布均值(小时) - recent_std_hours: float = 8.0 # 近期分布标准差(小时) - recent_weight: float = 0.7 # 近期分布权重 + recent_std_hours: float = 8.0 # 近期分布标准差(小时) + recent_weight: float = 0.7 # 近期分布权重 distant_mean_hours: float = 48.0 # 远期分布均值(小时) - distant_std_hours: float = 24.0 # 远期分布标准差(小时) - distant_weight: float = 0.3 # 远期分布权重 + distant_std_hours: float = 24.0 # 远期分布标准差(小时) + distant_weight: float = 0.3 # 远期分布权重 # 采样参数 - total_samples: int = 50 # 总采样数 - sample_interval: int = 1800 # 采样间隔(秒) - max_sample_length: int = 30 # 每次采样的最大消息数量 - batch_size: int = 5 # 批处理大小 + total_samples: int = 50 # 总采样数 + sample_interval: int = 1800 # 采样间隔(秒) + max_sample_length: int = 30 # 每次采样的最大消息数量 + batch_size: int = 5 # 批处理大小 @classmethod def from_global_config(cls) -> "HippocampusSampleConfig": @@ -84,12 +85,10 @@ class HippocampusSampler: try: # 初始化LLM模型 from src.config.config import model_config + task_config = getattr(model_config.model_task_config, "utils", None) if task_config: - self.memory_builder_model = LLMRequest( - model_set=task_config, - request_type="memory.hippocampus_build" - ) + self.memory_builder_model = LLMRequest(model_set=task_config, request_type="memory.hippocampus_build") asyncio.create_task(self.start_background_sampling()) logger.info("✅ 海马体采样器初始化成功") else: @@ -107,14 +106,10 @@ class HippocampusSampler: # 生成两个正态分布的小时偏移 recent_offsets = np.random.normal( - loc=self.config.recent_mean_hours, - scale=self.config.recent_std_hours, - size=recent_samples + loc=self.config.recent_mean_hours, 
scale=self.config.recent_std_hours, size=recent_samples ) distant_offsets = np.random.normal( - loc=self.config.distant_mean_hours, - scale=self.config.distant_std_hours, - size=distant_samples + loc=self.config.distant_mean_hours, scale=self.config.distant_std_hours, size=distant_samples ) # 合并两个分布的偏移 @@ -122,10 +117,7 @@ class HippocampusSampler: # 转换为时间戳(使用绝对值确保时间点在过去) base_time = datetime.now() - timestamps = [ - base_time - timedelta(hours=abs(offset)) - for offset in all_offsets - ] + timestamps = [base_time - timedelta(hours=abs(offset)) for offset in all_offsets] # 按时间排序(从最早到最近) return sorted(timestamps) @@ -171,7 +163,8 @@ class HippocampusSampler: if messages and len(messages) >= 2: # 至少需要2条消息 # 过滤掉已经记忆过的消息 filtered_messages = [ - msg for msg in messages + msg + for msg in messages if msg.get("memorized_times", 0) < 2 # 最多记忆2次 ] @@ -229,7 +222,7 @@ class HippocampusSampler: conversation_text=input_text, context=context, timestamp=time.time(), - bypass_interval=True # 海马体采样器绕过构建间隔限制 + bypass_interval=True, # 海马体采样器绕过构建间隔限制 ) if memories: @@ -367,7 +360,7 @@ class HippocampusSampler: max_concurrent = min(5, len(time_samples)) # 提高并发数到5 for i in range(0, len(time_samples), max_concurrent): - batch = time_samples[i:i + max_concurrent] + batch = time_samples[i : i + max_concurrent] tasks = [] # 创建并发收集任务 @@ -392,7 +385,9 @@ class HippocampusSampler: return collected_messages - async def _fuse_and_deduplicate_messages(self, collected_messages: list[list[dict[str, Any]]]) -> list[list[dict[str, Any]]]: + async def _fuse_and_deduplicate_messages( + self, collected_messages: list[list[dict[str, Any]]] + ) -> list[list[dict[str, Any]]]: """融合和去重消息样本""" if not collected_messages: return [] @@ -416,7 +411,7 @@ class HippocampusSampler: chat_id = message.get("chat_id", "") # 简单哈希:内容前50字符 + 时间戳(精确到分钟) + 聊天ID - hash_key = f"{content[:50]}_{int(timestamp//60)}_{chat_id}" + hash_key = f"{content[:50]}_{int(timestamp // 60)}_{chat_id}" if hash_key not in seen_hashes and len(content.strip()) > 10: seen_hashes.add(hash_key) @@ -448,7 +443,9 @@ class HippocampusSampler: # 返回原始消息组作为备选 return collected_messages[:5] # 限制返回数量 - def _merge_adjacent_messages(self, messages: list[dict[str, Any]], time_gap: int = 1800) -> list[list[dict[str, Any]]]: + def _merge_adjacent_messages( + self, messages: list[dict[str, Any]], time_gap: int = 1800 + ) -> list[list[dict[str, Any]]]: """合并时间间隔内的消息""" if not messages: return [] @@ -479,7 +476,9 @@ class HippocampusSampler: return result_groups - async def _build_batch_memory(self, fused_messages: list[list[dict[str, Any]]], time_samples: list[datetime]) -> dict[str, Any]: + async def _build_batch_memory( + self, fused_messages: list[list[dict[str, Any]]], time_samples: list[datetime] + ) -> dict[str, Any]: """批量构建记忆""" if not fused_messages: return {"memory_count": 0, "memories": []} @@ -513,10 +512,7 @@ class HippocampusSampler: # 一次性构建记忆 memories = await self.memory_system.build_memory_from_conversation( - conversation_text=batch_input_text, - context=batch_context, - timestamp=time.time(), - bypass_interval=True + conversation_text=batch_input_text, context=batch_context, timestamp=time.time(), bypass_interval=True ) if memories: @@ -545,11 +541,7 @@ class HippocampusSampler: if len(self.last_sample_results) > 10: self.last_sample_results.pop(0) - return { - "memory_count": total_memory_count, - "memories": total_memories, - "result": result - } + return {"memory_count": total_memory_count, "memories": total_memories, "result": result} except Exception as e: 
logger.error(f"批量构建记忆失败: {e}") @@ -601,11 +593,7 @@ class HippocampusSampler: except Exception as e: logger.debug(f"单独构建失败: {e}") - return { - "memory_count": total_count, - "memories": total_memories, - "fallback_mode": True - } + return {"memory_count": total_count, "memories": total_memories, "fallback_mode": True} async def process_sample_timestamp(self, target_timestamp: float) -> str | None: """处理单个时间戳采样(保留作为备选方法)""" @@ -696,7 +684,9 @@ class HippocampusSampler: "performance_metrics": { "avg_messages_per_sample": f"{recent_avg_messages:.1f}", "avg_memories_per_sample": f"{recent_avg_memory_count:.1f}", - "fusion_efficiency": f"{(recent_avg_messages/max(recent_avg_memory_count, 1)):.1f}x" if recent_avg_messages > 0 else "N/A" + "fusion_efficiency": f"{(recent_avg_messages / max(recent_avg_memory_count, 1)):.1f}x" + if recent_avg_messages > 0 + else "N/A", }, "config": { "sample_interval": self.config.sample_interval, diff --git a/src/chat/memory_system/memory_formatter.py b/src/chat/memory_system/memory_formatter.py index c5b1db134..5d69b32c3 100644 --- a/src/chat/memory_system/memory_formatter.py +++ b/src/chat/memory_system/memory_formatter.py @@ -15,6 +15,7 @@ 返回: 适合直接嵌入提示词的大段文本;若无有效记忆返回空串。 """ + from __future__ import annotations import time diff --git a/src/chat/memory_system/memory_system.py b/src/chat/memory_system/memory_system.py index b9f02c86d..8f185cb82 100644 --- a/src/chat/memory_system/memory_system.py +++ b/src/chat/memory_system/memory_system.py @@ -24,9 +24,12 @@ from src.chat.memory_system.memory_query_planner import MemoryQueryPlanner # 记忆采样模式枚举 class MemorySamplingMode(Enum): """记忆采样模式""" + HIPPOCAMPUS = "hippocampus" # 海马体模式:定时任务采样 - IMMEDIATE = "immediate" # 即时模式:回复后立即采样 - ALL = "all" # 所有模式:同时使用海马体和即时采样 + IMMEDIATE = "immediate" # 即时模式:回复后立即采样 + ALL = "all" # 所有模式:同时使用海马体和即时采样 + + from src.common.logger import get_logger from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest @@ -165,7 +168,6 @@ class MemorySystem: async def initialize(self): """异步初始化记忆系统""" try: - # 初始化LLM模型 fallback_task = getattr(self.llm_model, "model_for_task", None) if self.llm_model else None @@ -264,6 +266,7 @@ class MemorySystem: if global_config.memory.enable_hippocampus_sampling: try: from .hippocampus_sampler import initialize_hippocampus_sampler + self.hippocampus_sampler = await initialize_hippocampus_sampler(self) logger.info("✅ 海马体采样器初始化成功") except Exception as e: @@ -321,7 +324,11 @@ class MemorySystem: return [] async def build_memory_from_conversation( - self, conversation_text: str, context: dict[str, Any], timestamp: float | None = None, bypass_interval: bool = False + self, + conversation_text: str, + context: dict[str, Any], + timestamp: float | None = None, + bypass_interval: bool = False, ) -> list[MemoryChunk]: """从对话中构建记忆 @@ -560,7 +567,6 @@ class MemorySystem: sampling_mode = getattr(global_config.memory, "memory_sampling_mode", "precision") current_mode = MemorySamplingMode(sampling_mode) - context["__sampling_mode"] = current_mode.value logger.debug(f"使用记忆采样模式: {current_mode.value}") diff --git a/src/chat/message_manager/adaptive_stream_manager.py b/src/chat/message_manager/adaptive_stream_manager.py index 9e01403c4..fa0a97de5 100644 --- a/src/chat/message_manager/adaptive_stream_manager.py +++ b/src/chat/message_manager/adaptive_stream_manager.py @@ -17,6 +17,7 @@ logger = get_logger("adaptive_stream_manager") class StreamPriority(Enum): """流优先级""" + LOW = 1 NORMAL = 2 HIGH = 3 @@ -26,6 +27,7 @@ class 
StreamPriority(Enum): @dataclass class SystemMetrics: """系统指标""" + cpu_usage: float = 0.0 memory_usage: float = 0.0 active_coroutines: int = 0 @@ -36,6 +38,7 @@ class SystemMetrics: @dataclass class StreamMetrics: """流指标""" + stream_id: str priority: StreamPriority message_rate: float = 0.0 # 消息速率(消息/分钟) @@ -56,7 +59,7 @@ class AdaptiveStreamManager: metrics_window: float = 60.0, # 指标窗口时间 adjustment_interval: float = 30.0, # 调整间隔 cpu_threshold_high: float = 0.8, # CPU高负载阈值 - cpu_threshold_low: float = 0.3, # CPU低负载阈值 + cpu_threshold_low: float = 0.3, # CPU低负载阈值 memory_threshold_high: float = 0.85, # 内存高负载阈值 ): self.base_concurrent_limit = base_concurrent_limit @@ -139,10 +142,7 @@ class AdaptiveStreamManager: logger.info("自适应流管理器已停止") async def acquire_stream_slot( - self, - stream_id: str, - priority: StreamPriority = StreamPriority.NORMAL, - force: bool = False + self, stream_id: str, priority: StreamPriority = StreamPriority.NORMAL, force: bool = False ) -> bool: """ 获取流处理槽位 @@ -165,10 +165,7 @@ class AdaptiveStreamManager: # 更新流指标 if stream_id not in self.stream_metrics: - self.stream_metrics[stream_id] = StreamMetrics( - stream_id=stream_id, - priority=priority - ) + self.stream_metrics[stream_id] = StreamMetrics(stream_id=stream_id, priority=priority) self.stream_metrics[stream_id].last_activity = current_time # 检查是否已经活跃 @@ -271,8 +268,10 @@ class AdaptiveStreamManager: # 如果最近有活跃且响应时间较长,可能需要强制分发 current_time = time.time() - if (current_time - metrics.last_activity < 300 and # 5分钟内有活动 - metrics.response_time > 5.0): # 响应时间超过5秒 + if ( + current_time - metrics.last_activity < 300 # 5分钟内有活动 + and metrics.response_time > 5.0 + ): # 响应时间超过5秒 return True return False @@ -324,26 +323,20 @@ class AdaptiveStreamManager: memory_usage=memory_usage, active_coroutines=active_coroutines, event_loop_lag=event_loop_lag, - timestamp=time.time() + timestamp=time.time(), ) self.system_metrics.append(metrics) # 保持指标窗口大小 cutoff_time = time.time() - self.metrics_window - self.system_metrics = [ - m for m in self.system_metrics - if m.timestamp > cutoff_time - ] + self.system_metrics = [m for m in self.system_metrics if m.timestamp > cutoff_time] # 更新统计信息 self.stats["avg_concurrent_streams"] = ( self.stats["avg_concurrent_streams"] * 0.9 + len(self.active_streams) * 0.1 ) - self.stats["peak_concurrent_streams"] = max( - self.stats["peak_concurrent_streams"], - len(self.active_streams) - ) + self.stats["peak_concurrent_streams"] = max(self.stats["peak_concurrent_streams"], len(self.active_streams)) except Exception as e: logger.error(f"收集系统指标失败: {e}") @@ -445,14 +438,16 @@ class AdaptiveStreamManager: def get_stats(self) -> dict: """获取统计信息""" stats = self.stats.copy() - stats.update({ - "current_limit": self.current_limit, - "active_streams": len(self.active_streams), - "pending_streams": len(self.pending_streams), - "is_running": self.is_running, - "system_cpu": self.system_metrics[-1].cpu_usage if self.system_metrics else 0, - "system_memory": self.system_metrics[-1].memory_usage if self.system_metrics else 0, - }) + stats.update( + { + "current_limit": self.current_limit, + "active_streams": len(self.active_streams), + "pending_streams": len(self.pending_streams), + "is_running": self.is_running, + "system_cpu": self.system_metrics[-1].cpu_usage if self.system_metrics else 0, + "system_memory": self.system_metrics[-1].memory_usage if self.system_metrics else 0, + } + ) # 计算接受率 if stats["total_requests"] > 0: diff --git a/src/chat/message_manager/batch_database_writer.py 
b/src/chat/message_manager/batch_database_writer.py index cb8e87c3d..4bbe93e9c 100644 --- a/src/chat/message_manager/batch_database_writer.py +++ b/src/chat/message_manager/batch_database_writer.py @@ -20,6 +20,7 @@ logger = get_logger("batch_database_writer") @dataclass class StreamUpdatePayload: """流更新数据结构""" + stream_id: str update_data: dict[str, Any] priority: int = 0 # 优先级,数字越大优先级越高 @@ -95,12 +96,7 @@ class BatchDatabaseWriter: logger.info("批量数据库写入器已停止") - async def schedule_stream_update( - self, - stream_id: str, - update_data: dict[str, Any], - priority: int = 0 - ) -> bool: + async def schedule_stream_update(self, stream_id: str, update_data: dict[str, Any], priority: int = 0) -> bool: """ 调度流更新 @@ -119,11 +115,7 @@ class BatchDatabaseWriter: return True # 创建更新载荷 - payload = StreamUpdatePayload( - stream_id=stream_id, - update_data=update_data, - priority=priority - ) + payload = StreamUpdatePayload(stream_id=stream_id, update_data=update_data, priority=priority) # 非阻塞方式加入队列 try: @@ -178,10 +170,7 @@ class BatchDatabaseWriter: if remaining_time == 0: break - payload = await asyncio.wait_for( - self.write_queue.get(), - timeout=remaining_time - ) + payload = await asyncio.wait_for(self.write_queue.get(), timeout=remaining_time) batch.append(payload) except asyncio.TimeoutError: @@ -203,7 +192,10 @@ class BatchDatabaseWriter: # 合并同一流ID的更新(保留最新的) merged_updates = {} for payload in batch: - if payload.stream_id not in merged_updates or payload.timestamp > merged_updates[payload.stream_id].timestamp: + if ( + payload.stream_id not in merged_updates + or payload.timestamp > merged_updates[payload.stream_id].timestamp + ): merged_updates[payload.stream_id] = payload # 批量写入 @@ -211,9 +203,7 @@ class BatchDatabaseWriter: # 更新统计 self.stats["batch_writes"] += 1 - self.stats["avg_batch_size"] = ( - self.stats["avg_batch_size"] * 0.9 + len(batch) * 0.1 - ) # 滑动平均 + self.stats["avg_batch_size"] = self.stats["avg_batch_size"] * 0.9 + len(batch) * 0.1 # 滑动平均 self.stats["last_flush_time"] = start_time logger.debug(f"批量写入完成: {len(batch)} 个更新,耗时 {time.time() - start_time:.3f}s") @@ -238,31 +228,22 @@ class BatchDatabaseWriter: # 根据数据库类型选择不同的插入/更新策略 if global_config.database.database_type == "sqlite": from sqlalchemy.dialects.sqlite import insert as sqlite_insert - stmt = sqlite_insert(ChatStreams).values( - stream_id=stream_id, **update_data - ) - stmt = stmt.on_conflict_do_update( - index_elements=["stream_id"], - set_=update_data - ) + + stmt = sqlite_insert(ChatStreams).values(stream_id=stream_id, **update_data) + stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=update_data) elif global_config.database.database_type == "mysql": from sqlalchemy.dialects.mysql import insert as mysql_insert - stmt = mysql_insert(ChatStreams).values( - stream_id=stream_id, **update_data - ) + + stmt = mysql_insert(ChatStreams).values(stream_id=stream_id, **update_data) stmt = stmt.on_duplicate_key_update( **{key: value for key, value in update_data.items() if key != "stream_id"} ) else: # 默认使用SQLite语法 from sqlalchemy.dialects.sqlite import insert as sqlite_insert - stmt = sqlite_insert(ChatStreams).values( - stream_id=stream_id, **update_data - ) - stmt = stmt.on_conflict_do_update( - index_elements=["stream_id"], - set_=update_data - ) + + stmt = sqlite_insert(ChatStreams).values(stream_id=stream_id, **update_data) + stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=update_data) await session.execute(stmt) @@ -273,30 +254,21 @@ class BatchDatabaseWriter: async with 
get_db_session() as session: if global_config.database.database_type == "sqlite": from sqlalchemy.dialects.sqlite import insert as sqlite_insert - stmt = sqlite_insert(ChatStreams).values( - stream_id=stream_id, **update_data - ) - stmt = stmt.on_conflict_do_update( - index_elements=["stream_id"], - set_=update_data - ) + + stmt = sqlite_insert(ChatStreams).values(stream_id=stream_id, **update_data) + stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=update_data) elif global_config.database.database_type == "mysql": from sqlalchemy.dialects.mysql import insert as mysql_insert - stmt = mysql_insert(ChatStreams).values( - stream_id=stream_id, **update_data - ) + + stmt = mysql_insert(ChatStreams).values(stream_id=stream_id, **update_data) stmt = stmt.on_duplicate_key_update( **{key: value for key, value in update_data.items() if key != "stream_id"} ) else: from sqlalchemy.dialects.sqlite import insert as sqlite_insert - stmt = sqlite_insert(ChatStreams).values( - stream_id=stream_id, **update_data - ) - stmt = stmt.on_conflict_do_update( - index_elements=["stream_id"], - set_=update_data - ) + + stmt = sqlite_insert(ChatStreams).values(stream_id=stream_id, **update_data) + stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=update_data) await session.execute(stmt) await session.commit() diff --git a/src/chat/message_manager/context_manager.py b/src/chat/message_manager/context_manager.py index d5bfa9548..ea37ad0be 100644 --- a/src/chat/message_manager/context_manager.py +++ b/src/chat/message_manager/context_manager.py @@ -273,8 +273,10 @@ class SingleStreamContextManager: message.should_reply = result.should_reply message.should_act = result.should_act - logger.debug(f"消息 {message.message_id} 兴趣值已更新: {result.interest_value:.3f}, " - f"should_reply: {result.should_reply}, should_act: {result.should_act}") + logger.debug( + f"消息 {message.message_id} 兴趣值已更新: {result.interest_value:.3f}, " + f"should_reply: {result.should_reply}, should_act: {result.should_act}" + ) return result.interest_value else: logger.warning(f"消息 {message.message_id} 兴趣值计算失败: {result.error_message}") diff --git a/src/chat/message_manager/distribution_manager.py b/src/chat/message_manager/distribution_manager.py index b6eab795e..91545e5d5 100644 --- a/src/chat/message_manager/distribution_manager.py +++ b/src/chat/message_manager/distribution_manager.py @@ -79,7 +79,7 @@ class StreamLoopManager: logger.info(f"正在取消 {len(cancel_tasks)} 个流循环任务...") await asyncio.gather( *[self._wait_for_task_cancel(stream_id, task) for stream_id, task in cancel_tasks], - return_exceptions=True + return_exceptions=True, ) # 取消所有活跃的 chatter 处理任务 @@ -115,6 +115,7 @@ class StreamLoopManager: # 使用自适应流管理器获取槽位 try: from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager + adaptive_manager = get_adaptive_stream_manager() if adaptive_manager.is_running: @@ -123,9 +124,7 @@ class StreamLoopManager: # 获取处理槽位 slot_acquired = await adaptive_manager.acquire_stream_slot( - stream_id=stream_id, - priority=priority, - force=force + stream_id=stream_id, priority=priority, force=force ) if slot_acquired: @@ -140,10 +139,7 @@ class StreamLoopManager: # 创建流循环任务 try: - loop_task = asyncio.create_task( - self._stream_loop_worker(stream_id), - name=f"stream_loop_{stream_id}" - ) + loop_task = asyncio.create_task(self._stream_loop_worker(stream_id), name=f"stream_loop_{stream_id}") self.stream_loops[stream_id] = loop_task # 更新统计信息 self.stats["active_streams"] += 1 @@ -156,6 +152,7 @@ class 
StreamLoopManager: logger.error(f"启动流循环任务失败 {stream_id}: {e}") # 释放槽位 from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager + adaptive_manager = get_adaptive_stream_manager() adaptive_manager.release_stream_slot(stream_id) @@ -179,8 +176,8 @@ class StreamLoopManager: except Exception: from src.chat.message_manager.adaptive_stream_manager import StreamPriority - return StreamPriority.NORMAL + return StreamPriority.NORMAL async def stop_stream_loop(self, stream_id: str) -> bool: """停止指定流的循环任务 @@ -244,11 +241,12 @@ class StreamLoopManager: # 3. 更新自适应管理器指标 try: from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager + adaptive_manager = get_adaptive_stream_manager() adaptive_manager.update_stream_metrics( stream_id, message_rate=unread_count / 5.0 if unread_count > 0 else 0.0, # 简化计算 - last_activity=time.time() + last_activity=time.time(), ) except Exception as e: logger.debug(f"更新流指标失败: {e}") @@ -300,6 +298,7 @@ class StreamLoopManager: # 释放自适应管理器的槽位 try: from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager + adaptive_manager = get_adaptive_stream_manager() adaptive_manager.release_stream_slot(stream_id) logger.debug(f"释放自适应流处理槽位: {stream_id}") @@ -553,12 +552,12 @@ class StreamLoopManager: existing_task.cancel() # 创建异步任务来等待取消完成,并添加异常处理 cancel_task = asyncio.create_task( - self._wait_for_task_cancel(stream_id, existing_task), - name=f"cancel_existing_loop_{stream_id}" + self._wait_for_task_cancel(stream_id, existing_task), name=f"cancel_existing_loop_{stream_id}" ) # 为取消任务添加异常处理,避免孤儿任务 cancel_task.add_done_callback( - lambda task: logger.debug(f"取消任务完成: {stream_id}") if not task.exception() + lambda task: logger.debug(f"取消任务完成: {stream_id}") + if not task.exception() else logger.error(f"取消任务异常: {stream_id} - {task.exception()}") ) # 从字典中移除 @@ -582,10 +581,7 @@ class StreamLoopManager: logger.info(f"流 {stream_id} 当前未读消息数: {unread_count}") # 创建新的流循环任务 - new_task = asyncio.create_task( - self._stream_loop(stream_id), - name=f"force_stream_loop_{stream_id}" - ) + new_task = asyncio.create_task(self._stream_loop(stream_id), name=f"force_stream_loop_{stream_id}") self.stream_loops[stream_id] = new_task self.stats["total_loops"] += 1 diff --git a/src/chat/message_manager/message_manager.py b/src/chat/message_manager/message_manager.py index 4e8de1134..b6ad48d4e 100644 --- a/src/chat/message_manager/message_manager.py +++ b/src/chat/message_manager/message_manager.py @@ -59,6 +59,7 @@ class MessageManager: # 启动批量数据库写入器 try: from src.chat.message_manager.batch_database_writer import init_batch_writer + await init_batch_writer() except Exception as e: logger.error(f"启动批量数据库写入器失败: {e}") @@ -66,6 +67,7 @@ class MessageManager: # 启动流缓存管理器 try: from src.chat.message_manager.stream_cache_manager import init_stream_cache_manager + await init_stream_cache_manager() except Exception as e: logger.error(f"启动流缓存管理器失败: {e}") @@ -73,6 +75,7 @@ class MessageManager: # 启动自适应流管理器 try: from src.chat.message_manager.adaptive_stream_manager import init_adaptive_stream_manager + await init_adaptive_stream_manager() logger.info("🎯 自适应流管理器已启动") except Exception as e: @@ -97,6 +100,7 @@ class MessageManager: # 停止批量数据库写入器 try: from src.chat.message_manager.batch_database_writer import shutdown_batch_writer + await shutdown_batch_writer() logger.info("📦 批量数据库写入器已停止") except Exception as e: @@ -105,6 +109,7 @@ class MessageManager: # 停止流缓存管理器 try: from src.chat.message_manager.stream_cache_manager import 
shutdown_stream_cache_manager + await shutdown_stream_cache_manager() logger.info("🗄️ 流缓存管理器已停止") except Exception as e: @@ -113,6 +118,7 @@ class MessageManager: # 停止自适应流管理器 try: from src.chat.message_manager.adaptive_stream_manager import shutdown_adaptive_stream_manager + await shutdown_adaptive_stream_manager() logger.info("🎯 自适应流管理器已停止") except Exception as e: diff --git a/src/chat/message_manager/stream_cache_manager.py b/src/chat/message_manager/stream_cache_manager.py index 3e8cdebac..ea85c3855 100644 --- a/src/chat/message_manager/stream_cache_manager.py +++ b/src/chat/message_manager/stream_cache_manager.py @@ -19,6 +19,7 @@ logger = get_logger("stream_cache_manager") @dataclass class StreamCacheStats: """缓存统计信息""" + hot_cache_size: int = 0 warm_storage_size: int = 0 cold_storage_size: int = 0 @@ -38,9 +39,9 @@ class TieredStreamCache: max_warm_size: int = 500, max_cold_size: int = 2000, cleanup_interval: float = 300.0, # 5分钟清理一次 - hot_timeout: float = 1800.0, # 30分钟未访问降级到warm - warm_timeout: float = 7200.0, # 2小时未访问降级到cold - cold_timeout: float = 86400.0, # 24小时未访问删除 + hot_timeout: float = 1800.0, # 30分钟未访问降级到warm + warm_timeout: float = 7200.0, # 2小时未访问降级到cold + cold_timeout: float = 86400.0, # 24小时未访问删除 ): self.max_hot_size = max_hot_size self.max_warm_size = max_warm_size @@ -52,8 +53,8 @@ class TieredStreamCache: # 三层缓存存储 self.hot_cache: OrderedDict[str, OptimizedChatStream] = OrderedDict() # 热数据(LRU) - self.warm_storage: dict[str, tuple[OptimizedChatStream, float]] = {} # 温数据(最后访问时间) - self.cold_storage: dict[str, tuple[OptimizedChatStream, float]] = {} # 冷数据(最后访问时间) + self.warm_storage: dict[str, tuple[OptimizedChatStream, float]] = {} # 温数据(最后访问时间) + self.cold_storage: dict[str, tuple[OptimizedChatStream, float]] = {} # 冷数据(最后访问时间) # 统计信息 self.stats = StreamCacheStats() @@ -134,11 +135,7 @@ class TieredStreamCache: # 4. 
缓存未命中,创建新流 self.stats.cache_misses += 1 stream = create_optimized_chat_stream( - stream_id=stream_id, - platform=platform, - user_info=user_info, - group_info=group_info, - data=data + stream_id=stream_id, platform=platform, user_info=user_info, group_info=group_info, data=data ) logger.debug(f"缓存未命中,创建新流: {stream_id}") @@ -294,9 +291,9 @@ class TieredStreamCache: # 估算内存使用(粗略估计) self.stats.total_memory_usage = ( - len(self.hot_cache) * 1024 + # 每个热流约1KB - len(self.warm_storage) * 512 + # 每个温流约512B - len(self.cold_storage) * 256 # 每个冷流约256B + len(self.hot_cache) * 1024 # 每个热流约1KB + + len(self.warm_storage) * 512 # 每个温流约512B + + len(self.cold_storage) * 256 # 每个冷流约256B ) if sum(cleanup_stats.values()) > 0: diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index 059160471..c1ca149b3 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -557,7 +557,11 @@ class ChatBot: # 将兴趣度结果同步回原始消息,便于后续流程使用 message.interest_value = getattr(db_message, "interest_value", getattr(message, "interest_value", 0.0)) - setattr(message, "should_reply", getattr(db_message, "should_reply", getattr(message, "should_reply", False))) + setattr( + message, + "should_reply", + getattr(db_message, "should_reply", getattr(message, "should_reply", False)), + ) setattr(message, "should_act", getattr(db_message, "should_act", getattr(message, "should_act", False))) # 存储消息到数据库,只进行一次写入 diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index a7eee5ed5..223b50f9d 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -298,8 +298,10 @@ class ChatStream: db_message.should_reply = result.should_reply db_message.should_act = result.should_act - logger.debug(f"消息 {db_message.message_id} 兴趣值已更新: {result.interest_value:.3f}, " - f"should_reply: {result.should_reply}, should_act: {result.should_act}") + logger.debug( + f"消息 {db_message.message_id} 兴趣值已更新: {result.interest_value:.3f}, " + f"should_reply: {result.should_reply}, should_act: {result.should_act}" + ) else: logger.warning(f"消息 {db_message.message_id} 兴趣值计算失败: {result.error_message}") # 使用默认值 @@ -521,18 +523,17 @@ class ChatManager: # 优先使用缓存管理器(优化版本) try: from src.chat.message_manager.stream_cache_manager import get_stream_cache_manager + cache_manager = get_stream_cache_manager() if cache_manager.is_running: optimized_stream = await cache_manager.get_or_create_stream( - stream_id=stream_id, - platform=platform, - user_info=user_info, - group_info=group_info + stream_id=stream_id, platform=platform, user_info=user_info, group_info=group_info ) # 设置消息上下文 from .message import MessageRecv + if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv): optimized_stream.set_context(self.last_messages[stream_id]) @@ -715,7 +716,7 @@ class ChatManager: success = await batch_writer.schedule_stream_update( stream_id=stream_data_dict["stream_id"], update_data=ChatManager._prepare_stream_data(stream_data_dict), - priority=1 # 流更新的优先级 + priority=1, # 流更新的优先级 ) if success: stream.saved = True @@ -738,7 +739,7 @@ class ChatManager: result = await batch_update( model_class=ChatStreams, conditions={"stream_id": stream_data_dict["stream_id"]}, - data=ChatManager._prepare_stream_data(stream_data_dict) + data=ChatManager._prepare_stream_data(stream_data_dict), ) if result and result > 0: stream.saved = True @@ -874,43 +875,43 @@ chat_manager = None def _convert_to_original_stream(self, optimized_stream) -> 
"ChatStream": - """将OptimizedChatStream转换为原始ChatStream以保持兼容性""" - try: - # 创建原始ChatStream实例 - original_stream = ChatStream( - stream_id=optimized_stream.stream_id, - platform=optimized_stream.platform, - user_info=optimized_stream._get_effective_user_info(), - group_info=optimized_stream._get_effective_group_info() - ) + """将OptimizedChatStream转换为原始ChatStream以保持兼容性""" + try: + # 创建原始ChatStream实例 + original_stream = ChatStream( + stream_id=optimized_stream.stream_id, + platform=optimized_stream.platform, + user_info=optimized_stream._get_effective_user_info(), + group_info=optimized_stream._get_effective_group_info(), + ) - # 复制状态 - original_stream.create_time = optimized_stream.create_time - original_stream.last_active_time = optimized_stream.last_active_time - original_stream.sleep_pressure = optimized_stream.sleep_pressure - original_stream.base_interest_energy = optimized_stream.base_interest_energy - original_stream._focus_energy = optimized_stream._focus_energy - original_stream.no_reply_consecutive = optimized_stream.no_reply_consecutive - original_stream.saved = optimized_stream.saved + # 复制状态 + original_stream.create_time = optimized_stream.create_time + original_stream.last_active_time = optimized_stream.last_active_time + original_stream.sleep_pressure = optimized_stream.sleep_pressure + original_stream.base_interest_energy = optimized_stream.base_interest_energy + original_stream._focus_energy = optimized_stream._focus_energy + original_stream.no_reply_consecutive = optimized_stream.no_reply_consecutive + original_stream.saved = optimized_stream.saved - # 复制上下文信息(如果存在) - if hasattr(optimized_stream, "_stream_context") and optimized_stream._stream_context: - original_stream.stream_context = optimized_stream._stream_context + # 复制上下文信息(如果存在) + if hasattr(optimized_stream, "_stream_context") and optimized_stream._stream_context: + original_stream.stream_context = optimized_stream._stream_context - if hasattr(optimized_stream, "_context_manager") and optimized_stream._context_manager: - original_stream.context_manager = optimized_stream._context_manager + if hasattr(optimized_stream, "_context_manager") and optimized_stream._context_manager: + original_stream.context_manager = optimized_stream._context_manager - return original_stream + return original_stream - except Exception as e: - logger.error(f"转换OptimizedChatStream失败: {e}") - # 如果转换失败,创建一个新的原始流 - return ChatStream( - stream_id=optimized_stream.stream_id, - platform=optimized_stream.platform, - user_info=optimized_stream._get_effective_user_info(), - group_info=optimized_stream._get_effective_group_info() - ) + except Exception as e: + logger.error(f"转换OptimizedChatStream失败: {e}") + # 如果转换失败,创建一个新的原始流 + return ChatStream( + stream_id=optimized_stream.stream_id, + platform=optimized_stream.platform, + user_info=optimized_stream._get_effective_user_info(), + group_info=optimized_stream._get_effective_group_info(), + ) def get_chat_manager(): diff --git a/src/chat/message_receive/optimized_chat_stream.py b/src/chat/message_receive/optimized_chat_stream.py index c9b32d6f8..d9280c61a 100644 --- a/src/chat/message_receive/optimized_chat_stream.py +++ b/src/chat/message_receive/optimized_chat_stream.py @@ -80,10 +80,7 @@ class OptimizedChatStream: ): # 共享的只读数据 self._shared_context = SharedContext( - stream_id=stream_id, - platform=platform, - user_info=user_info, - group_info=group_info + stream_id=stream_id, platform=platform, user_info=user_info, group_info=group_info ) # 本地修改数据 @@ -269,14 +266,13 @@ class OptimizedChatStream: 
self._stream_context = StreamContext( stream_id=self.stream_id, chat_type=ChatType.GROUP if self.group_info else ChatType.PRIVATE, - chat_mode=ChatMode.NORMAL + chat_mode=ChatMode.NORMAL, ) # 创建单流上下文管理器 from src.chat.message_manager.context_manager import SingleStreamContextManager - self._context_manager = SingleStreamContextManager( - stream_id=self.stream_id, context=self._stream_context - ) + + self._context_manager = SingleStreamContextManager(stream_id=self.stream_id, context=self._stream_context) @property def stream_context(self): @@ -331,9 +327,11 @@ class OptimizedChatStream: # 恢复stream_context信息 if "stream_context_chat_type" in data: from src.plugin_system.base.component_types import ChatMode, ChatType + instance.stream_context.chat_type = ChatType(data["stream_context_chat_type"]) if "stream_context_chat_mode" in data: from src.plugin_system.base.component_types import ChatMode, ChatType + instance.stream_context.chat_mode = ChatMode(data["stream_context_chat_mode"]) # 恢复interruption_count信息 @@ -352,6 +350,7 @@ class OptimizedChatStream: if isinstance(actions, str): try: import json + actions = json.loads(actions) except json.JSONDecodeError: logger.warning(f"无法解析actions JSON字符串: {actions}") @@ -458,7 +457,7 @@ class OptimizedChatStream: stream_id=self.stream_id, platform=self.platform, user_info=self._get_effective_user_info(), - group_info=self._get_effective_group_info() + group_info=self._get_effective_group_info(), ) # 复制本地修改(但不触发写时复制) @@ -482,9 +481,5 @@ def create_optimized_chat_stream( ) -> OptimizedChatStream: """创建优化版聊天流实例""" return OptimizedChatStream( - stream_id=stream_id, - platform=platform, - user_info=user_info, - group_info=group_info, - data=data + stream_id=stream_id, platform=platform, user_info=user_info, group_info=group_info, data=data ) diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py index ec75eaf74..caf720da7 100644 --- a/src/chat/planner_actions/action_manager.py +++ b/src/chat/planner_actions/action_manager.py @@ -196,18 +196,20 @@ class ChatterActionManager: thinking_id=thinking_id or "", action_done=True, action_build_into_prompt=False, - action_prompt_display=reason + action_prompt_display=reason, ) else: - asyncio.create_task(database_api.store_action_info( - chat_stream=chat_stream, - action_build_into_prompt=False, - action_prompt_display=reason, - action_done=True, - thinking_id=thinking_id, - action_data={"reason": reason}, - action_name="no_reply", - )) + asyncio.create_task( + database_api.store_action_info( + chat_stream=chat_stream, + action_build_into_prompt=False, + action_prompt_display=reason, + action_done=True, + thinking_id=thinking_id, + action_data={"reason": reason}, + action_name="no_reply", + ) + ) # 自动清空所有未读消息 asyncio.create_task(self._clear_all_unread_messages(chat_stream.stream_id, "no_reply")) @@ -228,7 +230,9 @@ class ChatterActionManager: # 记录执行的动作到目标消息 if success: - asyncio.create_task(self._record_action_to_message(chat_stream, action_name, target_message, action_data)) + asyncio.create_task( + self._record_action_to_message(chat_stream, action_name, target_message, action_data) + ) # 自动清空所有未读消息 if clear_unread_messages: asyncio.create_task(self._clear_all_unread_messages(chat_stream.stream_id, action_name)) @@ -496,7 +500,7 @@ class ChatterActionManager: thinking_id=thinking_id or "", action_done=True, action_build_into_prompt=False, - action_prompt_display=action_prompt_display + action_prompt_display=action_prompt_display, ) else: await 
database_api.store_action_info( @@ -618,9 +622,15 @@ class ChatterActionManager: self._pending_actions = [] # 清空队列 logger.debug("已禁用批量存储模式") - def add_action_to_batch(self, action_name: str, action_data: dict, thinking_id: str = "", - action_done: bool = True, action_build_into_prompt: bool = False, - action_prompt_display: str = ""): + def add_action_to_batch( + self, + action_name: str, + action_data: dict, + thinking_id: str = "", + action_done: bool = True, + action_build_into_prompt: bool = False, + action_prompt_display: str = "", + ): """添加动作到批量存储列表""" if not self._batch_storage_enabled: return False @@ -632,7 +642,7 @@ class ChatterActionManager: "action_done": action_done, "action_build_into_prompt": action_build_into_prompt, "action_prompt_display": action_prompt_display, - "timestamp": time.time() + "timestamp": time.time(), } self._pending_actions.append(action_record) logger.debug(f"已添加动作到批量存储列表: {action_name} (当前待处理: {len(self._pending_actions)} 个)") @@ -658,7 +668,7 @@ class ChatterActionManager: action_done=action_data.get("action_done", True), action_build_into_prompt=action_data.get("action_build_into_prompt", False), action_prompt_display=action_data.get("action_prompt_display", ""), - thinking_id=action_data.get("thinking_id", "") + thinking_id=action_data.get("thinking_id", ""), ) if result: stored_count += 1 diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index 72e72fb27..5f42c0042 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -589,7 +589,7 @@ class DefaultReplyer: # 获取记忆系统实例 memory_system = get_memory_system() - # 使用统一记忆系统检索相关记忆 + # 使用统一记忆系统检索相关记忆 enhanced_memories = await memory_system.retrieve_relevant_memories( query=target, user_id=memory_user_id, scope_id=stream.stream_id, context=memory_context, limit=10 ) @@ -1208,12 +1208,32 @@ class DefaultReplyer: # 并行执行六个构建任务 tasks = { - "expression_habits": asyncio.create_task(self._time_and_run_task(self.build_expression_habits(chat_talking_prompt_short, target), "expression_habits")), - "relation_info": asyncio.create_task(self._time_and_run_task(self.build_relation_info(sender, target), "relation_info")), - "memory_block": asyncio.create_task(self._time_and_run_task(self.build_memory_block(chat_talking_prompt_short, target), "memory_block")), - "tool_info": asyncio.create_task(self._time_and_run_task(self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info")), - "prompt_info": asyncio.create_task(self._time_and_run_task(self.get_prompt_info(chat_talking_prompt_short, sender, target), "prompt_info")), - "cross_context": asyncio.create_task(self._time_and_run_task(Prompt.build_cross_context(chat_id, global_config.personality.prompt_mode, target_user_info), "cross_context")), + "expression_habits": asyncio.create_task( + self._time_and_run_task( + self.build_expression_habits(chat_talking_prompt_short, target), "expression_habits" + ) + ), + "relation_info": asyncio.create_task( + self._time_and_run_task(self.build_relation_info(sender, target), "relation_info") + ), + "memory_block": asyncio.create_task( + self._time_and_run_task(self.build_memory_block(chat_talking_prompt_short, target), "memory_block") + ), + "tool_info": asyncio.create_task( + self._time_and_run_task( + self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), + "tool_info", + ) + ), + "prompt_info": asyncio.create_task( + 
self._time_and_run_task(self.get_prompt_info(chat_talking_prompt_short, sender, target), "prompt_info") + ), + "cross_context": asyncio.create_task( + self._time_and_run_task( + Prompt.build_cross_context(chat_id, global_config.personality.prompt_mode, target_user_info), + "cross_context", + ) + ), } # 设置超时 @@ -1512,13 +1532,8 @@ class DefaultReplyer: chat_target_name = ( self.chat_target_info.get("person_name") or self.chat_target_info.get("user_nickname") or "对方" ) - await global_prompt_manager.format_prompt( - "chat_target_private1", sender_name=chat_target_name - ) - await global_prompt_manager.format_prompt( - "chat_target_private2", sender_name=chat_target_name - ) - + await global_prompt_manager.format_prompt("chat_target_private1", sender_name=chat_target_name) + await global_prompt_manager.format_prompt("chat_target_private2", sender_name=chat_target_name) # 使用新的统一Prompt系统 - Expressor模式,创建PromptParameters prompt_parameters = PromptParameters( diff --git a/src/chat/utils/utils_video.py b/src/chat/utils/utils_video.py index 3e989c8ab..5d99d9ca8 100644 --- a/src/chat/utils/utils_video.py +++ b/src/chat/utils/utils_video.py @@ -121,13 +121,14 @@ class VideoAnalyzer: async def _analyze_batch(self, frames: list[tuple[str, float]], question: str | None) -> str: from src.llm_models.payload_content.message import MessageBuilder, RoleType from src.llm_models.utils_model import RequestType + prompt = self.batch_analysis_prompt.format( personality_core=self.personality_core, personality_side=self.personality_side ) if question: prompt += f"\n用户关注: {question}" desc = [ - (f"第{i+1}帧 (时间: {ts:.2f}s)" if self.enable_frame_timing else f"第{i+1}帧") + (f"第{i + 1}帧 (时间: {ts:.2f}s)" if self.enable_frame_timing else f"第{i + 1}帧") for i, (_b, ts) in enumerate(frames) ] prompt += "\n帧列表: " + ", ".join(desc) @@ -151,16 +152,16 @@ class VideoAnalyzer: async def _analyze_sequential(self, frames: list[tuple[str, float]], question: str | None) -> str: results: list[str] = [] for i, (b64, ts) in enumerate(frames): - prompt = f"分析第{i+1}帧" + (f" (时间: {ts:.2f}s)" if self.enable_frame_timing else "") + prompt = f"分析第{i + 1}帧" + (f" (时间: {ts:.2f}s)" if self.enable_frame_timing else "") if question: prompt += f"\n关注: {question}" try: text, _ = await self.video_llm.generate_response_for_image( prompt=prompt, image_base64=b64, image_format="jpeg" ) - results.append(f"第{i+1}帧: {text}") + results.append(f"第{i + 1}帧: {text}") except Exception as e: # pragma: no cover - results.append(f"第{i+1}帧: 失败 {e}") + results.append(f"第{i + 1}帧: 失败 {e}") if i < len(frames) - 1: await asyncio.sleep(self.frame_analysis_delay) summary_prompt = "基于以下逐帧结果给出完整总结:\n\n" + "\n".join(results) @@ -182,7 +183,9 @@ class VideoAnalyzer: mode = self.analysis_mode if mode == "auto": mode = "batch" if len(frames) <= 20 else "sequential" - text = await (self._analyze_batch(frames, question) if mode == "batch" else self._analyze_sequential(frames, question)) + text = await ( + self._analyze_batch(frames, question) if mode == "batch" else self._analyze_sequential(frames, question) + ) return True, text async def analyze_video_from_bytes( diff --git a/src/common/data_models/database_data_model.py b/src/common/data_models/database_data_model.py index 9f988cdb7..3c1cd02ef 100644 --- a/src/common/data_models/database_data_model.py +++ b/src/common/data_models/database_data_model.py @@ -220,7 +220,9 @@ class DatabaseMessages(BaseDataModel): "chat_info_user_cardname": self.chat_info.user_info.user_cardname, } - def update_message_info(self, interest_value: 
float | None = None, actions: list | None = None, should_reply: bool | None = None): + def update_message_info( + self, interest_value: float | None = None, actions: list | None = None, should_reply: bool | None = None + ): """ 更新消息信息 diff --git a/src/common/data_models/message_manager_data_model.py b/src/common/data_models/message_manager_data_model.py index 164ce4d2d..ad08fbb6c 100644 --- a/src/common/data_models/message_manager_data_model.py +++ b/src/common/data_models/message_manager_data_model.py @@ -53,8 +53,6 @@ class StreamContext(BaseDataModel): priority_mode: str | None = None priority_info: dict | None = None - - def add_action_to_message(self, message_id: str, action: str): """ 向指定消息添加执行的动作 @@ -75,9 +73,6 @@ class StreamContext(BaseDataModel): message.add_action(action) break - - - def mark_message_as_read(self, message_id: str): """标记消息为已读""" for msg in self.unread_messages: diff --git a/src/common/database/connection_pool_manager.py b/src/common/database/connection_pool_manager.py index 6ce3f517e..4ca789b6f 100644 --- a/src/common/database/connection_pool_manager.py +++ b/src/common/database/connection_pool_manager.py @@ -78,7 +78,7 @@ class ConnectionPoolManager: "total_expired": 0, "active_connections": 0, "pool_hits": 0, - "pool_misses": 0 + "pool_misses": 0, } # 后台清理任务 @@ -156,7 +156,9 @@ class ConnectionPoolManager: if connection_info: connection_info.mark_released() - async def _get_reusable_connection(self, session_factory: async_sessionmaker[AsyncSession]) -> ConnectionInfo | None: + async def _get_reusable_connection( + self, session_factory: async_sessionmaker[AsyncSession] + ) -> ConnectionInfo | None: """获取可复用的连接""" async with self._lock: # 清理过期连接 @@ -164,9 +166,7 @@ class ConnectionPoolManager: # 查找可复用的连接 for connection_info in list(self._connections): - if (not connection_info.in_use and - not connection_info.is_expired(self.max_lifetime, self.max_idle)): - + if not connection_info.in_use and not connection_info.is_expired(self.max_lifetime, self.max_idle): # 验证连接是否仍然有效 try: # 执行一个简单的查询来验证连接 @@ -191,8 +191,7 @@ class ConnectionPoolManager: expired_connections = [] for connection_info in list(self._connections): - if (connection_info.is_expired(self.max_lifetime, self.max_idle) and - not connection_info.in_use): + if connection_info.is_expired(self.max_lifetime, self.max_idle) and not connection_info.in_use: expired_connections.append(connection_info) for connection_info in expired_connections: @@ -238,7 +237,8 @@ class ConnectionPoolManager: "max_pool_size": self.max_pool_size, "pool_efficiency": ( self._stats["pool_hits"] / max(1, self._stats["pool_hits"] + self._stats["pool_misses"]) - ) * 100 + ) + * 100, } diff --git a/src/common/database/db_batch_scheduler.py b/src/common/database/db_batch_scheduler.py index b0974cc42..a09f7fb84 100644 --- a/src/common/database/db_batch_scheduler.py +++ b/src/common/database/db_batch_scheduler.py @@ -24,6 +24,7 @@ T = TypeVar("T") @dataclass class BatchOperation: """批量操作基础类""" + operation_type: str # 'select', 'insert', 'update', 'delete' model_class: Any conditions: dict[str, Any] @@ -40,6 +41,7 @@ class BatchOperation: @dataclass class BatchResult: """批量操作结果""" + success: bool data: Any = None error: str | None = None @@ -48,10 +50,12 @@ class BatchResult: class DatabaseBatchScheduler: """数据库批量调度器""" - def __init__(self, - batch_size: int = 50, - max_wait_time: float = 0.1, # 100ms - max_queue_size: int = 1000): + def __init__( + self, + batch_size: int = 50, + max_wait_time: float = 0.1, # 100ms + max_queue_size: int = 
1000, + ): self.batch_size = batch_size self.max_wait_time = max_wait_time self.max_queue_size = max_queue_size @@ -65,12 +69,7 @@ class DatabaseBatchScheduler: self._lock = asyncio.Lock() # 统计信息 - self.stats = { - "total_operations": 0, - "batched_operations": 0, - "cache_hits": 0, - "execution_time": 0.0 - } + self.stats = {"total_operations": 0, "batched_operations": 0, "cache_hits": 0, "execution_time": 0.0} # 简单的结果缓存(用于频繁的查询) self._result_cache: dict[str, tuple[Any, float]] = {} @@ -105,11 +104,7 @@ class DatabaseBatchScheduler: def _generate_cache_key(self, operation_type: str, model_class: Any, conditions: dict[str, Any]) -> str: """生成缓存键""" # 简单的缓存键生成,实际可以根据需要优化 - key_parts = [ - operation_type, - model_class.__name__, - str(sorted(conditions.items())) - ] + key_parts = [operation_type, model_class.__name__, str(sorted(conditions.items()))] return "|".join(key_parts) def _get_from_cache(self, cache_key: str) -> Any | None: @@ -132,11 +127,7 @@ class DatabaseBatchScheduler: """添加操作到队列""" # 检查是否可以立即返回缓存结果 if operation.operation_type == "select": - cache_key = self._generate_cache_key( - operation.operation_type, - operation.model_class, - operation.conditions - ) + cache_key = self._generate_cache_key(operation.operation_type, operation.model_class, operation.conditions) cached_result = self._get_from_cache(cache_key) if cached_result is not None: if operation.callback: @@ -180,10 +171,7 @@ class DatabaseBatchScheduler: return # 复制队列内容,避免长时间占用锁 - queues_copy = { - key: deque(operations) - for key, operations in self.operation_queues.items() - } + queues_copy = {key: deque(operations) for key, operations in self.operation_queues.items()} # 清空原队列 for queue in self.operation_queues.values(): queue.clear() @@ -240,9 +228,7 @@ class DatabaseBatchScheduler: # 缓存查询结果 if operation.operation_type == "select": cache_key = self._generate_cache_key( - operation.operation_type, - operation.model_class, - operation.conditions + operation.operation_type, operation.model_class, operation.conditions ) self._set_cache(cache_key, result) @@ -287,12 +273,9 @@ class DatabaseBatchScheduler: else: # 需要根据条件过滤结果 op_result = [ - item for item in data - if all( - getattr(item, k) == v - for k, v in op.conditions.items() - if hasattr(item, k) - ) + item + for item in data + if all(getattr(item, k) == v for k, v in op.conditions.items() if hasattr(item, k)) ] results.append(op_result) @@ -429,7 +412,7 @@ class DatabaseBatchScheduler: **self.stats, "cache_size": len(self._result_cache), "queue_sizes": {k: len(v) for k, v in self.operation_queues.items()}, - "is_running": self._is_running + "is_running": self._is_running, } @@ -452,43 +435,25 @@ async def get_batch_session(): # 便捷函数 async def batch_select(model_class: Any, conditions: dict[str, Any]) -> Any: """批量查询""" - operation = BatchOperation( - operation_type="select", - model_class=model_class, - conditions=conditions - ) + operation = BatchOperation(operation_type="select", model_class=model_class, conditions=conditions) return await db_batch_scheduler.add_operation(operation) async def batch_insert(model_class: Any, data: dict[str, Any]) -> int: """批量插入""" - operation = BatchOperation( - operation_type="insert", - model_class=model_class, - conditions={}, - data=data - ) + operation = BatchOperation(operation_type="insert", model_class=model_class, conditions={}, data=data) return await db_batch_scheduler.add_operation(operation) async def batch_update(model_class: Any, conditions: dict[str, Any], data: dict[str, Any]) -> int: """批量更新""" - operation = 
BatchOperation( - operation_type="update", - model_class=model_class, - conditions=conditions, - data=data - ) + operation = BatchOperation(operation_type="update", model_class=model_class, conditions=conditions, data=data) return await db_batch_scheduler.add_operation(operation) async def batch_delete(model_class: Any, conditions: dict[str, Any]) -> int: """批量删除""" - operation = BatchOperation( - operation_type="delete", - model_class=model_class, - conditions=conditions - ) + operation = BatchOperation(operation_type="delete", model_class=model_class, conditions=conditions) return await db_batch_scheduler.add_operation(operation) diff --git a/src/common/logger.py b/src/common/logger.py index a0186a8e9..550478515 100644 --- a/src/common/logger.py +++ b/src/common/logger.py @@ -304,8 +304,7 @@ def load_log_config(): # sourcery skip: use-contextlib-suppress "library_log_levels": {"aiohttp": "WARNING"}, } - - # 误加的即刻线程启动已移除;真正的线程在 start_log_cleanup_task 中按午夜调度 + # 误加的即刻线程启动已移除;真正的线程在 start_log_cleanup_task 中按午夜调度 try: if config_path.exists(): diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 75997c7a2..3af0b1fb1 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -37,7 +37,9 @@ class DatabaseConfig(ValidatedConfigBase): connection_timeout: int = Field(default=10, ge=1, description="连接超时时间") # 批量动作记录存储配置 - batch_action_storage_enabled: bool = Field(default=True, description="是否启用批量保存动作记录(开启后将多个动作一次性写入数据库,提升性能)") + batch_action_storage_enabled: bool = Field( + default=True, description="是否启用批量保存动作记录(开启后将多个动作一次性写入数据库,提升性能)" + ) class BotConfig(ValidatedConfigBase): @@ -355,7 +357,7 @@ class MemoryConfig(ValidatedConfigBase): # 双峰分布配置 [近期均值, 近期标准差, 近期权重, 远期均值, 远期标准差, 远期权重] hippocampus_distribution_config: list[float] = Field( default=[12.0, 8.0, 0.7, 48.0, 24.0, 0.3], - description="海马体双峰分布配置:[近期均值(h), 近期标准差(h), 近期权重, 远期均值(h), 远期标准差(h), 远期权重]" + description="海马体双峰分布配置:[近期均值(h), 近期标准差(h), 近期权重, 远期均值(h), 远期标准差(h), 远期权重]", ) # 自适应采样配置 @@ -690,7 +692,6 @@ class AffinityFlowConfig(ValidatedConfigBase): base_relationship_score: float = Field(default=0.5, description="基础人物关系分") - class ProactiveThinkingConfig(ValidatedConfigBase): """主动思考(主动发起对话)功能配置""" diff --git a/src/llm_models/model_client/base_client.py b/src/llm_models/model_client/base_client.py index cd017b6a1..246b0618b 100644 --- a/src/llm_models/model_client/base_client.py +++ b/src/llm_models/model_client/base_client.py @@ -189,11 +189,11 @@ class ClientRegistry: bool: 事件循环是否变化 """ current_loop_id = self._get_current_loop_id() - + # 如果没有缓存的循环ID,说明是首次创建 if provider_name not in self._event_loop_cache: return False - + # 比较当前循环ID与缓存的循环ID cached_loop_id = self._event_loop_cache[provider_name] return current_loop_id != cached_loop_id @@ -208,7 +208,7 @@ class ClientRegistry: BaseClient: 注册的API客户端实例 """ provider_name = api_provider.name - + # 如果强制创建新实例,直接创建不使用缓存 if force_new: if client_class := self.client_registry.get(api_provider.client_type): @@ -224,7 +224,7 @@ class ClientRegistry: # 事件循环已变化,需要重新创建实例 logger.debug(f"检测到事件循环变化,为 {provider_name} 重新创建客户端实例") self._loop_change_count += 1 - + # 移除旧实例 if provider_name in self.client_instance_cache: del self.client_instance_cache[provider_name] @@ -237,7 +237,7 @@ class ClientRegistry: self._event_loop_cache[provider_name] = self._get_current_loop_id() else: raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册") - + return self.client_instance_cache[provider_name] def get_cache_stats(self) -> dict: diff --git 
a/src/llm_models/model_client/mcp_sse_client.py b/src/llm_models/model_client/mcp_sse_client.py index 91e58cde2..afed7c301 100644 --- a/src/llm_models/model_client/mcp_sse_client.py +++ b/src/llm_models/model_client/mcp_sse_client.py @@ -50,14 +50,16 @@ def _convert_messages_to_mcp(messages: list[Message]) -> list[dict[str, Any]]: for item in message.content: if isinstance(item, tuple): # 图片内容 - content_parts.append({ - "type": "image", - "source": { - "type": "base64", - "media_type": f"image/{item[0].lower()}", - "data": item[1], - }, - }) + content_parts.append( + { + "type": "image", + "source": { + "type": "base64", + "media_type": f"image/{item[0].lower()}", + "data": item[1], + }, + } + ) elif isinstance(item, str): # 文本内容 content_parts.append({"type": "text", "text": item}) @@ -138,9 +140,7 @@ async def _parse_sse_stream( async with session.post(url, json=payload, headers=headers) as response: if response.status != 200: error_text = await response.text() - raise RespNotOkException( - response.status, f"MCP SSE请求失败: {error_text}" - ) + raise RespNotOkException(response.status, f"MCP SSE请求失败: {error_text}") # 解析SSE流 async for line in response.content: @@ -258,10 +258,7 @@ async def _parse_sse_stream( response.reasoning_content = reasoning_buffer.getvalue() if tool_calls_buffer: - response.tool_calls = [ - ToolCall(call_id, func_name, args) - for call_id, func_name, args in tool_calls_buffer - ] + response.tool_calls = [ToolCall(call_id, func_name, args) for call_id, func_name, args in tool_calls_buffer] # 关闭缓冲区 content_buffer.close() @@ -351,9 +348,7 @@ class MCPSSEClient(BaseClient): url = f"{self.api_provider.base_url}/v1/messages" try: - response, usage_record = await _parse_sse_stream( - session, url, payload, headers, interrupt_flag - ) + response, usage_record = await _parse_sse_stream(session, url, payload, headers, interrupt_flag) except Exception as e: logger.error(f"MCP SSE请求失败: {e}") raise diff --git a/src/llm_models/model_client/openai_client.py b/src/llm_models/model_client/openai_client.py index 3ddbe00bd..06d6b50b2 100644 --- a/src/llm_models/model_client/openai_client.py +++ b/src/llm_models/model_client/openai_client.py @@ -378,7 +378,7 @@ class OpenaiClient(BaseClient): # 类级别的全局缓存:所有 OpenaiClient 实例共享 _global_client_cache: dict[int, AsyncOpenAI] = {} """全局 AsyncOpenAI 客户端缓存:config_hash -> AsyncOpenAI 实例""" - + def __init__(self, api_provider: APIProvider): super().__init__(api_provider) self._config_hash = self._calculate_config_hash() @@ -396,33 +396,31 @@ class OpenaiClient(BaseClient): def _create_client(self) -> AsyncOpenAI: """ 获取或创建 OpenAI 客户端实例(全局缓存) - + 多个 OpenaiClient 实例如果配置相同(base_url + api_key + timeout), 将共享同一个 AsyncOpenAI 客户端实例,最大化连接池复用。 """ # 检查全局缓存 if self._config_hash in self._global_client_cache: return self._global_client_cache[self._config_hash] - + # 创建新的 AsyncOpenAI 实例 logger.debug( - f"创建新的 AsyncOpenAI 客户端实例 " - f"(base_url={self.api_provider.base_url}, " - f"config_hash={self._config_hash})" + f"创建新的 AsyncOpenAI 客户端实例 (base_url={self.api_provider.base_url}, config_hash={self._config_hash})" ) - + client = AsyncOpenAI( base_url=self.api_provider.base_url, api_key=self.api_provider.get_api_key(), max_retries=0, timeout=self.api_provider.timeout, ) - + # 存入全局缓存 self._global_client_cache[self._config_hash] = client - + return client - + @classmethod def get_cache_stats(cls) -> dict: """获取全局缓存统计信息""" diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index 4719b60b3..dc4374c06 100644 --- a/src/llm_models/utils_model.py +++ 
b/src/llm_models/utils_model.py @@ -280,7 +280,9 @@ class _PromptProcessor: 这有助于我判断你的输出是否被截断。请不要在 `{self.end_marker}` 前后添加任何其他文字或标点。 """ - async def prepare_prompt(self, prompt: str, model_info: ModelInfo, api_provider: APIProvider, task_name: str) -> str: + async def prepare_prompt( + self, prompt: str, model_info: ModelInfo, api_provider: APIProvider, task_name: str + ) -> str: """ 为请求准备最终的提示词。 diff --git a/src/main.py b/src/main.py index 63faddfbc..46c6ca66d 100644 --- a/src/main.py +++ b/src/main.py @@ -88,6 +88,7 @@ class MainSystem: def _setup_signal_handlers(self) -> None: """设置信号处理器""" + def signal_handler(signum, frame): if self._shutting_down: logger.warning("系统已经在关闭过程中,忽略重复信号") @@ -132,6 +133,7 @@ class MainSystem: try: from src.plugin_system.apis.component_manage_api import get_components_info_by_type from src.plugin_system.base.component_types import ComponentType + interest_calculators = get_components_info_by_type(ComponentType.INTEREST_CALCULATOR) logger.info(f"通过组件注册表发现 {len(interest_calculators)} 个兴趣计算器组件") except Exception as e: @@ -143,6 +145,7 @@ class MainSystem: # 初始化兴趣度管理器 from src.chat.interest_system.interest_manager import get_interest_manager + interest_manager = get_interest_manager() await interest_manager.initialize() @@ -159,7 +162,10 @@ class MainSystem: try: from src.plugin_system.core.component_registry import component_registry - component_class = component_registry.get_component_class(calc_name, ComponentType.INTEREST_CALCULATOR) + + component_class = component_registry.get_component_class( + calc_name, ComponentType.INTEREST_CALCULATOR + ) if not component_class: logger.warning(f"无法找到 {calc_name} 的组件类") @@ -208,6 +214,7 @@ class MainSystem: # 停止数据库服务 try: from src.common.database.database import stop_database + cleanup_tasks.append(("数据库服务", stop_database())) except Exception as e: logger.error(f"准备停止数据库服务时出错: {e}") @@ -215,6 +222,7 @@ class MainSystem: # 停止消息管理器 try: from src.chat.message_manager import message_manager + cleanup_tasks.append(("消息管理器", message_manager.stop())) except Exception as e: logger.error(f"准备停止消息管理器时出错: {e}") @@ -222,6 +230,7 @@ class MainSystem: # 停止消息重组器 try: from src.utils.message_chunker import reassembler + cleanup_tasks.append(("消息重组器", reassembler.stop_cleanup_task())) except Exception as e: logger.error(f"准备停止消息重组器时出错: {e}") @@ -236,15 +245,18 @@ class MainSystem: # 触发停止事件 try: from src.plugin_system.core.event_manager import event_manager - cleanup_tasks.append(("插件系统停止事件", - event_manager.trigger_event(EventType.ON_STOP, permission_group="SYSTEM"))) + + cleanup_tasks.append( + ("插件系统停止事件", event_manager.trigger_event(EventType.ON_STOP, permission_group="SYSTEM")) + ) except Exception as e: logger.error(f"准备触发停止事件时出错: {e}") # 停止表情管理器 try: - cleanup_tasks.append(("表情管理器", - asyncio.get_event_loop().run_in_executor(None, get_emoji_manager().shutdown))) + cleanup_tasks.append( + ("表情管理器", asyncio.get_event_loop().run_in_executor(None, get_emoji_manager().shutdown)) + ) except Exception as e: logger.error(f"准备停止表情管理器时出错: {e}") @@ -275,7 +287,7 @@ class MainSystem: try: results = await asyncio.wait_for( asyncio.gather(*tasks, return_exceptions=True), - timeout=30.0 # 30秒超时 + timeout=30.0, # 30秒超时 ) # 记录结果 @@ -389,6 +401,7 @@ MoFox_Bot(第三方修改版) # 注册API路由 try: from src.api.message_router import router as message_router + self.server.register_router(message_router, prefix="/api") logger.info("API路由注册成功") except Exception as e: @@ -405,6 +418,7 @@ MoFox_Bot(第三方修改版) mcp_config = global_config.get("mcp_servers", []) if mcp_config: 
from src.plugin_system.utils.mcp_tool_provider import mcp_tool_provider + await mcp_tool_provider.initialize(mcp_config) logger.info("MCP工具提供器初始化成功") except Exception as e: @@ -445,6 +459,7 @@ MoFox_Bot(第三方修改版) # 初始化LPMM知识库 try: from src.chat.knowledge.knowledge_lib import initialize_lpmm_knowledge + initialize_lpmm_knowledge() logger.info("LPMM知识库初始化成功") except Exception as e: @@ -456,6 +471,7 @@ MoFox_Bot(第三方修改版) # 启动消息重组器 try: from src.utils.message_chunker import reassembler + await reassembler.start_cleanup_task() logger.info("消息重组器已启动") except Exception as e: @@ -464,6 +480,7 @@ MoFox_Bot(第三方修改版) # 启动消息管理器 try: from src.chat.message_manager import message_manager + await message_manager.start() logger.info("消息管理器已启动") except Exception as e: @@ -504,6 +521,7 @@ MoFox_Bot(第三方修改版) def _safe_init(self, component_name: str, init_func) -> callable: """安全初始化组件,捕获异常""" + async def wrapper(): try: result = init_func() @@ -514,6 +532,7 @@ MoFox_Bot(第三方修改版) except Exception as e: logger.error(f"{component_name}初始化失败: {e}") return False + return wrapper async def schedule_tasks(self) -> None: diff --git a/src/mood/mood_manager.py b/src/mood/mood_manager.py index 7974dba4d..370a65d41 100644 --- a/src/mood/mood_manager.py +++ b/src/mood/mood_manager.py @@ -59,6 +59,7 @@ class ChatMood: """异步初始化方法""" if not self._initialized: from src.chat.message_receive.chat_stream import get_chat_manager + chat_manager = get_chat_manager() self.chat_stream = await chat_manager.get_stream(self.chat_id) diff --git a/src/person_info/relationship_builder.py b/src/person_info/relationship_builder.py index 554af3260..4e258a8fd 100644 --- a/src/person_info/relationship_builder.py +++ b/src/person_info/relationship_builder.py @@ -69,6 +69,7 @@ class RelationshipBuilder: if not self._log_prefix_initialized: try: from src.chat.message_receive.chat_stream import get_chat_manager + chat_name = await get_chat_manager().get_stream_name(self.chat_id) self.log_prefix = f"[{chat_name}]" except Exception: diff --git a/src/person_info/relationship_fetcher.py b/src/person_info/relationship_fetcher.py index 655734bc0..8783d5e7f 100644 --- a/src/person_info/relationship_fetcher.py +++ b/src/person_info/relationship_fetcher.py @@ -85,6 +85,7 @@ class RelationshipFetcher: """异步初始化log_prefix""" if not self._log_prefix_initialized: from src.chat.message_receive.chat_stream import get_chat_manager + name = await get_chat_manager().get_stream_name(self.chat_id) self.log_prefix = f"[{name}] 实时信息" self._log_prefix_initialized = True diff --git a/src/plugin_system/apis/generator_api.py b/src/plugin_system/apis/generator_api.py index a76fc6e74..dc6a1e6d9 100644 --- a/src/plugin_system/apis/generator_api.py +++ b/src/plugin_system/apis/generator_api.py @@ -59,6 +59,7 @@ async def get_replyer( logger.debug(f"[GeneratorAPI] 正在获取回复器,chat_id: {chat_id}, chat_stream: {'有' if chat_stream else '无'}") # 动态导入避免循环依赖 from src.chat.replyer.replyer_manager import replyer_manager + return await replyer_manager.get_replyer( chat_stream=chat_stream, chat_id=chat_id, diff --git a/src/plugin_system/apis/tool_api.py b/src/plugin_system/apis/tool_api.py index d64d366ba..285df7884 100644 --- a/src/plugin_system/apis/tool_api.py +++ b/src/plugin_system/apis/tool_api.py @@ -39,6 +39,7 @@ def get_llm_available_tool_definitions(): # 添加MCP工具 try: from src.plugin_system.utils.mcp_tool_provider import mcp_tool_provider + mcp_tools = mcp_tool_provider.get_mcp_tool_definitions() tool_definitions.extend(mcp_tools) if mcp_tools: diff --git 
a/src/plugin_system/base/base_event.py b/src/plugin_system/base/base_event.py index a5638ebd3..47d410c60 100644 --- a/src/plugin_system/base/base_event.py +++ b/src/plugin_system/base/base_event.py @@ -86,7 +86,9 @@ class HandlerResultsCollection: class BaseEvent: - def __init__(self, name: str, allowed_subscribers: list[str] | None = None, allowed_triggers: list[str] | None = None): + def __init__( + self, name: str, allowed_subscribers: list[str] | None = None, allowed_triggers: list[str] | None = None + ): self.name = name self.enabled = True self.allowed_subscribers = allowed_subscribers # 记录事件处理器名 diff --git a/src/plugin_system/base/base_interest_calculator.py b/src/plugin_system/base/base_interest_calculator.py index 211c87b4f..d2b307df5 100644 --- a/src/plugin_system/base/base_interest_calculator.py +++ b/src/plugin_system/base/base_interest_calculator.py @@ -28,7 +28,7 @@ class InterestCalculationResult: should_reply: bool = False, should_act: bool = False, error_message: str | None = None, - calculation_time: float = 0.0 + calculation_time: float = 0.0, ): self.success = success self.message_id = message_id @@ -51,17 +51,19 @@ class InterestCalculationResult: "should_act": self.should_act, "error_message": self.error_message, "calculation_time": self.calculation_time, - "timestamp": self.timestamp + "timestamp": self.timestamp, } def __repr__(self) -> str: - return (f"InterestCalculationResult(" - f"success={self.success}, " - f"message_id={self.message_id}, " - f"interest_value={self.interest_value:.3f}, " - f"should_take_action={self.should_take_action}, " - f"should_reply={self.should_reply}, " - f"should_act={self.should_act})") + return ( + f"InterestCalculationResult(" + f"success={self.success}, " + f"message_id={self.message_id}, " + f"interest_value={self.interest_value:.3f}, " + f"should_take_action={self.should_take_action}, " + f"should_reply={self.should_reply}, " + f"should_act={self.should_act})" + ) class BaseInterestCalculator(ABC): @@ -144,7 +146,7 @@ class BaseInterestCalculator(ABC): "failed_calculations": self._failed_calculations, "success_rate": 1.0 - (self._failed_calculations / max(1, self._total_calculations)), "average_calculation_time": self._average_calculation_time, - "last_calculation_time": self._last_calculation_time + "last_calculation_time": self._last_calculation_time, } def _update_statistics(self, result: InterestCalculationResult): @@ -159,8 +161,7 @@ class BaseInterestCalculator(ABC): else: alpha = 0.1 # 指数移动平均的平滑因子 self._average_calculation_time = ( - alpha * result.calculation_time + - (1 - alpha) * self._average_calculation_time + alpha * result.calculation_time + (1 - alpha) * self._average_calculation_time ) self._last_calculation_time = result.timestamp @@ -172,7 +173,7 @@ class BaseInterestCalculator(ABC): success=False, message_id=getattr(message, "message_id", ""), interest_value=0.0, - error_message="组件未启用" + error_message="组件未启用", ) start_time = time.time() @@ -187,7 +188,7 @@ class BaseInterestCalculator(ABC): message_id=getattr(message, "message_id", ""), interest_value=0.0, error_message=f"计算执行失败: {e!s}", - calculation_time=time.time() - start_time + calculation_time=time.time() - start_time, ) self._update_statistics(result) return result @@ -214,7 +215,9 @@ class BaseInterestCalculator(ABC): ) def __repr__(self) -> str: - return (f"{self.__class__.__name__}(" - f"name={self.component_name}, " - f"version={self.component_version}, " - f"enabled={self._enabled})") + return ( + f"{self.__class__.__name__}(" + 
f"name={self.component_name}, " + f"version={self.component_version}, " + f"enabled={self._enabled})" + ) diff --git a/src/plugin_system/base/base_plugin.py b/src/plugin_system/base/base_plugin.py index 37c6e5ed5..df48a3164 100644 --- a/src/plugin_system/base/base_plugin.py +++ b/src/plugin_system/base/base_plugin.py @@ -60,7 +60,9 @@ class BasePlugin(PluginBase): if hasattr(component_class, "get_interest_calculator_info"): return component_class.get_interest_calculator_info() else: - logger.warning(f"InterestCalculator类 {component_class.__name__} 缺少 get_interest_calculator_info 方法") + logger.warning( + f"InterestCalculator类 {component_class.__name__} 缺少 get_interest_calculator_info 方法" + ) return None elif component_type == ComponentType.PLUS_COMMAND: @@ -96,6 +98,7 @@ class BasePlugin(PluginBase): 对应类型的ComponentInfo对象 """ return cls._get_component_info_from_class(component_class, component_type) + @abstractmethod def get_plugin_components( self, diff --git a/src/plugin_system/base/plugin_metadata.py b/src/plugin_system/base/plugin_metadata.py index 638cbb15e..8871fcf14 100644 --- a/src/plugin_system/base/plugin_metadata.py +++ b/src/plugin_system/base/plugin_metadata.py @@ -7,6 +7,7 @@ class PluginMetadata: """ 插件元数据,用于存储插件的开发者信息和用户帮助信息。 """ + name: str # 插件名称 (供用户查看) description: str # 插件功能描述 usage: str # 插件使用方法 diff --git a/src/plugin_system/core/component_registry.py b/src/plugin_system/core/component_registry.py index 515a81b62..91b3001da 100644 --- a/src/plugin_system/core/component_registry.py +++ b/src/plugin_system/core/component_registry.py @@ -319,7 +319,9 @@ class ComponentRegistry: return True def _register_interest_calculator_component( - self, interest_calculator_info: "InterestCalculatorInfo", interest_calculator_class: type["BaseInterestCalculator"] + self, + interest_calculator_info: "InterestCalculatorInfo", + interest_calculator_class: type["BaseInterestCalculator"], ) -> bool: """注册InterestCalculator组件到特定注册表""" calculator_name = interest_calculator_info.name @@ -327,7 +329,9 @@ class ComponentRegistry: if not calculator_name: logger.error(f"InterestCalculator组件 {interest_calculator_class.__name__} 必须指定名称") return False - if not isinstance(interest_calculator_info, InterestCalculatorInfo) or not issubclass(interest_calculator_class, BaseInterestCalculator): + if not isinstance(interest_calculator_info, InterestCalculatorInfo) or not issubclass( + interest_calculator_class, BaseInterestCalculator + ): logger.error(f"注册失败: {calculator_name} 不是有效的InterestCalculator") return False diff --git a/src/plugin_system/core/tool_use.py b/src/plugin_system/core/tool_use.py index e7461074c..7dd09a894 100644 --- a/src/plugin_system/core/tool_use.py +++ b/src/plugin_system/core/tool_use.py @@ -67,6 +67,7 @@ class ToolExecutor: """异步初始化log_prefix和chat_stream""" if not self._log_prefix_initialized: from src.chat.message_receive.chat_stream import get_chat_manager + self.chat_stream = await get_chat_manager().get_stream(self.chat_id) stream_name = await get_chat_manager().get_stream_name(self.chat_id) self.log_prefix = f"[{stream_name or self.chat_id}]" @@ -283,6 +284,7 @@ class ToolExecutor: # 检查是否是MCP工具 try: from src.plugin_system.utils.mcp_tool_provider import mcp_tool_provider + if function_name in mcp_tool_provider.mcp_tools: logger.info(f"{self.log_prefix}执行MCP工具: {function_name}") result = await mcp_tool_provider.call_mcp_tool(function_name, function_args) diff --git a/src/plugins/built_in/affinity_flow_chatter/__init__.py b/src/plugins/built_in/affinity_flow_chatter/__init__.py 
index 4fe7e792c..699eef1c0 100644 --- a/src/plugins/built_in/affinity_flow_chatter/__init__.py +++ b/src/plugins/built_in/affinity_flow_chatter/__init__.py @@ -8,8 +8,5 @@ __plugin_meta__ = PluginMetadata( author="MoFox", keywords=["chatter", "affinity", "conversation"], categories=["Chat", "AI"], - extra={ - "is_built_in": True - } + extra={"is_built_in": True}, ) - diff --git a/src/plugins/built_in/affinity_flow_chatter/affinity_chatter.py b/src/plugins/built_in/affinity_flow_chatter/affinity_chatter.py index a02d07b69..760ed07a6 100644 --- a/src/plugins/built_in/affinity_flow_chatter/affinity_chatter.py +++ b/src/plugins/built_in/affinity_flow_chatter/affinity_chatter.py @@ -149,7 +149,6 @@ class AffinityChatter(BaseChatter): """ return self.planner.get_mood_stats() - def reset_stats(self): """重置统计信息""" self.stats = { diff --git a/src/plugins/built_in/affinity_flow_chatter/affinity_interest_calculator.py b/src/plugins/built_in/affinity_flow_chatter/affinity_interest_calculator.py index ab5260bde..ba8beba7f 100644 --- a/src/plugins/built_in/affinity_flow_chatter/affinity_interest_calculator.py +++ b/src/plugins/built_in/affinity_flow_chatter/affinity_interest_calculator.py @@ -111,9 +111,11 @@ class AffinityInterestCalculator(BaseInterestCalculator): + mentioned_score * self.score_weights["mentioned"] ) - logger.debug(f"[Affinity兴趣计算] 综合得分计算: {interest_match_score:.3f}*{self.score_weights['interest_match']} + " - f"{relationship_score:.3f}*{self.score_weights['relationship']} + " - f"{mentioned_score:.3f}*{self.score_weights['mentioned']} = {total_score:.3f}") + logger.debug( + f"[Affinity兴趣计算] 综合得分计算: {interest_match_score:.3f}*{self.score_weights['interest_match']} + " + f"{relationship_score:.3f}*{self.score_weights['relationship']} + " + f"{mentioned_score:.3f}*{self.score_weights['mentioned']} = {total_score:.3f}" + ) # 5. 
考虑连续不回复的概率提升 adjusted_score = self._apply_no_reply_boost(total_score) @@ -135,8 +137,10 @@ class AffinityInterestCalculator(BaseInterestCalculator): calculation_time = time.time() - start_time - logger.debug(f"Affinity兴趣值计算完成 - 消息 {message_id}: {adjusted_score:.3f} " - f"(匹配:{interest_match_score:.2f}, 关系:{relationship_score:.2f}, 提及:{mentioned_score:.2f})") + logger.debug( + f"Affinity兴趣值计算完成 - 消息 {message_id}: {adjusted_score:.3f} " + f"(匹配:{interest_match_score:.2f}, 关系:{relationship_score:.2f}, 提及:{mentioned_score:.2f})" + ) return InterestCalculationResult( success=True, @@ -145,16 +149,13 @@ class AffinityInterestCalculator(BaseInterestCalculator): should_take_action=should_take_action, should_reply=should_reply, should_act=should_take_action, - calculation_time=calculation_time + calculation_time=calculation_time, ) except Exception as e: logger.error(f"Affinity兴趣值计算失败: {e}", exc_info=True) return InterestCalculationResult( - success=False, - message_id=getattr(message, "message_id", ""), - interest_value=0.0, - error_message=str(e) + success=False, message_id=getattr(message, "message_id", ""), interest_value=0.0, error_message=str(e) ) async def _calculate_interest_match_score(self, content: str, keywords: list[str] | None = None) -> float: diff --git a/src/plugins/built_in/affinity_flow_chatter/plan_executor.py b/src/plugins/built_in/affinity_flow_chatter/plan_executor.py index e68876aaf..ecae0f8e5 100644 --- a/src/plugins/built_in/affinity_flow_chatter/plan_executor.py +++ b/src/plugins/built_in/affinity_flow_chatter/plan_executor.py @@ -405,7 +405,6 @@ class ChatterPlanExecutor: # 移除执行时间列表以避免返回过大数据 stats.pop("execution_times", None) - return stats def reset_stats(self): @@ -434,12 +433,12 @@ class ChatterPlanExecutor: for i, time_val in enumerate(recent_times) ] - async def _flush_action_manager_batch_storage(self, plan: Plan): """使用 action_manager 的批量存储功能存储所有待处理的动作""" try: # 通过 chat_id 获取真实的 chat_stream 对象 from src.plugin_system.apis.chat_api import get_chat_manager + chat_manager = get_chat_manager() chat_stream = await chat_manager.get_stream(plan.chat_id) @@ -455,4 +454,3 @@ class ChatterPlanExecutor: logger.error(f"批量存储动作记录时发生错误: {e}") # 确保在出错时也禁用批量存储模式 self.action_manager.disable_batch_storage() - diff --git a/src/plugins/built_in/affinity_flow_chatter/plan_filter.py b/src/plugins/built_in/affinity_flow_chatter/plan_filter.py index 16cab3f72..72e486d26 100644 --- a/src/plugins/built_in/affinity_flow_chatter/plan_filter.py +++ b/src/plugins/built_in/affinity_flow_chatter/plan_filter.py @@ -64,7 +64,6 @@ class ChatterPlanFilter: llm_content, _ = await self.planner_llm.generate_response_async(prompt=prompt) - if llm_content: if global_config.debug.show_prompt: logger.info(f"LLM规划器原始响应:{llm_content}") diff --git a/src/plugins/built_in/affinity_flow_chatter/planner.py b/src/plugins/built_in/affinity_flow_chatter/planner.py index cacfdf5bf..5d0bed659 100644 --- a/src/plugins/built_in/affinity_flow_chatter/planner.py +++ b/src/plugins/built_in/affinity_flow_chatter/planner.py @@ -132,7 +132,6 @@ class ChatterActionPlanner: if message_should_act: aggregate_should_act = True - except Exception as e: logger.warning(f"处理消息 {message.message_id} 失败: {e}") message.interest_value = 0.0 diff --git a/src/plugins/built_in/affinity_flow_chatter/plugin.py b/src/plugins/built_in/affinity_flow_chatter/plugin.py index 63b682061..26b83a696 100644 --- a/src/plugins/built_in/affinity_flow_chatter/plugin.py +++ b/src/plugins/built_in/affinity_flow_chatter/plugin.py @@ -39,6 +39,7 @@ class 
AffinityChatterPlugin(BasePlugin): try: # 延迟导入 AffinityChatter from .affinity_chatter import AffinityChatter + components.append((AffinityChatter.get_chatter_info(), AffinityChatter)) except Exception as e: logger.error(f"加载 AffinityChatter 时出错: {e}") @@ -46,9 +47,9 @@ class AffinityChatterPlugin(BasePlugin): try: # 延迟导入 AffinityInterestCalculator from .affinity_interest_calculator import AffinityInterestCalculator + components.append((AffinityInterestCalculator.get_interest_calculator_info(), AffinityInterestCalculator)) except Exception as e: logger.error(f"加载 AffinityInterestCalculator 时出错: {e}") return components - diff --git a/src/plugins/built_in/core_actions/__init__.py b/src/plugins/built_in/core_actions/__init__.py index 00f14f526..1f8c271b6 100644 --- a/src/plugins/built_in/core_actions/__init__.py +++ b/src/plugins/built_in/core_actions/__init__.py @@ -13,5 +13,5 @@ __plugin_meta__ = PluginMetadata( extra={ "is_built_in": True, "plugin_type": "action_provider", - } + }, ) diff --git a/src/plugins/built_in/maizone_refactored/__init__.py b/src/plugins/built_in/maizone_refactored/__init__.py index 6292b9207..547394832 100644 --- a/src/plugins/built_in/maizone_refactored/__init__.py +++ b/src/plugins/built_in/maizone_refactored/__init__.py @@ -13,5 +13,5 @@ __plugin_meta__ = PluginMetadata( extra={ "is_built_in": False, "plugin_type": "social", - } + }, ) diff --git a/src/plugins/built_in/napcat_adapter_plugin/__init__.py b/src/plugins/built_in/napcat_adapter_plugin/__init__.py index a279a5409..15e1543e6 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/__init__.py +++ b/src/plugins/built_in/napcat_adapter_plugin/__init__.py @@ -12,5 +12,5 @@ __plugin_meta__ = PluginMetadata( categories=["protocol"], extra={ "is_built_in": False, - } -) \ No newline at end of file + }, +) diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/mmc_com_layer.py b/src/plugins/built_in/napcat_adapter_plugin/src/mmc_com_layer.py index 282e68b4e..2e4a8390f 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/mmc_com_layer.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/mmc_com_layer.py @@ -13,8 +13,8 @@ def create_router(plugin_config: dict): """创建路由器实例""" global router platform_name = config_api.get_plugin_config(plugin_config, "maibot_server.platform_name", "qq") - host = os.getenv("HOST","127.0.0.1") - port = os.getenv("PORT","8000") + host = os.getenv("HOST", "127.0.0.1") + port = os.getenv("PORT", "8000") logger.debug(f"初始化MaiBot连接,使用地址:{host}:{port}") route_config = RouteConfig( route_config={ diff --git a/src/plugins/built_in/permission_management/__init__.py b/src/plugins/built_in/permission_management/__init__.py index c54c8d279..1e58619ea 100644 --- a/src/plugins/built_in/permission_management/__init__.py +++ b/src/plugins/built_in/permission_management/__init__.py @@ -12,5 +12,5 @@ __plugin_meta__ = PluginMetadata( extra={ "is_built_in": True, "plugin_type": "permission", - } + }, ) diff --git a/src/plugins/built_in/plugin_management/__init__.py b/src/plugins/built_in/plugin_management/__init__.py index 6b56066d0..79d417eeb 100644 --- a/src/plugins/built_in/plugin_management/__init__.py +++ b/src/plugins/built_in/plugin_management/__init__.py @@ -13,5 +13,5 @@ __plugin_meta__ = PluginMetadata( extra={ "is_built_in": True, "plugin_type": "plugin_management", - } + }, ) diff --git a/src/plugins/built_in/proactive_thinker/__init__.py b/src/plugins/built_in/proactive_thinker/__init__.py index 81359e471..176db6dbe 100644 --- 
a/src/plugins/built_in/proactive_thinker/__init__.py +++ b/src/plugins/built_in/proactive_thinker/__init__.py @@ -8,10 +8,7 @@ __plugin_meta__ = PluginMetadata( author="MoFox-Studio", license="GPL-v3.0-or-later", repository_url="https://github.com/MoFox-Studio", - keywords=["主动思考","自己发消息"], + keywords=["主动思考", "自己发消息"], categories=["Chat", "Integration"], - extra={ - "is_built_in": True, - "plugin_type": "functional" - } + extra={"is_built_in": True, "plugin_type": "functional"}, ) diff --git a/src/plugins/built_in/proactive_thinker/proacive_thinker_event.py b/src/plugins/built_in/proactive_thinker/proacive_thinker_event.py index 0cdc74fd6..2f062c0b0 100644 --- a/src/plugins/built_in/proactive_thinker/proacive_thinker_event.py +++ b/src/plugins/built_in/proactive_thinker/proacive_thinker_event.py @@ -63,7 +63,9 @@ class ColdStartTask(AsyncTask): logger.info(f"【冷启动】发现全新用户 {chat_id},准备发起第一次问候。") elif stream.last_active_time < self.bot_start_time: should_wake_up = True - logger.info(f"【冷启动】发现沉睡的聊天流 {chat_id} (最后活跃于 {datetime.fromtimestamp(stream.last_active_time)}),准备唤醒。") + logger.info( + f"【冷启动】发现沉睡的聊天流 {chat_id} (最后活跃于 {datetime.fromtimestamp(stream.last_active_time)}),准备唤醒。" + ) if should_wake_up: person_id = person_api.get_person_id(platform, user_id) @@ -166,7 +168,9 @@ class ProactiveThinkingTask(AsyncTask): continue # 检查冷却时间 - recent_messages = await message_api.get_recent_messages(chat_id=stream.stream_id, limit=1,limit_mode="latest") + recent_messages = await message_api.get_recent_messages( + chat_id=stream.stream_id, limit=1, limit_mode="latest" + ) last_message_time = recent_messages[0]["time"] if recent_messages else stream.create_time time_since_last_active = time.time() - last_message_time if time_since_last_active > next_interval: @@ -209,7 +213,7 @@ class ProactiveThinkingTask(AsyncTask): logger.info("日常唤醒任务被正常取消。") break except Exception as e: - traceback.print_exc() # 打印完整的堆栈跟踪 + traceback.print_exc() # 打印完整的堆栈跟踪 logger.error(f"【日常唤醒】任务出现错误,将在60秒后重试: {e}", exc_info=True) await asyncio.sleep(60) diff --git a/src/plugins/built_in/proactive_thinker/proactive_thinker_executor.py b/src/plugins/built_in/proactive_thinker/proactive_thinker_executor.py index 6741c9794..3f5257790 100644 --- a/src/plugins/built_in/proactive_thinker/proactive_thinker_executor.py +++ b/src/plugins/built_in/proactive_thinker/proactive_thinker_executor.py @@ -143,14 +143,16 @@ class ProactiveThinkerExecutor: else "今天没有日程安排。" ) - recent_messages = await message_api.get_recent_messages(stream.stream_id,limit=50,limit_mode="latest",hours=12) + recent_messages = await message_api.get_recent_messages( + stream.stream_id, limit=50, limit_mode="latest", hours=12 + ) recent_chat_history = ( await message_api.build_readable_messages_to_str(recent_messages) if recent_messages else "无" ) action_history_list = await get_actions_by_timestamp_with_chat( chat_id=stream.stream_id, - timestamp_start=time.time() - 3600 * 24, #过去24小时 + timestamp_start=time.time() - 3600 * 24, # 过去24小时 timestamp_end=time.time(), limit=7, ) @@ -195,11 +197,9 @@ class ProactiveThinkerExecutor: person_id = person_api.get_person_id(user_info.platform, int(user_info.user_id)) person_info_manager = get_person_info_manager() person_info = await person_info_manager.get_values(person_id, ["user_id", "platform", "person_name"]) - cross_context_block = await Prompt.build_cross_context( - stream.stream_id, "s4u", person_info - ) + cross_context_block = await Prompt.build_cross_context(stream.stream_id, "s4u", person_info) - # 获取关系信息 + # 获取关系信息 
short_impression = await person_info_manager.get_value(person_id, "short_impression") or "无" impression = await person_info_manager.get_value(person_id, "impression") or "无" attitude = await person_info_manager.get_value(person_id, "attitude") or 50 diff --git a/src/plugins/built_in/social_toolkit_plugin/__init__.py b/src/plugins/built_in/social_toolkit_plugin/__init__.py index 92b89ca6f..9f48d7182 100644 --- a/src/plugins/built_in/social_toolkit_plugin/__init__.py +++ b/src/plugins/built_in/social_toolkit_plugin/__init__.py @@ -10,8 +10,5 @@ __plugin_meta__ = PluginMetadata( repository_url="https://github.com/MoFox-Studio", keywords=["emoji", "reaction", "like", "表情", "回应", "点赞"], categories=["Chat", "Integration"], - extra={ - "is_built_in": "true", - "plugin_type": "functional" - } + extra={"is_built_in": "true", "plugin_type": "functional"}, ) diff --git a/src/plugins/built_in/social_toolkit_plugin/plugin.py b/src/plugins/built_in/social_toolkit_plugin/plugin.py index 041acb11b..1319e29e0 100644 --- a/src/plugins/built_in/social_toolkit_plugin/plugin.py +++ b/src/plugins/built_in/social_toolkit_plugin/plugin.py @@ -548,7 +548,7 @@ class SetEmojiLikePlugin(BasePlugin): config_section_descriptions = {"plugin": "插件基本信息", "components": "插件组件"} # 配置Schema定义 - config_schema: ClassVar[dict ]= { + config_schema: ClassVar[dict] = { "plugin": { "name": ConfigField(type=str, default="set_emoji_like", description="插件名称"), "version": ConfigField(type=str, default="1.0.0", description="插件版本"), diff --git a/src/plugins/built_in/tts_plugin/__init__.py b/src/plugins/built_in/tts_plugin/__init__.py index e2595960d..c7db2791f 100644 --- a/src/plugins/built_in/tts_plugin/__init__.py +++ b/src/plugins/built_in/tts_plugin/__init__.py @@ -13,5 +13,5 @@ __plugin_meta__ = PluginMetadata( extra={ "is_built_in": True, "plugin_type": "audio_processor", - } + }, ) diff --git a/src/plugins/built_in/web_search_tool/__init__.py b/src/plugins/built_in/web_search_tool/__init__.py index 588af2378..1ebf0bec1 100644 --- a/src/plugins/built_in/web_search_tool/__init__.py +++ b/src/plugins/built_in/web_search_tool/__init__.py @@ -12,5 +12,5 @@ __plugin_meta__ = PluginMetadata( categories=["Tools"], extra={ "is_built_in": True, - } + }, ) diff --git a/src/plugins/built_in/web_search_tool/engines/searxng_engine.py b/src/plugins/built_in/web_search_tool/engines/searxng_engine.py index 75f9373bb..03fdc8885 100644 --- a/src/plugins/built_in/web_search_tool/engines/searxng_engine.py +++ b/src/plugins/built_in/web_search_tool/engines/searxng_engine.py @@ -43,7 +43,9 @@ class SearXNGSearchEngine(BaseSearchEngine): api_keys = config_api.get_global_config("web_search.searxng_api_keys", None) if isinstance(api_keys, list): - self.api_keys: list[str | None] = [k.strip() if isinstance(k, str) and k.strip() else None for k in api_keys] + self.api_keys: list[str | None] = [ + k.strip() if isinstance(k, str) and k.strip() else None for k in api_keys + ] else: self.api_keys = [] @@ -51,9 +53,7 @@ class SearXNGSearchEngine(BaseSearchEngine): if self.api_keys and len(self.api_keys) < len(self.instances): self.api_keys.extend([None] * (len(self.instances) - len(self.api_keys))) - logger.debug( - f"SearXNG 引擎配置: instances={self.instances}, api_keys={'yes' if any(self.api_keys) else 'no'}" - ) + logger.debug(f"SearXNG 引擎配置: instances={self.instances}, api_keys={'yes' if any(self.api_keys) else 'no'}") def is_available(self) -> bool: return bool(self.instances)
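
Reviewer note: the reflowed convenience helpers in src/common/database/db_batch_scheduler.py (batch_select, batch_insert, batch_update, batch_delete) keep their original signatures, so callers are unaffected. A minimal usage sketch, assuming a mapped model class from this project (a hypothetical stand-in below) and that the scheduler's flush loop is already running:

    from src.common.database.db_batch_scheduler import batch_insert, batch_select

    async def demo(model_cls):
        # model_cls: any mapped model class from this project (hypothetical stand-in here).
        # Queue an insert; the scheduler groups queued operations into batches.
        await batch_insert(model_cls, {"chat_id": "123", "content": "hi"})
        # A select with identical conditions may be answered from the scheduler's result cache.
        return await batch_select(model_cls, {"chat_id": "123"})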