diff --git a/scripts/cleanup_models.py b/scripts/cleanup_models.py index 0b09c4015..e02e8ce6b 100644 --- a/scripts/cleanup_models.py +++ b/scripts/cleanup_models.py @@ -16,7 +16,7 @@ models_file = os.path.join( print(f"正在清理文件: {models_file}") # 读取文件 -with open(models_file, "r", encoding="utf-8") as f: +with open(models_file, encoding="utf-8") as f: lines = f.readlines() # 找到最后一个模型类的结束位置(MonthlyPlan的 __table_args__ 结束) @@ -26,7 +26,7 @@ found_end = False for i, line in enumerate(lines, 1): keep_lines.append(line) - + # 检查是否到达 MonthlyPlan 的 __table_args__ 结束 if i > 580 and line.strip() == ")": # 再检查前一行是否有 Index 相关内容 @@ -43,7 +43,7 @@ if not found_end: with open(models_file, "w", encoding="utf-8") as f: f.writelines(keep_lines) -print(f"✅ 文件清理完成") +print("✅ 文件清理完成") print(f"保留行数: {len(keep_lines)}") print(f"原始行数: {len(lines)}") print(f"删除行数: {len(lines) - len(keep_lines)}") diff --git a/scripts/extract_models.py b/scripts/extract_models.py index 2eba4adaf..c97ca163c 100644 --- a/scripts/extract_models.py +++ b/scripts/extract_models.py @@ -4,20 +4,20 @@ import re # 读取原始文件 -with open('src/common/database/sqlalchemy_models.py', 'r', encoding='utf-8') as f: +with open("src/common/database/sqlalchemy_models.py", encoding="utf-8") as f: content = f.read() # 找到get_string_field函数的开始和结束 -get_string_field_start = content.find('# MySQL兼容的字段类型辅助函数') -get_string_field_end = content.find('\n\nclass ChatStreams(Base):') +get_string_field_start = content.find("# MySQL兼容的字段类型辅助函数") +get_string_field_end = content.find("\n\nclass ChatStreams(Base):") get_string_field = content[get_string_field_start:get_string_field_end] # 找到第一个class定义开始 -first_class_pos = content.find('class ChatStreams(Base):') +first_class_pos = content.find("class ChatStreams(Base):") # 找到所有class定义,直到遇到非class的def # 简单策略:找到所有以"class "开头且继承Base的类 -classes_pattern = r'class \w+\(Base\):.*?(?=\nclass \w+\(Base\):|$)' +classes_pattern = r"class \w+\(Base\):.*?(?=\nclass \w+\(Base\):|$)" matches = list(re.finditer(classes_pattern, content[first_class_pos:], re.DOTALL)) if matches: @@ -53,14 +53,14 @@ Base = declarative_base() ''' -new_content = header + get_string_field + '\n\n' + models_content +new_content = header + get_string_field + "\n\n" + models_content # 写入新文件 -with open('src/common/database/core/models.py', 'w', encoding='utf-8') as f: +with open("src/common/database/core/models.py", "w", encoding="utf-8") as f: f.write(new_content) -print('✅ Models file rewritten successfully') -print(f'File size: {len(new_content)} characters') +print("✅ Models file rewritten successfully") +print(f"File size: {len(new_content)} characters") pattern = r"^class \w+\(Base\):" model_count = len(re.findall(pattern, models_content, re.MULTILINE)) -print(f'Number of model classes: {model_count}') +print(f"Number of model classes: {model_count}") diff --git a/scripts/update_database_imports.py b/scripts/update_database_imports.py index 2e8df9bf5..15736e641 100644 --- a/scripts/update_database_imports.py +++ b/scripts/update_database_imports.py @@ -8,54 +8,53 @@ import re from pathlib import Path -from typing import Dict, List, Tuple # 定义导入映射规则 IMPORT_MAPPINGS = { # 模型导入 - r'from src\.common\.database\.sqlalchemy_models import (.+)': - r'from src.common.database.core.models import \1', - + r"from src\.common\.database\.sqlalchemy_models import (.+)": + r"from src.common.database.core.models import \1", + # API导入 - 需要特殊处理 - r'from src\.common\.database\.sqlalchemy_database_api import (.+)': - r'from src.common.database.compatibility import \1', - + r"from 
src\.common\.database\.sqlalchemy_database_api import (.+)": + r"from src.common.database.compatibility import \1", + # get_db_session 从 sqlalchemy_database_api 导入 - r'from src\.common\.database\.sqlalchemy_database_api import get_db_session': - r'from src.common.database.core import get_db_session', - + r"from src\.common\.database\.sqlalchemy_database_api import get_db_session": + r"from src.common.database.core import get_db_session", + # get_db_session 从 sqlalchemy_models 导入 - r'from src\.common\.database\.sqlalchemy_models import (.*)get_db_session(.*)': - lambda m: f'from src.common.database.core import {m.group(1)}get_db_session{m.group(2)}' - if 'get_db_session' in m.group(0) else m.group(0), - + r"from src\.common\.database\.sqlalchemy_models import (.*)get_db_session(.*)": + lambda m: f"from src.common.database.core import {m.group(1)}get_db_session{m.group(2)}" + if "get_db_session" in m.group(0) else m.group(0), + # get_engine 导入 - r'from src\.common\.database\.sqlalchemy_models import (.*)get_engine(.*)': - lambda m: f'from src.common.database.core import {m.group(1)}get_engine{m.group(2)}', - + r"from src\.common\.database\.sqlalchemy_models import (.*)get_engine(.*)": + lambda m: f"from src.common.database.core import {m.group(1)}get_engine{m.group(2)}", + # Base 导入 - r'from src\.common\.database\.sqlalchemy_models import (.*)Base(.*)': - lambda m: f'from src.common.database.core.models import {m.group(1)}Base{m.group(2)}', - + r"from src\.common\.database\.sqlalchemy_models import (.*)Base(.*)": + lambda m: f"from src.common.database.core.models import {m.group(1)}Base{m.group(2)}", + # initialize_database 导入 - r'from src\.common\.database\.sqlalchemy_models import initialize_database': - r'from src.common.database.core import check_and_migrate_database as initialize_database', - + r"from src\.common\.database\.sqlalchemy_models import initialize_database": + r"from src.common.database.core import check_and_migrate_database as initialize_database", + # database.py 导入 - r'from src\.common\.database\.database import stop_database': - r'from src.common.database.core import close_engine as stop_database', - - r'from src\.common\.database\.database import initialize_sql_database': - r'from src.common.database.core import check_and_migrate_database as initialize_sql_database', + r"from src\.common\.database\.database import stop_database": + r"from src.common.database.core import close_engine as stop_database", + + r"from src\.common\.database\.database import initialize_sql_database": + r"from src.common.database.core import check_and_migrate_database as initialize_sql_database", } # 需要排除的文件 EXCLUDE_PATTERNS = [ - '**/database_refactoring_plan.md', # 文档文件 - '**/old/**', # 旧文件目录 - '**/sqlalchemy_*.py', # 旧的数据库文件本身 - '**/database.py', # 旧的database文件 - '**/db_*.py', # 旧的db文件 + "**/database_refactoring_plan.md", # 文档文件 + "**/old/**", # 旧文件目录 + "**/sqlalchemy_*.py", # 旧的数据库文件本身 + "**/database.py", # 旧的database文件 + "**/db_*.py", # 旧的db文件 ] @@ -67,47 +66,47 @@ def should_exclude(file_path: Path) -> bool: return False -def update_imports_in_file(file_path: Path, dry_run: bool = True) -> Tuple[int, List[str]]: +def update_imports_in_file(file_path: Path, dry_run: bool = True) -> tuple[int, list[str]]: """更新单个文件中的导入语句 - + Args: file_path: 文件路径 dry_run: 是否只是预览而不实际修改 - + Returns: (修改次数, 修改详情列表) """ try: - content = file_path.read_text(encoding='utf-8') + content = file_path.read_text(encoding="utf-8") original_content = content changes = [] - + # 应用每个映射规则 for pattern, replacement in 
IMPORT_MAPPINGS.items(): matches = list(re.finditer(pattern, content)) for match in matches: old_line = match.group(0) - + # 处理函数类型的替换 if callable(replacement): new_line_result = replacement(match) new_line = new_line_result if isinstance(new_line_result, str) else old_line else: new_line = re.sub(pattern, replacement, old_line) - + if old_line != new_line and isinstance(new_line, str): content = content.replace(old_line, new_line, 1) changes.append(f" - {old_line}") changes.append(f" + {new_line}") - + # 如果有修改且不是dry_run,写回文件 if content != original_content: if not dry_run: - file_path.write_text(content, encoding='utf-8') + file_path.write_text(content, encoding="utf-8") return len(changes) // 2, changes - + return 0, [] - + except Exception as e: print(f"❌ 处理文件 {file_path} 时出错: {e}") return 0, [] @@ -116,34 +115,34 @@ def update_imports_in_file(file_path: Path, dry_run: bool = True) -> Tuple[int, def main(): """主函数""" print("🔍 搜索需要更新导入的文件...") - + # 获取项目根目录 root_dir = Path(__file__).parent.parent - + # 搜索所有Python文件 all_python_files = list(root_dir.rglob("*.py")) - + # 过滤掉排除的文件 target_files = [f for f in all_python_files if not should_exclude(f)] - + print(f"📊 找到 {len(target_files)} 个Python文件需要检查") print("\n" + "="*80) - + # 第一遍:预览模式 print("\n🔍 预览模式 - 检查需要更新的文件...\n") - + files_to_update = [] for file_path in target_files: count, changes = update_imports_in_file(file_path, dry_run=True) if count > 0: files_to_update.append((file_path, count, changes)) - + if not files_to_update: print("✅ 没有文件需要更新!") return - + print(f"📝 发现 {len(files_to_update)} 个文件需要更新:\n") - + total_changes = 0 for file_path, count, changes in files_to_update: rel_path = file_path.relative_to(root_dir) @@ -153,23 +152,23 @@ def main(): if len(changes) > 10: print(f" ... 还有 {len(changes) - 10} 行") total_changes += count - + print("\n" + "="*80) - print(f"\n📊 统计:") + print("\n📊 统计:") print(f" - 需要更新的文件: {len(files_to_update)}") print(f" - 总修改次数: {total_changes}") - + # 询问是否继续 print("\n" + "="*80) response = input("\n是否执行更新?(yes/no): ").strip().lower() - - if response != 'yes': + + if response != "yes": print("❌ 已取消更新") return - + # 第二遍:实际更新 print("\n✨ 开始更新文件...\n") - + success_count = 0 for file_path, _, _ in files_to_update: count, _ = update_imports_in_file(file_path, dry_run=False) @@ -177,7 +176,7 @@ def main(): rel_path = file_path.relative_to(root_dir) print(f"✅ {rel_path} ({count} 处修改)") success_count += 1 - + print("\n" + "="*80) print(f"\n🎉 完成!成功更新 {success_count} 个文件") diff --git a/src/chat/antipromptinjector/anti_injector.py b/src/chat/antipromptinjector/anti_injector.py index 809fd2c00..0a7b0d3da 100644 --- a/src/chat/antipromptinjector/anti_injector.py +++ b/src/chat/antipromptinjector/anti_injector.py @@ -263,8 +263,8 @@ class AntiPromptInjector: try: from sqlalchemy import delete - from src.common.database.core.models import Messages from src.common.database.core import get_db_session + from src.common.database.core.models import Messages message_id = message_data.get("message_id") if not message_id: @@ -291,8 +291,8 @@ class AntiPromptInjector: try: from sqlalchemy import update - from src.common.database.core.models import Messages from src.common.database.core import get_db_session + from src.common.database.core.models import Messages message_id = message_data.get("message_id") if not message_id: diff --git a/src/chat/antipromptinjector/management/statistics.py b/src/chat/antipromptinjector/management/statistics.py index 3bf3b2e5b..60cdd28fa 100644 --- 
a/src/chat/antipromptinjector/management/statistics.py +++ b/src/chat/antipromptinjector/management/statistics.py @@ -9,8 +9,8 @@ from typing import Any, TypeVar, cast from sqlalchemy import delete, select -from src.common.database.core.models import AntiInjectionStats from src.common.database.core import get_db_session +from src.common.database.core.models import AntiInjectionStats from src.common.logger import get_logger from src.config.config import global_config diff --git a/src/chat/antipromptinjector/management/user_ban.py b/src/chat/antipromptinjector/management/user_ban.py index ea5ac96dc..9c89fa885 100644 --- a/src/chat/antipromptinjector/management/user_ban.py +++ b/src/chat/antipromptinjector/management/user_ban.py @@ -8,8 +8,8 @@ import datetime from sqlalchemy import select -from src.common.database.core.models import BanUser from src.common.database.core import get_db_session +from src.common.database.core.models import BanUser from src.common.logger import get_logger from ..types import DetectionResult diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index 3ca02e477..1637e7570 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -15,9 +15,9 @@ from rich.traceback import install from sqlalchemy import select from src.chat.utils.utils_image import get_image_manager, image_path_to_base64 +from src.common.database.api.crud import CRUDBase from src.common.database.compatibility import get_db_session from src.common.database.core.models import Emoji, Images -from src.common.database.api.crud import CRUDBase from src.common.database.utils.decorators import cached from src.common.logger import get_logger from src.config.config import global_config, model_config @@ -215,7 +215,7 @@ class MaiEmoji: else: await crud.delete(will_delete_emoji.id) result = 1 # Successfully deleted one record - + # 使缓存失效 from src.common.database.optimization.cache_manager import get_cache from src.common.database.utils.decorators import generate_cache_key @@ -708,7 +708,7 @@ class EmojiManager: try: # 使用CRUD进行查询 crud = CRUDBase(Emoji) - + if emoji_hash: # 查询特定hash的表情包 emoji_record = await crud.get_by(emoji_hash=emoji_hash) diff --git a/src/chat/energy_system/energy_manager.py b/src/chat/energy_system/energy_manager.py index 3ccac8b07..f5ee7ca88 100644 --- a/src/chat/energy_system/energy_manager.py +++ b/src/chat/energy_system/energy_manager.py @@ -9,9 +9,8 @@ from dataclasses import dataclass, field from enum import Enum from typing import Any, TypedDict -from src.common.logger import get_logger from src.common.database.api.crud import CRUDBase -from src.common.database.utils.decorators import cached +from src.common.logger import get_logger from src.config.config import global_config logger = get_logger("energy_system") @@ -203,7 +202,6 @@ class RelationshipEnergyCalculator(EnergyCalculator): # 从数据库获取聊天流兴趣分数 try: - from sqlalchemy import select from src.common.database.core.models import ChatStreams diff --git a/src/chat/express/expression_learner.py b/src/chat/express/expression_learner.py index 91184124a..45bfb9544 100644 --- a/src/chat/express/expression_learner.py +++ b/src/chat/express/expression_learner.py @@ -236,12 +236,12 @@ class ExpressionLearner: """ 获取指定chat_id的style和grammar表达方式(带10分钟缓存) 返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作 - + 优化: 使用CRUD和缓存,减少数据库访问 """ # 使用静态方法以正确处理缓存键 return await self._get_expressions_by_chat_id_cached(self.chat_id) - + @staticmethod @cached(ttl=600, key_prefix="chat_expressions") 
async def _get_expressions_by_chat_id_cached(chat_id: str) -> tuple[list[dict[str, float]], list[dict[str, float]]]: @@ -278,7 +278,7 @@ class ExpressionLearner: async def _apply_global_decay_to_database(self, current_time: float) -> None: """ 对数据库中的所有表达方式应用全局衰减 - + 优化: 使用CRUD批量处理所有更改,最后统一提交 """ try: @@ -288,7 +288,7 @@ class ExpressionLearner: updated_count = 0 deleted_count = 0 - + # 需要手动操作的情况下使用session async with get_db_session() as session: # 批量处理所有修改 @@ -383,7 +383,7 @@ class ExpressionLearner: current_time = time.time() # 存储到数据库 Expression 表 - crud = CRUDBase(Expression) + CRUDBase(Expression) for chat_id, expr_list in chat_dict.items(): async with get_db_session() as session: for new_expr in expr_list: @@ -429,10 +429,10 @@ class ExpressionLearner: # 删除count最小的多余表达方式 for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]: await session.delete(expr) - + # 提交后清除相关缓存 await session.commit() - + # 清除该chat_id的表达方式缓存 from src.common.database.optimization.cache_manager import get_cache from src.common.database.utils.decorators import generate_cache_key diff --git a/src/chat/express/expression_selector.py b/src/chat/express/expression_selector.py index ceaba75c4..b7ac002d8 100644 --- a/src/chat/express/expression_selector.py +++ b/src/chat/express/expression_selector.py @@ -9,10 +9,8 @@ from json_repair import repair_json from sqlalchemy import select from src.chat.utils.prompt import Prompt, global_prompt_manager -from src.common.database.api.crud import CRUDBase from src.common.database.compatibility import get_db_session from src.common.database.core.models import Expression -from src.common.database.utils.decorators import cached from src.common.logger import get_logger from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest @@ -152,7 +150,7 @@ class ExpressionSelector: # sourcery skip: extract-duplicate-method, move-assign # 支持多chat_id合并抽选 related_chat_ids = self.get_related_chat_ids(chat_id) - + # 使用CRUD查询(由于需要IN条件,使用session) async with get_db_session() as session: # 优化:一次性查询所有相关chat_id的表达方式 @@ -224,7 +222,7 @@ class ExpressionSelector: if key not in updates_by_key: updates_by_key[key] = expr affected_chat_ids.add(source_id) - + for chat_id, expr_type, situation, style in updates_by_key: async with get_db_session() as session: query = await session.execute( @@ -247,7 +245,7 @@ class ExpressionSelector: f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db" ) await session.commit() - + # 清除所有受影响的chat_id的缓存 from src.common.database.optimization.cache_manager import get_cache from src.common.database.utils.decorators import generate_cache_key diff --git a/src/chat/memory_system/memory_system.py b/src/chat/memory_system/memory_system.py index c1fd44557..4863ea762 100644 --- a/src/chat/memory_system/memory_system.py +++ b/src/chat/memory_system/memory_system.py @@ -728,7 +728,6 @@ class MemorySystem: context = context or {} # 所有记忆完全共享,统一使用 global 作用域,不区分用户 - resolved_user_id = GLOBAL_MEMORY_SCOPE self.status = MemorySystemStatus.RETRIEVING start_time = time.time() diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index 4e583e8e7..36ec0dfbc 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -4,15 +4,14 @@ import time from maim_message import GroupInfo, UserInfo from rich.traceback import install -from sqlalchemy import select from sqlalchemy.dialects.mysql import insert as mysql_insert from sqlalchemy.dialects.sqlite 
import insert as sqlite_insert from src.common.data_models.database_data_model import DatabaseMessages +from src.common.database.api.crud import CRUDBase +from src.common.database.api.specialized import get_or_create_chat_stream from src.common.database.compatibility import get_db_session from src.common.database.core.models import ChatStreams # 新增导入 -from src.common.database.api.specialized import get_or_create_chat_stream -from src.common.database.api.crud import CRUDBase from src.common.logger import get_logger from src.config.config import global_config # 新增导入 @@ -708,7 +707,7 @@ class ChatManager: # 使用CRUD批量查询 crud = CRUDBase(ChatStreams) all_streams = await crud.get_multi(limit=100000) # 获取所有聊天流 - + for model_instance in all_streams: user_info_data = { "platform": model_instance.user_platform, diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index b40cae91b..a23bbd229 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -22,14 +22,14 @@ logger = get_logger("message_storage") class MessageStorageBatcher: """ 消息存储批处理器 - + 优化: 将消息缓存一段时间后批量写入数据库,减少数据库连接池压力 """ def __init__(self, batch_size: int = 50, flush_interval: float = 5.0): """ 初始化批处理器 - + Args: batch_size: 批量大小,达到此数量立即写入 flush_interval: 自动刷新间隔(秒) @@ -51,7 +51,7 @@ class MessageStorageBatcher: async def stop(self): """停止批处理器""" self._running = False - + if self._flush_task: self._flush_task.cancel() try: @@ -67,7 +67,7 @@ class MessageStorageBatcher: async def add_message(self, message_data: dict): """ 添加消息到批处理队列 - + Args: message_data: 包含消息对象和chat_stream的字典 { @@ -97,23 +97,23 @@ class MessageStorageBatcher: start_time = time.time() success_count = 0 - + try: # 🔧 优化:准备字典数据而不是ORM对象,使用批量INSERT messages_dicts = [] - + for msg_data in messages_to_store: try: message_dict = await self._prepare_message_dict( - msg_data['message'], - msg_data['chat_stream'] + msg_data["message"], + msg_data["chat_stream"] ) if message_dict: messages_dicts.append(message_dict) except Exception as e: logger.error(f"准备消息数据失败: {e}") continue - + # 批量写入数据库 - 使用高效的批量INSERT if messages_dicts: from sqlalchemy import insert @@ -122,7 +122,7 @@ class MessageStorageBatcher: await session.execute(stmt) await session.commit() success_count = len(messages_dicts) - + elapsed = time.time() - start_time logger.info( f"批量存储了 {success_count}/{len(messages_to_store)} 条消息 " @@ -134,18 +134,18 @@ class MessageStorageBatcher: async def _prepare_message_dict(self, message, chat_stream): """准备消息字典数据(用于批量INSERT) - + 这个方法准备字典而不是ORM对象,性能更高 """ message_obj = await self._prepare_message_object(message, chat_stream) if message_obj is None: return None - + # 将ORM对象转换为字典(只包含列字段) message_dict = {} for column in Messages.__table__.columns: message_dict[column.name] = getattr(message_obj, column.name) - + return message_dict async def _prepare_message_object(self, message, chat_stream): @@ -251,12 +251,12 @@ class MessageStorageBatcher: is_picid = message.is_picid is_notify = message.is_notify is_command = message.is_command - is_public_notice = getattr(message, 'is_public_notice', False) - notice_type = getattr(message, 'notice_type', None) - actions = getattr(message, 'actions', None) - should_reply = getattr(message, 'should_reply', None) - should_act = getattr(message, 'should_act', None) - additional_config = getattr(message, 'additional_config', None) + is_public_notice = getattr(message, "is_public_notice", False) + notice_type = getattr(message, "notice_type", None) + actions = getattr(message, "actions", 
None) + should_reply = getattr(message, "should_reply", None) + should_act = getattr(message, "should_act", None) + additional_config = getattr(message, "additional_config", None) key_words = MessageStorage._serialize_keywords(message.key_words) key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite) @@ -349,7 +349,7 @@ class MessageStorageBatcher: # 全局批处理器实例 -_message_storage_batcher: Optional[MessageStorageBatcher] = None +_message_storage_batcher: MessageStorageBatcher | None = None _message_update_batcher: Optional["MessageUpdateBatcher"] = None @@ -367,7 +367,7 @@ def get_message_storage_batcher() -> MessageStorageBatcher: class MessageUpdateBatcher: """ 消息更新批处理器 - + 优化: 将多个消息ID更新操作批量处理,减少数据库连接次数 """ @@ -478,7 +478,7 @@ class MessageStorage: async def store_message(message: DatabaseMessages | MessageSending, chat_stream: ChatStream, use_batch: bool = True) -> None: """ 存储消息到数据库 - + Args: message: 消息对象 chat_stream: 聊天流对象 @@ -488,11 +488,11 @@ class MessageStorage: if use_batch: batcher = get_message_storage_batcher() await batcher.add_message({ - 'message': message, - 'chat_stream': chat_stream + "message": message, + "chat_stream": chat_stream }) return - + # 直接写入模式(保留用于特殊场景) try: # 过滤敏感信息的正则模式 @@ -676,9 +676,9 @@ class MessageStorage: async def update_message(message_data: dict, use_batch: bool = True): """ 更新消息ID(从消息字典) - + 优化: 添加批处理选项,将多个更新操作合并,减少数据库连接 - + Args: message_data: 消息数据字典 use_batch: 是否使用批处理(默认True) diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py index 7702bb519..3701c59e9 100644 --- a/src/chat/utils/statistic.py +++ b/src/chat/utils/statistic.py @@ -3,7 +3,7 @@ from collections import defaultdict from datetime import datetime, timedelta from typing import Any -from src.common.database.compatibility import db_get, db_query, db_save +from src.common.database.compatibility import db_get, db_query from src.common.database.core.models import LLMUsage, Messages, OnlineTime from src.common.logger import get_logger from src.manager.async_task_manager import AsyncTask diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py index a43b96083..8213ad30b 100644 --- a/src/chat/utils/utils_image.py +++ b/src/chat/utils/utils_image.py @@ -12,8 +12,8 @@ from PIL import Image from rich.traceback import install from sqlalchemy import and_, select -from src.common.database.core.models import ImageDescriptions, Images from src.common.database.core import get_db_session +from src.common.database.core.models import ImageDescriptions, Images from src.common.logger import get_logger from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest diff --git a/src/chat/utils/utils_video.py b/src/chat/utils/utils_video.py index d51e7f7c3..945923403 100644 --- a/src/chat/utils/utils_video.py +++ b/src/chat/utils/utils_video.py @@ -25,8 +25,8 @@ from typing import Any from PIL import Image -from src.common.database.core.models import Videos from src.common.database.core import get_db_session +from src.common.database.core.models import Videos from src.common.logger import get_logger from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest diff --git a/src/common/database/__init__.py b/src/common/database/__init__.py index be633e619..2447c8f41 100644 --- a/src/common/database/__init__.py +++ b/src/common/database/__init__.py @@ -9,6 +9,37 @@ """ # ===== 核心层 ===== +# ===== API层 ===== +from src.common.database.api import ( + AggregateQuery, + 
CRUDBase, + QueryBuilder, + # ChatStreams API + get_active_streams, + # Messages API + get_chat_history, + get_message_count, + # PersonInfo API + get_or_create_person, + # ActionRecords API + get_recent_actions, + # LLMUsage API + get_usage_statistics, + record_llm_usage, + # 业务API + save_message, + store_action_info, + update_person_affinity, +) + +# ===== 兼容层(向后兼容旧API)===== +from src.common.database.compatibility import ( + MODEL_MAPPING, + build_filters, + db_get, + db_query, + db_save, +) from src.common.database.core import ( Base, check_and_migrate_database, @@ -27,29 +58,6 @@ from src.common.database.optimization import ( get_preloader, ) -# ===== API层 ===== -from src.common.database.api import ( - AggregateQuery, - CRUDBase, - QueryBuilder, - # ActionRecords API - get_recent_actions, - # ChatStreams API - get_active_streams, - # Messages API - get_chat_history, - get_message_count, - # PersonInfo API - get_or_create_person, - # LLMUsage API - get_usage_statistics, - record_llm_usage, - # 业务API - save_message, - store_action_info, - update_person_affinity, -) - # ===== Utils层 ===== from src.common.database.utils import ( cached, @@ -66,61 +74,52 @@ from src.common.database.utils import ( transactional, ) -# ===== 兼容层(向后兼容旧API)===== -from src.common.database.compatibility import ( - MODEL_MAPPING, - build_filters, - db_get, - db_query, - db_save, -) - __all__ = [ - # 核心层 - "Base", - "get_engine", - "get_session_factory", - "get_db_session", - "check_and_migrate_database", - # 优化层 - "MultiLevelCache", - "DataPreloader", - "AdaptiveBatchScheduler", - "get_cache", - "get_preloader", - "get_batch_scheduler", - # API层 - 基础类 - "CRUDBase", - "QueryBuilder", - "AggregateQuery", - # API层 - 业务API - "store_action_info", - "get_recent_actions", - "get_chat_history", - "get_message_count", - "save_message", - "get_or_create_person", - "update_person_affinity", - "get_active_streams", - "record_llm_usage", - "get_usage_statistics", - # Utils层 - "retry", - "timeout", - "cached", - "measure_time", - "transactional", - "db_operation", - "get_monitor", - "record_operation", - "record_cache_hit", - "record_cache_miss", - "print_stats", - "reset_stats", # 兼容层 "MODEL_MAPPING", + "AdaptiveBatchScheduler", + "AggregateQuery", + # 核心层 + "Base", + # API层 - 基础类 + "CRUDBase", + "DataPreloader", + # 优化层 + "MultiLevelCache", + "QueryBuilder", "build_filters", + "cached", + "check_and_migrate_database", + "db_get", + "db_operation", "db_query", "db_save", - "db_get", + "get_active_streams", + "get_batch_scheduler", + "get_cache", + "get_chat_history", + "get_db_session", + "get_engine", + "get_message_count", + "get_monitor", + "get_or_create_person", + "get_preloader", + "get_recent_actions", + "get_session_factory", + "get_usage_statistics", + "measure_time", + "print_stats", + "record_cache_hit", + "record_cache_miss", + "record_llm_usage", + "record_operation", + "reset_stats", + # Utils层 + "retry", + "save_message", + # API层 - 业务API + "store_action_info", + "timeout", + "transactional", + "update_person_affinity", ] diff --git a/src/common/database/api/__init__.py b/src/common/database/api/__init__.py index b80d8082e..b87a18a69 100644 --- a/src/common/database/api/__init__.py +++ b/src/common/database/api/__init__.py @@ -11,49 +11,49 @@ from src.common.database.api.query import AggregateQuery, QueryBuilder # 业务特定API from src.common.database.api.specialized import ( - # ActionRecords - get_recent_actions, - store_action_info, # ChatStreams get_active_streams, - get_or_create_chat_stream, - # LLMUsage - 
get_usage_statistics, - record_llm_usage, # Messages get_chat_history, get_message_count, - save_message, + get_or_create_chat_stream, # PersonInfo get_or_create_person, - update_person_affinity, + # ActionRecords + get_recent_actions, + # LLMUsage + get_usage_statistics, # UserRelationships get_user_relationship, + record_llm_usage, + save_message, + store_action_info, + update_person_affinity, update_relationship_affinity, ) __all__ = [ + "AggregateQuery", # 基础类 "CRUDBase", "QueryBuilder", - "AggregateQuery", - # ActionRecords API - "store_action_info", - "get_recent_actions", + "get_active_streams", # Messages API "get_chat_history", "get_message_count", - "save_message", - # PersonInfo API - "get_or_create_person", - "update_person_affinity", # ChatStreams API "get_or_create_chat_stream", - "get_active_streams", - # LLMUsage API - "record_llm_usage", + # PersonInfo API + "get_or_create_person", + "get_recent_actions", "get_usage_statistics", # UserRelationships API "get_user_relationship", + # LLMUsage API + "record_llm_usage", + "save_message", + # ActionRecords API + "store_action_info", + "update_person_affinity", "update_relationship_affinity", ] diff --git a/src/common/database/api/crud.py b/src/common/database/api/crud.py index 8a9a75de6..69af46562 100644 --- a/src/common/database/api/crud.py +++ b/src/common/database/api/crud.py @@ -110,12 +110,12 @@ class CRUDBase: if instance is not None: # ✅ 在 session 内部转换为字典,此时所有字段都可安全访问 instance_dict = _model_to_dict(instance) - + # 写入缓存 if use_cache: cache = await get_cache() await cache.set(cache_key, instance_dict) - + # 从字典重建对象返回(detached状态,所有字段已加载) return _dict_to_model(self.model, instance_dict) @@ -159,12 +159,12 @@ class CRUDBase: if instance is not None: # ✅ 在 session 内部转换为字典,此时所有字段都可安全访问 instance_dict = _model_to_dict(instance) - + # 写入缓存 if use_cache: cache = await get_cache() await cache.set(cache_key, instance_dict) - + # 从字典重建对象返回(detached状态,所有字段已加载) return _dict_to_model(self.model, instance_dict) @@ -206,7 +206,7 @@ class CRUDBase: # 应用过滤条件 for key, value in filters.items(): if hasattr(self.model, key): - if isinstance(value, (list, tuple, set)): + if isinstance(value, list | tuple | set): stmt = stmt.where(getattr(self.model, key).in_(value)) else: stmt = stmt.where(getattr(self.model, key) == value) @@ -219,12 +219,12 @@ class CRUDBase: # ✅ 在 session 内部转换为字典列表,此时所有字段都可安全访问 instances_dicts = [_model_to_dict(inst) for inst in instances] - + # 写入缓存 if use_cache: cache = await get_cache() await cache.set(cache_key, instances_dicts) - + # 从字典列表重建对象列表返回(detached状态,所有字段已加载) return [_dict_to_model(self.model, d) for d in instances_dicts] @@ -266,13 +266,13 @@ class CRUDBase: await session.refresh(instance) # 注意:commit在get_db_session的context manager退出时自动执行 # 但为了明确性,这里不需要显式commit - + # 注意:create不清除缓存,因为: # 1. 新记录不会影响已有的单条查询缓存(get/get_by) # 2. get_multi的缓存会自然过期(TTL机制) # 3. 
清除所有缓存代价太大,影响性能 # 如果需要强一致性,应该在查询时设置use_cache=False - + return instance async def update( @@ -397,7 +397,7 @@ class CRUDBase: # 应用过滤条件 for key, value in filters.items(): if hasattr(self.model, key): - if isinstance(value, (list, tuple, set)): + if isinstance(value, list | tuple | set): stmt = stmt.where(getattr(self.model, key).in_(value)) else: stmt = stmt.where(getattr(self.model, key) == value) @@ -466,14 +466,14 @@ class CRUDBase: for instance in instances: await session.refresh(instance) - + # 批量创建的缓存策略: # bulk_create通常用于批量导入场景,此时清除缓存是合理的 # 因为可能创建大量记录,缓存的列表查询会明显过期 cache = await get_cache() await cache.clear() logger.info(f"批量创建{len(instances)}条{self.model_name}记录后已清除缓存") - + return instances async def bulk_update( diff --git a/src/common/database/api/query.py b/src/common/database/api/query.py index 02cca7c12..8d7bab1b1 100644 --- a/src/common/database/api/query.py +++ b/src/common/database/api/query.py @@ -207,12 +207,12 @@ class QueryBuilder(Generic[T]): # ✅ 在 session 内部转换为字典列表,此时所有字段都可安全访问 instances_dicts = [_model_to_dict(inst) for inst in instances] - + # 写入缓存 if self._use_cache: cache = await get_cache() await cache.set(cache_key, instances_dicts) - + # 从字典列表重建对象列表返回(detached状态,所有字段已加载) return [_dict_to_model(self.model, d) for d in instances_dicts] @@ -241,12 +241,12 @@ class QueryBuilder(Generic[T]): if instance is not None: # ✅ 在 session 内部转换为字典,此时所有字段都可安全访问 instance_dict = _model_to_dict(instance) - + # 写入缓存 if self._use_cache: cache = await get_cache() await cache.set(cache_key, instance_dict) - + # 从字典重建对象返回(detached状态,所有字段已加载) return _dict_to_model(self.model, instance_dict) diff --git a/src/common/database/api/specialized.py b/src/common/database/api/specialized.py index 494fa4283..91dea9b4f 100644 --- a/src/common/database/api/specialized.py +++ b/src/common/database/api/specialized.py @@ -4,7 +4,7 @@ """ import time -from typing import Any, Optional +from typing import Any import orjson @@ -42,11 +42,11 @@ async def store_action_info( action_prompt_display: str = "", action_done: bool = True, thinking_id: str = "", - action_data: Optional[dict] = None, + action_data: dict | None = None, action_name: str = "", -) -> Optional[dict[str, Any]]: +) -> dict[str, Any] | None: """存储动作信息到数据库 - + Args: chat_stream: 聊天流对象 action_build_into_prompt: 是否将此动作构建到提示中 @@ -55,7 +55,7 @@ async def store_action_info( thinking_id: 关联的思考ID action_data: 动作数据字典 action_name: 动作名称 - + Returns: 保存的记录数据或None """ @@ -71,7 +71,7 @@ async def store_action_info( "action_build_into_prompt": action_build_into_prompt, "action_prompt_display": action_prompt_display, } - + # 从chat_stream获取聊天信息 if chat_stream: record_data.update( @@ -89,20 +89,20 @@ async def store_action_info( "chat_info_platform": "", } ) - + # 使用get_or_create保存记录 saved_record, created = await _action_records_crud.get_or_create( defaults=record_data, action_id=action_id, ) - + if saved_record: logger.debug(f"成功存储动作信息: {action_name} (ID: {action_id})") return {col.name: getattr(saved_record, col.name) for col in saved_record.__table__.columns} else: logger.error(f"存储动作信息失败: {action_name}") return None - + except Exception as e: logger.error(f"存储动作信息时发生错误: {e}", exc_info=True) return None @@ -113,11 +113,11 @@ async def get_recent_actions( limit: int = 10, ) -> list[ActionRecords]: """获取最近的动作记录 - + Args: chat_id: 聊天ID limit: 限制数量 - + Returns: 动作记录列表 """ @@ -132,12 +132,12 @@ async def get_chat_history( offset: int = 0, ) -> list[Messages]: """获取聊天历史 - + Args: stream_id: 流ID limit: 限制数量 offset: 偏移量 - + Returns: 消息列表 """ @@ -153,10 +153,10 @@ 
async def get_chat_history( async def get_message_count(stream_id: str) -> int: """获取消息数量 - + Args: stream_id: 流ID - + Returns: 消息数量 """ @@ -167,13 +167,13 @@ async def get_message_count(stream_id: str) -> int: async def save_message( message_data: dict[str, Any], use_batch: bool = True, -) -> Optional[Messages]: +) -> Messages | None: """保存消息 - + Args: message_data: 消息数据 use_batch: 是否使用批处理 - + Returns: 保存的消息实例 """ @@ -185,15 +185,15 @@ async def save_message( async def get_or_create_person( platform: str, person_id: str, - defaults: Optional[dict[str, Any]] = None, -) -> tuple[Optional[PersonInfo], bool]: + defaults: dict[str, Any] | None = None, +) -> tuple[PersonInfo | None, bool]: """获取或创建人员信息 - + Args: platform: 平台 person_id: 人员ID defaults: 默认值 - + Returns: (人员信息实例, 是否新创建) """ @@ -210,12 +210,12 @@ async def update_person_affinity( affinity_delta: float, ) -> bool: """更新人员好感度 - + Args: platform: 平台 person_id: 人员ID affinity_delta: 好感度变化值 - + Returns: 是否成功 """ @@ -225,26 +225,26 @@ async def update_person_affinity( platform=platform, person_id=person_id, ) - + if not person: logger.warning(f"人员不存在: {platform}/{person_id}") return False - + # 更新好感度 new_affinity = (person.affinity or 0.0) + affinity_delta await _person_info_crud.update( person.id, {"affinity": new_affinity}, ) - + # 使缓存失效 cache = await get_cache() cache_key = generate_cache_key("person_info", platform, person_id) await cache.delete(cache_key) - + logger.debug(f"更新好感度: {platform}/{person_id} {affinity_delta:+.2f} -> {new_affinity:.2f}") return True - + except Exception as e: logger.error(f"更新好感度失败: {e}", exc_info=True) return False @@ -255,15 +255,15 @@ async def update_person_affinity( async def get_or_create_chat_stream( stream_id: str, platform: str, - defaults: Optional[dict[str, Any]] = None, -) -> tuple[Optional[ChatStreams], bool]: + defaults: dict[str, Any] | None = None, +) -> tuple[ChatStreams | None, bool]: """获取或创建聊天流 - + Args: stream_id: 流ID platform: 平台 defaults: 默认值 - + Returns: (聊天流实例, 是否新创建) """ @@ -275,23 +275,23 @@ async def get_or_create_chat_stream( async def get_active_streams( - platform: Optional[str] = None, + platform: str | None = None, limit: int = 100, ) -> list[ChatStreams]: """获取活跃的聊天流 - + Args: platform: 平台(可选) limit: 限制数量 - + Returns: 聊天流列表 """ query = QueryBuilder(ChatStreams) - + if platform: query = query.filter(platform=platform) - + return await query.order_by("-last_message_time").limit(limit).all() @@ -300,20 +300,20 @@ async def record_llm_usage( model_name: str, input_tokens: int, output_tokens: int, - stream_id: Optional[str] = None, - platform: Optional[str] = None, + stream_id: str | None = None, + platform: str | None = None, user_id: str = "system", request_type: str = "chat", - model_assign_name: Optional[str] = None, - model_api_provider: Optional[str] = None, + model_assign_name: str | None = None, + model_api_provider: str | None = None, endpoint: str = "/v1/chat/completions", cost: float = 0.0, status: str = "success", - time_cost: Optional[float] = None, + time_cost: float | None = None, use_batch: bool = True, -) -> Optional[LLMUsage]: +) -> LLMUsage | None: """记录LLM使用情况 - + Args: model_name: 模型名称 input_tokens: 输入token数 @@ -329,7 +329,7 @@ async def record_llm_usage( status: 状态 time_cost: 时间成本 use_batch: 是否使用批处理 - + Returns: LLM使用记录实例 """ @@ -346,37 +346,36 @@ async def record_llm_usage( "model_assign_name": model_assign_name or model_name, "model_api_provider": model_api_provider or "unknown", } - + if time_cost is not None: usage_data["time_cost"] = time_cost - + 
return await _llm_usage_crud.create(usage_data, use_batch=use_batch) async def get_usage_statistics( - start_time: Optional[float] = None, - end_time: Optional[float] = None, - model_name: Optional[str] = None, + start_time: float | None = None, + end_time: float | None = None, + model_name: str | None = None, ) -> dict[str, Any]: """获取使用统计 - + Args: start_time: 开始时间戳 end_time: 结束时间戳 model_name: 模型名称 - + Returns: 统计数据字典 """ from src.common.database.api.query import AggregateQuery - + query = AggregateQuery(LLMUsage) - + # 添加时间过滤 if start_time: - async with get_db_session() as session: - from sqlalchemy import and_ - + async with get_db_session(): + conditions = [] if start_time: conditions.append(LLMUsage.timestamp >= start_time) @@ -384,15 +383,15 @@ async def get_usage_statistics( conditions.append(LLMUsage.timestamp <= end_time) if model_name: conditions.append(LLMUsage.model_name == model_name) - + if conditions: query._conditions = conditions - + # 聚合统计 total_input = await query.sum("input_tokens") total_output = await query.sum("output_tokens") total_count = await query.filter().count() if hasattr(query, "count") else 0 - + return { "total_input_tokens": int(total_input), "total_output_tokens": int(total_output), @@ -407,14 +406,14 @@ async def get_user_relationship( platform: str, user_id: str, target_id: str, -) -> Optional[UserRelationships]: +) -> UserRelationships | None: """获取用户关系 - + Args: platform: 平台 user_id: 用户ID target_id: 目标用户ID - + Returns: 用户关系实例 """ @@ -432,13 +431,13 @@ async def update_relationship_affinity( affinity_delta: float, ) -> bool: """更新关系好感度 - + Args: platform: 平台 user_id: 用户ID target_id: 目标用户ID affinity_delta: 好感度变化值 - + Returns: 是否成功 """ @@ -450,15 +449,15 @@ async def update_relationship_affinity( user_id=user_id, target_id=target_id, ) - + if not relationship: logger.error(f"无法创建关系: {platform}/{user_id}->{target_id}") return False - + # 更新好感度和互动次数 new_affinity = (relationship.affinity or 0.0) + affinity_delta new_count = (relationship.interaction_count or 0) + 1 - + await _user_relationships_crud.update( relationship.id, { @@ -467,19 +466,19 @@ async def update_relationship_affinity( "last_interaction_time": time.time(), }, ) - + # 使缓存失效 cache = await get_cache() cache_key = generate_cache_key("user_relationship", platform, user_id, target_id) await cache.delete(cache_key) - + logger.debug( f"更新关系: {platform}/{user_id}->{target_id} " f"好感度{affinity_delta:+.2f}->{new_affinity:.2f} " f"互动{new_count}次" ) return True - + except Exception as e: logger.error(f"更新关系好感度失败: {e}", exc_info=True) return False diff --git a/src/common/database/compatibility/__init__.py b/src/common/database/compatibility/__init__.py index 14e1902b4..fe7d2cdce 100644 --- a/src/common/database/compatibility/__init__.py +++ b/src/common/database/compatibility/__init__.py @@ -14,14 +14,14 @@ from .adapter import ( ) __all__ = [ - # 从 core 重新导出的函数 - "get_db_session", - "get_engine", # 兼容层适配器 "MODEL_MAPPING", "build_filters", + "db_get", "db_query", "db_save", - "db_get", + # 从 core 重新导出的函数 + "get_db_session", + "get_engine", "store_action_info", ] diff --git a/src/common/database/compatibility/adapter.py b/src/common/database/compatibility/adapter.py index 0e50c821d..f2d1b3e58 100644 --- a/src/common/database/compatibility/adapter.py +++ b/src/common/database/compatibility/adapter.py @@ -4,15 +4,13 @@ 保持原有函数签名和行为不变 """ -import time -from typing import Any, Optional - -import orjson -from sqlalchemy import and_, asc, desc, select +from typing import Any from src.common.database.api 
import ( CRUDBase, QueryBuilder, +) +from src.common.database.api import ( store_action_info as new_store_action_info, ) from src.common.database.core.models import ( @@ -34,15 +32,14 @@ from src.common.database.core.models import ( Messages, MonthlyPlan, OnlineTime, - PersonInfo, PermissionNodes, + PersonInfo, Schedule, ThinkingLog, UserPermissions, UserRelationships, Videos, ) -from src.common.database.core.session import get_db_session from src.common.logger import get_logger logger = get_logger("database.compatibility") @@ -82,11 +79,11 @@ _crud_instances = {name: CRUDBase(model) for name, model in MODEL_MAPPING.items( async def build_filters(model_class, filters: dict[str, Any]): """构建查询过滤条件(兼容MongoDB风格操作符) - + Args: model_class: SQLAlchemy模型类 filters: 过滤条件字典 - + Returns: 条件列表 """ @@ -127,16 +124,16 @@ async def build_filters(model_class, filters: dict[str, Any]): def _model_to_dict(instance) -> dict[str, Any]: """将模型实例转换为字典 - + Args: instance: 模型实例 - + Returns: 字典表示 """ if instance is None: return None - + result = {} for column in instance.__table__.columns: result[column.name] = getattr(instance, column.name) @@ -145,15 +142,15 @@ def _model_to_dict(instance) -> dict[str, Any]: async def db_query( model_class, - data: Optional[dict[str, Any]] = None, - query_type: Optional[str] = "get", - filters: Optional[dict[str, Any]] = None, - limit: Optional[int] = None, - order_by: Optional[list[str]] = None, - single_result: Optional[bool] = False, + data: dict[str, Any] | None = None, + query_type: str | None = "get", + filters: dict[str, Any] | None = None, + limit: int | None = None, + order_by: list[str] | None = None, + single_result: bool | None = False, ) -> list[dict[str, Any]] | dict[str, Any] | None: """执行异步数据库查询操作(兼容旧API) - + Args: model_class: SQLAlchemy模型类 data: 用于创建或更新的数据字典 @@ -162,7 +159,7 @@ async def db_query( limit: 限制结果数量 order_by: 排序字段,前缀'-'表示降序 single_result: 是否只返回单个结果 - + Returns: 根据查询类型返回相应结果 """ @@ -179,7 +176,7 @@ async def db_query( if query_type == "get": # 使用QueryBuilder query_builder = QueryBuilder(model_class) - + # 应用过滤条件 if filters: # 将MongoDB风格过滤器转换为QueryBuilder格式 @@ -202,15 +199,15 @@ async def db_query( query_builder = query_builder.filter(**{f"{field_name}__nin": op_value}) else: query_builder = query_builder.filter(**{field_name: value}) - + # 应用排序 if order_by: query_builder = query_builder.order_by(*order_by) - + # 应用限制 if limit: query_builder = query_builder.limit(limit) - + # 执行查询 if single_result: result = await query_builder.first() @@ -223,7 +220,7 @@ async def db_query( if not data: logger.error("创建操作需要提供data参数") return None - + instance = await crud.create(data) return _model_to_dict(instance) @@ -231,17 +228,17 @@ async def db_query( if not filters or not data: logger.error("更新操作需要提供filters和data参数") return None - + # 先查找记录 query_builder = QueryBuilder(model_class) for field_name, value in filters.items(): query_builder = query_builder.filter(**{field_name: value}) - + instance = await query_builder.first() if not instance: logger.warning(f"未找到匹配的记录: {filters}") return None - + # 更新记录 updated = await crud.update(instance.id, data) return _model_to_dict(updated) @@ -250,29 +247,29 @@ async def db_query( if not filters: logger.error("删除操作需要提供filters参数") return None - + # 先查找记录 query_builder = QueryBuilder(model_class) for field_name, value in filters.items(): query_builder = query_builder.filter(**{field_name: value}) - + instance = await query_builder.first() if not instance: logger.warning(f"未找到匹配的记录: {filters}") return None - + # 删除记录 success = 
await crud.delete(instance.id) return {"deleted": success} elif query_type == "count": query_builder = QueryBuilder(model_class) - + # 应用过滤条件 if filters: for field_name, value in filters.items(): query_builder = query_builder.filter(**{field_name: value}) - + count = await query_builder.count() return {"count": count} @@ -286,15 +283,15 @@ async def db_save( data: dict[str, Any], key_field: str, key_value: Any, -) -> Optional[dict[str, Any]]: +) -> dict[str, Any] | None: """保存或更新记录(兼容旧API) - + Args: model_class: SQLAlchemy模型类 data: 数据字典 key_field: 主键字段名 key_value: 主键值 - + Returns: 保存的记录数据或None """ @@ -303,15 +300,15 @@ async def db_save( crud = _crud_instances.get(model_name) if not crud: crud = CRUDBase(model_class) - + # 使用get_or_create (返回tuple[T, bool]) instance, created = await crud.get_or_create( defaults=data, **{key_field: key_value}, ) - + return _model_to_dict(instance) - + except Exception as e: logger.error(f"保存数据库记录出错: {e}", exc_info=True) return None @@ -319,20 +316,20 @@ async def db_save( async def db_get( model_class, - filters: Optional[dict[str, Any]] = None, - limit: Optional[int] = None, - order_by: Optional[str] = None, - single_result: Optional[bool] = False, + filters: dict[str, Any] | None = None, + limit: int | None = None, + order_by: str | None = None, + single_result: bool | None = False, ) -> list[dict[str, Any]] | dict[str, Any] | None: """从数据库获取记录(兼容旧API) - + Args: model_class: SQLAlchemy模型类 filters: 过滤条件 limit: 结果数量限制 order_by: 排序字段,前缀'-'表示降序 single_result: 是否只返回单个结果 - + Returns: 记录数据或None """ @@ -353,11 +350,11 @@ async def store_action_info( action_prompt_display: str = "", action_done: bool = True, thinking_id: str = "", - action_data: Optional[dict] = None, + action_data: dict | None = None, action_name: str = "", -) -> Optional[dict[str, Any]]: +) -> dict[str, Any] | None: """存储动作信息到数据库(兼容旧API) - + 直接使用新的specialized API """ return await new_store_action_info( diff --git a/src/common/database/config/old/database_config.py b/src/common/database/config/old/database_config.py index 1165682ee..71cc9824b 100644 --- a/src/common/database/config/old/database_config.py +++ b/src/common/database/config/old/database_config.py @@ -5,7 +5,7 @@ import os from dataclasses import dataclass -from typing import Any, Optional +from typing import Any from urllib.parse import quote_plus from src.common.logger import get_logger @@ -16,50 +16,50 @@ logger = get_logger("database_config") @dataclass class DatabaseConfig: """数据库配置""" - + # 基础配置 db_type: str # "sqlite" 或 "mysql" url: str # 数据库连接URL - + # 引擎配置 engine_kwargs: dict[str, Any] - + # SQLite特定配置 - sqlite_path: Optional[str] = None - + sqlite_path: str | None = None + # MySQL特定配置 - mysql_host: Optional[str] = None - mysql_port: Optional[int] = None - mysql_user: Optional[str] = None - mysql_password: Optional[str] = None - mysql_database: Optional[str] = None + mysql_host: str | None = None + mysql_port: int | None = None + mysql_user: str | None = None + mysql_password: str | None = None + mysql_database: str | None = None mysql_charset: str = "utf8mb4" - mysql_unix_socket: Optional[str] = None + mysql_unix_socket: str | None = None -_database_config: Optional[DatabaseConfig] = None +_database_config: DatabaseConfig | None = None def get_database_config() -> DatabaseConfig: """获取数据库配置 - + 从全局配置中读取数据库设置并构建配置对象 """ global _database_config - + if _database_config is not None: return _database_config - + from src.config.config import global_config - + config = global_config.database - + # 构建数据库URL if config.database_type 
== "mysql": # MySQL配置 encoded_user = quote_plus(config.mysql_user) encoded_password = quote_plus(config.mysql_password) - + if config.mysql_unix_socket: # Unix socket连接 encoded_socket = quote_plus(config.mysql_unix_socket) @@ -75,7 +75,7 @@ def get_database_config() -> DatabaseConfig: f"@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}" f"?charset={config.mysql_charset}" ) - + engine_kwargs = { "echo": False, "future": True, @@ -90,7 +90,7 @@ def get_database_config() -> DatabaseConfig: "connect_timeout": config.connection_timeout, }, } - + _database_config = DatabaseConfig( db_type="mysql", url=url, @@ -103,12 +103,12 @@ def get_database_config() -> DatabaseConfig: mysql_charset=config.mysql_charset, mysql_unix_socket=config.mysql_unix_socket, ) - + logger.info( f"MySQL配置已加载: " f"{config.mysql_user}@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}" ) - + else: # SQLite配置 if not os.path.isabs(config.sqlite_path): @@ -116,12 +116,12 @@ def get_database_config() -> DatabaseConfig: db_path = os.path.join(ROOT_PATH, config.sqlite_path) else: db_path = config.sqlite_path - + # 确保数据库目录存在 os.makedirs(os.path.dirname(db_path), exist_ok=True) - + url = f"sqlite+aiosqlite:///{db_path}" - + engine_kwargs = { "echo": False, "future": True, @@ -130,16 +130,16 @@ def get_database_config() -> DatabaseConfig: "timeout": 60, }, } - + _database_config = DatabaseConfig( db_type="sqlite", url=url, engine_kwargs=engine_kwargs, sqlite_path=db_path, ) - + logger.info(f"SQLite配置已加载: {db_path}") - + return _database_config diff --git a/src/common/database/core/__init__.py b/src/common/database/core/__init__.py index ca896467f..8f83149db 100644 --- a/src/common/database/core/__init__.py +++ b/src/common/database/core/__init__.py @@ -19,7 +19,6 @@ from .models import ( ChatStreams, Emoji, Expression, - get_string_field, GraphEdges, GraphNodes, ImageDescriptions, @@ -37,30 +36,17 @@ from .models import ( UserPermissions, UserRelationships, Videos, + get_string_field, ) from .session import get_db_session, get_db_session_direct, get_session_factory, reset_session_factory __all__ = [ - # Engine - "get_engine", - "close_engine", - "get_engine_info", - # Session - "get_db_session", - "get_db_session_direct", - "get_session_factory", - "reset_session_factory", - # Migration - "check_and_migrate_database", - "create_all_tables", - "drop_all_tables", - # Models - Base - "Base", - "get_string_field", # Models - Tables (按字母顺序) "ActionRecords", "AntiInjectionStats", "BanUser", + # Models - Base + "Base", "BotPersonalityInterests", "CacheEntries", "ChatStreams", @@ -83,4 +69,18 @@ __all__ = [ "UserPermissions", "UserRelationships", "Videos", + # Migration + "check_and_migrate_database", + "close_engine", + "create_all_tables", + "drop_all_tables", + # Session + "get_db_session", + "get_db_session_direct", + # Engine + "get_engine", + "get_engine_info", + "get_session_factory", + "get_string_field", + "reset_session_factory", ] diff --git a/src/common/database/core/engine.py b/src/common/database/core/engine.py index 4b8e0cc7a..2087ae49b 100644 --- a/src/common/database/core/engine.py +++ b/src/common/database/core/engine.py @@ -5,7 +5,6 @@ import asyncio import os -from typing import Optional from urllib.parse import quote_plus from sqlalchemy import text @@ -18,49 +17,49 @@ from ..utils.exceptions import DatabaseInitializationError logger = get_logger("database.engine") # 全局引擎实例 -_engine: Optional[AsyncEngine] = None -_engine_lock: Optional[asyncio.Lock] = None +_engine: AsyncEngine | None = None 
+_engine_lock: asyncio.Lock | None = None async def get_engine() -> AsyncEngine: """获取全局数据库引擎(单例模式) - + Returns: AsyncEngine: SQLAlchemy异步引擎 - + Raises: DatabaseInitializationError: 引擎初始化失败 """ global _engine, _engine_lock - + # 快速路径:引擎已初始化 if _engine is not None: return _engine - + # 延迟创建锁(避免在导入时创建) if _engine_lock is None: _engine_lock = asyncio.Lock() - + # 使用锁保护初始化过程 async with _engine_lock: # 双重检查锁定模式 if _engine is not None: return _engine - + try: from src.config.config import global_config - + config = global_config.database db_type = config.database_type - + logger.info(f"正在初始化 {db_type.upper()} 数据库引擎...") - + # 构建数据库URL和引擎参数 if db_type == "mysql": # MySQL配置 encoded_user = quote_plus(config.mysql_user) encoded_password = quote_plus(config.mysql_password) - + if config.mysql_unix_socket: # Unix socket连接 encoded_socket = quote_plus(config.mysql_unix_socket) @@ -76,7 +75,7 @@ async def get_engine() -> AsyncEngine: f"@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}" f"?charset={config.mysql_charset}" ) - + engine_kwargs = { "echo": False, "future": True, @@ -91,11 +90,11 @@ async def get_engine() -> AsyncEngine: "connect_timeout": config.connection_timeout, }, } - + logger.info( f"MySQL配置: {config.mysql_user}@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}" ) - + else: # SQLite配置 if not os.path.isabs(config.sqlite_path): @@ -103,12 +102,12 @@ async def get_engine() -> AsyncEngine: db_path = os.path.join(ROOT_PATH, config.sqlite_path) else: db_path = config.sqlite_path - + # 确保数据库目录存在 os.makedirs(os.path.dirname(db_path), exist_ok=True) - + url = f"sqlite+aiosqlite:///{db_path}" - + engine_kwargs = { "echo": False, "future": True, @@ -117,19 +116,19 @@ async def get_engine() -> AsyncEngine: "timeout": 60, }, } - + logger.info(f"SQLite配置: {db_path}") - + # 创建异步引擎 _engine = create_async_engine(url, **engine_kwargs) - + # SQLite特定优化 if db_type == "sqlite": await _enable_sqlite_optimizations(_engine) - + logger.info(f"✅ {db_type.upper()} 数据库引擎初始化成功") return _engine - + except Exception as e: logger.error(f"❌ 数据库引擎初始化失败: {e}", exc_info=True) raise DatabaseInitializationError(f"引擎初始化失败: {e}") from e @@ -137,11 +136,11 @@ async def get_engine() -> AsyncEngine: async def close_engine(): """关闭数据库引擎 - + 释放所有连接池资源 """ global _engine - + if _engine is not None: logger.info("正在关闭数据库引擎...") await _engine.dispose() @@ -151,13 +150,13 @@ async def close_engine(): async def _enable_sqlite_optimizations(engine: AsyncEngine): """启用SQLite性能优化 - + 优化项: - WAL模式:提高并发性能 - NORMAL同步:平衡性能和安全性 - 启用外键约束 - 设置busy_timeout:避免锁定错误 - + Args: engine: SQLAlchemy异步引擎 """ @@ -175,22 +174,22 @@ async def _enable_sqlite_optimizations(engine: AsyncEngine): await conn.execute(text("PRAGMA cache_size = -10000")) # 临时存储使用内存 await conn.execute(text("PRAGMA temp_store = MEMORY")) - + logger.info("✅ SQLite性能优化已启用 (WAL模式 + 并发优化)") - + except Exception as e: logger.warning(f"⚠️ SQLite性能优化失败: {e},将使用默认配置") async def get_engine_info() -> dict: """获取引擎信息(用于监控和调试) - + Returns: dict: 引擎信息字典 """ try: engine = await get_engine() - + info = { "name": engine.name, "driver": engine.driver, @@ -199,9 +198,9 @@ async def get_engine_info() -> dict: "pool_checked_out": getattr(engine.pool, "checked_out", lambda: 0)(), "pool_overflow": getattr(engine.pool, "overflow", lambda: 0)(), } - + return info - + except Exception as e: logger.error(f"获取引擎信息失败: {e}") return {} diff --git a/src/common/database/core/migration.py b/src/common/database/core/migration.py index eac6d0cde..587823318 100644 --- 
a/src/common/database/core/migration.py +++ b/src/common/database/core/migration.py @@ -20,15 +20,15 @@ logger = get_logger("db_migration") async def check_and_migrate_database(existing_engine=None): """异步检查数据库结构并自动迁移 - + 自动执行以下操作: - 创建不存在的表 - 为现有表添加缺失的列 - 为现有表创建缺失的索引 - + Args: existing_engine: 可选的已存在的数据库引擎。如果提供,将使用该引擎;否则获取全局引擎 - + Note: 此函数是幂等的,可以安全地多次调用 """ @@ -65,7 +65,7 @@ async def check_and_migrate_database(existing_engine=None): for table in tables_to_create: logger.info(f"表 '{table.name}' 创建成功。") db_table_names.add(table.name) # 将新创建的表添加到集合中 - + # 提交表创建事务 await connection.commit() except Exception as e: @@ -191,40 +191,40 @@ async def check_and_migrate_database(existing_engine=None): async def create_all_tables(existing_engine=None): """创建所有表(不进行迁移检查) - + 直接创建所有在 Base.metadata 中定义的表。 如果表已存在,将被跳过。 - + Args: existing_engine: 可选的已存在的数据库引擎 - + Note: 生产环境建议使用 check_and_migrate_database() """ logger.info("正在创建所有数据库表...") engine = existing_engine if existing_engine is not None else await get_engine() - + async with engine.begin() as connection: await connection.run_sync(Base.metadata.create_all) - + logger.info("数据库表创建完成。") async def drop_all_tables(existing_engine=None): """删除所有表(危险操作!) - + 删除所有在 Base.metadata 中定义的表。 - + Args: existing_engine: 可选的已存在的数据库引擎 - + Warning: 此操作将删除所有数据,不可恢复!仅用于测试环境! """ logger.warning("⚠️ 正在删除所有数据库表...") engine = existing_engine if existing_engine is not None else await get_engine() - + async with engine.begin() as connection: await connection.run_sync(Base.metadata.drop_all) - + logger.warning("所有数据库表已删除。") diff --git a/src/common/database/core/session.py b/src/common/database/core/session.py index c269ba9c4..90e3f634c 100644 --- a/src/common/database/core/session.py +++ b/src/common/database/core/session.py @@ -6,7 +6,6 @@ import asyncio from collections.abc import AsyncGenerator from contextlib import asynccontextmanager -from typing import Optional from sqlalchemy import text from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker @@ -18,38 +17,38 @@ from .engine import get_engine logger = get_logger("database.session") # 全局会话工厂 -_session_factory: Optional[async_sessionmaker] = None -_factory_lock: Optional[asyncio.Lock] = None +_session_factory: async_sessionmaker | None = None +_factory_lock: asyncio.Lock | None = None async def get_session_factory() -> async_sessionmaker: """获取会话工厂(单例模式) - + Returns: async_sessionmaker: SQLAlchemy异步会话工厂 """ global _session_factory, _factory_lock - + # 快速路径 if _session_factory is not None: return _session_factory - + # 延迟创建锁 if _factory_lock is None: _factory_lock = asyncio.Lock() - + async with _factory_lock: # 双重检查 if _session_factory is not None: return _session_factory - + engine = await get_engine() _session_factory = async_sessionmaker( bind=engine, class_=AsyncSession, expire_on_commit=False, # 避免在commit后访问属性时重新查询 ) - + logger.debug("会话工厂已创建") return _session_factory @@ -57,28 +56,28 @@ async def get_session_factory() -> async_sessionmaker: @asynccontextmanager async def get_db_session() -> AsyncGenerator[AsyncSession, None]: """获取数据库会话上下文管理器 - + 这是数据库操作的主要入口点,通过连接池管理器提供透明的连接复用。 - + 使用示例: async with get_db_session() as session: result = await session.execute(select(User)) users = result.scalars().all() - + Yields: AsyncSession: SQLAlchemy异步会话对象 """ # 延迟导入避免循环依赖 from ..optimization.connection_pool import get_connection_pool_manager - + session_factory = await get_session_factory() pool_manager = get_connection_pool_manager() - + # 使用连接池管理器(透明复用连接) async with pool_manager.get_session(session_factory) as 
diff --git a/src/common/database/core/session.py b/src/common/database/core/session.py
index c269ba9c4..90e3f634c 100644
--- a/src/common/database/core/session.py
+++ b/src/common/database/core/session.py
@@ -6,7 +6,6 @@
 import asyncio
 from collections.abc import AsyncGenerator
 from contextlib import asynccontextmanager
-from typing import Optional

 from sqlalchemy import text
 from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
@@ -18,38 +17,38 @@ from .engine import get_engine
 logger = get_logger("database.session")

 # Global session factory
-_session_factory: Optional[async_sessionmaker] = None
-_factory_lock: Optional[asyncio.Lock] = None
+_session_factory: async_sessionmaker | None = None
+_factory_lock: asyncio.Lock | None = None


 async def get_session_factory() -> async_sessionmaker:
     """Get the session factory (singleton)
-
+
     Returns:
         async_sessionmaker: SQLAlchemy async session factory
     """
     global _session_factory, _factory_lock
-
+
     # Fast path
     if _session_factory is not None:
         return _session_factory
-
+
     # Create the lock lazily
     if _factory_lock is None:
         _factory_lock = asyncio.Lock()
-
+
     async with _factory_lock:
         # Double-checked locking
         if _session_factory is not None:
             return _session_factory
-
+
         engine = await get_engine()
         _session_factory = async_sessionmaker(
             bind=engine,
             class_=AsyncSession,
             expire_on_commit=False,  # avoid re-querying attributes after commit
         )
-
+
         logger.debug("Session factory created")
         return _session_factory

@@ -57,28 +56,28 @@ async def get_session_factory() -> async_sessionmaker:
 @asynccontextmanager
 async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
     """Database session context manager
-
+
     This is the main entry point for database operations; the connection pool
     manager provides transparent connection reuse.
-
+
     Usage example:
         async with get_db_session() as session:
             result = await session.execute(select(User))
             users = result.scalars().all()
-
+
     Yields:
         AsyncSession: SQLAlchemy async session object
     """
     # Late import to avoid a circular dependency
     from ..optimization.connection_pool import get_connection_pool_manager
-
+
     session_factory = await get_session_factory()
     pool_manager = get_connection_pool_manager()
-
+
     # Use the connection pool manager (transparent connection reuse)
     async with pool_manager.get_session(session_factory) as session:
         # Apply SQLite-specific PRAGMAs
         from src.config.config import global_config
-
+
         if global_config.database.database_type == "sqlite":
             try:
                 await session.execute(text("PRAGMA busy_timeout = 60000"))
@@ -86,22 +85,22 @@ async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
             except Exception:
                 # PRAGMAs may already be set on a reused connection; ignore the error
                 pass
-
+
         yield session


 @asynccontextmanager
 async def get_db_session_direct() -> AsyncGenerator[AsyncSession, None]:
     """Database session (direct mode, bypassing the connection pool)
-
+
     For special scenarios that need a completely independent connection.
     In general, use get_db_session() instead.
-
+
     Yields:
         AsyncSession: SQLAlchemy async session object
     """
     session_factory = await get_session_factory()
-
+
     async with session_factory() as session:
         try:
             yield session
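For reference, a usage sketch of the two session entry points, assuming the ChatStreams model and its stream_id column from src.common.database.core.models (both appear elsewhere in this diff):

from sqlalchemy import select

from src.common.database.core.models import ChatStreams
from src.common.database.core.session import get_db_session, get_db_session_direct


async def load_stream(stream_id: str):
    # Pooled path: the connection pool manager transparently reuses connections.
    async with get_db_session() as session:
        result = await session.execute(
            select(ChatStreams).where(ChatStreams.stream_id == stream_id)
        )
        return result.scalar_one_or_none()


async def load_stream_isolated(stream_id: str):
    # Direct path: a fully independent connection for special cases.
    async with get_db_session_direct() as session:
        result = await session.execute(
            select(ChatStreams).where(ChatStreams.stream_id == stream_id)
        )
        return result.scalar_one_or_none()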
diff --git a/src/common/database/optimization/__init__.py b/src/common/database/optimization/__init__.py
index c0eb80251..7cd0c99df 100644
--- a/src/common/database/optimization/__init__.py
+++ b/src/common/database/optimization/__init__.py
@@ -11,17 +11,17 @@ from .batch_scheduler import (
     AdaptiveBatchScheduler,
     BatchOperation,
     BatchStats,
+    Priority,
     close_batch_scheduler,
     get_batch_scheduler,
-    Priority,
 )
 from .cache_manager import (
     CacheEntry,
     CacheStats,
-    close_cache,
-    get_cache,
     LRUCache,
     MultiLevelCache,
+    close_cache,
+    get_cache,
 )
 from .connection_pool import (
     ConnectionPoolManager,
@@ -31,36 +31,36 @@ from .connection_pool import (
 )
 from .preloader import (
     AccessPattern,
-    close_preloader,
     CommonDataPreloader,
     DataPreloader,
+    close_preloader,
     get_preloader,
 )

 __all__ = [
-    # Connection Pool
-    "ConnectionPoolManager",
-    "get_connection_pool_manager",
-    "start_connection_pool",
-    "stop_connection_pool",
-    # Cache
-    "MultiLevelCache",
-    "LRUCache",
-    "CacheEntry",
-    "CacheStats",
-    "get_cache",
-    "close_cache",
-    # Preloader
-    "DataPreloader",
-    "CommonDataPreloader",
     "AccessPattern",
-    "get_preloader",
-    "close_preloader",
     # Batch Scheduler
     "AdaptiveBatchScheduler",
     "BatchOperation",
     "BatchStats",
+    "CacheEntry",
+    "CacheStats",
+    "CommonDataPreloader",
+    # Connection Pool
+    "ConnectionPoolManager",
+    # Preloader
+    "DataPreloader",
+    "LRUCache",
+    # Cache
+    "MultiLevelCache",
     "Priority",
-    "get_batch_scheduler",
     "close_batch_scheduler",
+    "close_cache",
+    "close_preloader",
+    "get_batch_scheduler",
+    "get_cache",
+    "get_connection_pool_manager",
+    "get_preloader",
+    "start_connection_pool",
+    "stop_connection_pool",
 ]
diff --git a/src/common/database/optimization/batch_scheduler.py b/src/common/database/optimization/batch_scheduler.py
index 7498a7b16..919155423 100644
--- a/src/common/database/optimization/batch_scheduler.py
+++ b/src/common/database/optimization/batch_scheduler.py
@@ -10,12 +10,12 @@
 import asyncio
 import time
 from collections import defaultdict, deque
+from collections.abc import Callable
 from dataclasses import dataclass, field
 from enum import IntEnum
-from typing import Any, Callable, Optional, TypeVar
+from typing import Any, TypeVar

 from sqlalchemy import delete, insert, select, update
-from sqlalchemy.ext.asyncio import AsyncSession

 from src.common.database.core.session import get_db_session
 from src.common.logger import get_logger
@@ -36,22 +36,22 @@ class Priority(IntEnum):

 @dataclass
 class BatchOperation:
     """A batched operation"""
-
+
     operation_type: str  # 'select', 'insert', 'update', 'delete'
     model_class: type
     conditions: dict[str, Any] = field(default_factory=dict)
-    data: Optional[dict[str, Any]] = None
-    callback: Optional[Callable] = None
-    future: Optional[asyncio.Future] = None
+    data: dict[str, Any] | None = None
+    callback: Callable | None = None
+    future: asyncio.Future | None = None
     timestamp: float = field(default_factory=time.time)
     priority: Priority = Priority.NORMAL
-    timeout: Optional[float] = None  # timeout (seconds)
+    timeout: float | None = None  # timeout (seconds)


 @dataclass
 class BatchStats:
     """Batching statistics"""
-
+
     total_operations: int = 0
     batched_operations: int = 0
     cache_hits: int = 0
@@ -60,7 +60,7 @@ class BatchStats:
     avg_wait_time: float = 0.0
     timeout_count: int = 0
     error_count: int = 0
-
+
     # Adaptive statistics
     last_batch_duration: float = 0.0
     last_batch_size: int = 0
@@ -69,7 +69,7 @@

 class AdaptiveBatchScheduler:
     """Adaptive batch scheduler
-
+
     Features:
     - Dynamic batch size: adjusts automatically with load
     - Priority queues: high-priority operations run first
@@ -87,7 +87,7 @@ class AdaptiveBatchScheduler:
         cache_ttl: float = 5.0,
     ):
         """Initialize the scheduler
-
+
         Args:
             min_batch_size: Minimum batch size
             max_batch_size: Maximum batch size
@@ -104,23 +104,23 @@ class AdaptiveBatchScheduler:
         self.current_wait_time = base_wait_time
         self.max_queue_size = max_queue_size
         self.cache_ttl = cache_ttl
-
+
         # Operation queues, grouped by priority
         self.operation_queues: dict[Priority, deque[BatchOperation]] = {
             priority: deque() for priority in Priority
         }
-
+
         # Scheduling control
-        self._scheduler_task: Optional[asyncio.Task] = None
+        self._scheduler_task: asyncio.Task | None = None
         self._is_running = False
         self._lock = asyncio.Lock()
-
+
         # Statistics
         self.stats = BatchStats()
-
+
         # Simple result cache
         self._result_cache: dict[str, tuple[Any, float]] = {}
-
+
         logger.info(
             f"Adaptive batch scheduler initialized: "
             f"batch size {min_batch_size}-{max_batch_size}, "
@@ -132,7 +132,7 @@ class AdaptiveBatchScheduler:
         if self._is_running:
             logger.warning("Scheduler is already running")
             return
-
+
         self._is_running = True
         self._scheduler_task = asyncio.create_task(self._scheduler_loop())
         logger.info("Batch scheduler started")
@@ -141,16 +141,16 @@ class AdaptiveBatchScheduler:
         """Stop the scheduler"""
         if not self._is_running:
             return
-
+
         self._is_running = False
-
+
         if self._scheduler_task:
             self._scheduler_task.cancel()
             try:
                 await self._scheduler_task
             except asyncio.CancelledError:
                 pass
-
+
         # Process any remaining operations
         await self._flush_all_queues()
         logger.info("Batch scheduler stopped")
@@ -160,10 +160,10 @@ class AdaptiveBatchScheduler:
         operation: BatchOperation,
     ) -> asyncio.Future:
         """Add an operation to the queue
-
+
         Args:
             operation: The batch operation
-
+
         Returns:
             A Future object that can be used to retrieve the result
         """
@@ -175,11 +175,11 @@ class AdaptiveBatchScheduler:
             future = asyncio.get_event_loop().create_future()
             future.set_result(cached_result)
             return future
-
+
         # Create the future
         future = asyncio.get_event_loop().create_future()
         operation.future = future
-
+
         async with self._lock:
             # Check whether the queue is full
             total_queued = sum(len(q) for q in self.operation_queues.values())
@@ -191,7 +191,7 @@ class AdaptiveBatchScheduler:
             # Add to the priority queue
             self.operation_queues[operation.priority].append(operation)
             self.stats.total_operations += 1
-
+
         return future

     async def _scheduler_loop(self) -> None:
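A sketch of what enqueueing looks like from the caller's side, assuming the scheduler singleton defined at the end of this file; list-valued conditions become IN (...) filters in _execute_select_batch below:

from src.common.database.core.models import ChatStreams
from src.common.database.optimization.batch_scheduler import (
    BatchOperation,
    Priority,
    get_batch_scheduler,
)


async def fetch_streams_batched(stream_ids: list[str]):
    scheduler = await get_batch_scheduler()
    operation = BatchOperation(
        operation_type="select",
        model_class=ChatStreams,
        conditions={"stream_id": stream_ids},  # a list turns into an IN (...) clause
        priority=Priority.NORMAL,
        timeout=5.0,  # drop the operation if it waits longer than 5 seconds
    )
    # add_operation() returns a future that resolves once the batch runs
    # (or immediately, when the result cache already has an entry).
    future = await scheduler.add_operation(operation)
    return await future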
@@ -217,10 +217,10 @@
                 for _ in range(count):
                     if queue:
                         operations.append(queue.popleft())
-
+
         if not operations:
             return
-
+
         # Execute the batch of operations
         await self._execute_operations(operations)

@@ -231,10 +231,10 @@
         """Execute batch operations"""
         if not operations:
             return
-
+
         start_time = time.time()
         batch_size = len(operations)
-
+
         try:
             # Check for timeouts
             valid_operations = []
@@ -246,41 +246,41 @@ class AdaptiveBatchScheduler:
                     self.stats.timeout_count += 1
                 else:
                     valid_operations.append(op)
-
+
             if not valid_operations:
                 return
-
+
             # Group by operation type
             op_groups = defaultdict(list)
             for op in valid_operations:
                 key = f"{op.operation_type}_{op.model_class.__name__}"
                 op_groups[key].append(op)
-
+
             # Execute each group
-            for group_key, ops in op_groups.items():
+            for ops in op_groups.values():
                 await self._execute_group(ops)
-
+
             # Update statistics
             duration = time.time() - start_time
             self.stats.batched_operations += batch_size
             self.stats.total_execution_time += duration
             self.stats.last_batch_duration = duration
             self.stats.last_batch_size = batch_size
-
+
             if self.stats.batched_operations > 0:
                 self.stats.avg_batch_size = (
-                    self.stats.batched_operations / 
+                    self.stats.batched_operations /
                     (self.stats.total_execution_time / duration)
                 )
-
+
             logger.debug(
                 f"Batch executed: {batch_size} operations in {duration*1000:.2f}ms"
             )
-
+
         except Exception as e:
             logger.error(f"Batch execution failed: {e}", exc_info=True)
             self.stats.error_count += 1
-
+
             # Propagate the exception to every future
             for op in operations:
                 if op.future and not op.future.done():
@@ -290,9 +290,9 @@ class AdaptiveBatchScheduler:
         """Execute a group of same-type operations"""
         if not operations:
             return
-
+
         op_type = operations[0].operation_type
-
+
         try:
             if op_type == "select":
                 await self._execute_select_batch(operations)
@@ -304,7 +304,7 @@ class AdaptiveBatchScheduler:
                 await self._execute_delete_batch(operations)
             else:
                 raise ValueError(f"Unknown operation type: {op_type}")
-
+
         except Exception as e:
             logger.error(f"Failed to execute {op_type} operation group: {e}", exc_info=True)
             for op in operations:
@@ -323,30 +323,30 @@ class AdaptiveBatchScheduler:
                     stmt = select(op.model_class)
                     for key, value in op.conditions.items():
                         attr = getattr(op.model_class, key)
-                        if isinstance(value, (list, tuple, set)):
+                        if isinstance(value, list | tuple | set):
                             stmt = stmt.where(attr.in_(value))
                         else:
                             stmt = stmt.where(attr == value)
-
+
                     # Run the query
                     result = await session.execute(stmt)
                     data = result.scalars().all()
-
+
                     # Set the result
                     if op.future and not op.future.done():
                         op.future.set_result(data)
-
+
                     # Cache the result
                     cache_key = self._generate_cache_key(op)
                     self._set_cache(cache_key, data)
-
+
                     # Run the callback
                     if op.callback:
                         try:
                             op.callback(data)
                         except Exception as e:
                             logger.warning(f"Callback failed: {e}")
-
+
                 except Exception as e:
                     logger.error(f"Query failed: {e}", exc_info=True)
                     if op.future and not op.future.done():
@@ -363,23 +363,23 @@ class AdaptiveBatchScheduler:
             all_data = [op.data for op in operations if op.data]
             if not all_data:
                 return
-
+
             # Bulk insert
             stmt = insert(operations[0].model_class).values(all_data)
-            result = await session.execute(stmt)
+            await session.execute(stmt)
             await session.commit()
-
+
             # Set the results
             for op in operations:
                 if op.future and not op.future.done():
                     op.future.set_result(True)
-
+
                 if op.callback:
                     try:
                         op.callback(True)
                     except Exception as e:
                         logger.warning(f"Callback failed: {e}")
-
+
         except Exception as e:
             logger.error(f"Bulk insert failed: {e}", exc_info=True)
             await session.rollback()
@@ -402,28 +402,28 @@ class AdaptiveBatchScheduler:
                 for key, value in op.conditions.items():
                     attr = getattr(op.model_class, key)
                     stmt = stmt.where(attr == value)
-
+
                 if op.data:
                     stmt = stmt.values(**op.data)
-
+
                 # Execute the update (without committing yet)
                 result = await session.execute(stmt)
                 results.append((op, result.rowcount))
-
+
             # Commit once after all operations succeed
             await session.commit()
-
+
             # Set every operation's result
             for op, rowcount in results:
                 if op.future and not op.future.done():
                     op.future.set_result(rowcount)
-
+
                 if op.callback:
                     try:
                         op.callback(rowcount)
                     except Exception as e:
                         logger.warning(f"Callback failed: {e}")
-
+
         except Exception as e:
             logger.error(f"Bulk update failed: {e}", exc_info=True)
             await session.rollback()
@@ -447,25 +447,25 @@ class AdaptiveBatchScheduler:
                 for key, value in op.conditions.items():
                     attr = getattr(op.model_class, key)
                     stmt = stmt.where(attr == value)
-
+
                 # Execute the delete (without committing yet)
                 result = await session.execute(stmt)
                 results.append((op, result.rowcount))
-
+
             # Commit once after all operations succeed
             await session.commit()
-
+
             # Set every operation's result
             for op, rowcount in results:
                 if op.future and not op.future.done():
                     op.future.set_result(rowcount)
-
+
                 if op.callback:
                     try:
                         op.callback(rowcount)
                     except Exception as e:
                         logger.warning(f"Callback failed: {e}")
-
+
         except Exception as e:
             logger.error(f"Bulk delete failed: {e}", exc_info=True)
             await session.rollback()
@@ -479,7 +479,7 @@ class AdaptiveBatchScheduler:
         # Compute a congestion score
        total_queued = sum(len(q) for q in self.operation_queues.values())
         self.stats.congestion_score = min(1.0, total_queued / self.max_queue_size)
-
+
         # Adjust the batch size according to congestion
         if self.stats.congestion_score > 0.7:
             # High congestion: grow the batch size
@@ -493,7 +493,7 @@ class AdaptiveBatchScheduler:
                 self.min_batch_size,
                 int(self.current_batch_size * 0.9),
             )
-
+
         # Adjust the wait time based on batch execution time
         if self.stats.last_batch_duration > 0:
             if self.stats.last_batch_duration > self.current_wait_time * 2:
@@ -518,7 +518,7 @@ class AdaptiveBatchScheduler:
         ]
         return "|".join(key_parts)

-    def _get_from_cache(self, cache_key: str) -> Optional[Any]:
+    def _get_from_cache(self, cache_key: str) -> Any | None:
         """Fetch a result from the cache"""
         if cache_key in self._result_cache:
             result, timestamp = self._result_cache[cache_key]
@@ -551,27 +551,27 @@

 # Global scheduler instance
-_global_scheduler: Optional[AdaptiveBatchScheduler] = None
+_global_scheduler: AdaptiveBatchScheduler | None = None
 _scheduler_lock = asyncio.Lock()


 async def get_batch_scheduler() -> AdaptiveBatchScheduler:
     """Get the global batch scheduler (singleton)"""
     global _global_scheduler
-
+
     if _global_scheduler is None:
         async with _scheduler_lock:
             if _global_scheduler is None:
                 _global_scheduler = AdaptiveBatchScheduler()
                 await _global_scheduler.start()
-
+
     return _global_scheduler


 async def close_batch_scheduler() -> None:
     """Close the global batch scheduler"""
     global _global_scheduler
-
+
     if _global_scheduler is not None:
         await _global_scheduler.stop()
         _global_scheduler = None
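A lifecycle sketch for the module-level singleton: get_batch_scheduler() starts the loop lazily, and close_batch_scheduler() cancels it and flushes the queues; the stats fields read here are the ones BatchStats defines above:

from src.common.database.optimization.batch_scheduler import (
    close_batch_scheduler,
    get_batch_scheduler,
)


async def scheduler_snapshot() -> dict:
    scheduler = await get_batch_scheduler()
    stats = scheduler.stats
    return {
        "total": stats.total_operations,
        "batched": stats.batched_operations,
        "timeouts": stats.timeout_count,
        "errors": stats.error_count,
    }


async def shutdown_batching() -> None:
    # stop() cancels the scheduler loop, then drains the remaining queues.
    await close_batch_scheduler()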
diff --git a/src/common/database/optimization/cache_manager.py b/src/common/database/optimization/cache_manager.py
index a0021c7c7..240885b3a 100644
--- a/src/common/database/optimization/cache_manager.py
+++ b/src/common/database/optimization/cache_manager.py
@@ -11,8 +11,9 @@
 import asyncio
 import time
 from collections import OrderedDict
+from collections.abc import Callable
 from dataclasses import dataclass
-from typing import Any, Callable, Generic, Optional, TypeVar
+from typing import Any, Generic, TypeVar

 from src.common.logger import get_logger

@@ -24,7 +25,7 @@ T = TypeVar("T")
 @dataclass
 class CacheEntry(Generic[T]):
     """Cache entry
-
+
     Attributes:
         value: The cached value
         created_at: Creation timestamp
@@ -42,7 +43,7 @@ class CacheEntry(Generic[T]):
 @dataclass
 class CacheStats:
     """Cache statistics
-
+
     Attributes:
         hits: Number of hits
         misses: Number of misses
@@ -70,7 +71,7 @@ class CacheStats:

 class LRUCache(Generic[T]):
     """LRU cache implementation
-
+
     Uses an OrderedDict for O(1) get/set operations
     """

@@ -81,7 +82,7 @@ class LRUCache(Generic[T]):
         name: str = "cache",
     ):
         """Initialize the LRU cache
-
+
         Args:
             max_size: Maximum number of cache entries
             ttl: Expiry time (seconds)
@@ -94,18 +95,18 @@ class LRUCache(Generic[T]):
         self._lock = asyncio.Lock()
         self._stats = CacheStats()

-    async def get(self, key: str) -> Optional[T]:
+    async def get(self, key: str) -> T | None:
         """Get a cached value
-
+
         Args:
             key: Cache key
-
+
         Returns:
             The cached value, or None if missing or expired
         """
         async with self._lock:
             entry = self._cache.get(key)
-
+
             if entry is None:
                 self._stats.misses += 1
                 return None
@@ -125,20 +126,20 @@ class LRUCache(Generic[T]):
             entry.last_accessed = now
             entry.access_count += 1
             self._stats.hits += 1
-
+
             # Move to the end (most recently used)
             self._cache.move_to_end(key)
-
+
             return entry.value

     async def set(
         self,
         key: str,
         value: T,
-        size: Optional[int] = None,
+        size: int | None = None,
     ) -> None:
         """Set a cached value
-
+
         Args:
             key: Cache key
             value: Value to cache
@@ -146,16 +147,16 @@ class LRUCache(Generic[T]):
         """
         async with self._lock:
             now = time.time()
-
+
             # If the key already exists, update it
             if key in self._cache:
                 old_entry = self._cache[key]
                 self._stats.total_size -= old_entry.size
-
+
             # Estimate the size
             if size is None:
                 size = self._estimate_size(value)
-
+
             # Create the new entry
             entry = CacheEntry(
                 value=value,
@@ -164,7 +165,7 @@ class LRUCache(Generic[T]):
                 access_count=0,
                 size=size,
             )
-
+
             # If the cache is full, evict the least recently used entries
             while len(self._cache) >= self.max_size:
                 oldest_key, oldest_entry = self._cache.popitem(last=False)
@@ -175,7 +176,7 @@ class LRUCache(Generic[T]):
                     f"[{self.name}] Evicted cache entry: {oldest_key} "
                     f"(accessed {oldest_entry.access_count} times)"
                 )
-
+
             # Add the new entry
             self._cache[key] = entry
             self._stats.item_count += 1
@@ -183,10 +184,10 @@ class LRUCache(Generic[T]):

     async def delete(self, key: str) -> bool:
         """Delete a cache entry
-
+
         Args:
             key: Cache key
-
+
         Returns:
             Whether the deletion succeeded
         """
@@ -217,7 +218,7 @@ class LRUCache(Generic[T]):

     def _estimate_size(self, value: Any) -> int:
         """Estimate the size of a value (bytes)
-
+
         This is a rough estimate; the actual size may differ
         """
         import sys
@@ -230,11 +231,11 @@ class LRUCache(Generic[T]):

 class MultiLevelCache:
     """Multi-level cache manager
-
+
     Implements a two-level cache architecture:
     - L1: fast cache, small capacity, short TTL
     - L2: extended cache, large capacity, long TTL
-
+
     Lookups check L1 first, then L2, then fall back to the data source
     """

@@ -246,7 +247,7 @@ class MultiLevelCache:
         l2_ttl: float = 300,
     ):
         """Initialize the multi-level cache
-
+
         Args:
             l1_max_size: Maximum number of L1 entries
             l1_ttl: L1 TTL (seconds)
@@ -255,8 +256,8 @@ class MultiLevelCache:
         """
         self.l1_cache: LRUCache[Any] = LRUCache(l1_max_size, l1_ttl, "L1")
         self.l2_cache: LRUCache[Any] = LRUCache(l2_max_size, l2_ttl, "L2")
-        self._cleanup_task: Optional[asyncio.Task] = None
-
+        self._cleanup_task: asyncio.Task | None = None
+
         logger.info(
             f"Multi-level cache initialized: L1({l1_max_size} items/{l1_ttl}s) "
             f"L2({l2_max_size} items/{l2_ttl}s)"
@@ -265,16 +266,16 @@ class MultiLevelCache:
     async def get(
         self,
         key: str,
-        loader: Optional[Callable[[], Any]] = None,
-    ) -> Optional[Any]:
+        loader: Callable[[], Any] | None = None,
+    ) -> Any | None:
         """Get data from the cache
-
+
         Lookup order: L1 -> L2 -> loader
-
+
         Args:
             key: Cache key
             loader: Data loading function, called on a cache miss
-
+
         Returns:
             The cached or loaded value, or None if neither exists
         """
@@ -307,12 +308,12 @@ class MultiLevelCache:
         self,
         key: str,
         value: Any,
-        size: Optional[int] = None,
+        size: int | None = None,
     ) -> None:
         """Set a cached value
-
+
         Writes to both L1 and L2
-
+
         Args:
             key: Cache key
             value: Value to cache
@@ -323,9 +324,9 @@ class MultiLevelCache:

     async def delete(self, key: str) -> None:
         """Delete a cache entry
-
+
         Deletes from both L1 and L2
-
+
         Args:
             key: Cache key
         """
@@ -347,7 +348,7 @@ class MultiLevelCache:

     async def start_cleanup_task(self, interval: float = 60) -> None:
         """Start the periodic cleanup task
-
+
         Args:
             interval: Cleanup interval (seconds)
         """
@@ -387,27 +388,27 @@ class MultiLevelCache:

 # Global cache instance
-_global_cache: Optional[MultiLevelCache] = None
+_global_cache: MultiLevelCache | None = None
 _cache_lock = asyncio.Lock()


 async def get_cache() -> MultiLevelCache:
     """Get the global cache instance (singleton)"""
     global _global_cache
-
+
     if _global_cache is None:
         async with _cache_lock:
             if _global_cache is None:
                 _global_cache = MultiLevelCache()
                 await _global_cache.start_cleanup_task()
-
+
     return _global_cache


 async def close_cache() -> None:
     """Close the global cache"""
     global _global_cache
-
+
     if _global_cache is not None:
         await _global_cache.stop_cleanup_task()
         await _global_cache.clear()
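A round-trip sketch of the global accessor, assuming a plain dict value; per the docstrings above, set() writes both levels, get() checks L1 then L2, and delete() clears both:

from src.common.database.optimization.cache_manager import get_cache


async def cache_roundtrip() -> None:
    cache = await get_cache()  # lazily builds the MultiLevelCache and its cleanup task

    await cache.set("person:qq:12345", {"nickname": "demo"})  # written to L1 and L2
    value = await cache.get("person:qq:12345")                # L1 first, then L2
    assert value == {"nickname": "demo"}
    await cache.delete("person:qq:12345")                     # removed from both levels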
diff --git a/src/common/database/optimization/connection_pool.py b/src/common/database/optimization/connection_pool.py
index f32302766..4030bc061 100644
--- a/src/common/database/optimization/connection_pool.py
+++ b/src/common/database/optimization/connection_pool.py
@@ -150,7 +150,7 @@ class ConnectionPoolManager:
                 logger.debug(f"🆕 Created connection (pool size: {len(self._connections)})")

             yield connection_info.session
-
+
             # 🔧 Fix: commit the transaction on normal exit.
             # This is critical for SQLite, which has no autocommit.
             if connection_info and connection_info.session:
@@ -249,7 +249,7 @@ class ConnectionPoolManager:
         """Get connection pool statistics"""
         total_requests = self._stats["pool_hits"] + self._stats["pool_misses"]
         pool_efficiency = (self._stats["pool_hits"] / max(1, total_requests)) * 100 if total_requests > 0 else 0
-
+
         return {
             **self._stats,
             "active_connections": len(self._connections),
diff --git a/src/common/database/optimization/preloader.py b/src/common/database/optimization/preloader.py
index 7802a1cee..8335f5d0b 100644
--- a/src/common/database/optimization/preloader.py
+++ b/src/common/database/optimization/preloader.py
@@ -10,8 +10,9 @@
 import asyncio
 import time
 from collections import defaultdict
+from collections.abc import Awaitable, Callable
 from dataclasses import dataclass, field
-from typing import Any, Awaitable, Callable, Optional
+from typing import Any

 from sqlalchemy import select
 from sqlalchemy.ext.asyncio import AsyncSession
@@ -25,7 +26,7 @@ logger = get_logger("preloader")
 @dataclass
 class AccessPattern:
     """Access pattern statistics
-
+
     Attributes:
         key: Data key
         access_count: Number of accesses
@@ -42,7 +43,7 @@ class AccessPattern:

 class DataPreloader:
     """Data preloader
-
+
     Analyzes access patterns to predict and preload data that is likely to be needed
     """

@@ -53,7 +54,7 @@ class DataPreloader:
         max_patterns: int = 1000,
     ):
         """Initialize the preloader
-
+
         Args:
             decay_factor: Time decay factor (0-1); smaller values decay faster
             preload_threshold: Preload threshold; keys are preloaded once their score exceeds it
@@ -62,7 +63,7 @@ class DataPreloader:
         self.decay_factor = decay_factor
         self.preload_threshold = preload_threshold
         self.max_patterns = max_patterns
-
+
         # Access pattern tracking
         self._patterns: dict[str, AccessPattern] = {}
         # Associations: key -> [related_keys]
@@ -73,9 +74,9 @@ class DataPreloader:
         self._total_accesses = 0
         self._preload_count = 0
         self._preload_hits = 0
-
+
         self._lock = asyncio.Lock()
-
+
         logger.info(
             f"Data preloader initialized: decay factor={decay_factor}, "
             f"preload threshold={preload_threshold}"
@@ -84,10 +85,10 @@ class DataPreloader:
     async def record_access(
         self,
         key: str,
-        related_keys: Optional[list[str]] = None,
+        related_keys: list[str] | None = None,
     ) -> None:
         """Record a data access
-
+
         Args:
             key: The data key that was accessed
             related_keys: List of keys accessed in association with it
@@ -95,7 +96,7 @@ class DataPreloader:
         async with self._lock:
             self._total_accesses += 1
             now = time.time()
-
+
             # Update or create the access pattern
             if key in self._patterns:
                 pattern = self._patterns[key]
@@ -108,15 +109,15 @@ class DataPreloader:
                     last_access=now,
                 )
                 self._patterns[key] = pattern
-
+
             # Update the heat score (time decay)
             pattern.score = self._calculate_score(pattern)
-
+
             # Record associations
             if related_keys:
                 self._associations[key].update(related_keys)
                 pattern.related_keys = list(self._associations[key])
-
+
             # If there are too many patterns, drop the lowest-scoring one
             if len(self._patterns) > self.max_patterns:
                 min_key = min(self._patterns, key=lambda k: self._patterns[k].score)
@@ -126,10 +127,10 @@ class DataPreloader:

     async def should_preload(self, key: str) -> bool:
         """Decide whether a given key should be preloaded
-
+
         Args:
             key: Data key
-
+
         Returns:
             Whether the key should be preloaded
         """
@@ -137,18 +138,18 @@ class DataPreloader:
             pattern = self._patterns.get(key)
             if pattern is None:
                 return False
-
+
             # Refresh the score
             pattern.score = self._calculate_score(pattern)
-
+
             return pattern.score >= self.preload_threshold
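A sketch of feeding the pattern tracker, with stream-scoped keys assumed purely for illustration; repeated calls raise the decayed score until should_preload() crosses preload_threshold:

from src.common.database.optimization.preloader import get_preloader


async def track_stream_access(stream_id: str) -> bool:
    preloader = await get_preloader()
    # Record the access together with a key that tends to be read alongside it.
    await preloader.record_access(
        f"chat_stream:{stream_id}",
        related_keys=[f"stream_impression:{stream_id}"],
    )
    return await preloader.should_preload(f"chat_stream:{stream_id}")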
     async def get_preload_keys(self, limit: int = 100) -> list[str]:
         """Get the list of keys that should be preloaded
-
+
         Args:
             limit: Maximum number of keys to return
-
+
         Returns:
             Data keys sorted by score
         """
@@ -156,14 +157,14 @@ class DataPreloader:
             # Refresh all scores
             for pattern in self._patterns.values():
                 pattern.score = self._calculate_score(pattern)
-
+
             # Sort by score
             sorted_patterns = sorted(
                 self._patterns.values(),
                 key=lambda p: p.score,
                 reverse=True,
             )
-
+
             # Return the keys above the threshold
             return [
                 p.key for p in sorted_patterns[:limit]
@@ -172,10 +173,10 @@ class DataPreloader:

     async def get_related_keys(self, key: str) -> list[str]:
         """Get associated data keys
-
+
         Args:
             key: Data key
-
+
         Returns:
             List of associated data keys
         """
@@ -188,27 +189,27 @@ class DataPreloader:
         loader: Callable[[], Awaitable[Any]],
     ) -> None:
         """Preload data
-
+
         Args:
             key: Data key
             loader: Async loading function
         """
         try:
             cache = await get_cache()
-
+
             # Skip if the value is already cached
             if await cache.l1_cache.get(key) is not None:
                 return
-
+
             # Load the data
             logger.debug(f"Preloading data: {key}")
             data = await loader()
-
+
             if data is not None:
                 # Write it to the cache
                 await cache.set(key, data)
                 self._preload_count += 1
-
+
                 # Preload associated data
                 related_keys = await self.get_related_keys(key)
                 for related_key in related_keys[:5]:  # preload at most 5 related items
@@ -216,7 +217,7 @@ class DataPreloader:
                     # The caller must supply loaders for related data;
                     # for now we only log it without actually loading.
                     logger.debug(f"Found related data: {related_key}")
-
+
         except Exception as e:
             logger.error(f"Failed to preload data {key}: {e}", exc_info=True)

@@ -226,13 +227,13 @@ class DataPreloader:
         loaders: dict[str, Callable[[], Awaitable[Any]]],
     ) -> None:
         """Start preload tasks in bulk
-
+
         Args:
             session: Database session
             loaders: Mapping of data keys to loader functions
         """
         preload_keys = await self.get_preload_keys()
-
+
         for key in preload_keys:
             if key in loaders:
                 loader = loaders[key]
@@ -242,9 +243,9 @@ class DataPreloader:

     async def record_hit(self, key: str) -> None:
         """Record a preload hit
-
+
         Called when a cache hit was served by preloaded data
-
+
         Args:
             key: Data key
         """
@@ -259,7 +260,7 @@ class DataPreloader:
             if self._preload_count > 0
             else 0.0
         )
-
+
         return {
             "total_accesses": self._total_accesses,
             "tracked_patterns": len(self._patterns),
@@ -278,7 +279,7 @@ class DataPreloader:
         self._total_accesses = 0
         self._preload_count = 0
         self._preload_hits = 0
-
+
         # Cancel all preload tasks
         for task in self._preload_tasks:
             task.cancel()
@@ -286,38 +287,38 @@ class DataPreloader:

     def _calculate_score(self, pattern: AccessPattern) -> float:
         """Compute the heat score
-
+
         Uses a time-decayed access frequency:
         score = access_count * decay_factor^(time_since_last_access)
-
+
         Args:
             pattern: Access pattern
-
+
         Returns:
             The heat score
         """
         now = time.time()
         time_diff = now - pattern.last_access
-
+
         # Time decay (in hours)
         hours_passed = time_diff / 3600
         decay = self.decay_factor ** hours_passed
-
+
         # score = access count * time decay
         score = pattern.access_count * decay
-
+
         return score


 class CommonDataPreloader:
     """Common-data preloader
-
+
     Provides preload strategies for specific data types
     """

     def __init__(self, preloader: DataPreloader):
         """Initialize
-
+
         Args:
             preloader: The underlying preloader
         """
@@ -330,16 +331,16 @@ class CommonDataPreloader:
         platform: str,
     ) -> None:
         """Preload user-related data
-
+
         Includes: profile info, permissions, relationships, etc.
-
+
         Args:
             session: Database session
             user_id: User ID
             platform: Platform
         """
         from src.common.database.core.models import PersonInfo, UserPermissions, UserRelationships
-
+
         # Preload profile info
         await self._preload_model(
             session,
@@ -347,7 +348,7 @@ class CommonDataPreloader:
             PersonInfo,
             {"platform": platform, "user_id": user_id},
         )
-
+
         # Preload user permissions
         await self._preload_model(
             session,
@@ -355,7 +356,7 @@ class CommonDataPreloader:
             UserPermissions,
             {"platform": platform, "user_id": user_id},
         )
-
+
         # Preload user relationships
         await self._preload_model(
             session,
@@ -371,16 +372,16 @@ class CommonDataPreloader:
         limit: int = 50,
     ) -> None:
         """Preload chat context
-
+
         Includes: recent messages, chat stream info, etc.
-
+
         Args:
             session: Database session
             stream_id: Chat stream ID
             limit: Limit on the number of messages
         """
-        from src.common.database.core.models import ChatStreams, Messages
-
+        from src.common.database.core.models import ChatStreams
+
         # Preload chat stream info
         await self._preload_model(
             session,
@@ -388,7 +389,7 @@ class CommonDataPreloader:
             ChatStreams,
             {"stream_id": stream_id},
         )
-
+
         # Preload recent messages (complex; skipped for now)
         # TODO: implement preloading of the message list

@@ -400,7 +401,7 @@ class CommonDataPreloader:
         filters: dict[str, Any],
     ) -> None:
         """Preload model data
-
+
         Args:
             session: Database session
             cache_key: Cache key
@@ -413,31 +414,31 @@ class CommonDataPreloader:
                 stmt = stmt.where(getattr(model_class, key) == value)
             result = await session.execute(stmt)
             return result.scalar_one_or_none()
-
+
         await self.preloader.preload_data(cache_key, loader)


 # Global preloader instance
-_global_preloader: Optional[DataPreloader] = None
+_global_preloader: DataPreloader | None = None
 _preloader_lock = asyncio.Lock()
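To make the decay concrete, a worked example of the _calculate_score() formula (score = access_count * decay_factor ** hours_since_last_access); the decay factor below is illustrative, not the module's actual default:

decay_factor = 0.9

# 40 accesses, last seen 2 hours ago: barely decayed.
score_warm = 40 * decay_factor**2    # ~32.4

# The same 40 accesses, last seen 24 hours ago: mostly decayed away.
score_cold = 40 * decay_factor**24   # ~3.2

print(f"warm={score_warm:.1f} cold={score_cold:.1f}")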
async def get_preloader() -> DataPreloader:
     """Get the global preloader instance (singleton)"""
     global _global_preloader
-
+
     if _global_preloader is None:
         async with _preloader_lock:
             if _global_preloader is None:
                 _global_preloader = DataPreloader()
-
+
     return _global_preloader


 async def close_preloader() -> None:
     """Close the global preloader"""
     global _global_preloader
-
+
     if _global_preloader is not None:
         await _global_preloader.clear()
         _global_preloader = None
diff --git a/src/common/database/utils/__init__.py b/src/common/database/utils/__init__.py
index d59fba36c..4df4ac93f 100644
--- a/src/common/database/utils/__init__.py
+++ b/src/common/database/utils/__init__.py
@@ -37,29 +37,29 @@ from .monitoring import (
 )

 __all__ = [
+    "BatchSchedulerError",
+    "CacheError",
+    "ConnectionPoolError",
+    "DatabaseConnectionError",
     # Exceptions
     "DatabaseError",
     "DatabaseInitializationError",
-    "DatabaseConnectionError",
+    "DatabaseMigrationError",
+    # Monitoring
+    "DatabaseMonitor",
     "DatabaseQueryError",
     "DatabaseTransactionError",
-    "DatabaseMigrationError",
-    "CacheError",
-    "BatchSchedulerError",
-    "ConnectionPoolError",
+    "cached",
+    "db_operation",
+    "get_monitor",
+    "measure_time",
+    "print_stats",
+    "record_cache_hit",
+    "record_cache_miss",
+    "record_operation",
+    "reset_stats",
     # Decorators
     "retry",
     "timeout",
-    "cached",
-    "measure_time",
     "transactional",
-    "db_operation",
-    # Monitoring
-    "DatabaseMonitor",
-    "get_monitor",
-    "record_operation",
-    "record_cache_hit",
-    "record_cache_miss",
-    "print_stats",
-    "reset_stats",
 ]
diff --git a/src/common/database/utils/decorators.py b/src/common/database/utils/decorators.py
index 176a5c25b..a5c4fdc43 100644
--- a/src/common/database/utils/decorators.py
+++ b/src/common/database/utils/decorators.py
@@ -10,9 +10,11 @@
 import asyncio
 import functools
 import hashlib
 import time
-from typing import Any, Awaitable, Callable, Optional, TypeVar
+from collections.abc import Awaitable, Callable
+from typing import Any, TypeVar

-from sqlalchemy.exc import DBAPIError, OperationalError, TimeoutError as SQLTimeoutError
+from sqlalchemy.exc import DBAPIError, OperationalError
+from sqlalchemy.exc import TimeoutError as SQLTimeoutError

 from src.common.logger import get_logger

@@ -25,33 +27,33 @@ def generate_cache_key(
     **kwargs: Any,
 ) -> str:
     """Generate the same cache key as the @cached decorator
-
+
     Useful for manual cache invalidation and similar operations
-
+
     Args:
         key_prefix: Cache key prefix
         *args: Positional arguments
         **kwargs: Keyword arguments
-
+
     Returns:
         The cache key string
-
+
     Example:
         cache_key = generate_cache_key("person_info", platform, person_id)
         await cache.delete(cache_key)
     """
     cache_key_parts = [key_prefix]
-
+
     if args:
         args_str = ",".join(str(arg) for arg in args)
         args_hash = hashlib.md5(args_str.encode()).hexdigest()[:8]
         cache_key_parts.append(f"args:{args_hash}")
-
+
     if kwargs:
         kwargs_str = ",".join(f"{k}={v}" for k, v in sorted(kwargs.items()))
         kwargs_hash = hashlib.md5(kwargs_str.encode()).hexdigest()[:8]
         cache_key_parts.append(f"kwargs:{kwargs_hash}")
-
+
     return ":".join(cache_key_parts)


 T = TypeVar("T")
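The point of exposing generate_cache_key() is manual invalidation after writes, the same pattern chat_stream_impression_tool.py uses later in this diff; a sketch with a hypothetical stream_impression prefix:

from src.common.database.optimization.cache_manager import get_cache
from src.common.database.utils.decorators import generate_cache_key


async def invalidate_stream_impression(stream_id: str) -> None:
    cache = await get_cache()
    # The prefix and positional arguments must match the @cached call site
    # exactly, because the key embeds an md5 hash of the stringified args.
    await cache.delete(generate_cache_key("stream_impression", stream_id))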
@@ -65,15 +67,15 @@ def retry(
     exceptions: tuple[type[Exception], ...] = (OperationalError, DBAPIError, SQLTimeoutError),
 ):
     """Retry decorator
-
+
     Automatically retries failed database operations; intended for transient errors
-
+
     Args:
         max_attempts: Maximum number of attempts
         delay: Initial delay (seconds)
         backoff: Delay multiplier (exponential backoff)
         exceptions: Exception types that trigger a retry
-
+
     Example:
         @retry(max_attempts=3, delay=1.0)
         async def query_data():
@@ -114,12 +116,12 @@ def retry(

 def timeout(seconds: float):
     """Timeout decorator
-
+
     Adds a timeout to database operations
-
+
     Args:
         seconds: Timeout (seconds)
-
+
     Example:
         @timeout(30.0)
         async def long_query():
@@ -141,21 +143,21 @@ def timeout(seconds: float):

 def cached(
-    ttl: Optional[int] = 300,
-    key_prefix: Optional[str] = None,
+    ttl: int | None = 300,
+    key_prefix: str | None = None,
     use_args: bool = True,
     use_kwargs: bool = True,
 ):
     """Caching decorator
-
+
     Automatically caches function return values
-
+
     Args:
         ttl: Cache TTL (seconds); None means never expire
         key_prefix: Cache key prefix; defaults to the function name
         use_args: Whether positional arguments are part of the cache key
         use_kwargs: Whether keyword arguments are part of the cache key
-
+
     Example:
         @cached(ttl=60, key_prefix="user_data")
         async def get_user_info(user_id: str) -> dict:
@@ -167,7 +169,7 @@ def cached(
         async def wrapper(*args: Any, **kwargs: Any) -> T:
             # Late import to avoid a circular dependency
             from src.common.database.optimization import get_cache
-
+
             # Generate the cache key
             cache_key_parts = [key_prefix or func.__name__]

@@ -207,14 +209,14 @@ def cached(
     return decorator


-def measure_time(log_slow: Optional[float] = None):
+def measure_time(log_slow: float | None = None):
     """Timing decorator
-
+
     Measures function execution time and can optionally log slow queries
-
+
     Args:
         log_slow: Slow-query threshold (seconds); a warning is logged past it
-
+
     Example:
         @measure_time(log_slow=1.0)
         async def complex_query():
@@ -246,19 +248,19 @@ def measure_time(log_slow: float | None = None):

 def transactional(auto_commit: bool = True, auto_rollback: bool = True):
     """Transaction decorator
-
+
     Automatically manages commit and rollback
-
+
     Args:
         auto_commit: Whether to commit automatically
         auto_rollback: Whether to roll back automatically on exceptions
-
+
     Example:
         @transactional()
         async def update_multiple_records(session):
             await session.execute(stmt1)
             await session.execute(stmt2)
-
+
     Note:
         The decorated function must accept a session parameter
     """
@@ -306,20 +308,20 @@ def transactional(auto_commit: bool = True, auto_rollback: bool = True):
 # Combined decorator example
 def db_operation(
     retry_attempts: int = 3,
-    timeout_seconds: Optional[float] = None,
-    cache_ttl: Optional[int] = None,
+    timeout_seconds: float | None = None,
+    cache_ttl: int | None = None,
     measure: bool = True,
 ):
     """Combined decorator
-
+
     Stacks several decorators to fully protect a database operation
-
+
     Args:
         retry_attempts: Number of retries
         timeout_seconds: Timeout (seconds)
         cache_ttl: Cache TTL (seconds)
         measure: Whether to measure performance
-
+
     Example:
         @db_operation(retry_attempts=3, timeout_seconds=30, cache_ttl=60)
         async def important_query():
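A sketch of the combined decorator on a read path, assuming (as its parameters suggest) that db_operation stacks retry, timeout, cached, and measure_time internally; PersonInfo and its platform/user_id columns appear elsewhere in this diff:

from sqlalchemy import select

from src.common.database.core.models import PersonInfo
from src.common.database.core.session import get_db_session
from src.common.database.utils.decorators import db_operation


@db_operation(retry_attempts=3, timeout_seconds=30, cache_ttl=60)
async def get_person(platform: str, user_id: str):
    async with get_db_session() as session:
        result = await session.execute(
            select(PersonInfo).where(
                PersonInfo.platform == platform,
                PersonInfo.user_id == user_id,
            )
        )
        return result.scalar_one_or_none()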
diff --git a/src/common/database/utils/monitoring.py b/src/common/database/utils/monitoring.py
index c8eef3628..bfd102806 100644
--- a/src/common/database/utils/monitoring.py
+++ b/src/common/database/utils/monitoring.py
@@ -4,7 +4,6 @@
 """

 import time
-from collections import defaultdict
 from dataclasses import dataclass, field
 from typing import Any, Optional

@@ -22,7 +21,7 @@ class OperationMetrics:
     min_time: float = float("inf")
     max_time: float = 0.0
     error_count: int = 0
-    last_execution_time: Optional[float] = None
+    last_execution_time: float | None = None

     @property
     def avg_time(self) -> float:
@@ -91,7 +90,7 @@ class DatabaseMetrics:

 class DatabaseMonitor:
     """Database monitor
-
+
     Singleton that collects and reports database performance metrics
     """

@@ -285,7 +284,7 @@ class DatabaseMonitor:

 # Global monitor instance
-_monitor: Optional[DatabaseMonitor] = None
+_monitor: DatabaseMonitor | None = None


 def get_monitor() -> DatabaseMonitor:
diff --git a/src/llm_models/utils.py b/src/llm_models/utils.py
index e64b4f8b3..32131fb4f 100644
--- a/src/llm_models/utils.py
+++ b/src/llm_models/utils.py
@@ -4,8 +4,8 @@
 from datetime import datetime
 from PIL import Image

-from src.common.database.core.models import LLMUsage
 from src.common.database.core import get_db_session
+from src.common.database.core.models import LLMUsage
 from src.common.logger import get_logger
 from src.config.api_ada_configs import ModelInfo
diff --git a/src/main.py b/src/main.py
index 111ee9904..642487286 100644
--- a/src/main.py
+++ b/src/main.py
@@ -224,7 +224,7 @@ class MainSystem:
                 storage_batcher = get_message_storage_batcher()
                 cleanup_tasks.append(("message storage batcher", storage_batcher.stop()))
-
+
                 update_batcher = get_message_update_batcher()
                 cleanup_tasks.append(("message update batcher", update_batcher.stop()))
             except Exception as e:
@@ -502,7 +502,7 @@ MoFox_Bot (third-party modified version)
             storage_batcher = get_message_storage_batcher()
             await storage_batcher.start()
             logger.info("Message storage batcher started")
-
+
             update_batcher = get_message_update_batcher()
             await update_batcher.start()
             logger.info("Message update batcher started")
diff --git a/src/person_info/relationship_fetcher.py b/src/person_info/relationship_fetcher.py
index e551b270e..7498fb911 100644
--- a/src/person_info/relationship_fetcher.py
+++ b/src/person_info/relationship_fetcher.py
@@ -256,9 +256,10 @@ class RelationshipFetcher:
             str: The formatted chat-stream impression string
         """
         try:
-            from src.common.database.api.specialized import get_or_create_chat_stream
             import time

+            from src.common.database.api.specialized import get_or_create_chat_stream
+
             # Use the optimized API (with caching)
             # Parse the platform from stream_id, or fall back to a default value
             platform = stream_id.split("_")[0] if "_" in stream_id else "unknown"
@@ -289,7 +290,7 @@ class RelationshipFetcher:
             except Exception as e:
                 logger.warning(f"Failed to access stream object attributes: {e}")
                 stream_data = {}
-
+
             impression_parts = []

             # 1. Basic information about the chat environment
diff --git a/src/plugin_system/apis/schedule_api.py b/src/plugin_system/apis/schedule_api.py
index 154780da9..993ae3d1a 100644
--- a/src/plugin_system/apis/schedule_api.py
+++ b/src/plugin_system/apis/schedule_api.py
@@ -52,8 +52,8 @@ from typing import Any
 import orjson
 from sqlalchemy import func, select

-from src.common.database.core.models import MonthlyPlan, Schedule
 from src.common.database.core import get_db_session
+from src.common.database.core.models import MonthlyPlan, Schedule
 from src.common.logger import get_logger
 from src.schedule.database import get_active_plans_for_month
diff --git a/src/plugin_system/base/base_command.py b/src/plugin_system/base/base_command.py
index 8376caa38..5c1dd7c32 100644
--- a/src/plugin_system/base/base_command.py
+++ b/src/plugin_system/base/base_command.py
@@ -7,7 +7,7 @@ from src.plugin_system.base.component_types import ChatType, CommandInfo, Compon
 from src.plugin_system.base.plus_command import PlusCommand

 if TYPE_CHECKING:
-    from src.chat.message_receive.chat_stream import ChatStream
+    pass

 logger = get_logger("base_command")
diff --git a/src/plugin_system/core/permission_manager.py b/src/plugin_system/core/permission_manager.py
index 573492782..3132c48c8 100644
--- a/src/plugin_system/core/permission_manager.py
+++ b/src/plugin_system/core/permission_manager.py
@@ -10,8 +10,8 @@
 from sqlalchemy import delete, select
 from sqlalchemy.exc import IntegrityError, SQLAlchemyError
 from sqlalchemy.ext.asyncio import async_sessionmaker

-from src.common.database.core.models import PermissionNodes, UserPermissions
 from src.common.database.core import get_engine
+from src.common.database.core.models import PermissionNodes, UserPermissions
 from src.common.logger import get_logger
 from src.config.config import global_config
 from src.plugin_system.apis.permission_api import IPermissionManager, PermissionNode, UserInfo
diff --git a/src/plugin_system/services/relationship_service.py b/src/plugin_system/services/relationship_service.py
index 32a7b3ca2..1bb995209 100644
--- a/src/plugin_system/services/relationship_service.py
+++ b/src/plugin_system/services/relationship_service.py
@@ -5,8 +5,8 @@

 import time

-from src.common.database.core.models import UserRelationships
 from src.common.database.core import get_db_session
+from src.common.database.core.models import UserRelationships
 from src.common.logger import get_logger
 from src.config.config import global_config
diff --git a/src/plugins/built_in/affinity_flow_chatter/chat_stream_impression_tool.py b/src/plugins/built_in/affinity_flow_chatter/chat_stream_impression_tool.py
index 23981188a..33a67ec6e 100644
--- a/src/plugins/built_in/affinity_flow_chatter/chat_stream_impression_tool.py
+++ b/src/plugins/built_in/affinity_flow_chatter/chat_stream_impression_tool.py
@@ -7,12 +7,8 @@
 import json
 from typing import Any, ClassVar

-from sqlalchemy import select
-
-from src.common.database.compatibility import get_db_session
-from src.common.database.core.models import ChatStreams
 from src.common.database.api.crud import CRUDBase
-from src.common.database.utils.decorators import cached
+from src.common.database.core.models import ChatStreams
 from src.common.logger import get_logger
 from src.config.config import model_config
 from src.llm_models.utils_model import LLMRequest
@@ -358,14 +354,14 @@ class ChatStreamImpressionTool(BaseTool):
                         "stream_interest_score": impression.get("stream_interest_score", 0.5),
                     }
                 )
-
+
                 # Invalidate the cache
                 from src.common.database.optimization.cache_manager import get_cache
                 from src.common.database.utils.decorators import generate_cache_key
                 cache = await get_cache()
                 await cache.delete(generate_cache_key("stream_impression", stream_id))
                 await cache.delete(generate_cache_key("chat_stream", stream_id))
-
+
                 logger.info(f"Chat stream impression updated in the database: {stream_id}")
             else:
                 error_msg = f"Chat stream {stream_id} does not exist in the database; cannot update its impression"
diff --git a/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_event.py b/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_event.py
index e3243b45e..1d33f3121 100644
--- a/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_event.py
+++ b/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_event.py
@@ -64,7 +64,7 @@ class ProactiveThinkingReplyHandler(BaseEventHandler):
             from src.chat.message_receive.chat_stream import get_chat_manager
             chat_manager = get_chat_manager()
             chat_stream = await chat_manager.get_stream(stream_id)
-
+
             if chat_stream:
                 stream_config = chat_stream.get_raw_id()
                 if not proactive_thinking_scheduler._check_whitelist_blacklist(stream_config):
diff --git a/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_executor.py b/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_executor.py
index eb1c136ab..918575d8a 100644
--- a/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_executor.py
+++ b/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_executor.py
@@ -7,13 +7,10 @@
 import json
 from datetime import datetime
 from typing import Any, Literal

-from sqlalchemy import select
-
 from src.chat.express.expression_selector import expression_selector
 from src.chat.utils.prompt import Prompt
-from src.common.database.compatibility import get_db_session
-from src.common.database.core.models import ChatStreams
 from src.common.database.api.crud import CRUDBase
+from src.common.database.core.models import ChatStreams
 from src.common.database.utils.decorators import cached
 from src.common.logger import get_logger
 from src.config.config import global_config, model_config
@@ -208,7 +205,7 @@ class ProactiveThinkingPlanner:
         # 3. Fetch the bot persona and time information
         individuality = Individuality()
         bot_personality = await individuality.get_personality_block()
-
+
         # Build the time information block
         time_block = f"The current time is {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"

@@ -624,11 +621,11 @@ async def execute_proactive_thinking(stream_id: str):
             from src.chat.message_receive.chat_stream import get_chat_manager
             chat_manager = get_chat_manager()
             chat_stream = await chat_manager.get_stream(stream_id)
-
+
             if chat_stream:
                 # Use ChatStream.get_raw_id() to obtain the config string
                 stream_config = chat_stream.get_raw_id()
-
+
                 # Run the whitelist/blacklist check
                 if not proactive_thinking_scheduler._check_whitelist_blacklist(stream_config):
                     logger.debug(f"Chat stream {stream_id} ({stream_config}) failed the whitelist/blacklist check; skipping proactive thinking")
@@ -637,7 +634,7 @@ async def execute_proactive_thinking(stream_id: str):
             else:
                 logger.warning(f"Could not fetch info for chat stream {stream_id}; skipping the whitelist check")
         except Exception as e:
             logger.warning(f"Error during whitelist check: {e}; continuing anyway")
-
+
         # 0.2 Check quiet hours
         if proactive_thinking_scheduler._is_in_quiet_hours():
             logger.debug("Quiet hours; skipping")
diff --git a/src/schedule/database.py b/src/schedule/database.py
index ef281976c..d0bb11aa8 100644
--- a/src/schedule/database.py
+++ b/src/schedule/database.py
@@ -3,8 +3,8 @@

 from sqlalchemy import delete, func, select, update

-from src.common.database.core.models import MonthlyPlan
 from src.common.database.core import get_db_session
+from src.common.database.core.models import MonthlyPlan
 from src.common.logger import get_logger
 from src.config.config import global_config
@@ -312,7 +312,7 @@ async def delete_plans_older_than(month: str):
             logger.info(f"No monthly plans earlier than {month} were found to delete.")
             return 0

-        plan_months = sorted(list(set(p.target_month for p in plans_to_delete)))
+        plan_months = sorted({p.target_month for p in plans_to_delete})
         logger.info(f"Deleting {len(plans_to_delete)} monthly plans earlier than {month} (months involved: {', '.join(plan_months)}).")

         # Then perform the deletion
diff --git a/src/schedule/monthly_plan_manager.py b/src/schedule/monthly_plan_manager.py
index 1893cbc91..e75dde62f 100644
--- a/src/schedule/monthly_plan_manager.py
+++ b/src/schedule/monthly_plan_manager.py
@@ -100,7 +100,7 @@ class MonthlyPlanGenerationTask(AsyncTask):
                 next_month = datetime(now.year + 1, 1, 1)
             else:
                 next_month = datetime(now.year, now.month + 1, 1)
-
+
             sleep_seconds = (next_month - now).total_seconds()
             logger.info(
                 f" The next monthly-plan generation task will run in {sleep_seconds:.2f} seconds (Beijing time {next_month.strftime('%Y-%m-%d %H:%M:%S')})"
@@ -110,7 +110,7 @@ class MonthlyPlanGenerationTask(AsyncTask):
             # At the start of the month, archive last month's plans first
             last_month = (next_month - timedelta(days=1)).strftime("%Y-%m")
             await self.monthly_plan_manager.plan_manager.archive_current_month_plans(last_month)
-
+
             # Generate new plans for the current month
             current_month = next_month.strftime("%Y-%m")
             logger.info(f" Month start reached; generating monthly plans for {current_month}...")
diff --git a/src/schedule/prompts.py b/src/schedule/prompts.py
index b239ae433..34c0e767e 100644
--- a/src/schedule/prompts.py
+++ b/src/schedule/prompts.py
@@ -97,4 +97,4 @@ MONTHLY_PLAN_GENERATION_PROMPT = Prompt(
 Please take on my role and, based on my identity and interests, draw up a suitable monthly plan for {target_month}.
 """,
-)
\ No newline at end of file
+)
diff --git a/src/schedule/schedule_manager.py b/src/schedule/schedule_manager.py
index c32fccfc3..4556f3638 100644
--- a/src/schedule/schedule_manager.py
+++ b/src/schedule/schedule_manager.py
@@ -5,8 +5,8 @@
 from typing import Any

 import orjson
 from sqlalchemy import select

-from src.common.database.core.models import MonthlyPlan, Schedule
 from src.common.database.core import get_db_session
+from src.common.database.core.models import MonthlyPlan, Schedule
 from src.common.logger import get_logger
 from src.config.config import global_config
 from src.manager.async_task_manager import AsyncTask, async_task_manager
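Finally, a quick smoke check for the import moves this diff performs throughout (old flat modules -> src.common.database.core); the names imported here are ones the diff itself uses, so only the assembled layout is being assumed:

def check_new_layout() -> None:
    from src.common.database.core import get_db_session, get_engine
    from src.common.database.core.models import Base, ChatStreams, MonthlyPlan

    assert callable(get_db_session) and callable(get_engine)
    assert hasattr(Base, "metadata")
    print("new layout OK:", ChatStreams.__name__, MonthlyPlan.__name__)


if __name__ == "__main__":
    check_new_layout()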