From 72e7492953290bbcfdc2272eeba54e2e54508e36 Mon Sep 17 00:00:00 2001 From: ikun-11451 <334495606@qq.com> Date: Sat, 29 Nov 2025 21:26:42 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BE=9D=E6=97=A7=E4=BF=AEpyright=E5=96=B5~?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/chatter_manager.py | 2 +- src/chat/planner_actions/action_manager.py | 3 +- src/chat/planner_actions/action_modifier.py | 2 + src/chat/replyer/default_generator.py | 16 +++++ src/chat/security/manager.py | 19 ++--- src/chat/utils/chat_message_builder.py | 12 +++- src/chat/utils/prompt.py | 7 ++ src/chat/utils/report_generator.py | 1 + src/chat/utils/statistic.py | 4 +- src/chat/utils/utils.py | 8 ++- src/chat/utils/utils_image.py | 5 +- src/chat/utils/utils_video.py | 2 + src/chat/utils/utils_video_legacy.py | 2 + src/chat/utils/utils_voice.py | 2 + .../data_models/message_manager_data_model.py | 13 ++-- src/common/database/api/crud.py | 29 ++++---- src/common/database/api/query.py | 4 +- src/common/database/api/specialized.py | 8 +-- src/common/database/compatibility/adapter.py | 6 +- src/common/database/core/engine.py | 2 + src/common/database/core/models.py | 1 + src/common/database/core/session.py | 3 + .../database/optimization/batch_scheduler.py | 8 +-- .../database/optimization/cache_manager.py | 71 ++++++++++--------- src/common/database/utils/decorators.py | 44 ++++++------ 25 files changed, 170 insertions(+), 104 deletions(-) diff --git a/src/chat/chatter_manager.py b/src/chat/chatter_manager.py index 7a2e76d68..5785f3971 100644 --- a/src/chat/chatter_manager.py +++ b/src/chat/chatter_manager.py @@ -83,7 +83,7 @@ class ChatterManager: inactive_streams = [] for stream_id, instance in self.instances.items(): if hasattr(instance, "get_activity_time"): - activity_time = instance.get_activity_time() + activity_time = getattr(instance, "get_activity_time")() if (current_time - activity_time) > max_inactive_seconds: inactive_streams.append(stream_id) diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py index ef8b24657..2e297c7d0 100644 --- a/src/chat/planner_actions/action_manager.py +++ b/src/chat/planner_actions/action_manager.py @@ -104,7 +104,7 @@ class ChatterActionManager: log_prefix=log_prefix, shutting_down=shutting_down, plugin_config=plugin_config, - action_message=action_message, + action_message=action_message, # type: ignore ) logger.debug(f"创建Action实例成功: {action_name}") @@ -173,6 +173,7 @@ class ChatterActionManager: Returns: 执行结果 """ + assert global_config is not None chat_stream = None try: diff --git a/src/chat/planner_actions/action_modifier.py b/src/chat/planner_actions/action_modifier.py index ca057d8e7..a13bf2cce 100644 --- a/src/chat/planner_actions/action_modifier.py +++ b/src/chat/planner_actions/action_modifier.py @@ -30,6 +30,7 @@ class ActionModifier: def __init__(self, action_manager: ChatterActionManager, chat_id: str): """初始化动作处理器""" + assert model_config is not None self.chat_id = chat_id # chat_stream 和 log_prefix 将在异步方法中初始化 self.chat_stream: "ChatStream | None" = None @@ -67,6 +68,7 @@ class ActionModifier: 处理后,ActionManager 将包含最终的可用动作集,供规划器直接使用 """ + assert global_config is not None # 初始化log_prefix await self._initialize_log_prefix() # 根据 stream_id 加载当前可用的动作 diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index fe9be0494..560d518ef 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -240,6 +240,8 @@ class DefaultReplyer: chat_stream: "ChatStream", request_type: str = "replyer", ): + assert global_config is not None + assert model_config is not None self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type) self.chat_stream = chat_stream # 这些将在异步初始化中设置 @@ -267,6 +269,7 @@ class DefaultReplyer: async def _build_auth_role_prompt(self) -> str: """根据主人配置生成额外提示词""" + assert global_config is not None master_config = global_config.permission.master_prompt if not master_config or not master_config.enable: return "" @@ -515,6 +518,7 @@ class DefaultReplyer: Returns: str: 表达习惯信息字符串 """ + assert global_config is not None # 检查是否允许在此聊天流中使用表达 use_expression, _, _ = global_config.expression.get_expression_config_for_chat(self.chat_stream.stream_id) if not use_expression: @@ -583,6 +587,7 @@ class DefaultReplyer: Returns: str: 记忆信息字符串 """ + assert global_config is not None # 检查是否启用三层记忆系统 if not (global_config.memory and global_config.memory.enable): return "" @@ -776,6 +781,7 @@ class DefaultReplyer: Returns: str: 关键词反应提示字符串,如果没有触发任何反应则为空字符串 """ + assert global_config is not None if target is None: return "" @@ -834,6 +840,7 @@ class DefaultReplyer: Returns: str: 格式化的notice信息文本,如果没有notice或未启用则返回空字符串 """ + assert global_config is not None try: logger.debug(f"开始构建notice块,chat_id={chat_id}") @@ -902,6 +909,7 @@ class DefaultReplyer: Returns: Tuple[str, str]: (已读历史消息prompt, 未读历史消息prompt) """ + assert global_config is not None try: # 从message_manager获取真实的已读/未读消息 @@ -1002,6 +1010,7 @@ class DefaultReplyer: """ 回退的已读/未读历史消息构建方法 """ + assert global_config is not None # 通过is_read字段分离已读和未读消息 read_messages = [] unread_messages = [] @@ -1115,6 +1124,7 @@ class DefaultReplyer: Returns: str: 构建好的上下文 """ + assert global_config is not None if available_actions is None: available_actions = {} chat_stream = self.chat_stream @@ -1607,6 +1617,7 @@ class DefaultReplyer: reply_to: str, reply_message: dict[str, Any] | DatabaseMessages | None = None, ) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if + assert global_config is not None chat_stream = self.chat_stream chat_id = chat_stream.stream_id is_group_chat = bool(chat_stream.group_info) @@ -1767,6 +1778,7 @@ class DefaultReplyer: return prompt_text async def llm_generate_content(self, prompt: str): + assert global_config is not None with Timer("LLM生成", {}): # 内部计时器,可选保留 # 直接使用已初始化的模型实例 logger.info(f"使用模型集生成回复: {self.express_model.model_for_task}") @@ -1792,6 +1804,8 @@ class DefaultReplyer: return content, reasoning_content, model_name, tool_calls async def get_prompt_info(self, message: str, sender: str, target: str): + assert global_config is not None + assert model_config is not None related_info = "" start_time = time.time() from src.plugins.built_in.knowledge.lpmm_get_knowledge import SearchKnowledgeFromLPMMTool @@ -1843,6 +1857,7 @@ class DefaultReplyer: return "" async def build_relation_info(self, sender: str, target: str): + assert global_config is not None # 获取用户ID if sender == f"{global_config.bot.nickname}(你)": return "你将要回复的是你自己发送的消息。" @@ -1927,6 +1942,7 @@ class DefaultReplyer: reply_to: 回复对象 reply_message: 回复的原始消息 """ + assert global_config is not None return # 已禁用,保留函数签名以防其他地方有引用 # 以下代码已废弃,不再执行 diff --git a/src/chat/security/manager.py b/src/chat/security/manager.py index a8c3a5716..1c8c2ecfa 100644 --- a/src/chat/security/manager.py +++ b/src/chat/security/manager.py @@ -173,9 +173,10 @@ class SecurityManager: pre_check_results = await asyncio.gather(*pre_check_tasks, return_exceptions=True) # 筛选需要完整检查的检测器 - checkers_to_run = [ - c for c, need_check in zip(enabled_checkers, pre_check_results) if need_check is True - ] + checkers_to_run = [] + for c, need_check in zip(enabled_checkers, pre_check_results): + if need_check is True: + checkers_to_run.append(c) if not checkers_to_run: return SecurityCheckResult( @@ -192,20 +193,22 @@ class SecurityManager: results = await asyncio.gather(*check_tasks, return_exceptions=True) # 过滤异常结果 - valid_results = [] + valid_results: list[SecurityCheckResult] = [] for checker, result in zip(checkers_to_run, results): - if isinstance(result, Exception): + if isinstance(result, BaseException): logger.error(f"检测器 '{checker.name}' 执行失败: {result}") continue - result.checker_name = checker.name - valid_results.append(result) + + if isinstance(result, SecurityCheckResult): + result.checker_name = checker.name + valid_results.append(result) # 合并结果 return self._merge_results(valid_results, time.time() - start_time) async def _check_all(self, message: str, context: dict, start_time: float) -> SecurityCheckResult: """检测所有模式(顺序执行所有检测器)""" - results = [] + results: list[SecurityCheckResult] = [] for checker in self._checkers: if not checker.enabled: diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index b0991d53d..7cd4e0596 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -39,11 +39,13 @@ def replace_user_references_sync( Returns: str: 处理后的内容字符串 """ + assert global_config is not None if not content: return "" if name_resolver is None: def default_resolver(platform: str, user_id: str) -> str: + assert global_config is not None # 检查是否是机器人自己 if replace_bot_name and (user_id == str(global_config.bot.qq_account)): return f"{global_config.bot.nickname}(你)" @@ -116,10 +118,12 @@ async def replace_user_references_async( Returns: str: 处理后的内容字符串 """ + assert global_config is not None if name_resolver is None: person_info_manager = get_person_info_manager() async def default_resolver(platform: str, user_id: str) -> str: + assert global_config is not None # 检查是否是机器人自己 if replace_bot_name and (user_id == str(global_config.bot.qq_account)): return f"{global_config.bot.nickname}(你)" @@ -392,7 +396,7 @@ async def get_actions_by_timestamp_with_chat_inclusive( actions = list(result.scalars()) return [action.__dict__ for action in reversed(actions)] else: # earliest - result = await session.execute( + query = await session.execute( select(ActionRecords) .where( and_( @@ -540,6 +544,7 @@ async def _build_readable_messages_internal( Returns: 包含格式化消息的字符串、原始消息详情列表、图片映射字典和更新后的计数器的元组。 """ + assert global_config is not None if not messages: return "", [], pic_id_mapping or {}, pic_counter @@ -694,6 +699,7 @@ async def _build_readable_messages_internal( percentile = i / n_messages # 计算消息在列表中的位置百分比 (0 <= percentile < 1) original_len = len(content) limit = -1 # 默认不截断 + replace_content = "" if percentile < 0.2: # 60% 之前的消息 (即最旧的 60%) limit = 50 @@ -973,6 +979,7 @@ async def build_readable_messages( truncate: 是否截断长消息 show_actions: 是否显示动作记录 """ + assert global_config is not None # 创建messages的深拷贝,避免修改原始列表 if not messages: return "" @@ -1112,6 +1119,7 @@ async def build_anonymous_messages(messages: list[dict[str, Any]]) -> str: 构建匿名可读消息,将不同人的名称转为唯一占位符(A、B、C...),bot自己用SELF。 处理 回复 和 @ 字段,将bbb映射为匿名占位符。 """ + assert global_config is not None if not messages: print("111111111111没有消息,无法构建匿名消息") return "" @@ -1127,6 +1135,7 @@ async def build_anonymous_messages(messages: list[dict[str, Any]]) -> str: def get_anon_name(platform, user_id): # print(f"get_anon_name: platform:{platform}, user_id:{user_id}") # print(f"global_config.bot.qq_account:{global_config.bot.qq_account}") + assert global_config is not None if user_id == global_config.bot.qq_account: # print("SELF11111111111111") @@ -1204,6 +1213,7 @@ async def get_person_id_list(messages: list[dict[str, Any]]) -> list[str]: Returns: 一个包含唯一 person_id 的列表。 """ + assert global_config is not None person_ids_set = set() # 使用集合来自动去重 for msg in messages: diff --git a/src/chat/utils/prompt.py b/src/chat/utils/prompt.py index 35234c352..ae1149b05 100644 --- a/src/chat/utils/prompt.py +++ b/src/chat/utils/prompt.py @@ -649,6 +649,7 @@ class Prompt: async def _build_expression_habits(self) -> dict[str, Any]: """构建表达习惯(如表情、口癖)的上下文块.""" + assert global_config is not None # 检查当前聊天是否启用了表达习惯功能 use_expression, _, _ = global_config.expression.get_expression_config_for_chat( self.parameters.chat_id @@ -728,6 +729,7 @@ class Prompt: async def _build_tool_info(self) -> dict[str, Any]: """构建工具调用结果的上下文块.""" + assert global_config is not None if not global_config.tool.enable_tool: return {"tool_info_block": ""} @@ -779,6 +781,7 @@ class Prompt: async def _build_knowledge_info(self) -> dict[str, Any]: """构建从知识库检索到的相关信息的上下文块.""" + assert global_config is not None if not global_config.lpmm_knowledge.enable: return {"knowledge_prompt": ""} @@ -873,6 +876,7 @@ class Prompt: def _prepare_s4u_params(self, context_data: dict[str, Any]) -> dict[str, Any]: """为S4U(Scene for You)模式准备最终用于格式化的参数字典.""" + assert global_config is not None return { **context_data, "expression_habits_block": context_data.get("expression_habits_block", ""), @@ -915,6 +919,7 @@ class Prompt: def _prepare_normal_params(self, context_data: dict[str, Any]) -> dict[str, Any]: """为Normal模式准备最终用于格式化的参数字典.""" + assert global_config is not None return { **context_data, "expression_habits_block": context_data.get("expression_habits_block", ""), @@ -959,6 +964,7 @@ class Prompt: def _prepare_default_params(self, context_data: dict[str, Any]) -> dict[str, Any]: """为默认模式(或其他未指定模式)准备最终用于格式化的参数字典.""" + assert global_config is not None return { "expression_habits_block": context_data.get("expression_habits_block", ""), "relation_info_block": context_data.get("relation_info_block", ""), @@ -1143,6 +1149,7 @@ class Prompt: Returns: str: 构建好的跨群聊上下文字符串。 """ + assert global_config is not None if not global_config.cross_context.enable: return "" diff --git a/src/chat/utils/report_generator.py b/src/chat/utils/report_generator.py index f923b6e63..874451efc 100644 --- a/src/chat/utils/report_generator.py +++ b/src/chat/utils/report_generator.py @@ -338,6 +338,7 @@ class HTMLReportGenerator: # 渲染模板 # 读取CSS和JS文件内容 + assert isinstance(self.jinja_env.loader, FileSystemLoader) async with aiofiles.open(os.path.join(self.jinja_env.loader.searchpath[0], "report.css"), encoding="utf-8") as f: report_css = await f.read() async with aiofiles.open(os.path.join(self.jinja_env.loader.searchpath[0], "report.js"), encoding="utf-8") as f: diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py index a99f5785d..c6fdcec44 100644 --- a/src/chat/utils/statistic.py +++ b/src/chat/utils/statistic.py @@ -192,7 +192,7 @@ class StatisticOutputTask(AsyncTask): self._statistic_console_output(stats, now) # 使用新的 HTMLReportGenerator 生成报告 chart_data = await self._collect_chart_data(stats) - deploy_time = datetime.fromtimestamp(local_storage.get("deploy_time", now.timestamp())) + deploy_time = datetime.fromtimestamp(float(local_storage.get("deploy_time", now.timestamp()))) # type: ignore report_generator = HTMLReportGenerator( name_mapping=self.name_mapping, stat_period=self.stat_period, @@ -219,7 +219,7 @@ class StatisticOutputTask(AsyncTask): # 使用新的 HTMLReportGenerator 生成报告 chart_data = await self._collect_chart_data(stats) - deploy_time = datetime.fromtimestamp(local_storage.get("deploy_time", now.timestamp())) + deploy_time = datetime.fromtimestamp(float(local_storage.get("deploy_time", now.timestamp()))) # type: ignore report_generator = HTMLReportGenerator( name_mapping=self.name_mapping, stat_period=self.stat_period, diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index 450fff21c..381f69206 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -49,6 +49,7 @@ def is_mentioned_bot_in_message(message) -> tuple[bool, float]: tuple[bool, float]: (是否提及, 提及类型) 提及类型: 0=未提及, 1=弱提及(文本匹配), 2=强提及(@/回复/私聊) """ + assert global_config is not None nicknames = global_config.bot.alias_names mention_type = 0 # 0=未提及, 1=弱提及, 2=强提及 @@ -132,6 +133,7 @@ def is_mentioned_bot_in_message(message) -> tuple[bool, float]: async def get_embedding(text, request_type="embedding") -> list[float] | None: """获取文本的embedding向量""" + assert model_config is not None # 每次都创建新的LLMRequest实例以避免事件循环冲突 llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type=request_type) try: @@ -139,11 +141,12 @@ async def get_embedding(text, request_type="embedding") -> list[float] | None: except Exception as e: logger.error(f"获取embedding失败: {e!s}") embedding = None - return embedding + return embedding # type: ignore async def get_recent_group_speaker(chat_stream_id: str, sender, limit: int = 12) -> list: # 获取当前群聊记录内发言的人 + assert global_config is not None filter_query = {"chat_id": chat_stream_id} sort_order = [("time", -1)] recent_messages = await find_messages(message_filter=filter_query, sort=sort_order, limit=limit) @@ -400,11 +403,12 @@ def recover_quoted_content(sentences: list[str], placeholder_map: dict[str, str] recovered_sentences.append(sentence) return recovered_sentences - def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese_typo: bool = True) -> list[str]: + assert global_config is not None if not global_config.response_post_process.enable_response_post_process: return [text] + # --- 三层防护系统 --- # --- 三层防护系统 --- # 第一层:保护颜文字 protected_text, kaomoji_mapping = protect_kaomoji(text) if global_config.response_splitter.enable_kaomoji_protection else (text, {}) diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py index f0ae224c8..967d5af08 100644 --- a/src/chat/utils/utils_image.py +++ b/src/chat/utils/utils_image.py @@ -64,8 +64,6 @@ class ImageManager: # except Exception as e: # logger.error(f"数据库连接失败: {e}") - self._initialized = True - def _ensure_image_dir(self): """确保图像存储目录存在""" os.makedirs(self.IMAGE_DIR, exist_ok=True) @@ -159,6 +157,7 @@ class ImageManager: async def get_emoji_description(self, image_base64: str) -> str: """获取表情包描述,统一使用EmojiManager中的逻辑进行处理和缓存""" try: + assert global_config is not None from src.chat.emoji_system.emoji_manager import get_emoji_manager emoji_manager = get_emoji_manager() @@ -190,7 +189,7 @@ class ImageManager: return "[表情包(描述生成失败)]" # 4. (可选) 如果启用了“偷表情包”,则将图片和完整描述存入待注册区 - if global_config and global_config.emoji and global_config.emoji.steal_emoji: + if global_config.emoji and global_config.emoji.steal_emoji: logger.debug(f"偷取表情包功能已开启,保存待注册表情包: {image_hash}") try: image_format = (Image.open(io.BytesIO(image_bytes)).format or "jpeg").lower() diff --git a/src/chat/utils/utils_video.py b/src/chat/utils/utils_video.py index 945923403..8ed85e2cc 100644 --- a/src/chat/utils/utils_video.py +++ b/src/chat/utils/utils_video.py @@ -44,6 +44,8 @@ class VideoAnalyzer: """基于 inkfox 的视频关键帧 + LLM 描述分析器""" def __init__(self) -> None: + assert global_config is not None + assert model_config is not None cfg = getattr(global_config, "video_analysis", object()) self.max_frames: int = getattr(cfg, "max_frames", 20) self.frame_quality: int = getattr(cfg, "frame_quality", 85) diff --git a/src/chat/utils/utils_video_legacy.py b/src/chat/utils/utils_video_legacy.py index 306cb1591..91219d402 100644 --- a/src/chat/utils/utils_video_legacy.py +++ b/src/chat/utils/utils_video_legacy.py @@ -135,6 +135,8 @@ class LegacyVideoAnalyzer: def __init__(self): """初始化视频分析器""" + assert global_config is not None + assert model_config is not None # 使用专用的视频分析配置 try: self.video_llm = LLMRequest( diff --git a/src/chat/utils/utils_voice.py b/src/chat/utils/utils_voice.py index f74359f18..58cc88b09 100644 --- a/src/chat/utils/utils_voice.py +++ b/src/chat/utils/utils_voice.py @@ -11,6 +11,8 @@ logger = get_logger("chat_voice") async def get_voice_text(voice_base64: str) -> str: """获取音频文件转录文本""" + assert global_config is not None + assert model_config is not None if not global_config.voice.enable_asr: logger.warning("语音识别未启用,无法处理语音消息") return "[语音]" diff --git a/src/common/data_models/message_manager_data_model.py b/src/common/data_models/message_manager_data_model.py index a3f348898..04ee330c1 100644 --- a/src/common/data_models/message_manager_data_model.py +++ b/src/common/data_models/message_manager_data_model.py @@ -62,7 +62,9 @@ class StreamContext(BaseDataModel): stream_id: str chat_type: ChatType = ChatType.PRIVATE # 聊天类型,默认为私聊 chat_mode: ChatMode = ChatMode.FOCUS # 聊天模式,默认为专注模式 - max_context_size: int = field(default_factory=lambda: getattr(global_config.chat, "max_context_size", 100)) + max_context_size: int = field( + default_factory=lambda: getattr(global_config.chat, "max_context_size", 100) if global_config else 100 + ) unread_messages: list["DatabaseMessages"] = field(default_factory=list) history_messages: list["DatabaseMessages"] = field(default_factory=list) last_check_time: float = field(default_factory=time.time) @@ -98,7 +100,9 @@ class StreamContext(BaseDataModel): def __post_init__(self): """初始化历史消息异步加载""" if not self.max_context_size or self.max_context_size <= 0: - self.max_context_size = getattr(global_config.chat, "max_context_size", 100) + self.max_context_size = ( + getattr(global_config.chat, "max_context_size", 100) if global_config else 100 + ) try: loop = asyncio.get_event_loop() @@ -118,6 +122,7 @@ class StreamContext(BaseDataModel): async def add_message(self, message: "DatabaseMessages", skip_energy_update: bool = False) -> bool: """添加消息到上下文,支持跳过能量更新的选项""" try: + assert global_config is not None cache_enabled = global_config.chat.enable_message_cache if cache_enabled and not self.is_cache_enabled: self.enable_cache(True) @@ -150,7 +155,7 @@ class StreamContext(BaseDataModel): # ͬ�����ݵ�ͳһ������� try: if global_config.memory and global_config.memory.enable: - unified_manager = _get_unified_memory_manager() + unified_manager: Any = _get_unified_memory_manager() if unified_manager: message_dict = { "message_id": str(message.message_id), @@ -161,7 +166,7 @@ class StreamContext(BaseDataModel): "platform": message.chat_info.platform, "stream_id": self.stream_id, } - await unified_manager.add_message(message_dict) + await unified_manager.add_message(message_dict) # type: ignore logger.debug(f"��Ϣ�����ӵ��������ϵͳ: {message.message_id}") except Exception as e: logger.error(f"������Ϣ���������ϵͳʧ��: {e}") diff --git a/src/common/database/api/crud.py b/src/common/database/api/crud.py index 1c9b1aef9..f6e40ae3f 100644 --- a/src/common/database/api/crud.py +++ b/src/common/database/api/crud.py @@ -9,9 +9,10 @@ import operator from collections.abc import Callable from functools import lru_cache -from typing import Any, TypeVar +from typing import Any, Generic, TypeVar from sqlalchemy import delete, func, select, update +from sqlalchemy.engine import CursorResult, Result from src.common.database.core.models import Base from src.common.database.core.session import get_db_session @@ -25,23 +26,23 @@ from src.common.logger import get_logger logger = get_logger("database.crud") -T = TypeVar("T", bound=Base) +T = TypeVar("T", bound=Any) @lru_cache(maxsize=256) -def _get_model_column_names(model: type[Base]) -> tuple[str, ...]: +def _get_model_column_names(model: type[Any]) -> tuple[str, ...]: """获取模型的列名称列表""" return tuple(column.name for column in model.__table__.columns) @lru_cache(maxsize=256) -def _get_model_field_set(model: type[Base]) -> frozenset[str]: +def _get_model_field_set(model: type[Any]) -> frozenset[str]: """获取模型的有效字段集合""" return frozenset(_get_model_column_names(model)) @lru_cache(maxsize=256) -def _get_model_value_fetcher(model: type[Base]) -> Callable[[Base], tuple[Any, ...]]: +def _get_model_value_fetcher(model: type[Any]) -> Callable[[Any], tuple[Any, ...]]: """为模型准备attrgetter,用于批量获取属性值""" column_names = _get_model_column_names(model) @@ -51,21 +52,21 @@ def _get_model_value_fetcher(model: type[Base]) -> Callable[[Base], tuple[Any, . if len(column_names) == 1: attr_name = column_names[0] - def _single(instance: Base) -> tuple[Any, ...]: + def _single(instance: Any) -> tuple[Any, ...]: return (getattr(instance, attr_name),) return _single getter = operator.attrgetter(*column_names) - def _multi(instance: Base) -> tuple[Any, ...]: + def _multi(instance: Any) -> tuple[Any, ...]: values = getter(instance) return values if isinstance(values, tuple) else (values,) return _multi -def _model_to_dict(instance: Base) -> dict[str, Any]: +def _model_to_dict(instance: Any) -> dict[str, Any]: """将 SQLAlchemy 模型实例转换为字典 Args: @@ -113,7 +114,7 @@ def _dict_to_model(model_class: type[T], data: dict[str, Any]) -> T: return instance -class CRUDBase: +class CRUDBase(Generic[T]): """基础CRUD操作类 提供通用的增删改查操作,自动集成缓存和批处理 @@ -249,7 +250,7 @@ class CRUDBase: if cached_dicts is not None: logger.debug(f"缓存命中: {cache_key}") # 从字典列表恢复对象列表 - return [_dict_to_model(self.model, d) for d in cached_dicts] + return [_dict_to_model(self.model, d) for d in cached_dicts] # type: ignore # 从数据库查询 async with get_db_session() as session: @@ -278,7 +279,7 @@ class CRUDBase: await cache.set(cache_key, instances_dicts) # 从字典列表重建对象列表返回(detached状态,所有字段已加载) - return [_dict_to_model(self.model, d) for d in instances_dicts] + return [_dict_to_model(self.model, d) for d in instances_dicts] # type: ignore async def create( self, @@ -420,7 +421,7 @@ class CRUDBase: async with get_db_session() as session: stmt = delete(self.model).where(self.model.id == id) result = await session.execute(stmt) - success = result.rowcount > 0 + success = result.rowcount > 0 # type: ignore # 注意:commit在get_db_session的context manager退出时自动执行 # 清除缓存 @@ -455,7 +456,7 @@ class CRUDBase: stmt = stmt.where(getattr(self.model, key) == value) result = await session.execute(stmt) - return result.scalar() + return int(result.scalar() or 0) async def exists( self, @@ -549,7 +550,7 @@ class CRUDBase: .values(**obj_in) ) result = await session.execute(stmt) - count += result.rowcount + count += result.rowcount # type: ignore # 清除缓存 cache_key = f"{self.model_name}:id:{id}" diff --git a/src/common/database/api/query.py b/src/common/database/api/query.py index 6815820ef..51cbc4da4 100644 --- a/src/common/database/api/query.py +++ b/src/common/database/api/query.py @@ -20,7 +20,7 @@ from src.common.logger import get_logger logger = get_logger("database.query") -T = TypeVar("T", bound="Base") +T = TypeVar("T", bound=Any) class QueryBuilder(Generic[T]): @@ -330,7 +330,7 @@ class QueryBuilder(Generic[T]): items = await self.all() - return items, total + return items, total # type: ignore class AggregateQuery: diff --git a/src/common/database/api/specialized.py b/src/common/database/api/specialized.py index f9ddc274e..01b2372e2 100644 --- a/src/common/database/api/specialized.py +++ b/src/common/database/api/specialized.py @@ -122,7 +122,7 @@ async def get_recent_actions( 动作记录列表 """ query = QueryBuilder(ActionRecords) - return await query.filter(chat_id=chat_id).order_by("-time").limit(limit).all() + return await query.filter(chat_id=chat_id).order_by("-time").limit(limit).all() # type: ignore # ===== Messages 业务API ===== @@ -148,7 +148,7 @@ async def get_chat_history( .limit(limit) .offset(offset) .all() - ) + ) # type: ignore async def get_message_count(stream_id: str) -> int: @@ -292,7 +292,7 @@ async def get_active_streams( if platform: query = query.filter(platform=platform) - return await query.order_by("-last_message_time").limit(limit).all() + return await query.order_by("-last_message_time").limit(limit).all() # type: ignore # ===== LLMUsage 业务API ===== @@ -390,7 +390,7 @@ async def get_usage_statistics( # 聚合统计 total_input = await query.sum("input_tokens") total_output = await query.sum("output_tokens") - total_count = await query.filter().count() if hasattr(query, "count") else 0 + total_count = await getattr(query.filter(), "count")() if hasattr(query, "count") else 0 return { "total_input_tokens": int(total_input), diff --git a/src/common/database/compatibility/adapter.py b/src/common/database/compatibility/adapter.py index a19eddfea..783998fc7 100644 --- a/src/common/database/compatibility/adapter.py +++ b/src/common/database/compatibility/adapter.py @@ -123,7 +123,7 @@ async def build_filters(model_class, filters: dict[str, Any]): return conditions -def _model_to_dict(instance) -> dict[str, Any]: +def _model_to_dict(instance) -> dict[str, Any] | None: """将数据库模型实例转换为字典(兼容旧API Args: @@ -238,7 +238,7 @@ async def db_query( return None # 更新记录 - updated = await crud.update(instance.id, data) + updated = await crud.update(instance.id, data) # type: ignore return _model_to_dict(updated) elif query_type == "delete": @@ -257,7 +257,7 @@ async def db_query( return None # 删除记录 - success = await crud.delete(instance.id) + success = await crud.delete(instance.id) # type: ignore return {"deleted": success} elif query_type == "count": diff --git a/src/common/database/core/engine.py b/src/common/database/core/engine.py index a235449b0..064178595 100644 --- a/src/common/database/core/engine.py +++ b/src/common/database/core/engine.py @@ -46,6 +46,7 @@ async def get_engine() -> AsyncEngine: if _engine_lock is None: _engine_lock = asyncio.Lock() + assert _engine_lock is not None # 使用锁保护初始化过程 async with _engine_lock: # 双重检查锁定模式 @@ -55,6 +56,7 @@ async def get_engine() -> AsyncEngine: try: from src.config.config import global_config + assert global_config is not None config = global_config.database db_type = config.database_type diff --git a/src/common/database/core/models.py b/src/common/database/core/models.py index 89d5d6f68..8fcf4145d 100644 --- a/src/common/database/core/models.py +++ b/src/common/database/core/models.py @@ -44,6 +44,7 @@ def get_string_field(max_length=255, **kwargs): """ from src.config.config import global_config + assert global_config is not None db_type = global_config.database.database_type # MySQL 索引需要指定长度的 VARCHAR diff --git a/src/common/database/core/session.py b/src/common/database/core/session.py index 9cd701fc2..751bb51bf 100644 --- a/src/common/database/core/session.py +++ b/src/common/database/core/session.py @@ -75,6 +75,7 @@ async def _apply_session_settings(session: AsyncSession, db_type: str) -> None: # 可以设置 schema 搜索路径等 from src.config.config import global_config + assert global_config is not None schema = global_config.database.postgresql_schema if schema and schema != "public": await session.execute(text(f"SET search_path TO {schema}")) @@ -114,6 +115,7 @@ async def get_db_session() -> AsyncGenerator[AsyncSession, None]: # 获取数据库类型并应用特定设置 from src.config.config import global_config + assert global_config is not None await _apply_session_settings(session, global_config.database.database_type) yield session @@ -142,6 +144,7 @@ async def get_db_session_direct() -> AsyncGenerator[AsyncSession, None]: # 应用数据库特定设置 from src.config.config import global_config + assert global_config is not None await _apply_session_settings(session, global_config.database.database_type) yield session diff --git a/src/common/database/optimization/batch_scheduler.py b/src/common/database/optimization/batch_scheduler.py index f2d2591fb..8a385ddc3 100644 --- a/src/common/database/optimization/batch_scheduler.py +++ b/src/common/database/optimization/batch_scheduler.py @@ -13,7 +13,7 @@ from collections import defaultdict, deque from collections.abc import Callable from dataclasses import dataclass, field from enum import IntEnum -from typing import Any, TypeVar +from typing import Any from sqlalchemy import delete, insert, select, update @@ -23,8 +23,6 @@ from src.common.memory_utils import estimate_size_smart logger = get_logger("batch_scheduler") -T = TypeVar("T") - class Priority(IntEnum): """操作优先级""" @@ -429,7 +427,7 @@ class AdaptiveBatchScheduler: # 执行更新(但不commit) result = await session.execute(stmt) - results.append((op, result.rowcount)) + results.append((op, result.rowcount)) # type: ignore # 注意:commit 由 get_db_session_direct 上下文管理器自动处理 @@ -471,7 +469,7 @@ class AdaptiveBatchScheduler: # 执行删除(但不commit) result = await session.execute(stmt) - results.append((op, result.rowcount)) + results.append((op, result.rowcount)) # type: ignore # 注意:commit 由 get_db_session_direct 上下文管理器自动处理 diff --git a/src/common/database/optimization/cache_manager.py b/src/common/database/optimization/cache_manager.py index 700502c47..eed9783a5 100644 --- a/src/common/database/optimization/cache_manager.py +++ b/src/common/database/optimization/cache_manager.py @@ -398,47 +398,48 @@ class MultiLevelCache: l2_stats_task = asyncio.create_task(self._get_cache_stats_safe(self.l2_cache, "L2")) # 使用超时避免死锁 - try: - l1_stats, l2_stats = await asyncio.gather( - asyncio.wait_for(l1_stats_task, timeout=1.0), - asyncio.wait_for(l2_stats_task, timeout=1.0), - return_exceptions=True - ) - except asyncio.TimeoutError: - logger.warning("缓存统计获取超时,使用基本统计") - l1_stats = await self.l1_cache.get_stats() - l2_stats = await self.l2_cache.get_stats() + results = await asyncio.gather( + asyncio.wait_for(l1_stats_task, timeout=1.0), + asyncio.wait_for(l2_stats_task, timeout=1.0), + return_exceptions=True + ) + l1_stats = results[0] + l2_stats = results[1] # 处理异常情况 - if isinstance(l1_stats, Exception): + if isinstance(l1_stats, BaseException): logger.error(f"L1统计获取失败: {l1_stats}") l1_stats = CacheStats() - if isinstance(l2_stats, Exception): + if isinstance(l2_stats, BaseException): logger.error(f"L2统计获取失败: {l2_stats}") l2_stats = CacheStats() + assert isinstance(l1_stats, CacheStats) + assert isinstance(l2_stats, CacheStats) + # 🔧 修复:并行获取键集合,避免锁嵌套 l1_keys_task = asyncio.create_task(self._get_cache_keys_safe(self.l1_cache)) l2_keys_task = asyncio.create_task(self._get_cache_keys_safe(self.l2_cache)) - try: - l1_keys, l2_keys = await asyncio.gather( - asyncio.wait_for(l1_keys_task, timeout=1.0), - asyncio.wait_for(l2_keys_task, timeout=1.0), - return_exceptions=True - ) - except asyncio.TimeoutError: - logger.warning("缓存键获取超时,使用默认值") - l1_keys, l2_keys = set(), set() + results = await asyncio.gather( + asyncio.wait_for(l1_keys_task, timeout=1.0), + asyncio.wait_for(l2_keys_task, timeout=1.0), + return_exceptions=True + ) + l1_keys = results[0] + l2_keys = results[1] # 处理异常情况 - if isinstance(l1_keys, Exception): + if isinstance(l1_keys, BaseException): logger.warning(f"L1键获取失败: {l1_keys}") l1_keys = set() - if isinstance(l2_keys, Exception): + if isinstance(l2_keys, BaseException): logger.warning(f"L2键获取失败: {l2_keys}") l2_keys = set() + assert isinstance(l1_keys, set) + assert isinstance(l2_keys, set) + # 计算共享键和独占键 shared_keys = l1_keys & l2_keys l1_only_keys = l1_keys - l2_keys @@ -448,24 +449,25 @@ class MultiLevelCache: l1_size_task = asyncio.create_task(self._calculate_memory_usage_safe(self.l1_cache, l1_keys)) l2_size_task = asyncio.create_task(self._calculate_memory_usage_safe(self.l2_cache, l2_keys)) - try: - l1_size, l2_size = await asyncio.gather( - asyncio.wait_for(l1_size_task, timeout=1.0), - asyncio.wait_for(l2_size_task, timeout=1.0), - return_exceptions=True - ) - except asyncio.TimeoutError: - logger.warning("内存计算超时,使用统计值") - l1_size, l2_size = l1_stats.total_size, l2_stats.total_size + results = await asyncio.gather( + asyncio.wait_for(l1_size_task, timeout=1.0), + asyncio.wait_for(l2_size_task, timeout=1.0), + return_exceptions=True + ) + l1_size = results[0] + l2_size = results[1] # 处理异常情况 - if isinstance(l1_size, Exception): + if isinstance(l1_size, BaseException): logger.warning(f"L1内存计算失败: {l1_size}") l1_size = l1_stats.total_size - if isinstance(l2_size, Exception): + if isinstance(l2_size, BaseException): logger.warning(f"L2内存计算失败: {l2_size}") l2_size = l2_stats.total_size + assert isinstance(l1_size, int) + assert isinstance(l2_size, int) + # 计算实际总内存(避免重复计数) actual_total_size = l1_size + l2_size - min(l1_stats.total_size, l2_stats.total_size) @@ -769,6 +771,7 @@ async def get_cache() -> MultiLevelCache: try: from src.config.config import global_config + assert global_config is not None db_config = global_config.database # 检查是否启用缓存 diff --git a/src/common/database/utils/decorators.py b/src/common/database/utils/decorators.py index 319debcb1..b2baad6db 100644 --- a/src/common/database/utils/decorators.py +++ b/src/common/database/utils/decorators.py @@ -11,7 +11,7 @@ import functools import hashlib import time from collections.abc import Awaitable, Callable -from typing import Any, TypeVar +from typing import Any, ParamSpec, TypeVar from sqlalchemy.exc import DBAPIError, OperationalError from sqlalchemy.exc import TimeoutError as SQLTimeoutError @@ -56,8 +56,9 @@ def generate_cache_key( return ":".join(cache_key_parts) -T = TypeVar("T") -F = TypeVar("F", bound=Callable[..., Awaitable[Any]]) + +P = ParamSpec("P") +R = TypeVar("R") def retry( @@ -77,14 +78,13 @@ def retry( exceptions: 需要重试的异常类型 Example: - @retry(max_attempts=3, delay=1.0) async def query_data(): return await session.execute(stmt) """ - def decorator(func: Callable[..., T]) -> Callable[..., T]: + def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: @functools.wraps(func) - async def wrapper(*args: Any, **kwargs: Any) -> T: + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: last_exception = None current_delay = delay @@ -107,7 +107,9 @@ def retry( ) # 所有尝试都失败 - raise last_exception + if last_exception: + raise last_exception + raise RuntimeError(f"Retry failed after {max_attempts} attempts") return wrapper @@ -128,9 +130,9 @@ def timeout(seconds: float): return await session.execute(complex_stmt) """ - def decorator(func: Callable[..., T]) -> Callable[..., T]: + def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: @functools.wraps(func) - async def wrapper(*args: Any, **kwargs: Any) -> T: + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: try: return await asyncio.wait_for(func(*args, **kwargs), timeout=seconds) except asyncio.TimeoutError: @@ -164,9 +166,9 @@ def cached( return await query_user(user_id) """ - def decorator(func: Callable[..., T]) -> Callable[..., T]: + def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: @functools.wraps(func) - async def wrapper(*args: Any, **kwargs: Any) -> T: + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: # 延迟导入避免循环依赖 from src.common.database.optimization import get_cache @@ -226,9 +228,9 @@ def measure_time(log_slow: float | None = None): return await session.execute(stmt) """ - def decorator(func: Callable[..., T]) -> Callable[..., T]: + def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: @functools.wraps(func) - async def wrapper(*args: Any, **kwargs: Any) -> T: + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: start_time = time.perf_counter() try: @@ -268,21 +270,23 @@ def transactional(auto_commit: bool = True, auto_rollback: bool = True): 函数需要接受session参数 """ - def decorator(func: Callable[..., T]) -> Callable[..., T]: + def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: @functools.wraps(func) - async def wrapper(*args: Any, **kwargs: Any) -> T: + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: # 查找session参数 - session = None - if args: - from sqlalchemy.ext.asyncio import AsyncSession + from sqlalchemy.ext.asyncio import AsyncSession + session: AsyncSession | None = None + if args: for arg in args: if isinstance(arg, AsyncSession): session = arg break if not session and "session" in kwargs: - session = kwargs["session"] + possible_session = kwargs["session"] + if isinstance(possible_session, AsyncSession): + session = possible_session if not session: logger.warning(f"{func.__name__} 未找到session参数,跳过事务管理") @@ -331,7 +335,7 @@ def db_operation( return await complex_operation() """ - def decorator(func: Callable[..., T]) -> Callable[..., T]: + def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: # 从内到外应用装饰器 wrapped = func