From c57711c6745ad2ff6edd7c2182390ead31f75265 Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Sun, 14 Sep 2025 15:09:55 +0800 Subject: [PATCH 01/31] =?UTF-8?q?refactor(data=5Fmodel):=20=E8=A7=A3?= =?UTF-8?q?=E9=99=A4=20plan=5Ffilter=20=E5=AF=B9=20DatabaseMessages=20?= =?UTF-8?q?=E7=9A=84=E7=9B=B4=E6=8E=A5=E4=BE=9D=E8=B5=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 将 `target_message_obj` 的类型从 `DatabaseMessages` 实例改为字典,从而消除了 `plan_filter` 模块对 `database_data_model` 的循环导入风险。同时更新了 `ActionPlannerInfo` 中 `action_message` 的类型注解以保持一致性。 --- src/chat/planner_actions/plan_filter.py | 3 +-- src/common/data_models/info_data_model.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/chat/planner_actions/plan_filter.py b/src/chat/planner_actions/plan_filter.py index d76f1aa04..838320700 100644 --- a/src/chat/planner_actions/plan_filter.py +++ b/src/chat/planner_actions/plan_filter.py @@ -263,8 +263,7 @@ class PlanFilter: target_message_dict = self._get_latest_message(message_id_list) if target_message_dict: - from src.common.data_models.database_data_model import DatabaseMessages - target_message_obj = DatabaseMessages(**target_message_dict) + target_message_obj = target_message_dict available_action_names = list(plan.available_actions.keys()) if action not in ["no_action", "no_reply", "reply", "do_nothing", "proactive_reply"] and action not in available_action_names: diff --git a/src/common/data_models/info_data_model.py b/src/common/data_models/info_data_model.py index 2806587c1..32893706d 100644 --- a/src/common/data_models/info_data_model.py +++ b/src/common/data_models/info_data_model.py @@ -21,7 +21,7 @@ class ActionPlannerInfo(BaseDataModel): action_type: str = field(default_factory=str) reasoning: Optional[str] = None action_data: Optional[Dict] = None - action_message: Optional["DatabaseMessages"] = None + action_message: Optional[Dict] = None available_actions: Optional[Dict[str, "ActionInfo"]] = None From 4dbc651d74f2ae43fecd4186a40661f78277b70f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=85=E8=AF=BA=E7=8B=90?= <212194964+foxcyber907@users.noreply.github.com> Date: Fri, 19 Sep 2025 20:20:20 +0800 Subject: [PATCH 02/31] =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=BA=93=E5=BC=82?= =?UTF-8?q?=E6=AD=A5=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 仅仅支持还有107处待迁移 --- requirements.txt | 2 + rust_image/Cargo.toml | 0 src/common/database/db_migration.py | 116 ++++++++++-------- .../database/sqlalchemy_database_api.py | 101 ++++++++------- src/common/database/sqlalchemy_init.py | 37 +++--- src/common/database/sqlalchemy_models.py | 49 ++++---- 6 files changed, 169 insertions(+), 136 deletions(-) create mode 100644 rust_image/Cargo.toml diff --git a/requirements.txt b/requirements.txt index edc7b9cb8..757c1e09a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,6 @@ sqlalchemy +aiosqlite +aiomysql APScheduler aiohttp aiohttp-cors diff --git a/rust_image/Cargo.toml b/rust_image/Cargo.toml new file mode 100644 index 000000000..e69de29bb diff --git a/src/common/database/db_migration.py b/src/common/database/db_migration.py index cae3cbd29..9d2be9e5b 100644 --- a/src/common/database/db_migration.py +++ b/src/common/database/db_migration.py @@ -7,71 +7,81 @@ from src.common.logger import get_logger logger = get_logger("db_migration") -def check_and_migrate_database(): +async def check_and_migrate_database(): """ - 检查数据库结构并自动迁移(添加缺失的表和列)。 + 异步检查数据库结构并自动迁移(添加缺失的表和列)。 
""" logger.info("正在检查数据库结构并执行自动迁移...") - engine = get_engine() - inspector = inspect(engine) + engine = await get_engine() + + # 使用异步引擎获取inspector + async with engine.connect() as connection: + # 在同步上下文中运行inspector操作 + inspector = await connection.run_sync(lambda sync_conn: inspect(sync_conn)) + + # 1. 获取数据库中所有已存在的表名 + db_table_names = await connection.run_sync(lambda sync_conn: set(inspect(sync_conn).get_table_names())) - # 1. 获取数据库中所有已存在的表名 - db_table_names = set(inspector.get_table_names()) + # 2. 遍历所有在代码中定义的模型 + for table_name, table in Base.metadata.tables.items(): + logger.debug(f"正在检查表: {table_name}") - # 2. 遍历所有在代码中定义的模型 - for table_name, table in Base.metadata.tables.items(): - logger.debug(f"正在检查表: {table_name}") + # 3. 如果表不存在,则创建它 + if table_name not in db_table_names: + logger.info(f"表 '{table_name}' 不存在,正在创建...") + try: + await connection.run_sync(lambda sync_conn: table.create(sync_conn)) + logger.info(f"表 '{table_name}' 创建成功。") + except Exception as e: + logger.error(f"创建表 '{table_name}' 失败: {e}") + continue - # 3. 如果表不存在,则创建它 - if table_name not in db_table_names: - logger.info(f"表 '{table_name}' 不存在,正在创建...") - try: - table.create(engine) - logger.info(f"表 '{table_name}' 创建成功。") - except Exception as e: - logger.error(f"创建表 '{table_name}' 失败: {e}") - continue + # 4. 如果表已存在,则检查并添加缺失的列 + db_columns = await connection.run_sync( + lambda sync_conn: {col["name"] for col in inspect(sync_conn).get_columns(table_name)} + ) + model_columns = {col.name for col in table.c} - # 4. 如果表已存在,则检查并添加缺失的列 - db_columns = {col["name"] for col in inspector.get_columns(table_name)} - model_columns = {col.name for col in table.c} + missing_columns = model_columns - db_columns + if not missing_columns: + logger.debug(f"表 '{table_name}' 结构一致,无需修改。") + continue - missing_columns = model_columns - db_columns - if not missing_columns: - logger.debug(f"表 '{table_name}' 结构一致,无需修改。") - continue + logger.info(f"在表 '{table_name}' 中发现缺失的列: {', '.join(missing_columns)}") + + # 开始事务来添加缺失的列 + async with connection.begin() as trans: + try: + for column_name in missing_columns: + column = table.c[column_name] - logger.info(f"在表 '{table_name}' 中发现缺失的列: {', '.join(missing_columns)}") - with engine.connect() as connection: - trans = connection.begin() - try: - for column_name in missing_columns: - column = table.c[column_name] + # 构造并执行 ALTER TABLE 语句 + try: + # 在同步上下文中编译列类型 + column_type = await connection.run_sync( + lambda sync_conn: column.type.compile(sync_conn.dialect) + ) + sql = f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_type}" - # 构造并执行 ALTER TABLE 语句 - try: - column_type = column.type.compile(engine.dialect) - sql = f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_type}" + # 添加默认值和非空约束的处理 + if column.default is not None: + default_value = column.default.arg + if isinstance(default_value, str): + sql += f" DEFAULT '{default_value}'" + else: + sql += f" DEFAULT {default_value}" - # 添加默认值和非空约束的处理 - if column.default is not None: - default_value = column.default.arg - if isinstance(default_value, str): - sql += f" DEFAULT '{default_value}'" - else: - sql += f" DEFAULT {default_value}" + if not column.nullable: + sql += " NOT NULL" - if not column.nullable: - sql += " NOT NULL" + await connection.execute(text(sql)) + logger.info(f"成功向表 '{table_name}' 添加列 '{column_name}'。") + except Exception as e: + logger.error(f"向表 '{table_name}' 添加列 '{column_name}' 失败: {e}") - connection.execute(text(sql)) - logger.info(f"成功向表 '{table_name}' 添加列 '{column_name}'。") - except Exception as e: - 
logger.error(f"向表 '{table_name}' 添加列 '{column_name}' 失败: {e}") - - trans.commit() - except Exception as e: - logger.error(f"在表 '{table_name}' 添加列时发生错误,事务已回滚: {e}") - trans.rollback() + except Exception as e: + logger.error(f"在表 '{table_name}' 添加列时发生错误,事务已回滚: {e}") + await trans.rollback() + raise logger.info("数据库结构检查与自动迁移完成。") diff --git a/src/common/database/sqlalchemy_database_api.py b/src/common/database/sqlalchemy_database_api.py index 1643f5838..4f0258c2b 100644 --- a/src/common/database/sqlalchemy_database_api.py +++ b/src/common/database/sqlalchemy_database_api.py @@ -6,9 +6,11 @@ import traceback import time +import asyncio from typing import Dict, List, Any, Union, Type, Optional from sqlalchemy.exc import SQLAlchemyError -from sqlalchemy import desc, asc, func, and_ +from sqlalchemy import desc, asc, func, and_, select +from sqlalchemy.ext.asyncio import AsyncSession from src.common.logger import get_logger from src.common.database.sqlalchemy_models import ( Base, @@ -56,7 +58,7 @@ MODEL_MAPPING = { } -def build_filters(session, model_class: Type[Base], filters: Dict[str, Any]): +async def build_filters(model_class, filters: Dict[str, Any]): """构建查询过滤条件""" conditions = [] @@ -94,7 +96,7 @@ def build_filters(session, model_class: Type[Base], filters: Dict[str, Any]): async def db_query( - model_class: Type[Base], + model_class, data: Optional[Dict[str, Any]] = None, query_type: Optional[str] = "get", filters: Optional[Dict[str, Any]] = None, @@ -102,7 +104,7 @@ async def db_query( order_by: Optional[List[str]] = None, single_result: Optional[bool] = False, ) -> Union[List[Dict[str, Any]], Dict[str, Any], None]: - """执行数据库查询操作 + """执行异步数据库查询操作 Args: model_class: SQLAlchemy模型类 @@ -120,15 +122,15 @@ async def db_query( if query_type not in ["get", "create", "update", "delete", "count"]: raise ValueError("query_type must be 'get', 'create', 'update', 'delete' or 'count'") - with get_db_session() as session: + async with get_db_session() as session: if query_type == "get": - query = session.query(model_class) + query = select(model_class) # 应用过滤条件 if filters: - conditions = build_filters(session, model_class, filters) + conditions = await build_filters(model_class, filters) if conditions: - query = query.filter(and_(*conditions)) + query = query.where(and_(*conditions)) # 应用排序 if order_by: @@ -146,14 +148,15 @@ async def db_query( query = query.limit(limit) # 执行查询 - results = query.all() + result = await session.execute(query) + results = result.scalars().all() # 转换为字典格式 result_dicts = [] - for result in results: + for result_obj in results: result_dict = {} - for column in result.__table__.columns: - result_dict[column.name] = getattr(result, column.name) + for column in result_obj.__table__.columns: + result_dict[column.name] = getattr(result_obj, column.name) result_dicts.append(result_dict) if single_result: @@ -167,7 +170,7 @@ async def db_query( # 创建新记录 new_record = model_class(**data) session.add(new_record) - session.flush() # 获取自动生成的ID + await session.flush() # 获取自动生成的ID # 转换为字典格式返回 result_dict = {} @@ -179,43 +182,60 @@ async def db_query( if not data: raise ValueError("更新记录需要提供data参数") - query = session.query(model_class) + query = select(model_class) # 应用过滤条件 if filters: - conditions = build_filters(session, model_class, filters) + conditions = await build_filters(model_class, filters) if conditions: - query = query.filter(and_(*conditions)) + query = query.where(and_(*conditions)) - # 执行更新 - affected_rows = query.update(data) + # 首先获取要更新的记录 + result = await 
session.execute(query) + records_to_update = result.scalars().all() + + # 更新每个记录 + affected_rows = 0 + for record in records_to_update: + for field, value in data.items(): + if hasattr(record, field): + setattr(record, field, value) + affected_rows += 1 + return affected_rows elif query_type == "delete": - query = session.query(model_class) + query = select(model_class) # 应用过滤条件 if filters: - conditions = build_filters(session, model_class, filters) + conditions = await build_filters(model_class, filters) if conditions: - query = query.filter(and_(*conditions)) + query = query.where(and_(*conditions)) - # 执行删除 - affected_rows = query.delete() + # 首先获取要删除的记录 + result = await session.execute(query) + records_to_delete = result.scalars().all() + + # 删除记录 + affected_rows = 0 + for record in records_to_delete: + session.delete(record) + affected_rows += 1 + return affected_rows elif query_type == "count": - query = session.query(func.count(model_class.id)) + query = select(func.count(model_class.id)) # 应用过滤条件 if filters: - base_query = session.query(model_class) - conditions = build_filters(session, model_class, filters) + conditions = await build_filters(model_class, filters) if conditions: - base_query = base_query.filter(and_(*conditions)) - query = session.query(func.count()).select_from(base_query.subquery()) + query = query.where(and_(*conditions)) - return query.scalar() + result = await session.execute(query) + return result.scalar() except SQLAlchemyError as e: logger.error(f"[SQLAlchemy] 数据库操作出错: {e}") @@ -238,9 +258,9 @@ async def db_query( async def db_save( - model_class: Type[Base], data: Dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None + model_class, data: Dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None ) -> Optional[Dict[str, Any]]: - """保存数据到数据库(创建或更新) + """异步保存数据到数据库(创建或更新) Args: model_class: SQLAlchemy模型类 @@ -252,13 +272,13 @@ async def db_save( 保存后的记录数据或None """ try: - with get_db_session() as session: + async with get_db_session() as session: # 如果提供了key_field和key_value,尝试更新现有记录 if key_field and key_value is not None: if hasattr(model_class, key_field): - existing_record = ( - session.query(model_class).filter(getattr(model_class, key_field) == key_value).first() - ) + query = select(model_class).where(getattr(model_class, key_field) == key_value) + result = await session.execute(query) + existing_record = result.scalars().first() if existing_record: # 更新现有记录 @@ -266,7 +286,7 @@ async def db_save( if hasattr(existing_record, field): setattr(existing_record, field, value) - session.flush() + await session.flush() # 转换为字典格式返回 result_dict = {} @@ -277,8 +297,7 @@ async def db_save( # 创建新记录 new_record = model_class(**data) session.add(new_record) - session.commit() - session.flush() + await session.flush() # 转换为字典格式返回 result_dict = {} @@ -297,13 +316,13 @@ async def db_save( async def db_get( - model_class: Type[Base], + model_class, filters: Optional[Dict[str, Any]] = None, limit: Optional[int] = None, order_by: Optional[str] = None, single_result: Optional[bool] = False, ) -> Union[List[Dict[str, Any]], Dict[str, Any], None]: - """从数据库获取记录 + """异步从数据库获取记录 Args: model_class: SQLAlchemy模型类 @@ -335,7 +354,7 @@ async def store_action_info( action_data: Optional[dict] = None, action_name: str = "", ) -> Optional[Dict[str, Any]]: - """存储动作信息到数据库 + """异步存储动作信息到数据库 Args: chat_stream: 聊天流对象 diff --git a/src/common/database/sqlalchemy_init.py b/src/common/database/sqlalchemy_init.py index fa7a864eb..7d3f97136 100644 --- 
a/src/common/database/sqlalchemy_init.py +++ b/src/common/database/sqlalchemy_init.py @@ -1,7 +1,7 @@ """SQLAlchemy数据库初始化模块 替换Peewee的数据库初始化逻辑 -提供统一的数据库初始化接口 +提供统一的异步数据库初始化接口 """ from typing import Optional @@ -12,25 +12,25 @@ from src.common.database.sqlalchemy_models import Base, get_engine, initialize_d logger = get_logger("sqlalchemy_init") -def initialize_sqlalchemy_database() -> bool: +async def initialize_sqlalchemy_database() -> bool: """ - 初始化SQLAlchemy数据库 + 初始化SQLAlchemy异步数据库 创建所有表结构 Returns: bool: 初始化是否成功 """ try: - logger.info("开始初始化SQLAlchemy数据库...") + logger.info("开始初始化SQLAlchemy异步数据库...") # 初始化数据库引擎和会话 - engine, session_local = initialize_database() + engine, session_local = await initialize_database() if engine is None: logger.error("数据库引擎初始化失败") return False - logger.info("SQLAlchemy数据库初始化成功") + logger.info("SQLAlchemy异步数据库初始化成功") return True except SQLAlchemyError as e: @@ -41,9 +41,9 @@ def initialize_sqlalchemy_database() -> bool: return False -def create_all_tables() -> bool: +async def create_all_tables() -> bool: """ - 创建所有数据库表 + 异步创建所有数据库表 Returns: bool: 创建是否成功 @@ -51,13 +51,14 @@ def create_all_tables() -> bool: try: logger.info("开始创建数据库表...") - engine = get_engine() + engine = await get_engine() if engine is None: logger.error("无法获取数据库引擎") return False - # 创建所有表 - Base.metadata.create_all(bind=engine) + # 异步创建所有表 + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) logger.info("数据库表创建成功") return True @@ -70,15 +71,15 @@ def create_all_tables() -> bool: return False -def get_database_info() -> Optional[dict]: +async def get_database_info() -> Optional[dict]: """ - 获取数据库信息 + 异步获取数据库信息 Returns: dict: 数据库信息字典,包含引擎信息等 """ try: - engine = get_engine() + engine = await get_engine() if engine is None: return None @@ -100,9 +101,9 @@ def get_database_info() -> Optional[dict]: _database_initialized = False -def initialize_database_compat() -> bool: +async def initialize_database_compat() -> bool: """ - 兼容性数据库初始化函数 + 兼容性异步数据库初始化函数 用于替换原有的Peewee初始化代码 Returns: @@ -113,9 +114,9 @@ def initialize_database_compat() -> bool: if _database_initialized: return True - success = initialize_sqlalchemy_database() + success = await initialize_sqlalchemy_database() if success: - success = create_all_tables() + success = await create_all_tables() if success: _database_initialized = True diff --git a/src/common/database/sqlalchemy_models.py b/src/common/database/sqlalchemy_models.py index 464b38e9f..f9b0ef68d 100644 --- a/src/common/database/sqlalchemy_models.py +++ b/src/common/database/sqlalchemy_models.py @@ -3,16 +3,18 @@ 替换Peewee ORM,使用SQLAlchemy提供更好的连接池管理和错误恢复能力 """ -from sqlalchemy import Column, String, Float, Integer, Boolean, Text, Index, create_engine, DateTime +from sqlalchemy import Column, String, Float, Integer, Boolean, Text, Index, DateTime from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker, Session, Mapped, mapped_column +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker +from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.pool import QueuePool import os import datetime import time -from typing import Iterator, Optional, Any, Dict +from typing import Iterator, Optional, Any, Dict, AsyncGenerator from src.common.logger import get_logger -from contextlib import contextmanager +from contextlib import asynccontextmanager +import asyncio logger = get_logger("sqlalchemy_models") @@ -575,14 +577,14 @@ def get_database_url(): # 使用Unix socket连接 
encoded_socket = quote_plus(config.mysql_unix_socket) return ( - f"mysql+pymysql://{encoded_user}:{encoded_password}" + f"mysql+aiomysql://{encoded_user}:{encoded_password}" f"@/{config.mysql_database}" f"?unix_socket={encoded_socket}&charset={config.mysql_charset}" ) else: # 使用标准TCP连接 return ( - f"mysql+pymysql://{encoded_user}:{encoded_password}" + f"mysql+aiomysql://{encoded_user}:{encoded_password}" f"@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}" f"?charset={config.mysql_charset}" ) @@ -597,11 +599,11 @@ def get_database_url(): # 确保数据库目录存在 os.makedirs(os.path.dirname(db_path), exist_ok=True) - return f"sqlite:///{db_path}" + return f"sqlite+aiosqlite:///{db_path}" -def initialize_database(): - """初始化数据库引擎和会话""" +async def initialize_database(): + """初始化异步数据库引擎和会话""" global _engine, _SessionLocal if _engine is not None: @@ -654,41 +656,40 @@ def initialize_database(): } ) - _engine = create_engine(database_url, **engine_kwargs) - _SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=_engine) + _engine = create_async_engine(database_url, **engine_kwargs) + _SessionLocal = async_sessionmaker(bind=_engine, class_=AsyncSession, expire_on_commit=False) # 调用新的迁移函数,它会处理表的创建和列的添加 from src.common.database.db_migration import check_and_migrate_database - check_and_migrate_database() + await check_and_migrate_database() - logger.info(f"SQLAlchemy数据库初始化成功: {config.database_type}") + logger.info(f"SQLAlchemy异步数据库初始化成功: {config.database_type}") return _engine, _SessionLocal -@contextmanager -def get_db_session() -> Iterator[Session]: - """数据库会话上下文管理器 - 推荐使用这个而不是get_session()""" - session: Optional[Session] = None +@asynccontextmanager +async def get_db_session() -> AsyncGenerator[AsyncSession, None]: + """异步数据库会话上下文管理器""" + session: Optional[AsyncSession] = None try: - engine, SessionLocal = initialize_database() + engine, SessionLocal = await initialize_database() if not SessionLocal: raise RuntimeError("Database session not initialized") session = SessionLocal() yield session - # session.commit() except Exception: if session: - session.rollback() + await session.rollback() raise finally: if session: - session.close() + await session.close() -def get_engine(): - """获取数据库引擎""" - engine, _ = initialize_database() +async def get_engine(): + """获取异步数据库引擎""" + engine, _ = await initialize_database() return engine From 6a98ae62081667084f37499f3dfe30af3f8465a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=85=E8=AF=BA=E7=8B=90?= <212194964+foxcyber907@users.noreply.github.com> Date: Fri, 19 Sep 2025 20:45:26 +0800 Subject: [PATCH 03/31] =?UTF-8?q?=E4=BA=8C=E6=AC=A1=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bot.py | 6 +++++- pyproject.toml | 2 ++ src/common/database/database.py | 4 ++-- src/common/database/sqlalchemy_models.py | 8 +++----- 4 files changed, 12 insertions(+), 8 deletions(-) diff --git a/bot.py b/bot.py index b4c448670..f382df1e1 100644 --- a/bot.py +++ b/bot.py @@ -193,9 +193,11 @@ class MaiBotMain(BaseMain): logger.error(f"数据库连接初始化失败: {e}") raise e + async def initialize_database_async(self): + """异步初始化数据库表结构""" logger.info("正在初始化数据库表结构...") try: - init_db() + await init_db() logger.info("数据库表结构初始化完成") except Exception as e: logger.error(f"数据库表结构初始化失败: {e}") @@ -229,6 +231,8 @@ if __name__ == "__main__": try: # 执行初始化和任务调度 loop.run_until_complete(main_system.initialize()) + # 异步初始化数据库表结构 + loop.run_until_complete(maibot.initialize_database_async()) initialize_lpmm_knowledge() # Schedule 
tasks returns a future that runs forever. # We can run console_input_loop concurrently. diff --git a/pyproject.toml b/pyproject.toml index 68b1837e7..ea7bc77f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,6 +72,8 @@ dependencies = [ "uvicorn>=0.35.0", "watchdog>=6.0.0", "websockets>=15.0.1", + "aiomysql>=0.2.0", + "aiosqlite>=0.21.0", ] [[tool.uv.index]] diff --git a/src/common/database/database.py b/src/common/database/database.py index 88dee6464..d196df032 100644 --- a/src/common/database/database.py +++ b/src/common/database/database.py @@ -89,7 +89,7 @@ def get_db(): return _db -def initialize_sql_database(database_config): +async def initialize_sql_database(database_config): """ 根据配置初始化SQL数据库连接(SQLAlchemy版本) @@ -119,7 +119,7 @@ def initialize_sql_database(database_config): # 使用SQLAlchemy初始化 success = initialize_database_compat() if success: - _sql_engine = get_engine() + _sql_engine = await get_engine() logger.info("SQLAlchemy数据库初始化成功") else: logger.error("SQLAlchemy数据库初始化失败") diff --git a/src/common/database/sqlalchemy_models.py b/src/common/database/sqlalchemy_models.py index f9b0ef68d..e5eacee1f 100644 --- a/src/common/database/sqlalchemy_models.py +++ b/src/common/database/sqlalchemy_models.py @@ -7,7 +7,6 @@ from sqlalchemy import Column, String, Float, Integer, Boolean, Text, Index, Dat from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker from sqlalchemy.orm import Mapped, mapped_column -from sqlalchemy.pool import QueuePool import os import datetime import time @@ -621,10 +620,9 @@ async def initialize_database(): } if config.database_type == "mysql": - # MySQL连接池配置 + # MySQL连接池配置 - 异步引擎使用默认连接池 engine_kwargs.update( { - "poolclass": QueuePool, "pool_size": config.connection_pool_size, "max_overflow": config.connection_pool_size * 2, "pool_timeout": config.connection_timeout, @@ -640,10 +638,9 @@ async def initialize_database(): } ) else: - # SQLite配置 - 添加连接池设置以避免连接耗尽 + # SQLite配置 - 异步引擎使用默认连接池 engine_kwargs.update( { - "poolclass": QueuePool, "pool_size": 20, # 增加池大小 "max_overflow": 30, # 增加溢出连接数 "pool_timeout": 60, # 增加超时时间 @@ -678,6 +675,7 @@ async def get_db_session() -> AsyncGenerator[AsyncSession, None]: raise RuntimeError("Database session not initialized") session = SessionLocal() yield session + # await session.commit() except Exception: if session: await session.rollback() From 0cc4f5bb27a9f05646d255bd1c84014547e157a6 Mon Sep 17 00:00:00 2001 From: tt-P607 <68868379+tt-P607@users.noreply.github.com> Date: Sat, 20 Sep 2025 02:21:53 +0800 Subject: [PATCH 04/31] =?UTF-8?q?=E4=B8=89=E6=AC=A1=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/chat_loop/heartFC_chat.py | 2 +- src/chat/emoji_system/emoji_manager.py | 61 +++-- src/chat/express/expression_selector.py | 41 ++-- src/chat/memory_system/Hippocampus.py | 77 +++--- src/chat/message_receive/chat_stream.py | 34 ++- src/chat/message_receive/storage.py | 48 ++-- src/chat/planner_actions/action_modifier.py | 4 +- src/chat/planner_actions/plan_filter.py | 4 +- src/chat/planner_actions/plan_generator.py | 2 +- src/chat/replyer/default_generator.py | 17 +- src/chat/utils/chat_message_builder.py | 132 +++++----- src/common/message_repository.py | 18 +- src/main.py | 10 +- src/person_info/person_info.py | 257 +++++++------------- src/person_info/relationship_fetcher.py | 28 ++- src/plugin_system/apis/message_api.py | 18 +- src/schedule/database.py | 147 
+++++------ src/schedule/monthly_plan_manager.py | 5 + src/schedule/plan_manager.py | 18 +- src/schedule/schedule_manager.py | 36 +-- 20 files changed, 478 insertions(+), 481 deletions(-) diff --git a/src/chat/chat_loop/heartFC_chat.py b/src/chat/chat_loop/heartFC_chat.py index 6f63cff1b..05edb3ee0 100644 --- a/src/chat/chat_loop/heartFC_chat.py +++ b/src/chat/chat_loop/heartFC_chat.py @@ -361,7 +361,7 @@ class HeartFChatting: # 核心修复:在睡眠模式(包括失眠)下获取消息时,不过滤命令消息,以确保@消息能被接收 filter_command_flag = not (is_sleeping or is_in_insomnia) - recent_messages = message_api.get_messages_by_time_in_chat( + recent_messages = await message_api.get_messages_by_time_in_chat( chat_id=self.context.stream_id, start_time=self.context.last_read_time, end_time=time.time(), diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index 8e6079897..ce7b0d074 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -149,7 +149,7 @@ class MaiEmoji: # --- 数据库操作 --- try: # 准备数据库记录 for emoji collection - with get_db_session() as session: + async with get_db_session() as session: emotion_str = ",".join(self.emotion) if self.emotion else "" emoji = Emoji( @@ -167,7 +167,7 @@ class MaiEmoji: last_used_time=self.last_used_time, ) session.add(emoji) - session.commit() + await session.commit() logger.info(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})") @@ -203,17 +203,17 @@ class MaiEmoji: # 2. 删除数据库记录 try: - with get_db_session() as session: - will_delete_emoji = session.execute( - select(Emoji).where(Emoji.emoji_hash == self.hash) + async with get_db_session() as session: + will_delete_emoji = ( + await session.execute(select(Emoji).where(Emoji.emoji_hash == self.hash)) ).scalar_one_or_none() if will_delete_emoji is None: logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。") - result = 0 # Indicate no DB record was deleted + result = 0 else: - session.delete(will_delete_emoji) - result = 1 # Successfully deleted one record - session.commit() + await session.delete(will_delete_emoji) + result = 1 + await session.commit() except Exception as e: logger.error(f"[错误] 删除数据库记录时出错: {str(e)}") result = 0 @@ -424,17 +424,19 @@ class EmojiManager: # if not self._initialized: # raise RuntimeError("EmojiManager not initialized") - def record_usage(self, emoji_hash: str) -> None: + async def record_usage(self, emoji_hash: str) -> None: """记录表情使用次数""" try: - with get_db_session() as session: - emoji_update = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalar_one_or_none() + async with get_db_session() as session: + emoji_update = ( + await session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)) + ).scalar_one_or_none() if emoji_update is None: logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包") else: emoji_update.usage_count += 1 - emoji_update.last_used_time = time.time() # Update last used time - session.commit() + emoji_update.last_used_time = time.time() + await session.commit() except Exception as e: logger.error(f"记录表情使用失败: {str(e)}") @@ -658,10 +660,11 @@ class EmojiManager: async def get_all_emoji_from_db(self) -> None: """获取所有表情包并初始化为MaiEmoji类对象,更新 self.emoji_objects""" try: - with get_db_session() as session: + async with get_db_session() as session: logger.debug("[数据库] 开始加载所有表情包记录 ...") - emoji_instances = session.execute(select(Emoji)).scalars().all() + result = await session.execute(select(Emoji)) + emoji_instances = result.scalars().all() emoji_objects, load_errors = _to_emoji_objects(emoji_instances) 
# 更新内存中的列表和数量 @@ -687,14 +690,16 @@ class EmojiManager: list[MaiEmoji]: 表情包对象列表 """ try: - with get_db_session() as session: + async with get_db_session() as session: if emoji_hash: - query = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalars().all() + result = await session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)) + query = result.scalars().all() else: logger.warning( "[查询] 未提供 hash,将尝试加载所有表情包,建议使用 get_all_emoji_from_db 更新管理器状态。" ) - query = session.execute(select(Emoji)).scalars().all() + result = await session.execute(select(Emoji)) + query = result.scalars().all() emoji_instances = query emoji_objects, load_errors = _to_emoji_objects(emoji_instances) @@ -771,10 +776,11 @@ class EmojiManager: # 如果内存中没有,从数据库查找 try: - with get_db_session() as session: - emoji_record = session.execute( + async with get_db_session() as session: + result = await session.execute( select(Emoji).where(Emoji.emoji_hash == emoji_hash) - ).scalar_one_or_none() + ) + emoji_record = result.scalar_one_or_none() if emoji_record and emoji_record.description: logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.description[:50]}...") return emoji_record.description @@ -937,10 +943,13 @@ class EmojiManager: # 2. 检查数据库中是否已存在该表情包的描述,实现复用 existing_description = None try: - with get_db_session() as session: - existing_image = session.query(Images).filter( - (Images.emoji_hash == image_hash) & (Images.type == "emoji") - ).one_or_none() + async with get_db_session() as session: + result = await session.execute( + select(Images).filter( + (Images.emoji_hash == image_hash) & (Images.type == "emoji") + ) + ) + existing_image = result.scalar_one_or_none() if existing_image and existing_image.description: existing_description = existing_image.description logger.info(f"[复用描述] 找到已有详细描述: {existing_description[:50]}...") diff --git a/src/chat/express/expression_selector.py b/src/chat/express/expression_selector.py index 2883ec82d..2a269fbf9 100644 --- a/src/chat/express/expression_selector.py +++ b/src/chat/express/expression_selector.py @@ -136,18 +136,18 @@ class ExpressionSelector: return related_chat_ids if related_chat_ids else [chat_id] - def get_random_expressions( + async def get_random_expressions( self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: # sourcery skip: extract-duplicate-method, move-assign # 支持多chat_id合并抽选 related_chat_ids = self.get_related_chat_ids(chat_id) - with get_db_session() as session: + async with get_db_session() as session: # 优化:一次性查询所有相关chat_id的表达方式 - style_query = session.execute( + style_query = await session.execute( select(Expression).where((Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style")) ) - grammar_query = session.execute( + grammar_query = await session.execute( select(Expression).where((Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "grammar")) ) @@ -193,7 +193,7 @@ class ExpressionSelector: return selected_style, selected_grammar - def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, Any]], increment: float = 0.1): + async def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, Any]], increment: float = 0.1): """对一批表达方式更新count值,按chat_id+type分组后一次性写入数据库""" if not expressions_to_update: return @@ -210,26 +210,27 @@ class ExpressionSelector: if key not in updates_by_key: updates_by_key[key] = expr for chat_id, expr_type, situation, style in updates_by_key: - with 
get_db_session() as session: - query = session.execute( + async with get_db_session() as session: + query = await session.execute( select(Expression).where( (Expression.chat_id == chat_id) & (Expression.type == expr_type) & (Expression.situation == situation) & (Expression.style == style) ) - ).scalar() - if query: - expr_obj = query - current_count = expr_obj.count - new_count = min(current_count + increment, 5.0) - expr_obj.count = new_count - expr_obj.last_active_time = time.time() - - logger.debug( - f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db" ) - session.commit() + query = query.scalar() + if query: + expr_obj = query + current_count = expr_obj.count + new_count = min(current_count + increment, 5.0) + expr_obj.count = new_count + expr_obj.last_active_time = time.time() + + logger.debug( + f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db" + ) + await session.commit() async def select_suitable_expressions_llm( self, @@ -248,7 +249,7 @@ class ExpressionSelector: return [] # 1. 获取35个随机表达方式(现在按权重抽取) - style_exprs, grammar_exprs = self.get_random_expressions(chat_id, 30, 0.5, 0.5) + style_exprs, grammar_exprs = await self.get_random_expressions(chat_id, 30, 0.5, 0.5) # 2. 构建所有表达方式的索引和情境列表 all_expressions = [] @@ -334,7 +335,7 @@ class ExpressionSelector: # 对选中的所有表达方式,一次性更新count数 if valid_expressions: - self.update_expressions_count_batch(valid_expressions, 0.006) + await self.update_expressions_count_batch(valid_expressions, 0.006) # logger.info(f"LLM从{len(all_expressions)}个情境中选择了{len(valid_expressions)}个") return valid_expressions diff --git a/src/chat/memory_system/Hippocampus.py b/src/chat/memory_system/Hippocampus.py index af5078caf..46d46b202 100644 --- a/src/chat/memory_system/Hippocampus.py +++ b/src/chat/memory_system/Hippocampus.py @@ -201,7 +201,7 @@ class Hippocampus: self.entorhinal_cortex = EntorhinalCortex(self) self.parahippocampal_gyrus = ParahippocampalGyrus(self) # 从数据库加载记忆图 - self.entorhinal_cortex.sync_memory_from_db() + # self.entorhinal_cortex.sync_memory_from_db() # 改为异步启动 self.model_small = LLMRequest(model_set=model_config.model_task_config.utils_small, request_type="memory.small") def get_all_node_names(self) -> list: @@ -789,7 +789,7 @@ class EntorhinalCortex: self.hippocampus = hippocampus self.memory_graph = hippocampus.memory_graph - def get_memory_sample(self): + async def get_memory_sample(self): """从数据库获取记忆样本""" # 硬编码:每条消息最大记忆次数 max_memorized_time_per_msg = 2 @@ -812,7 +812,7 @@ class EntorhinalCortex: logger.debug(f"回忆往事: {readable_timestamp}") chat_samples = [] for timestamp in timestamps: - if messages := self.random_get_msg_snippet( + if messages := await self.random_get_msg_snippet( timestamp, global_config.memory.memory_build_sample_length, max_memorized_time_per_msg, @@ -826,7 +826,9 @@ class EntorhinalCortex: return chat_samples @staticmethod - def random_get_msg_snippet(target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int) -> list | None: + async def random_get_msg_snippet( + target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int + ) -> list | None: # sourcery skip: invert-any-all, use-any, use-named-expression, use-next """从数据库中随机获取指定时间戳附近的消息片段 (使用 chat_message_builder)""" time_window_seconds = random.randint(300, 1800) # 随机时间窗口,5到30分钟 @@ -864,13 +866,13 @@ class EntorhinalCortex: for message in messages: # 确保在更新前获取最新的 memorized_times current_memorized_times = message.get("memorized_times", 0) - with get_db_session() as session: - 
session.execute( + async with get_db_session() as session: + await session.execute( update(Messages) .where(Messages.message_id == message["message_id"]) .values(memorized_times=current_memorized_times + 1) ) - session.commit() + await session.commit() return messages # 直接返回原始的消息列表 target_timestamp -= 120 # 如果第一次尝试失败,稍微向前调整时间戳再试 @@ -884,8 +886,8 @@ class EntorhinalCortex: current_time = datetime.datetime.now().timestamp() # 获取数据库中所有节点和内存中所有节点 - with get_db_session() as session: - db_nodes = {node.concept: node for node in session.execute(select(GraphNodes)).scalars()} + async with get_db_session() as session: + db_nodes = {node.concept: node for node in (await session.execute(select(GraphNodes))).scalars()} memory_nodes = list(self.memory_graph.G.nodes(data=True)) # 批量准备节点数据 @@ -954,24 +956,24 @@ class EntorhinalCortex: batch_size = 100 for i in range(0, len(nodes_to_create), batch_size): batch = nodes_to_create[i : i + batch_size] - session.execute(insert(GraphNodes), batch) + await session.execute(insert(GraphNodes), batch) if nodes_to_update: batch_size = 100 for i in range(0, len(nodes_to_update), batch_size): batch = nodes_to_update[i : i + batch_size] for node_data in batch: - session.execute( + await session.execute( update(GraphNodes) .where(GraphNodes.concept == node_data["concept"]) .values(**{k: v for k, v in node_data.items() if k != "concept"}) ) if nodes_to_delete: - session.execute(delete(GraphNodes).where(GraphNodes.concept.in_(nodes_to_delete))) + await session.execute(delete(GraphNodes).where(GraphNodes.concept.in_(nodes_to_delete))) # 处理边的信息 - db_edges = list(session.execute(select(GraphEdges)).scalars()) + db_edges = list((await session.execute(select(GraphEdges))).scalars()) memory_edges = list(self.memory_graph.G.edges(data=True)) # 创建边的哈希值字典 @@ -1023,14 +1025,14 @@ class EntorhinalCortex: batch_size = 100 for i in range(0, len(edges_to_create), batch_size): batch = edges_to_create[i : i + batch_size] - session.execute(insert(GraphEdges), batch) + await session.execute(insert(GraphEdges), batch) if edges_to_update: batch_size = 100 for i in range(0, len(edges_to_update), batch_size): batch = edges_to_update[i : i + batch_size] for edge_data in batch: - session.execute( + await session.execute( update(GraphEdges) .where( (GraphEdges.source == edge_data["source"]) & (GraphEdges.target == edge_data["target"]) @@ -1040,12 +1042,12 @@ class EntorhinalCortex: if edges_to_delete: for source, target in edges_to_delete: - session.execute( + await session.execute( delete(GraphEdges).where((GraphEdges.source == source) & (GraphEdges.target == target)) ) # 提交事务 - session.commit() + await session.commit() end_time = time.time() logger.info(f"[同步] 总耗时: {end_time - start_time:.2f}秒") @@ -1057,10 +1059,10 @@ class EntorhinalCortex: logger.info("[数据库] 开始重新同步所有记忆数据...") # 清空数据库 - with get_db_session() as session: + async with get_db_session() as session: clear_start = time.time() - session.execute(delete(GraphNodes)) - session.execute(delete(GraphEdges)) + await session.execute(delete(GraphNodes)) + await session.execute(delete(GraphEdges)) clear_end = time.time() logger.info(f"[数据库] 清空数据库耗时: {clear_end - clear_start:.2f}秒") @@ -1119,7 +1121,7 @@ class EntorhinalCortex: batch_size = 500 # 增加批量大小 for i in range(0, len(nodes_data), batch_size): batch = nodes_data[i : i + batch_size] - session.execute(insert(GraphNodes), batch) + await session.execute(insert(GraphNodes), batch) node_end = time.time() logger.info(f"[数据库] 写入 {len(nodes_data)} 个节点耗时: {node_end - node_start:.2f}秒") @@ -1130,8 
+1132,8 @@ class EntorhinalCortex: batch_size = 500 # 增加批量大小 for i in range(0, len(edges_data), batch_size): batch = edges_data[i : i + batch_size] - session.execute(insert(GraphEdges), batch) - session.commit() + await session.execute(insert(GraphEdges), batch) + await session.commit() edge_end = time.time() logger.info(f"[数据库] 写入 {len(edges_data)} 条边耗时: {edge_end - edge_start:.2f}秒") @@ -1140,7 +1142,7 @@ class EntorhinalCortex: logger.info(f"[数据库] 重新同步完成,总耗时: {end_time - start_time:.2f}秒") logger.info(f"[数据库] 同步了 {len(nodes_data)} 个节点和 {len(edges_data)} 条边") - def sync_memory_from_db(self): + async def sync_memory_from_db(self): """从数据库同步数据到内存中的图结构""" current_time = datetime.datetime.now().timestamp() need_update = False @@ -1149,8 +1151,8 @@ class EntorhinalCortex: self.memory_graph.G.clear() # 从数据库加载所有节点 - with get_db_session() as session: - nodes = list(session.execute(select(GraphNodes)).scalars()) + async with get_db_session() as session: + nodes = list((await session.execute(select(GraphNodes))).scalars()) for node in nodes: concept = node.concept try: @@ -1168,7 +1170,9 @@ class EntorhinalCortex: if not node.last_modified: update_data["last_modified"] = current_time - session.execute(update(GraphNodes).where(GraphNodes.concept == concept).values(**update_data)) + await session.execute( + update(GraphNodes).where(GraphNodes.concept == concept).values(**update_data) + ) # 获取时间信息(如果不存在则使用当前时间) created_time = node.created_time or current_time @@ -1183,7 +1187,7 @@ class EntorhinalCortex: continue # 从数据库加载所有边 - edges = list(session.execute(select(GraphEdges)).scalars()) + edges = list((await session.execute(select(GraphEdges))).scalars()) for edge in edges: source = edge.source target = edge.target @@ -1199,7 +1203,7 @@ class EntorhinalCortex: if not edge.last_modified: update_data["last_modified"] = current_time - session.execute( + await session.execute( update(GraphEdges) .where((GraphEdges.source == source) & (GraphEdges.target == target)) .values(**update_data) @@ -1214,7 +1218,7 @@ class EntorhinalCortex: self.memory_graph.G.add_edge( source, target, strength=strength, created_time=created_time, last_modified=last_modified ) - session.commit() + await session.commit() if need_update: logger.info("[数据库] 已为缺失的时间字段进行补充") @@ -1254,7 +1258,7 @@ class ParahippocampalGyrus: # 1. 
使用 build_readable_messages 生成格式化文本 # build_readable_messages 只返回一个字符串,不需要解包 - input_text = build_readable_messages( + input_text = await build_readable_messages( messages, merge_messages=True, # 合并连续消息 timestamp_mode="normal_no_YMD", # 使用 'YYYY-MM-DD HH:MM:SS' 格式 @@ -1342,7 +1346,7 @@ class ParahippocampalGyrus: # sourcery skip: merge-list-appends-into-extend logger.info("------------------------------------开始构建记忆--------------------------------------") start_time = time.time() - memory_samples = self.hippocampus.entorhinal_cortex.get_memory_sample() + memory_samples = await self.hippocampus.entorhinal_cortex.get_memory_sample() all_added_nodes = [] all_connected_nodes = [] all_added_edges = [] @@ -1620,7 +1624,7 @@ class HippocampusManager: return self._hippocampus self._hippocampus = Hippocampus() - self._hippocampus.initialize() + # self._hippocampus.initialize() # 改为异步启动 self._initialized = True # 输出记忆图统计信息 @@ -1639,6 +1643,13 @@ class HippocampusManager: return self._hippocampus + async def initialize_async(self): + """异步初始化海马体实例""" + if not self._initialized: + self.initialize() # 先进行同步部分的初始化 + self._hippocampus.initialize() + await self._hippocampus.entorhinal_cortex.sync_memory_from_db() + def get_hippocampus(self): if not self._initialized: raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法") diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index f5822acfb..4f91d15c6 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -246,11 +246,11 @@ class ChatManager: return stream # 检查数据库中是否存在 - def _db_find_stream_sync(s_id: str): - with get_db_session() as session: - return session.execute(select(ChatStreams).where(ChatStreams.stream_id == s_id)).scalar() + async def _db_find_stream_async(s_id: str): + async with get_db_session() as session: + return (await session.execute(select(ChatStreams).where(ChatStreams.stream_id == s_id))).scalar() - model_instance = await asyncio.to_thread(_db_find_stream_sync, stream_id) + model_instance = await _db_find_stream_async(stream_id) if model_instance: # 从 Peewee 模型转换回 ChatStream.from_dict 期望的格式 @@ -344,11 +344,10 @@ class ChatManager: return stream_data_dict = stream.to_dict() - def _db_save_stream_sync(s_data_dict: dict): - with get_db_session() as session: + async def _db_save_stream_async(s_data_dict: dict): + async with get_db_session() as session: user_info_d = s_data_dict.get("user_info") group_info_d = s_data_dict.get("group_info") - fields_to_save = { "platform": s_data_dict["platform"], "create_time": s_data_dict["create_time"], @@ -364,8 +363,6 @@ class ChatManager: "sleep_pressure": s_data_dict.get("sleep_pressure", 0.0), "focus_energy": s_data_dict.get("focus_energy", global_config.chat.focus_value), } - - # 根据数据库类型选择插入语句 if global_config.database.database_type == "sqlite": stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save) stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=fields_to_save) @@ -375,15 +372,13 @@ class ChatManager: **{key: value for key, value in fields_to_save.items() if key != "stream_id"} ) else: - # 默认使用通用插入,尝试SQLite语法 stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save) stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=fields_to_save) - - session.execute(stmt) - session.commit() + await session.execute(stmt) + await session.commit() try: - await asyncio.to_thread(_db_save_stream_sync, 
stream_data_dict) + await _db_save_stream_async(stream_data_dict) stream.saved = True except Exception as e: logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (Peewee): {e}", exc_info=True) @@ -397,10 +392,10 @@ class ChatManager: """从数据库加载所有聊天流""" logger.info("正在从数据库加载所有聊天流") - def _db_load_all_streams_sync(): + async def _db_load_all_streams_async(): loaded_streams_data = [] - with get_db_session() as session: - for model_instance in session.execute(select(ChatStreams)).scalars(): + async with get_db_session() as session: + for model_instance in (await session.execute(select(ChatStreams))).scalars(): user_info_data = { "platform": model_instance.user_platform, "user_id": model_instance.user_id, @@ -414,7 +409,6 @@ class ChatManager: "group_id": model_instance.group_id, "group_name": model_instance.group_name, } - data_for_from_dict = { "stream_id": model_instance.stream_id, "platform": model_instance.platform, @@ -427,11 +421,11 @@ class ChatManager: "focus_energy": getattr(model_instance, "focus_energy", global_config.chat.focus_value), } loaded_streams_data.append(data_for_from_dict) - session.commit() + await session.commit() return loaded_streams_data try: - all_streams_data_list = await asyncio.to_thread(_db_load_all_streams_sync) + all_streams_data_list = await _db_load_all_streams_async() self.streams.clear() for data in all_streams_data_list: stream = ChatStream.from_dict(data) diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index 8219ee761..c362187e2 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -41,7 +41,7 @@ class MessageStorage: processed_plain_text = message.processed_plain_text if processed_plain_text: - processed_plain_text = MessageStorage.replace_image_descriptions(processed_plain_text) + processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text) filtered_processed_plain_text = re.sub(pattern, "", processed_plain_text, flags=re.DOTALL) else: filtered_processed_plain_text = "" @@ -128,9 +128,9 @@ class MessageStorage: key_words=key_words, key_words_lite=key_words_lite, ) - with get_db_session() as session: + async with get_db_session() as session: session.add(new_message) - session.commit() + await session.commit() except Exception: logger.exception("存储消息失败") @@ -173,16 +173,18 @@ class MessageStorage: # 使用上下文管理器确保session正确管理 from src.common.database.sqlalchemy_models import get_db_session - with get_db_session() as session: - matched_message = session.execute( - select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time)) + async with get_db_session() as session: + matched_message = ( + await session.execute( + select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time)) + ) ).scalar() if matched_message: - session.execute( + await session.execute( update(Messages).where(Messages.id == matched_message.id).values(message_id=qq_message_id) ) - session.commit() + await session.commit() # 会在上下文管理器中自动调用 logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}") else: @@ -196,28 +198,36 @@ class MessageStorage: ) @staticmethod - def replace_image_descriptions(text: str) -> str: + async def replace_image_descriptions(text: str) -> str: """将[图片:描述]替换为[picid:image_id]""" # 先检查文本中是否有图片标记 pattern = r"\[图片:([^\]]+)\]" - matches = re.findall(pattern, text) + matches = list(re.finditer(pattern, text)) if not matches: logger.debug("文本中没有图片标记,直接返回原文本") return text - def replace_match(match): + 
new_text = "" + last_end = 0 + for match in matches: + new_text += text[last_end : match.start()] description = match.group(1).strip() try: from src.common.database.sqlalchemy_models import get_db_session - with get_db_session() as session: - image_record = session.execute( - select(Images).where(Images.description == description).order_by(desc(Images.timestamp)) + async with get_db_session() as session: + image_record = ( + await session.execute( + select(Images).where(Images.description == description).order_by(desc(Images.timestamp)) + ) ).scalar() - session.commit() - return f"[picid:{image_record.image_id}]" if image_record else match.group(0) + if image_record: + new_text += f"[picid:{image_record.image_id}]" + else: + new_text += match.group(0) except Exception: - return match.group(0) - - return re.sub(r"\[图片:([^\]]+)\]", replace_match, text) + new_text += match.group(0) + last_end = match.end() + new_text += text[last_end:] + return new_text diff --git a/src/chat/planner_actions/action_modifier.py b/src/chat/planner_actions/action_modifier.py index e9cc1d106..a061c15ae 100644 --- a/src/chat/planner_actions/action_modifier.py +++ b/src/chat/planner_actions/action_modifier.py @@ -97,12 +97,12 @@ class ActionModifier: for action_name, reason in chat_type_removals: logger.debug(f"{self.log_prefix} - 移除 {action_name}: {reason}") - message_list_before_now_half = get_raw_msg_before_timestamp_with_chat( + message_list_before_now_half = await get_raw_msg_before_timestamp_with_chat( chat_id=self.chat_stream.stream_id, timestamp=time.time(), limit=min(int(global_config.chat.max_context_size * 0.33), 10), ) - chat_content = build_readable_messages( + chat_content = await build_readable_messages( message_list_before_now_half, replace_bot_name=True, merge_messages=False, diff --git a/src/chat/planner_actions/plan_filter.py b/src/chat/planner_actions/plan_filter.py index 53e1e4a80..2d05f0511 100644 --- a/src/chat/planner_actions/plan_filter.py +++ b/src/chat/planner_actions/plan_filter.py @@ -152,7 +152,7 @@ class PlanFilter: ) return prompt, message_id_list - chat_content_block, message_id_list = build_readable_messages_with_id( + chat_content_block, message_id_list = await build_readable_messages_with_id( messages=[msg.flatten() for msg in plan.chat_history], timestamp_mode="normal", read_mark=self.last_obs_time_mark, @@ -167,7 +167,7 @@ class PlanFilter: limit=5, ) - actions_before_now_block = build_readable_actions(actions=actions_before_now) + actions_before_now_block = build_readable_actions(actions=await actions_before_now) actions_before_now_block = f"你刚刚选择并执行过的action是:\n{actions_before_now_block}" self.last_obs_time_mark = time.time() diff --git a/src/chat/planner_actions/plan_generator.py b/src/chat/planner_actions/plan_generator.py index 5dd1b680c..96af31c4b 100644 --- a/src/chat/planner_actions/plan_generator.py +++ b/src/chat/planner_actions/plan_generator.py @@ -63,7 +63,7 @@ class PlanGenerator: timestamp=time.time(), limit=int(global_config.chat.max_context_size), ) - chat_history = [DatabaseMessages(**msg) for msg in chat_history_raw] + chat_history = [DatabaseMessages(**msg) for msg in await chat_history_raw] plan = Plan( diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index d2a8eb850..705924c73 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -828,7 +828,8 @@ class DefaultReplyer: platform, # type: ignore reply_message.get("user_id"), # type: ignore ) - person_name = await 
person_info_manager.get_value(person_id, "person_name") + person_info = await person_info_manager.get_values(person_id, ["person_name", "user_id"]) + person_name = person_info.get("person_name") # 如果person_name为None,使用fallback值 if person_name is None: @@ -839,7 +840,7 @@ class DefaultReplyer: # 检查是否是bot自己的名字,如果是则替换为"(你)" bot_user_id = str(global_config.bot.qq_account) - current_user_id = person_info_manager.get_value_sync(person_id, "user_id") + current_user_id = person_info.get("user_id") current_platform = reply_message.get("chat_info_platform") if current_user_id == bot_user_id and current_platform == global_config.bot.platform: @@ -872,18 +873,18 @@ class DefaultReplyer: action_descriptions += f"- {action_name}: {action_description}\n" action_descriptions += "\n" - message_list_before_now_long = get_raw_msg_before_timestamp_with_chat( + message_list_before_now_long = await get_raw_msg_before_timestamp_with_chat( chat_id=chat_id, timestamp=time.time(), limit=global_config.chat.max_context_size * 2, ) - message_list_before_short = get_raw_msg_before_timestamp_with_chat( + message_list_before_short = await get_raw_msg_before_timestamp_with_chat( chat_id=chat_id, timestamp=time.time(), limit=int(global_config.chat.max_context_size * 0.33), ) - chat_talking_prompt_short = build_readable_messages( + chat_talking_prompt_short = await build_readable_messages( message_list_before_short, replace_bot_name=True, merge_messages=False, @@ -895,7 +896,7 @@ class DefaultReplyer: # 获取目标用户信息,用于s4u模式 target_user_info = None if sender: - target_user_info = await person_info_manager.get_person_info_by_name(sender) + target_user_info = person_info_manager.get_person_info_by_name(sender) from src.chat.utils.prompt import Prompt # 并行执行六个构建任务 @@ -1122,12 +1123,12 @@ class DefaultReplyer: else: mood_prompt = "" - message_list_before_now_half = get_raw_msg_before_timestamp_with_chat( + message_list_before_now_half = await get_raw_msg_before_timestamp_with_chat( chat_id=chat_id, timestamp=time.time(), limit=min(int(global_config.chat.max_context_size * 0.33), 15), ) - chat_talking_prompt_half = build_readable_messages( + chat_talking_prompt_half = await build_readable_messages( message_list_before_now_half, replace_bot_name=True, merge_messages=False, diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index 83b1b0587..e49c218c4 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -121,7 +121,8 @@ async def replace_user_references_async( if replace_bot_name and user_id == global_config.bot.qq_account: return f"{global_config.bot.nickname}(你)" person_id = PersonInfoManager.get_person_id(platform, user_id) - return await person_info_manager.get_value(person_id, "person_name") or user_id # type: ignore + person_info = await person_info_manager.get_values(person_id, ["person_name"]) + return person_info.get("person_name") or user_id name_resolver = default_resolver @@ -169,7 +170,7 @@ async def replace_user_references_async( return content -def get_raw_msg_by_timestamp( +async def get_raw_msg_by_timestamp( timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest" ) -> List[Dict[str, Any]]: """ @@ -180,10 +181,10 @@ def get_raw_msg_by_timestamp( filter_query = {"time": {"$gt": timestamp_start, "$lt": timestamp_end}} # 只有当 limit 为 0 时才应用外部 sort sort_order = [("time", 1)] if limit == 0 else None - return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode) + 
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode) -def get_raw_msg_by_timestamp_with_chat( +async def get_raw_msg_by_timestamp_with_chat( chat_id: str, timestamp_start: float, timestamp_end: float, @@ -200,7 +201,7 @@ def get_raw_msg_by_timestamp_with_chat( # 只有当 limit 为 0 时才应用外部 sort sort_order = [("time", 1)] if limit == 0 else None # 直接将 limit_mode 传递给 find_messages - return find_messages( + return await find_messages( message_filter=filter_query, sort=sort_order, limit=limit, @@ -210,7 +211,7 @@ def get_raw_msg_by_timestamp_with_chat( ) -def get_raw_msg_by_timestamp_with_chat_inclusive( +async def get_raw_msg_by_timestamp_with_chat_inclusive( chat_id: str, timestamp_start: float, timestamp_end: float, @@ -227,12 +228,12 @@ def get_raw_msg_by_timestamp_with_chat_inclusive( sort_order = [("time", 1)] if limit == 0 else None # 直接将 limit_mode 传递给 find_messages - return find_messages( + return await find_messages( message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode, filter_bot=filter_bot ) -def get_raw_msg_by_timestamp_with_chat_users( +async def get_raw_msg_by_timestamp_with_chat_users( chat_id: str, timestamp_start: float, timestamp_end: float, @@ -251,10 +252,10 @@ def get_raw_msg_by_timestamp_with_chat_users( } # 只有当 limit 为 0 时才应用外部 sort sort_order = [("time", 1)] if limit == 0 else None - return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode) + return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode) -def get_actions_by_timestamp_with_chat( +async def get_actions_by_timestamp_with_chat( chat_id: str, timestamp_start: float = 0, timestamp_end: float = time.time(), @@ -273,10 +274,10 @@ def get_actions_by_timestamp_with_chat( f"limit={limit}, limit_mode={limit_mode}" ) - with get_db_session() as session: + async with get_db_session() as session: if limit > 0: if limit_mode == "latest": - query = session.execute( + query = await session.execute( select(ActionRecords) .where( and_( @@ -306,7 +307,7 @@ def get_actions_by_timestamp_with_chat( } actions_result.append(action_dict) else: # earliest - query = session.execute( + query = await session.execute( select(ActionRecords) .where( and_( @@ -336,7 +337,7 @@ def get_actions_by_timestamp_with_chat( } actions_result.append(action_dict) else: - query = session.execute( + query = await session.execute( select(ActionRecords) .where( and_( @@ -367,14 +368,14 @@ def get_actions_by_timestamp_with_chat( return actions_result -def get_actions_by_timestamp_with_chat_inclusive( +async def get_actions_by_timestamp_with_chat_inclusive( chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest" ) -> List[Dict[str, Any]]: """获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表""" - with get_db_session() as session: + async with get_db_session() as session: if limit > 0: if limit_mode == "latest": - query = session.execute( + query = await session.execute( select(ActionRecords) .where( and_( @@ -389,7 +390,7 @@ def get_actions_by_timestamp_with_chat_inclusive( actions = list(query.scalars()) return [action.__dict__ for action in reversed(actions)] else: # earliest - query = session.execute( + query = await session.execute( select(ActionRecords) .where( and_( @@ -402,7 +403,7 @@ def get_actions_by_timestamp_with_chat_inclusive( .limit(limit) ) else: - query = session.execute( + query = await session.execute( select(ActionRecords) 
.where( and_( @@ -418,14 +419,14 @@ def get_actions_by_timestamp_with_chat_inclusive( return [action.__dict__ for action in actions] -def get_raw_msg_by_timestamp_random( +async def get_raw_msg_by_timestamp_random( timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest" ) -> List[Dict[str, Any]]: """ 先在范围时间戳内随机选择一条消息,取得消息的chat_id,然后根据chat_id获取该聊天在指定时间戳范围内的消息 """ # 获取所有消息,只取chat_id字段 - all_msgs = get_raw_msg_by_timestamp(timestamp_start, timestamp_end) + all_msgs = await get_raw_msg_by_timestamp(timestamp_start, timestamp_end) if not all_msgs: return [] # 随机选一条 @@ -433,10 +434,10 @@ def get_raw_msg_by_timestamp_random( chat_id = msg["chat_id"] timestamp_start = msg["time"] # 用 chat_id 获取该聊天在指定时间戳范围内的消息 - return get_raw_msg_by_timestamp_with_chat(chat_id, timestamp_start, timestamp_end, limit, "earliest") + return await get_raw_msg_by_timestamp_with_chat(chat_id, timestamp_start, timestamp_end, limit, "earliest") -def get_raw_msg_by_timestamp_with_users( +async def get_raw_msg_by_timestamp_with_users( timestamp_start: float, timestamp_end: float, person_ids: list, limit: int = 0, limit_mode: str = "latest" ) -> List[Dict[str, Any]]: """获取某些特定用户在 *所有聊天* 中从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表 @@ -446,37 +447,39 @@ def get_raw_msg_by_timestamp_with_users( filter_query = {"time": {"$gt": timestamp_start, "$lt": timestamp_end}, "user_id": {"$in": person_ids}} # 只有当 limit 为 0 时才应用外部 sort sort_order = [("time", 1)] if limit == 0 else None - return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode) + return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode) -def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[Dict[str, Any]]: +async def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[Dict[str, Any]]: """获取指定时间戳之前的消息,按时间升序排序,返回消息列表 limit: 限制返回的消息数量,0为不限制 """ filter_query = {"time": {"$lt": timestamp}} sort_order = [("time", 1)] - return find_messages(message_filter=filter_query, sort=sort_order, limit=limit) + return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit) -def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float, limit: int = 0) -> List[Dict[str, Any]]: +async def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float, limit: int = 0) -> List[Dict[str, Any]]: """获取指定时间戳之前的消息,按时间升序排序,返回消息列表 limit: 限制返回的消息数量,0为不限制 """ filter_query = {"chat_id": chat_id, "time": {"$lt": timestamp}} sort_order = [("time", 1)] - return find_messages(message_filter=filter_query, sort=sort_order, limit=limit) + return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit) -def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: list, limit: int = 0) -> List[Dict[str, Any]]: +async def get_raw_msg_before_timestamp_with_users( + timestamp: float, person_ids: list, limit: int = 0 +) -> List[Dict[str, Any]]: """获取指定时间戳之前的消息,按时间升序排序,返回消息列表 limit: 限制返回的消息数量,0为不限制 """ filter_query = {"time": {"$lt": timestamp}, "user_id": {"$in": person_ids}} sort_order = [("time", 1)] - return find_messages(message_filter=filter_query, sort=sort_order, limit=limit) + return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit) -def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None) -> int: +async def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: 
Optional[float] = None) -> int: """ 检查特定聊天从 timestamp_start (不含) 到 timestamp_end (不含) 之间有多少新消息。 如果 timestamp_end 为 None,则检查从 timestamp_start (不含) 到当前时间的消息。 @@ -490,10 +493,10 @@ def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp return 0 # 起始时间大于等于结束时间,没有新消息 filter_query = {"chat_id": chat_id, "time": {"$gt": timestamp_start, "$lt": _timestamp_end}} - return count_messages(message_filter=filter_query) + return await count_messages(message_filter=filter_query) -def num_new_messages_since_with_users( +async def num_new_messages_since_with_users( chat_id: str, timestamp_start: float, timestamp_end: float, person_ids: list ) -> int: """检查某些特定用户在特定聊天在指定时间戳之间有多少新消息""" @@ -504,10 +507,10 @@ def num_new_messages_since_with_users( "time": {"$gt": timestamp_start, "$lt": timestamp_end}, "user_id": {"$in": person_ids}, } - return count_messages(message_filter=filter_query) + return await count_messages(message_filter=filter_query) -def _build_readable_messages_internal( +async def _build_readable_messages_internal( messages: List[Dict[str, Any]], replace_bot_name: bool = True, merge_messages: bool = False, @@ -627,7 +630,8 @@ def _build_readable_messages_internal( if replace_bot_name and user_id == global_config.bot.qq_account: person_name = f"{global_config.bot.nickname}(你)" else: - person_name = person_info_manager.get_value_sync(person_id, "person_name") # type: ignore + person_info = await person_info_manager.get_values(person_id, ["person_name"]) + person_name = person_info.get("person_name") # type: ignore # 如果 person_name 未设置,则使用消息中的 nickname 或默认名称 if not person_name: @@ -796,7 +800,7 @@ def _build_readable_messages_internal( ) -def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str: +async def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str: # sourcery skip: use-contextlib-suppress """ 构建图片映射信息字符串,显示图片的具体描述内容 @@ -819,9 +823,9 @@ def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str: # 从数据库中获取图片描述 description = "[图片内容未知]" # 默认描述 try: - with get_db_session() as session: - image = session.execute(select(Images).where(Images.image_id == pic_id)).scalar_one_or_none() - if image and image.description: # type: ignore + async with get_db_session() as session: + image = (await session.execute(select(Images).where(Images.image_id == pic_id))).scalar_one_or_none() + if image and image.description: # type: ignore description = image.description except Exception: # 如果查询失败,保持默认描述 @@ -917,17 +921,17 @@ async def build_readable_messages_with_list( 将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。 允许通过参数控制格式化行为。 """ - formatted_string, details_list, pic_id_mapping, _ = _build_readable_messages_internal( + formatted_string, details_list, pic_id_mapping, _ = await _build_readable_messages_internal( messages, replace_bot_name, merge_messages, timestamp_mode, truncate ) - if pic_mapping_info := build_pic_mapping_info(pic_id_mapping): + if pic_mapping_info := await build_pic_mapping_info(pic_id_mapping): formatted_string = f"{pic_mapping_info}\n\n{formatted_string}" return formatted_string, details_list -def build_readable_messages_with_id( +async def build_readable_messages_with_id( messages: List[Dict[str, Any]], replace_bot_name: bool = True, merge_messages: bool = False, @@ -943,7 +947,7 @@ def build_readable_messages_with_id( """ message_id_list = assign_message_ids(messages) - formatted_string = build_readable_messages( + formatted_string = await build_readable_messages( messages=messages, replace_bot_name=replace_bot_name, merge_messages=merge_messages, 
@@ -958,7 +962,7 @@ def build_readable_messages_with_id( return formatted_string, message_id_list -def build_readable_messages( +async def build_readable_messages( messages: List[Dict[str, Any]], replace_bot_name: bool = True, merge_messages: bool = False, @@ -999,24 +1003,28 @@ def build_readable_messages( from src.common.database.sqlalchemy_database_api import get_db_session - with get_db_session() as session: + async with get_db_session() as session: # 获取这个时间范围内的动作记录,并匹配chat_id - actions_in_range = session.execute( - select(ActionRecords) - .where( - and_( - ActionRecords.time >= min_time, ActionRecords.time <= max_time, ActionRecords.chat_id == chat_id + actions_in_range = ( + await session.execute( + select(ActionRecords) + .where( + and_( + ActionRecords.time >= min_time, ActionRecords.time <= max_time, ActionRecords.chat_id == chat_id + ) ) + .order_by(ActionRecords.time) ) - .order_by(ActionRecords.time) ).scalars() # 获取最新消息之后的第一个动作记录 - action_after_latest = session.execute( - select(ActionRecords) - .where(and_(ActionRecords.time > max_time, ActionRecords.chat_id == chat_id)) - .order_by(ActionRecords.time) - .limit(1) + action_after_latest = ( + await session.execute( + select(ActionRecords) + .where(and_(ActionRecords.time > max_time, ActionRecords.chat_id == chat_id)) + .order_by(ActionRecords.time) + .limit(1) + ) ).scalars() # 合并两部分动作记录,并转为 dict,避免 DetachedInstanceError @@ -1048,7 +1056,7 @@ def build_readable_messages( if read_mark <= 0: # 没有有效的 read_mark,直接格式化所有消息 - formatted_string, _, pic_id_mapping, _ = _build_readable_messages_internal( + formatted_string, _, pic_id_mapping, _ = await _build_readable_messages_internal( copy_messages, replace_bot_name, merge_messages, @@ -1059,7 +1067,7 @@ def build_readable_messages( ) # 生成图片映射信息并添加到最前面 - pic_mapping_info = build_pic_mapping_info(pic_id_mapping) + pic_mapping_info = await build_pic_mapping_info(pic_id_mapping) if pic_mapping_info: return f"{pic_mapping_info}\n\n{formatted_string}" else: @@ -1074,7 +1082,7 @@ def build_readable_messages( pic_counter = 1 # 分别格式化,但使用共享的图片映射 - formatted_before, _, pic_id_mapping, pic_counter = _build_readable_messages_internal( + formatted_before, _, pic_id_mapping, pic_counter = await _build_readable_messages_internal( messages_before_mark, replace_bot_name, merge_messages, @@ -1085,7 +1093,7 @@ def build_readable_messages( show_pic=show_pic, message_id_list=message_id_list, ) - formatted_after, _, pic_id_mapping, _ = _build_readable_messages_internal( + formatted_after, _, pic_id_mapping, _ = await _build_readable_messages_internal( messages_after_mark, replace_bot_name, merge_messages, @@ -1101,7 +1109,7 @@ def build_readable_messages( # 生成图片映射信息 if pic_id_mapping: - pic_mapping_info = f"图片信息:\n{build_pic_mapping_info(pic_id_mapping)}\n聊天记录信息:\n" + pic_mapping_info = f"图片信息:\n{await build_pic_mapping_info(pic_id_mapping)}\n聊天记录信息:\n" else: pic_mapping_info = "聊天记录信息:\n" diff --git a/src/common/message_repository.py b/src/common/message_repository.py index 78e856f39..63d4c000d 100644 --- a/src/common/message_repository.py +++ b/src/common/message_repository.py @@ -25,7 +25,7 @@ def _model_to_dict(instance: Base) -> Dict[str, Any]: return {col.name: getattr(instance, col.name) for col in instance.__table__.columns} -def find_messages( +async def find_messages( message_filter: dict[str, Any], sort: Optional[List[tuple[str, int]]] = None, limit: int = 0, @@ -46,7 +46,7 @@ def find_messages( 消息字典列表,如果出错则返回空列表。 """ try: - with get_db_session() as session: + async with get_db_session() as 
session: query = select(Messages) # 应用过滤器 @@ -96,7 +96,7 @@ def find_messages( # 获取时间最早的 limit 条记录,已经是正序 query = query.order_by(Messages.time.asc()).limit(limit) try: - results = session.execute(query).scalars().all() + results = (await session.execute(query)).scalars().all() except Exception as e: logger.error(f"执行earliest查询失败: {e}") results = [] @@ -104,7 +104,7 @@ def find_messages( # 获取时间最晚的 limit 条记录 query = query.order_by(Messages.time.desc()).limit(limit) try: - latest_results = session.execute(query).scalars().all() + latest_results = (await session.execute(query)).scalars().all() # 将结果按时间正序排列 results = sorted(latest_results, key=lambda msg: msg.time) except Exception as e: @@ -128,12 +128,12 @@ def find_messages( if sort_terms: query = query.order_by(*sort_terms) try: - results = session.execute(query).scalars().all() + results = (await session.execute(query)).scalars().all() except Exception as e: logger.error(f"执行无限制查询失败: {e}") results = [] - return [_model_to_dict(msg) for msg in results] + return [_model_to_dict(msg) for msg in results] except Exception as e: log_message = ( f"使用 SQLAlchemy 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n" @@ -143,7 +143,7 @@ def find_messages( return [] -def count_messages(message_filter: dict[str, Any]) -> int: +async def count_messages(message_filter: dict[str, Any]) -> int: """ 根据提供的过滤器计算消息数量。 @@ -154,7 +154,7 @@ def count_messages(message_filter: dict[str, Any]) -> int: 符合条件的消息数量,如果出错则返回 0。 """ try: - with get_db_session() as session: + async with get_db_session() as session: query = select(func.count(Messages.id)) # 应用过滤器 @@ -192,7 +192,7 @@ def count_messages(message_filter: dict[str, Any]) -> int: if conditions: query = query.where(*conditions) - count = session.execute(query).scalar() + count = (await session.execute(query)).scalar() return count or 0 except Exception as e: log_message = f"使用 SQLAlchemy 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}" diff --git a/src/main.py b/src/main.py index 103884c10..d37a5b630 100644 --- a/src/main.py +++ b/src/main.py @@ -40,6 +40,9 @@ if not global_config.memory.enable_memory: def initialize(self): pass + async def initialize_async(self): + pass + def get_hippocampus(self): return None @@ -248,7 +251,7 @@ MoFox_Bot(第三方修改版) logger.info("聊天管理器初始化成功") # 初始化记忆系统 - self.hippocampus_manager.initialize() + await self.hippocampus_manager.initialize_async() logger.info("记忆系统初始化成功") # 初始化LPMM知识库 @@ -283,7 +286,7 @@ MoFox_Bot(第三方修改版) if global_config.planning_system.monthly_plan_enable: logger.info("正在初始化月度计划管理器...") try: - await monthly_plan_manager.start_monthly_plan_generation() + await monthly_plan_manager.initialize() logger.info("月度计划管理器初始化成功") except Exception as e: logger.error(f"月度计划管理器初始化失败: {e}") @@ -291,8 +294,7 @@ MoFox_Bot(第三方修改版) # 初始化日程管理器 if global_config.planning_system.schedule_enable: logger.info("日程表功能已启用,正在初始化管理器...") - await schedule_manager.load_or_generate_today_schedule() - await schedule_manager.start_daily_schedule_generation() + await schedule_manager.initialize() logger.info("日程表管理器初始化成功。") try: diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index 2f89c43ff..95dc41cfb 100644 --- a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -73,14 +73,15 @@ class PersonInfoManager: # # 初始化时读取所有person_name try: + pass # 在这里获取会话 - with get_db_session() as session: - for record in session.execute( - select(PersonInfo.person_id, 
PersonInfo.person_name).where(PersonInfo.person_name.is_not(None)) - ).fetchall(): - if record.person_name: - self.person_name_list[record.person_id] = record.person_name - logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (SQLAlchemy)") + # with get_db_session() as session: + # for record in session.execute( + # select(PersonInfo.person_id, PersonInfo.person_name).where(PersonInfo.person_name.is_not(None)) + # ).fetchall(): + # if record.person_name: + # self.person_name_list[record.person_id] = record.person_name + # logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (SQLAlchemy)") except Exception as e: logger.error(f"从 SQLAlchemy 加载 person_name_list 失败: {e}") @@ -102,23 +103,25 @@ class PersonInfoManager: """判断是否认识某人""" person_id = self.get_person_id(platform, user_id) - def _db_check_known_sync(p_id: str): + async def _db_check_known_async(p_id: str): # 在需要时获取会话 - with get_db_session() as session: - return session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() is not None + async with get_db_session() as session: + return ( + await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)) + ).scalar() is not None try: - return await asyncio.to_thread(_db_check_known_sync, person_id) + return await _db_check_known_async(person_id) except Exception as e: logger.error(f"检查用户 {person_id} 是否已知时出错 (SQLAlchemy): {e}") return False - def get_person_id_by_person_name(self, person_name: str) -> str: + async def get_person_id_by_person_name(self, person_name: str) -> str: """根据用户名获取用户ID""" try: # 在需要时获取会话 - with get_db_session() as session: - record = session.execute(select(PersonInfo).where(PersonInfo.person_name == person_name)).scalar() + async with get_db_session() as session: + record = (await session.execute(select(PersonInfo).where(PersonInfo.person_name == person_name))).scalar() return record.person_id if record else "" except Exception as e: logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (SQLAlchemy): {e}") @@ -172,19 +175,18 @@ class PersonInfoManager: final_data[key] = orjson.dumps([]).decode("utf-8") # If it's already a string, assume it's valid JSON or a non-JSON string field - def _db_create_sync(p_data: dict): - with get_db_session() as session: + async def _db_create_async(p_data: dict): + async with get_db_session() as session: try: new_person = PersonInfo(**p_data) session.add(new_person) - session.commit() - + await session.commit() return True except Exception as e: logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (SQLAlchemy): {e}") return False - await asyncio.to_thread(_db_create_sync, final_data) + await _db_create_async(final_data) async def _safe_create_person_info(self, person_id: str, data: Optional[dict] = None): """安全地创建用户信息,处理竞态条件""" @@ -229,11 +231,11 @@ class PersonInfoManager: elif final_data[key] is None: # Default for lists is [], store as "[]" final_data[key] = orjson.dumps([]).decode("utf-8") - def _db_safe_create_sync(p_data: dict): - with get_db_session() as session: + async def _db_safe_create_async(p_data: dict): + async with get_db_session() as session: try: - existing = session.execute( - select(PersonInfo).where(PersonInfo.person_id == p_data["person_id"]) + existing = ( + await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_data["person_id"])) ).scalar() if existing: logger.debug(f"用户 {p_data['person_id']} 已存在,跳过创建") @@ -242,18 +244,17 @@ class PersonInfoManager: # 尝试创建 new_person = PersonInfo(**p_data) session.add(new_person) - session.commit() - + await session.commit() 
return True except Exception as e: if "UNIQUE constraint failed" in str(e): logger.debug(f"检测到并发创建用户 {p_data.get('person_id')},跳过错误") - return True # 其他协程已创建,视为成功 + return True else: logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (SQLAlchemy): {e}") return False - await asyncio.to_thread(_db_safe_create_sync, final_data) + await _db_safe_create_async(final_data) async def update_one_field(self, person_id: str, field_name: str, value, data: Optional[Dict] = None): """更新某一个字段,会补全""" @@ -270,37 +271,33 @@ class PersonInfoManager: elif value is None: # Store None as "[]" for JSON list fields processed_value = orjson.dumps([]).decode("utf-8") - def _db_update_sync(p_id: str, f_name: str, val_to_set): + async def _db_update_async(p_id: str, f_name: str, val_to_set): start_time = time.time() - with get_db_session() as session: + async with get_db_session() as session: try: - record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() + record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar() query_time = time.time() - if record: setattr(record, f_name, val_to_set) - save_time = time.time() - total_time = save_time - start_time - if total_time > 0.5: # 如果超过500ms就记录日志 + if total_time > 0.5: logger.warning( f"数据库更新操作耗时 {total_time:.3f}秒 (查询: {query_time - start_time:.3f}s, 保存: {save_time - query_time:.3f}s) person_id={p_id}, field={f_name}" ) - session.commit() - - return True, False # Found and updated, no creation needed + await session.commit() + return True, False else: total_time = time.time() - start_time if total_time > 0.5: logger.warning(f"数据库查询操作耗时 {total_time:.3f}秒 person_id={p_id}, field={f_name}") - return False, True # Not found, needs creation + return False, True except Exception as e: total_time = time.time() - start_time logger.error(f"数据库操作异常,耗时 {total_time:.3f}秒: {e}") raise - found, needs_creation = await asyncio.to_thread(_db_update_sync, person_id, field_name, processed_value) + found, needs_creation = await _db_update_async(person_id, field_name, processed_value) if needs_creation: logger.info(f"{person_id} 不存在,将新建。") @@ -338,13 +335,13 @@ class PersonInfoManager: logger.debug(f"检查字段'{field_name}'失败,未在 PersonInfo SQLAlchemy 模型中定义。") return False - def _db_has_field_sync(p_id: str, f_name: str): - with get_db_session() as session: - record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() + async def _db_has_field_async(p_id: str, f_name: str): + async with get_db_session() as session: + record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar() return bool(record) try: - return await asyncio.to_thread(_db_has_field_sync, person_id, field_name) + return await _db_has_field_async(person_id, field_name) except Exception as e: logger.error(f"检查字段 {field_name} for {person_id} 时出错 (SQLAlchemy): {e}") return False @@ -449,14 +446,14 @@ class PersonInfoManager: logger.info(f"尝试给用户{user_nickname} {person_id} 取名,但是 {generated_nickname} 已存在,重试中...") else: - def _db_check_name_exists_sync(name_to_check): - with get_db_session() as session: + async def _db_check_name_exists_async(name_to_check): + async with get_db_session() as session: return ( - session.execute(select(PersonInfo).where(PersonInfo.person_name == name_to_check)).scalar() + (await session.execute(select(PersonInfo).where(PersonInfo.person_name == name_to_check))).scalar() is not None ) - if await asyncio.to_thread(_db_check_name_exists_sync, generated_nickname): + if 
await _db_check_name_exists_async(generated_nickname): is_duplicate = True current_name_set.add(generated_nickname) @@ -492,91 +489,26 @@ class PersonInfoManager: logger.debug("删除失败:person_id 不能为空") return - def _db_delete_sync(p_id: str): + async def _db_delete_async(p_id: str): try: - with get_db_session() as session: - record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() + async with get_db_session() as session: + record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar() if record: - session.delete(record) - session.commit() - return 1 + await session.delete(record) + await session.commit() + return 1 return 0 except Exception as e: logger.error(f"删除 PersonInfo {p_id} 失败 (SQLAlchemy): {e}") return 0 - deleted_count = await asyncio.to_thread(_db_delete_sync, person_id) + deleted_count = await _db_delete_async(person_id) if deleted_count > 0: - logger.debug(f"删除成功:person_id={person_id} (Peewee)") + logger.debug(f"删除成功:person_id={person_id}") else: - logger.debug(f"删除失败:未找到 person_id={person_id} 或删除未影响行 (Peewee)") + logger.debug(f"删除失败:未找到 person_id={person_id} 或删除未影响行") - @staticmethod - async def get_value(person_id: str, field_name: str): - """获取指定用户指定字段的值""" - default_value_for_field = person_info_default.get(field_name) - if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None: - default_value_for_field = [] # Ensure JSON fields default to [] if not in DB - - def _db_get_value_sync(p_id: str, f_name: str): - with get_db_session() as session: - record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() - if record: - val = getattr(record, f_name, None) - if f_name in JSON_SERIALIZED_FIELDS: - if isinstance(val, str): - try: - return orjson.loads(val) - except orjson.JSONDecodeError: - logger.warning(f"字段 {f_name} for {p_id} 包含无效JSON: {val}. 返回默认值.") - return [] # Default for JSON fields on error - elif val is None: # Field exists in DB but is None - return [] # Default for JSON fields - # If val is already a list/dict (e.g. if somehow set without serialization) - return val # Should ideally not happen if update_one_field is always used - return val - return None # Record not found - - try: - value_from_db = await asyncio.to_thread(_db_get_value_sync, person_id, field_name) - if value_from_db is not None: - return value_from_db - if field_name in person_info_default: - return default_value_for_field - logger.warning(f"字段 {field_name} 在 person_info_default 中未定义,且在数据库中未找到。") - return None # Ultimate fallback - except Exception as e: - logger.error(f"获取字段 {field_name} for {person_id} 时出错 (Peewee): {e}") - # Fallback to default in case of any error during DB access - return default_value_for_field if field_name in person_info_default else None - - @staticmethod - def get_value_sync(person_id: str, field_name: str): - """同步获取指定用户指定字段的值""" - default_value_for_field = person_info_default.get(field_name) - with get_db_session() as session: - if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None: - default_value_for_field = [] - - if record := session.execute(select(PersonInfo).where(PersonInfo.person_id == person_id)).scalar(): - val = getattr(record, field_name, None) - if field_name in JSON_SERIALIZED_FIELDS: - if isinstance(val, str): - try: - return orjson.loads(val) - except orjson.JSONDecodeError: - logger.warning(f"字段 {field_name} for {person_id} 包含无效JSON: {val}. 
返回默认值.") - return [] - elif val is None: - return [] - return val - return val - - if field_name in person_info_default: - return default_value_for_field - logger.warning(f"字段 {field_name} 在 person_info_default 中未定义,且在数据库中未找到。") - return None @staticmethod async def get_values(person_id: str, field_names: list) -> dict: @@ -587,11 +519,11 @@ class PersonInfoManager: result = {} - def _db_get_record_sync(p_id: str): - with get_db_session() as session: - return session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() + async def _db_get_record_async(p_id: str): + async with get_db_session() as session: + return (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar() - record = await asyncio.to_thread(_db_get_record_sync, person_id) + record = await _db_get_record_async(person_id) # 获取 SQLAlchemy 模型的所有字段名 model_fields = [column.name for column in PersonInfo.__table__.columns] @@ -628,14 +560,15 @@ class PersonInfoManager: # 获取 SQLAlchemy 模型的所有字段名 model_fields = [column.name for column in PersonInfo.__table__.columns] if field_name not in model_fields: - logger.error(f"字段检查失败:'{field_name}'未在 PersonInfo SQLAlchemy 模 modelo中定义") + logger.error(f"字段检查失败:'{field_name}'未在 PersonInfo SQLAlchemy 模型中定义") return {} - def _db_get_specific_sync(f_name: str): + async def _db_get_specific_async(f_name: str): found_results = {} try: - with get_db_session() as session: - for record in session.execute(select(PersonInfo.person_id, getattr(PersonInfo, f_name))).fetchall(): + async with get_db_session() as session: + result = await session.execute(select(PersonInfo.person_id, getattr(PersonInfo, f_name))) + for record in result.fetchall(): value = getattr(record, f_name) if way(value): found_results[record.person_id] = value @@ -646,9 +579,9 @@ class PersonInfoManager: return found_results try: - return await asyncio.to_thread(_db_get_specific_sync, field_name) + return await _db_get_specific_async(field_name) except Exception as e: - logger.error(f"执行 get_specific_value_list 线程时出错: {str(e)}", exc_info=True) + logger.error(f"执行 get_specific_value_list 时出错: {str(e)}", exc_info=True) return {} async def get_or_create_person( @@ -661,40 +594,38 @@ class PersonInfoManager: """ person_id = self.get_person_id(platform, user_id) - def _db_get_or_create_sync(p_id: str, init_data: dict): + async def _db_get_or_create_async(p_id: str, init_data: dict): """原子性的获取或创建操作""" - with get_db_session() as session: + async with get_db_session() as session: # 首先尝试获取现有记录 - record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() + record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar() if record: return record, False # 记录存在,未创建 - # 记录不存在,尝试创建 - try: - new_person = PersonInfo(**init_data) - session.add(new_person) - session.commit() - - return session.execute( - select(PersonInfo).where(PersonInfo.person_id == p_id) - ).scalar(), True # 创建成功 - except Exception as e: - # 如果创建失败(可能是因为竞态条件),再次尝试获取 - if "UNIQUE constraint failed" in str(e): - logger.debug(f"检测到并发创建用户 {p_id},获取现有记录") - record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() - if record: - return record, False # 其他协程已创建,返回现有记录 - # 如果仍然失败,重新抛出异常 - raise e - + # 记录不存在,尝试创建 + try: + new_person = PersonInfo(**init_data) + session.add(new_person) + await session.commit() + await session.refresh(new_person) + return new_person, True # 创建成功 + except Exception as e: + # 如果创建失败(可能是因为竞态条件),再次尝试获取 + if "UNIQUE constraint 
failed" in str(e): + logger.debug(f"检测到并发创建用户 {p_id},获取现有记录") + record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar() + if record: + return record, False # 其他协程已创建,返回现有记录 + # 如果仍然失败,重新抛出异常 + raise e + unique_nickname = await self._generate_unique_person_name(nickname) initial_data = { "person_id": person_id, "platform": platform, "user_id": str(user_id), "nickname": nickname, - "person_name": unique_nickname, # 使用群昵称作为person_name + "person_name": unique_nickname, "name_reason": "从群昵称获取", "know_times": 0, "know_since": int(datetime.datetime.now().timestamp()), @@ -704,7 +635,6 @@ class PersonInfoManager: "forgotten_points": [], } - # 序列化JSON字段 for key in JSON_SERIALIZED_FIELDS: if key in initial_data: if isinstance(initial_data[key], (list, dict)): @@ -712,15 +642,14 @@ class PersonInfoManager: elif initial_data[key] is None: initial_data[key] = orjson.dumps([]).decode("utf-8") - # 获取 SQLAlchemy 模odel的所有字段名 model_fields = [column.name for column in PersonInfo.__table__.columns] filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields} - record, was_created = await asyncio.to_thread(_db_get_or_create_sync, person_id, filtered_initial_data) + record, was_created = await _db_get_or_create_async(person_id, filtered_initial_data) if was_created: - logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录 (Peewee)。") - logger.info(f"已为 {person_id} 创建新记录,初始数据 (filtered for model): {filtered_initial_data}") + logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录。") + logger.info(f"已为 {person_id} 创建新记录,初始数据: {filtered_initial_data}") else: logger.debug(f"用户 {platform}:{user_id} (person_id: {person_id}) 已存在,返回现有记录。") @@ -740,11 +669,13 @@ class PersonInfoManager: if not found_person_id: - def _db_find_by_name_sync(p_name_to_find: str): - with get_db_session() as session: - return session.execute(select(PersonInfo).where(PersonInfo.person_name == p_name_to_find)).scalar() + async def _db_find_by_name_async(p_name_to_find: str): + async with get_db_session() as session: + return ( + await session.execute(select(PersonInfo).where(PersonInfo.person_name == p_name_to_find)) + ).scalar() - record = await asyncio.to_thread(_db_find_by_name_sync, person_name) + record = await _db_find_by_name_async(person_name) if record: found_person_id = record.person_id if ( diff --git a/src/person_info/relationship_fetcher.py b/src/person_info/relationship_fetcher.py index 1c62dec1a..e903915a7 100644 --- a/src/person_info/relationship_fetcher.py +++ b/src/person_info/relationship_fetcher.py @@ -99,16 +99,18 @@ class RelationshipFetcher: self._cleanup_expired_cache() person_info_manager = get_person_info_manager() - person_name = await person_info_manager.get_value(person_id, "person_name") - short_impression = await person_info_manager.get_value(person_id, "short_impression") - - nickname_str = await person_info_manager.get_value(person_id, "nickname") - platform = await person_info_manager.get_value(person_id, "platform") + person_info = await person_info_manager.get_values( + person_id, ["person_name", "short_impression", "nickname", "platform", "points"] + ) + person_name = person_info.get("person_name") + short_impression = person_info.get("short_impression") + nickname_str = person_info.get("nickname") + platform = person_info.get("platform") if person_name == nickname_str and not short_impression: return "" - current_points = await person_info_manager.get_value(person_id, "points") or [] + 
current_points = person_info.get("points") or [] # 按时间排序forgotten_points current_points.sort(key=lambda x: x[2]) @@ -170,7 +172,8 @@ class RelationshipFetcher: nickname_str = ",".join(global_config.bot.alias_names) name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。" person_info_manager = get_person_info_manager() - person_name: str = await person_info_manager.get_value(person_id, "person_name") # type: ignore + person_info = await person_info_manager.get_values(person_id, ["person_name"]) + person_name: str = person_info.get("person_name") # type: ignore info_cache_block = self._build_info_cache_block() @@ -252,7 +255,8 @@ class RelationshipFetcher: person_info_manager = get_person_info_manager() # 首先检查 info_list 缓存 - info_list = await person_info_manager.get_value(person_id, "info_list") or [] + person_info = await person_info_manager.get_values(person_id, ["info_list"]) + info_list = person_info.get("info_list") or [] cached_info = None # 查找对应的 info_type @@ -279,8 +283,9 @@ class RelationshipFetcher: # 如果缓存中没有,尝试从用户档案中提取 try: - person_impression = await person_info_manager.get_value(person_id, "impression") - points = await person_info_manager.get_value(person_id, "points") + person_info = await person_info_manager.get_values(person_id, ["impression", "points"]) + person_impression = person_info.get("impression") + points = person_info.get("points") # 构建印象信息块 if person_impression: @@ -372,7 +377,8 @@ class RelationshipFetcher: person_info_manager = get_person_info_manager() # 获取现有的 info_list - info_list = await person_info_manager.get_value(person_id, "info_list") or [] + person_info = await person_info_manager.get_values(person_id, ["info_list"]) + info_list = person_info.get("info_list") or [] # 查找是否已存在相同 info_type 的记录 found_index = -1 diff --git a/src/plugin_system/apis/message_api.py b/src/plugin_system/apis/message_api.py index 98fab2342..3d161b847 100644 --- a/src/plugin_system/apis/message_api.py +++ b/src/plugin_system/apis/message_api.py @@ -62,7 +62,7 @@ def get_messages_by_time( return get_raw_msg_by_timestamp(start_time, end_time, limit, limit_mode) -def get_messages_by_time_in_chat( +async def get_messages_by_time_in_chat( chat_id: str, start_time: float, end_time: float, @@ -97,13 +97,13 @@ def get_messages_by_time_in_chat( if not isinstance(chat_id, str): raise ValueError("chat_id 必须是字符串类型") if filter_mai: - return filter_mai_messages( - get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode, filter_command) + return await filter_mai_messages( + await get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode, filter_command) ) - return get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode, filter_command) + return await get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode, filter_command) -def get_messages_by_time_in_chat_inclusive( +async def get_messages_by_time_in_chat_inclusive( chat_id: str, start_time: float, end_time: float, @@ -138,12 +138,12 @@ def get_messages_by_time_in_chat_inclusive( if not isinstance(chat_id, str): raise ValueError("chat_id 必须是字符串类型") if filter_mai: - return filter_mai_messages( - get_raw_msg_by_timestamp_with_chat_inclusive( + return await filter_mai_messages( + await get_raw_msg_by_timestamp_with_chat_inclusive( chat_id, start_time, end_time, limit, limit_mode, filter_command ) ) - return get_raw_msg_by_timestamp_with_chat_inclusive( + return await get_raw_msg_by_timestamp_with_chat_inclusive( 
chat_id, start_time, end_time, limit, limit_mode, filter_command ) @@ -478,7 +478,7 @@ async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[s # ============================================================================= -def filter_mai_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: +async def filter_mai_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """ 从消息列表中移除麦麦的消息 Args: diff --git a/src/schedule/database.py b/src/schedule/database.py index 88337f4df..5025c1fa3 100644 --- a/src/schedule/database.py +++ b/src/schedule/database.py @@ -1,6 +1,7 @@ # mmc/src/schedule/database.py from typing import List +from sqlalchemy import select, func, update, delete from src.common.database.sqlalchemy_models import MonthlyPlan, get_db_session from src.common.logger import get_logger from src.config.config import global_config @@ -8,21 +9,22 @@ from src.config.config import global_config logger = get_logger("schedule_database") -def add_new_plans(plans: List[str], month: str): +async def add_new_plans(plans: List[str], month: str): """ 批量添加新生成的月度计划到数据库,并确保不超过上限。 :param plans: 计划内容列表。 :param month: 目标月份,格式为 "YYYY-MM"。 """ - with get_db_session() as session: + async with get_db_session() as session: try: # 1. 获取当前有效计划数量(状态为 'active') - current_plan_count = ( - session.query(MonthlyPlan) - .filter(MonthlyPlan.target_month == month, MonthlyPlan.status == "active") - .count() + result = await session.execute( + select(func.count(MonthlyPlan.id)).where( + MonthlyPlan.target_month == month, MonthlyPlan.status == "active" + ) ) + current_plan_count = result.scalar_one() # 2. 从配置获取上限 max_plans = global_config.planning_system.max_plans_per_month @@ -41,7 +43,7 @@ def add_new_plans(plans: List[str], month: str): MonthlyPlan(plan_text=plan, target_month=month, status="active") for plan in plans_to_add ] session.add_all(new_plan_objects) - session.commit() + await session.commit() logger.info(f"成功向数据库添加了 {len(new_plan_objects)} 条 {month} 的月度计划。") if len(plans) > len(plans_to_add): @@ -49,32 +51,31 @@ def add_new_plans(plans: List[str], month: str): except Exception as e: logger.error(f"添加月度计划时发生错误: {e}") - session.rollback() + await session.rollback() raise -def get_active_plans_for_month(month: str) -> List[MonthlyPlan]: +async def get_active_plans_for_month(month: str) -> List[MonthlyPlan]: """ 获取指定月份所有状态为 'active' 的计划。 :param month: 目标月份,格式为 "YYYY-MM"。 :return: MonthlyPlan 对象列表。 """ - with get_db_session() as session: + async with get_db_session() as session: try: - plans = ( - session.query(MonthlyPlan) - .filter(MonthlyPlan.target_month == month, MonthlyPlan.status == "active") + result = await session.execute( + select(MonthlyPlan) + .where(MonthlyPlan.target_month == month, MonthlyPlan.status == "active") .order_by(MonthlyPlan.created_at.desc()) - .all() ) - return plans + return result.scalars().all() except Exception as e: logger.error(f"查询 {month} 的有效月度计划时发生错误: {e}") return [] -def mark_plans_completed(plan_ids: List[int]): +async def mark_plans_completed(plan_ids: List[int]): """ 将指定ID的计划标记为已完成。 @@ -83,9 +84,10 @@ def mark_plans_completed(plan_ids: List[int]): if not plan_ids: return - with get_db_session() as session: + async with get_db_session() as session: try: - plans_to_mark = session.query(MonthlyPlan).filter(MonthlyPlan.id.in_(plan_ids)).all() + result = await session.execute(select(MonthlyPlan).where(MonthlyPlan.id.in_(plan_ids))) + plans_to_mark = result.scalars().all() if not plans_to_mark: logger.info("没有需要标记为完成的月度计划。") return @@ 
-93,17 +95,17 @@ def mark_plans_completed(plan_ids: List[int]): plan_details = "\n".join([f" {i + 1}. {plan.plan_text}" for i, plan in enumerate(plans_to_mark)]) logger.info(f"以下 {len(plans_to_mark)} 条月度计划将被标记为已完成:\n{plan_details}") - session.query(MonthlyPlan).filter(MonthlyPlan.id.in_(plan_ids)).update( - {"status": "completed"}, synchronize_session=False + await session.execute( + update(MonthlyPlan).where(MonthlyPlan.id.in_(plan_ids)).values(status="completed") ) - session.commit() + await session.commit() except Exception as e: logger.error(f"标记月度计划为完成时发生错误: {e}") - session.rollback() + await session.rollback() raise -def delete_plans_by_ids(plan_ids: List[int]): +async def delete_plans_by_ids(plan_ids: List[int]): """ 根据ID列表从数据库中物理删除月度计划。 @@ -112,10 +114,11 @@ def delete_plans_by_ids(plan_ids: List[int]): if not plan_ids: return - with get_db_session() as session: + async with get_db_session() as session: try: # 先查询要删除的计划,用于日志记录 - plans_to_delete = session.query(MonthlyPlan).filter(MonthlyPlan.id.in_(plan_ids)).all() + result = await session.execute(select(MonthlyPlan).where(MonthlyPlan.id.in_(plan_ids))) + plans_to_delete = result.scalars().all() if not plans_to_delete: logger.info("没有找到需要删除的月度计划。") return @@ -124,16 +127,16 @@ def delete_plans_by_ids(plan_ids: List[int]): logger.info(f"检测到月度计划超额,将删除以下 {len(plans_to_delete)} 条计划:\n{plan_details}") # 执行删除 - session.query(MonthlyPlan).filter(MonthlyPlan.id.in_(plan_ids)).delete(synchronize_session=False) - session.commit() + await session.execute(delete(MonthlyPlan).where(MonthlyPlan.id.in_(plan_ids))) + await session.commit() except Exception as e: logger.error(f"删除月度计划时发生错误: {e}") - session.rollback() + await session.rollback() raise -def update_plan_usage(plan_ids: List[int], used_date: str): +async def update_plan_usage(plan_ids: List[int], used_date: str): """ 更新计划的使用统计信息。 @@ -143,44 +146,47 @@ def update_plan_usage(plan_ids: List[int], used_date: str): if not plan_ids: return - with get_db_session() as session: + async with get_db_session() as session: try: # 获取完成阈值配置,如果不存在则使用默认值 completion_threshold = getattr(global_config.planning_system, "completion_threshold", 3) # 批量更新使用次数和最后使用日期 - session.query(MonthlyPlan).filter(MonthlyPlan.id.in_(plan_ids)).update( - {"usage_count": MonthlyPlan.usage_count + 1, "last_used_date": used_date}, synchronize_session=False + await session.execute( + update(MonthlyPlan) + .where(MonthlyPlan.id.in_(plan_ids)) + .values(usage_count=MonthlyPlan.usage_count + 1, last_used_date=used_date) ) # 检查是否有计划达到完成阈值 - plans_to_complete = ( - session.query(MonthlyPlan) - .filter( + result = await session.execute( + select(MonthlyPlan).where( MonthlyPlan.id.in_(plan_ids), MonthlyPlan.usage_count >= completion_threshold, MonthlyPlan.status == "active", ) - .all() ) + plans_to_complete = result.scalars().all() if plans_to_complete: completed_ids = [plan.id for plan in plans_to_complete] - session.query(MonthlyPlan).filter(MonthlyPlan.id.in_(completed_ids)).update( - {"status": "completed"}, synchronize_session=False + await session.execute( + update(MonthlyPlan).where(MonthlyPlan.id.in_(completed_ids)).values(status="completed") ) logger.info(f"计划 {completed_ids} 已达到使用阈值 ({completion_threshold}),标记为已完成。") - session.commit() + await session.commit() logger.info(f"成功更新了 {len(plan_ids)} 条月度计划的使用统计。") except Exception as e: logger.error(f"更新月度计划使用统计时发生错误: {e}") - session.rollback() + await session.rollback() raise -def get_smart_plans_for_daily_schedule(month: str, max_count: int = 3, avoid_days: int = 7) -> 
List[MonthlyPlan]: +async def get_smart_plans_for_daily_schedule( + month: str, max_count: int = 3, avoid_days: int = 7 +) -> List[MonthlyPlan]: """ 智能抽取月度计划用于每日日程生成。 @@ -196,19 +202,24 @@ def get_smart_plans_for_daily_schedule(month: str, max_count: int = 3, avoid_day """ from datetime import datetime, timedelta - with get_db_session() as session: + async with get_db_session() as session: try: # 计算避免重复的日期阈值 avoid_date = (datetime.now() - timedelta(days=avoid_days)).strftime("%Y-%m-%d") # 查询符合条件的计划 - query = session.query(MonthlyPlan).filter(MonthlyPlan.target_month == month, MonthlyPlan.status == "active") + query = select(MonthlyPlan).where( + MonthlyPlan.target_month == month, MonthlyPlan.status == "active" + ) # 排除最近使用过的计划 - query = query.filter((MonthlyPlan.last_used_date.is_(None)) | (MonthlyPlan.last_used_date < avoid_date)) + query = query.where( + (MonthlyPlan.last_used_date.is_(None)) | (MonthlyPlan.last_used_date < avoid_date) + ) # 按使用次数升序排列,优先选择使用次数少的 - plans = query.order_by(MonthlyPlan.usage_count.asc()).all() + result = await session.execute(query.order_by(MonthlyPlan.usage_count.asc())) + plans = result.scalars().all() if not plans: logger.info(f"没有找到符合条件的 {month} 月度计划。") @@ -228,31 +239,31 @@ def get_smart_plans_for_daily_schedule(month: str, max_count: int = 3, avoid_day return [] -def archive_active_plans_for_month(month: str): +async def archive_active_plans_for_month(month: str): """ 将指定月份所有状态为 'active' 的计划归档为 'archived'。 通常在月底调用。 :param month: 目标月份,格式为 "YYYY-MM"。 """ - with get_db_session() as session: + async with get_db_session() as session: try: - updated_count = ( - session.query(MonthlyPlan) - .filter(MonthlyPlan.target_month == month, MonthlyPlan.status == "active") - .update({"status": "archived"}, synchronize_session=False) + result = await session.execute( + update(MonthlyPlan) + .where(MonthlyPlan.target_month == month, MonthlyPlan.status == "active") + .values(status="archived") ) - - session.commit() + updated_count = result.rowcount + await session.commit() logger.info(f"成功将 {updated_count} 条 {month} 的活跃月度计划归档。") return updated_count except Exception as e: logger.error(f"归档 {month} 的月度计划时发生错误: {e}") - session.rollback() + await session.rollback() raise -def get_archived_plans_for_month(month: str) -> List[MonthlyPlan]: +async def get_archived_plans_for_month(month: str) -> List[MonthlyPlan]: """ 获取指定月份所有状态为 'archived' 的计划。 用于生成下个月计划时的参考。 @@ -260,34 +271,34 @@ def get_archived_plans_for_month(month: str) -> List[MonthlyPlan]: :param month: 目标月份,格式为 "YYYY-MM"。 :return: MonthlyPlan 对象列表。 """ - with get_db_session() as session: + async with get_db_session() as session: try: - plans = ( - session.query(MonthlyPlan) - .filter(MonthlyPlan.target_month == month, MonthlyPlan.status == "archived") - .all() + result = await session.execute( + select(MonthlyPlan).where( + MonthlyPlan.target_month == month, MonthlyPlan.status == "archived" + ) ) - return plans + return result.scalars().all() except Exception as e: logger.error(f"查询 {month} 的归档月度计划时发生错误: {e}") return [] -def has_active_plans(month: str) -> bool: +async def has_active_plans(month: str) -> bool: """ 检查指定月份是否存在任何状态为 'active' 的计划。 :param month: 目标月份,格式为 "YYYY-MM"。 :return: 如果存在则返回 True,否则返回 False。 """ - with get_db_session() as session: + async with get_db_session() as session: try: - count = ( - session.query(MonthlyPlan) - .filter(MonthlyPlan.target_month == month, MonthlyPlan.status == "active") - .count() + result = await session.execute( + select(func.count(MonthlyPlan.id)).where( + 
MonthlyPlan.target_month == month, MonthlyPlan.status == "active" + ) ) - return count > 0 + return result.scalar_one() > 0 except Exception as e: logger.error(f"检查 {month} 的有效月度计划时发生错误: {e}") return False \ No newline at end of file diff --git a/src/schedule/monthly_plan_manager.py b/src/schedule/monthly_plan_manager.py index 1d5984ea3..7deaaf77d 100644 --- a/src/schedule/monthly_plan_manager.py +++ b/src/schedule/monthly_plan_manager.py @@ -14,6 +14,11 @@ class MonthlyPlanManager: self.plan_manager = PlanManager() self.monthly_task_started = False + async def initialize(self): + logger.info("正在初始化月度计划管理器...") + await self.start_monthly_plan_generation() + logger.info("月度计划管理器初始化成功") + async def start_monthly_plan_generation(self): if not self.monthly_task_started: logger.info(" 正在启动每月月度计划生成任务...") diff --git a/src/schedule/plan_manager.py b/src/schedule/plan_manager.py index 0fae5c381..b84a37b72 100644 --- a/src/schedule/plan_manager.py +++ b/src/schedule/plan_manager.py @@ -28,20 +28,20 @@ class PlanManager: if target_month is None: target_month = datetime.now().strftime("%Y-%m") - if not has_active_plans(target_month): + if not await has_active_plans(target_month): logger.info(f" {target_month} 没有任何有效的月度计划,将触发同步生成。") generation_successful = await self._generate_monthly_plans_logic(target_month) return generation_successful else: logger.info(f"{target_month} 已存在有效的月度计划。") - plans = get_active_plans_for_month(target_month) + plans = await get_active_plans_for_month(target_month) max_plans = global_config.planning_system.max_plans_per_month if len(plans) > max_plans: logger.warning(f"当前月度计划数量 ({len(plans)}) 超出上限 ({max_plans}),将自动删除多余的计划。") plans_to_delete = plans[: len(plans) - max_plans] delete_ids = [p.id for p in plans_to_delete] - delete_plans_by_ids(delete_ids) # type: ignore - plans = get_active_plans_for_month(target_month) + await delete_plans_by_ids(delete_ids) # type: ignore + plans = await get_active_plans_for_month(target_month) if plans: plan_texts = "\n".join([f" {i + 1}. 
{plan.plan_text}" for i, plan in enumerate(plans)]) @@ -64,11 +64,11 @@ class PlanManager: return False last_month = self._get_previous_month(target_month) - archived_plans = get_archived_plans_for_month(last_month) + archived_plans = await get_archived_plans_for_month(last_month) plans = await self.llm_generator.generate_plans_with_llm(target_month, archived_plans) if plans: - add_new_plans(plans, target_month) + await add_new_plans(plans, target_month) logger.info(f"成功为 {target_month} 生成并保存了 {len(plans)} 条月度计划。") return True else: @@ -95,11 +95,11 @@ class PlanManager: if target_month is None: target_month = datetime.now().strftime("%Y-%m") logger.info(f" 开始归档 {target_month} 的活跃月度计划...") - archived_count = archive_active_plans_for_month(target_month) + archived_count = await archive_active_plans_for_month(target_month) logger.info(f" 成功归档了 {archived_count} 条 {target_month} 的月度计划。") except Exception as e: logger.error(f" 归档 {target_month} 月度计划时发生错误: {e}") - def get_plans_for_schedule(self, month: str, max_count: int) -> List: + async def get_plans_for_schedule(self, month: str, max_count: int) -> List: avoid_days = global_config.planning_system.avoid_repetition_days - return get_smart_plans_for_daily_schedule(month, max_count=max_count, avoid_days=avoid_days) \ No newline at end of file + return await get_smart_plans_for_daily_schedule(month, max_count=max_count, avoid_days=avoid_days) \ No newline at end of file diff --git a/src/schedule/schedule_manager.py b/src/schedule/schedule_manager.py index f97d7c03c..822131dec 100644 --- a/src/schedule/schedule_manager.py +++ b/src/schedule/schedule_manager.py @@ -23,6 +23,13 @@ class ScheduleManager: self.daily_task_started = False self.schedule_generation_running = False + async def initialize(self): + if global_config.planning_system.schedule_enable: + logger.info("日程表功能已启用,正在初始化管理器...") + await self.load_or_generate_today_schedule() + await self.start_daily_schedule_generation() + logger.info("日程表管理器初始化成功。") + async def start_daily_schedule_generation(self): if not self.daily_task_started: logger.info("正在启动每日日程生成任务...") @@ -40,7 +47,7 @@ class ScheduleManager: today_str = datetime.now().strftime("%Y-%m-%d") try: - schedule_data = self._load_schedule_from_db(today_str) + schedule_data = await self._load_schedule_from_db(today_str) if schedule_data: self.today_schedule = schedule_data self._log_loaded_schedule(today_str) @@ -54,9 +61,10 @@ class ScheduleManager: logger.info("尝试生成日程作为备用方案...") await self.generate_and_save_schedule() - def _load_schedule_from_db(self, date_str: str) -> Optional[List[Dict[str, Any]]]: - with get_db_session() as session: - schedule_record = session.query(Schedule).filter(Schedule.date == date_str).first() + async def _load_schedule_from_db(self, date_str: str) -> Optional[List[Dict[str, Any]]]: + async with get_db_session() as session: + result = await session.execute(select(Schedule).filter(Schedule.date == date_str)) + schedule_record = result.scalars().first() if schedule_record: logger.info(f"从数据库加载今天的日程 ({date_str})。") schedule_data = orjson.loads(str(schedule_record.schedule_data)) @@ -90,35 +98,35 @@ class ScheduleManager: sampled_plans = [] if global_config.planning_system.monthly_plan_enable: await self.plan_manager.ensure_and_generate_plans_if_needed(current_month_str) - sampled_plans = self.plan_manager.get_plans_for_schedule(current_month_str, max_count=3) + sampled_plans = await self.plan_manager.get_plans_for_schedule(current_month_str, max_count=3) schedule_data = await 
self.llm_generator.generate_schedule_with_llm(sampled_plans) if schedule_data: - self._save_schedule_to_db(today_str, schedule_data) + await self._save_schedule_to_db(today_str, schedule_data) self.today_schedule = schedule_data self._log_generated_schedule(today_str, schedule_data) if sampled_plans: used_plan_ids = [plan.id for plan in sampled_plans] logger.info(f"更新使用过的月度计划 {used_plan_ids} 的统计信息。") - update_plan_usage(used_plan_ids, today_str) + await update_plan_usage(used_plan_ids, today_str) finally: self.schedule_generation_running = False logger.info("日程生成任务结束") - def _save_schedule_to_db(self, date_str: str, schedule_data: List[Dict[str, Any]]): - with get_db_session() as session: + async def _save_schedule_to_db(self, date_str: str, schedule_data: List[Dict[str, Any]]): + async with get_db_session() as session: schedule_json = orjson.dumps(schedule_data).decode("utf-8") - existing_schedule = session.query(Schedule).filter(Schedule.date == date_str).first() + result = await session.execute(select(Schedule).filter(Schedule.date == date_str)) + existing_schedule = result.scalars().first() if existing_schedule: - session.query(Schedule).filter(Schedule.date == date_str).update( - {Schedule.schedule_data: schedule_json, Schedule.updated_at: datetime.now()} - ) + existing_schedule.schedule_data = schedule_json + existing_schedule.updated_at = datetime.now() else: new_schedule = Schedule(date=date_str, schedule_data=schedule_json) session.add(new_schedule) - session.commit() + await session.commit() def _log_generated_schedule(self, date_str: str, schedule_data: List[Dict[str, Any]]): schedule_str = f"✅ 成功生成并保存今天的日程 ({date_str}):\n" From 898208f42566454c57fef3a9fd623f6351a7abd8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=85=E8=AF=BA=E7=8B=90?= <212194964+foxcyber907@users.noreply.github.com> Date: Sat, 20 Sep 2025 10:55:06 +0800 Subject: [PATCH 05/31] =?UTF-8?q?perf(methods):=20=E9=80=9A=E8=BF=87?= =?UTF-8?q?=E7=A7=BB=E9=99=A4=E4=B8=8D=E5=BF=85=E8=A6=81=E7=9A=84=20self?= =?UTF-8?q?=20=E5=8F=82=E6=95=B0=E4=BC=98=E5=8C=96=E6=96=B9=E6=B3=95?= =?UTF-8?q?=E7=AD=BE=E5=90=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 在包括 chat、plugin_system、schedule 和 mais4u 在内的多个模块中,消除冗余的实例引用。此次改动将无需访问实例状态的实用函数转换为静态方法,从而提升了内存效率,并使方法依赖关系更加清晰。 --- src/__init__.py | 3 +- src/chat/antipromptinjector/anti_injector.py | 6 +- src/chat/antipromptinjector/core/detector.py | 9 +- src/chat/antipromptinjector/core/shield.py | 14 +- src/chat/antipromptinjector/counter_attack.py | 3 +- .../decision/counter_attack.py | 3 +- .../decision/decision_maker.py | 3 +- src/chat/antipromptinjector/decision_maker.py | 3 +- src/chat/antipromptinjector/detector.py | 9 +- .../management/statistics.py | 9 +- .../processors/message_processor.py | 9 +- src/chat/chat_loop/heartFC_chat.py | 5 +- src/chat/chat_loop/hfc_context.py | 16 +-- .../chat_loop/proactive/proactive_thinker.py | 17 ++- .../chat_loop/sleep_manager/sleep_manager.py | 4 +- .../chat_loop/sleep_manager/time_checker.py | 3 +- src/chat/emoji_system/emoji_history.py | 3 +- src/chat/emoji_system/emoji_manager.py | 13 +- src/chat/express/expression_learner.py | 128 ++++++++++-------- src/chat/express/expression_selector.py | 6 +- src/chat/frequency_analyzer/analyzer.py | 3 +- src/chat/frequency_analyzer/tracker.py | 3 +- .../heart_flow/heartflow_message_processor.py | 14 +- src/chat/knowledge/embedding_store.py | 9 +- src/chat/memory_system/Hippocampus.py | 4 +- .../memory_system/async_memory_optimizer.py | 9 +- 
src/chat/memory_system/instant_memory.py | 6 +- .../memory_system/vector_instant_memory.py | 3 +- src/chat/message_receive/bot.py | 12 +- src/chat/message_receive/chat_stream.py | 3 +- src/chat/message_receive/message.py | 22 +-- src/chat/message_receive/storage.py | 22 +-- src/chat/planner_actions/action_manager.py | 4 +- src/chat/planner_actions/action_modifier.py | 3 +- src/chat/planner_actions/plan_executor.py | 3 +- src/chat/planner_actions/plan_filter.py | 15 +- src/chat/planner_actions/plan_generator.py | 2 +- src/chat/planner_actions/planner.py | 2 - src/chat/replyer/default_generator.py | 13 +- src/chat/utils/prompt.py | 19 +-- src/chat/utils/statistic.py | 15 +- src/chat/utils/utils.py | 5 +- src/chat/utils/utils_image.py | 3 +- src/chat/utils/utils_video.py | 26 ++-- src/chat/utils/utils_video_legacy.py | 7 +- src/common/cache_manager.py | 6 +- src/common/data_models/info_data_model.py | 4 +- src/common/data_models/llm_data_model.py | 3 +- src/common/data_models/message_data_model.py | 4 +- src/common/database/database.py | 3 +- .../database/sqlalchemy_database_api.py | 13 +- src/common/database/sqlalchemy_models.py | 17 +-- src/common/server.py | 8 +- src/config/api_ada_configs.py | 8 -- src/config/config.py | 2 - src/config/config_base.py | 2 +- src/config/official_configs.py | 9 +- src/individuality/individuality.py | 3 +- src/llm_models/payload_content/message.py | 4 +- src/llm_models/utils.py | 4 +- src/llm_models/utils_model.py | 10 +- src/main.py | 68 +++++----- .../body_emotion_action_manager.py | 4 +- src/mais4u/mais4u_chat/s4u_chat.py | 21 ++- src/mais4u/mais4u_chat/s4u_mood_manager.py | 3 +- src/mais4u/mais4u_chat/s4u_msg_processor.py | 17 ++- src/mais4u/mais4u_chat/s4u_prompt.py | 18 ++- .../mais4u_chat/s4u_stream_generator.py | 3 +- src/mais4u/mais4u_chat/super_chat_manager.py | 3 +- src/mais4u/s4u_config.py | 2 +- src/manager/async_task_manager.py | 4 +- src/person_info/person_info.py | 22 +-- src/person_info/relationship_builder.py | 22 ++- src/person_info/relationship_manager.py | 9 +- src/plugin_system/apis/message_api.py | 28 ++-- src/plugin_system/apis/permission_api.py | 2 +- src/plugin_system/apis/send_api.py | 4 +- src/plugin_system/base/base_event.py | 6 +- src/plugin_system/base/plugin_base.py | 3 +- src/plugin_system/base/plus_command.py | 2 +- src/plugin_system/core/component_registry.py | 6 +- src/plugin_system/core/event_manager.py | 3 +- src/plugin_system/core/plugin_hot_reload.py | 15 +- src/plugin_system/core/plugin_manager.py | 15 +- src/plugin_system/utils/dependency_manager.py | 6 +- src/plugin_system/utils/manifest_utils.py | 4 +- .../utils/permission_decorators.py | 16 +-- .../services/content_service.py | 7 +- .../services/cookie_service.py | 3 +- .../services/image_service.py | 3 +- .../services/qzone_service.py | 9 +- .../services/reply_tracker_service.py | 3 +- .../services/scheduler_service.py | 6 +- .../maizone_refactored/utils/history_utils.py | 3 +- .../built_in/napcat_adapter_plugin/plugin.py | 5 +- .../src/message_buffer.py | 9 +- .../src/message_chunker.py | 3 +- .../src/recv_handler/message_handler.py | 33 +++-- .../src/recv_handler/meta_event_handler.py | 2 +- .../src/recv_handler/notice_handler.py | 8 +- .../napcat_adapter_plugin/src/send_handler.py | 67 ++++++--- .../built_in/permission_management/plugin.py | 3 +- .../built_in/plugin_management/plugin.py | 7 +- .../built_in/reminder_plugin/plugin.py | 3 +- src/plugins/built_in/tts_plugin/plugin.py | 3 +- .../web_search_tool/engines/bing_engine.py | 3 +- 
.../web_search_tool/tools/url_parser.py | 2 +- src/schedule/llm_generator.py | 6 +- src/schedule/plan_manager.py | 9 +- src/schedule/schedule_manager.py | 11 +- src/utils/message_chunker.py | 3 +- 111 files changed, 643 insertions(+), 467 deletions(-) diff --git a/src/__init__.py b/src/__init__.py index d359f56eb..bdb90be85 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -48,7 +48,8 @@ class BaseMain: """初始化基础主程序""" self.easter_egg() - def easter_egg(self): + @staticmethod + def easter_egg(): # 彩蛋 init() items = [ diff --git a/src/chat/antipromptinjector/anti_injector.py b/src/chat/antipromptinjector/anti_injector.py index f270759d6..751a7d87e 100644 --- a/src/chat/antipromptinjector/anti_injector.py +++ b/src/chat/antipromptinjector/anti_injector.py @@ -249,7 +249,8 @@ class AntiPromptInjector: await self._update_message_in_storage(message_data, modified_content) logger.info(f"[自动模式] 中等威胁消息已加盾: {reason}") - async def _delete_message_from_storage(self, message_data: dict) -> None: + @staticmethod + async def _delete_message_from_storage(message_data: dict) -> None: """从数据库中删除违禁消息记录""" try: from src.common.database.sqlalchemy_models import Messages, get_db_session @@ -274,7 +275,8 @@ class AntiPromptInjector: except Exception as e: logger.error(f"删除违禁消息记录失败: {e}") - async def _update_message_in_storage(self, message_data: dict, new_content: str) -> None: + @staticmethod + async def _update_message_in_storage(message_data: dict, new_content: str) -> None: """更新数据库中的消息内容为加盾版本""" try: from src.common.database.sqlalchemy_models import Messages, get_db_session diff --git a/src/chat/antipromptinjector/core/detector.py b/src/chat/antipromptinjector/core/detector.py index 1bba79935..39e65db8b 100644 --- a/src/chat/antipromptinjector/core/detector.py +++ b/src/chat/antipromptinjector/core/detector.py @@ -93,7 +93,8 @@ class PromptInjectionDetector: except re.error as e: logger.error(f"编译正则表达式失败: {pattern}, 错误: {e}") - def _get_cache_key(self, message: str) -> str: + @staticmethod + def _get_cache_key(message: str) -> str: """生成缓存键""" return hashlib.md5(message.encode("utf-8")).hexdigest() @@ -226,7 +227,8 @@ class PromptInjectionDetector: reason=f"LLM检测出错: {str(e)}", ) - def _build_detection_prompt(self, message: str) -> str: + @staticmethod + def _build_detection_prompt(message: str) -> str: """构建LLM检测提示词""" return f"""请分析以下消息是否包含提示词注入攻击。 @@ -247,7 +249,8 @@ class PromptInjectionDetector: 请客观分析,避免误判正常对话。""" - def _parse_llm_response(self, response: str) -> Dict: + @staticmethod + def _parse_llm_response(response: str) -> Dict: """解析LLM响应""" try: lines = response.strip().split("\n") diff --git a/src/chat/antipromptinjector/core/shield.py b/src/chat/antipromptinjector/core/shield.py index ba9bf3175..c4ab8afa8 100644 --- a/src/chat/antipromptinjector/core/shield.py +++ b/src/chat/antipromptinjector/core/shield.py @@ -29,11 +29,13 @@ class MessageShield: """初始化加盾器""" self.config = global_config.anti_prompt_injection - def get_safety_system_prompt(self) -> str: + @staticmethod + def get_safety_system_prompt() -> str: """获取安全系统提示词""" return SAFETY_SYSTEM_PROMPT - def is_shield_needed(self, confidence: float, matched_patterns: List[str]) -> bool: + @staticmethod + def is_shield_needed(confidence: float, matched_patterns: List[str]) -> bool: """判断是否需要加盾 Args: @@ -57,7 +59,8 @@ class MessageShield: return False - def create_safety_summary(self, confidence: float, matched_patterns: List[str]) -> str: + @staticmethod + def create_safety_summary(confidence: float, matched_patterns: List[str]) -> str: 
"""创建安全处理摘要 Args: @@ -93,7 +96,8 @@ class MessageShield: # 低风险:添加警告前缀 return f"{self.config.shield_prefix}[内容已检查]{self.config.shield_suffix} {original_message}" - def _partially_shield_content(self, message: str) -> str: + @staticmethod + def _partially_shield_content(message: str) -> str: """部分遮蔽消息内容""" # 遮蔽策略:替换关键词 dangerous_keywords = [ @@ -231,4 +235,4 @@ def create_default_shield() -> MessageShield: """创建默认的消息加盾器""" from .config import default_config - return MessageShield(default_config) + return MessageShield() diff --git a/src/chat/antipromptinjector/counter_attack.py b/src/chat/antipromptinjector/counter_attack.py index ad16ad6b6..7c2bd86c5 100644 --- a/src/chat/antipromptinjector/counter_attack.py +++ b/src/chat/antipromptinjector/counter_attack.py @@ -18,7 +18,8 @@ logger = get_logger("anti_injector.counter_attack") class CounterAttackGenerator: """反击消息生成器""" - def get_personality_context(self) -> str: + @staticmethod + def get_personality_context() -> str: """获取人格上下文信息 Returns: diff --git a/src/chat/antipromptinjector/decision/counter_attack.py b/src/chat/antipromptinjector/decision/counter_attack.py index c12e7697e..9d6aac2ff 100644 --- a/src/chat/antipromptinjector/decision/counter_attack.py +++ b/src/chat/antipromptinjector/decision/counter_attack.py @@ -18,7 +18,8 @@ logger = get_logger("anti_injector.counter_attack") class CounterAttackGenerator: """反击消息生成器""" - def get_personality_context(self) -> str: + @staticmethod + def get_personality_context() -> str: """获取人格上下文信息 Returns: diff --git a/src/chat/antipromptinjector/decision/decision_maker.py b/src/chat/antipromptinjector/decision/decision_maker.py index a988512c4..12a2c95b5 100644 --- a/src/chat/antipromptinjector/decision/decision_maker.py +++ b/src/chat/antipromptinjector/decision/decision_maker.py @@ -22,7 +22,8 @@ class ProcessingDecisionMaker: """ self.config = config - def determine_auto_action(self, detection_result: DetectionResult) -> str: + @staticmethod + def determine_auto_action(detection_result: DetectionResult) -> str: """自动模式:根据检测结果确定处理动作 Args: diff --git a/src/chat/antipromptinjector/decision_maker.py b/src/chat/antipromptinjector/decision_maker.py index dbad9761b..972253fab 100644 --- a/src/chat/antipromptinjector/decision_maker.py +++ b/src/chat/antipromptinjector/decision_maker.py @@ -22,7 +22,8 @@ class ProcessingDecisionMaker: """ self.config = config - def determine_auto_action(self, detection_result: DetectionResult) -> str: + @staticmethod + def determine_auto_action(detection_result: DetectionResult) -> str: """自动模式:根据检测结果确定处理动作 Args: diff --git a/src/chat/antipromptinjector/detector.py b/src/chat/antipromptinjector/detector.py index cd6634060..6c1e3b4bd 100644 --- a/src/chat/antipromptinjector/detector.py +++ b/src/chat/antipromptinjector/detector.py @@ -93,7 +93,8 @@ class PromptInjectionDetector: except re.error as e: logger.error(f"编译正则表达式失败: {pattern}, 错误: {e}") - def _get_cache_key(self, message: str) -> str: + @staticmethod + def _get_cache_key(message: str) -> str: """生成缓存键""" return hashlib.md5(message.encode("utf-8")).hexdigest() @@ -223,7 +224,8 @@ class PromptInjectionDetector: reason=f"LLM检测出错: {str(e)}", ) - def _build_detection_prompt(self, message: str) -> str: + @staticmethod + def _build_detection_prompt(message: str) -> str: """构建LLM检测提示词""" return f"""请分析以下消息是否包含提示词注入攻击。 @@ -244,7 +246,8 @@ class PromptInjectionDetector: 请客观分析,避免误判正常对话。""" - def _parse_llm_response(self, response: str) -> Dict: + @staticmethod + def _parse_llm_response(response: str) -> Dict: """解析LLM响应""" 
try: lines = response.strip().split("\n") diff --git a/src/chat/antipromptinjector/management/statistics.py b/src/chat/antipromptinjector/management/statistics.py index 318ff5404..12606d4ba 100644 --- a/src/chat/antipromptinjector/management/statistics.py +++ b/src/chat/antipromptinjector/management/statistics.py @@ -23,7 +23,8 @@ class AntiInjectionStatistics: self.session_start_time = datetime.datetime.now() """当前会话开始时间""" - async def get_or_create_stats(self): + @staticmethod + async def get_or_create_stats(): """获取或创建统计记录""" try: with get_db_session() as session: @@ -39,7 +40,8 @@ class AntiInjectionStatistics: logger.error(f"获取统计记录失败: {e}") return None - async def update_stats(self, **kwargs): + @staticmethod + async def update_stats(**kwargs): """更新统计数据""" try: with get_db_session() as session: @@ -132,7 +134,8 @@ class AntiInjectionStatistics: logger.error(f"获取统计信息失败: {e}") return {"error": f"获取统计信息失败: {e}"} - async def reset_stats(self): + @staticmethod + async def reset_stats(): """重置统计信息""" try: with get_db_session() as session: diff --git a/src/chat/antipromptinjector/processors/message_processor.py b/src/chat/antipromptinjector/processors/message_processor.py index 76add60f0..935848c2d 100644 --- a/src/chat/antipromptinjector/processors/message_processor.py +++ b/src/chat/antipromptinjector/processors/message_processor.py @@ -37,7 +37,8 @@ class MessageProcessor: # 只返回用户新增的内容,避免重复 return new_content - def extract_new_content_from_reply(self, full_text: str) -> str: + @staticmethod + def extract_new_content_from_reply(full_text: str) -> str: """从包含引用的完整消息中提取用户新增的内容 Args: @@ -64,7 +65,8 @@ class MessageProcessor: return new_content - def check_whitelist(self, message: MessageRecv, whitelist: list) -> Optional[tuple]: + @staticmethod + def check_whitelist(message: MessageRecv, whitelist: list) -> Optional[tuple]: """检查用户白名单 Args: @@ -85,7 +87,8 @@ class MessageProcessor: return None - def check_whitelist_dict(self, user_id: str, platform: str, whitelist: list) -> bool: + @staticmethod + def check_whitelist_dict(user_id: str, platform: str, whitelist: list) -> bool: """检查用户是否在白名单中(字典格式) Args: diff --git a/src/chat/chat_loop/heartFC_chat.py b/src/chat/chat_loop/heartFC_chat.py index 05edb3ee0..7a351c97c 100644 --- a/src/chat/chat_loop/heartFC_chat.py +++ b/src/chat/chat_loop/heartFC_chat.py @@ -94,7 +94,7 @@ class HeartFChatting: self.context.running = True self.context.relationship_builder = relationship_builder_manager.get_or_create_builder(self.context.stream_id) - self.context.expression_learner = expression_learner_manager.get_expression_learner(self.context.stream_id) + self.context.expression_learner = await expression_learner_manager.get_expression_learner(self.context.stream_id) # 启动主动思考监视器 if global_config.chat.enable_proactive_thinking: @@ -281,7 +281,8 @@ class HeartFChatting: logger.error(f"{self.context.log_prefix} 动态间隔计算出错: {e},使用固定间隔") return max(300, abs(global_config.chat.proactive_thinking_interval)) - def _format_duration(self, seconds: float) -> str: + @staticmethod + def _format_duration(seconds: float) -> str: """ 格式化时长为可读字符串 diff --git a/src/chat/chat_loop/hfc_context.py b/src/chat/chat_loop/hfc_context.py index fe5d283ae..67606de12 100644 --- a/src/chat/chat_loop/hfc_context.py +++ b/src/chat/chat_loop/hfc_context.py @@ -1,17 +1,15 @@ -from typing import List, Optional, TYPE_CHECKING import time -from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager -from src.person_info.relationship_builder_manager import RelationshipBuilder -from 
src.chat.express.expression_learner import ExpressionLearner -from src.chat.planner_actions.action_manager import ActionManager +from typing import List, Optional, TYPE_CHECKING + from src.chat.chat_loop.hfc_utils import CycleDetail +from src.chat.express.expression_learner import ExpressionLearner +from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager +from src.chat.planner_actions.action_manager import ActionManager from src.config.config import global_config +from src.person_info.relationship_builder_manager import RelationshipBuilder if TYPE_CHECKING: - from .sleep_manager.wakeup_manager import WakeUpManager - from .energy_manager import EnergyManager - from .heartFC_chat import HeartFChatting - from .sleep_manager.sleep_manager import SleepManager + pass class HfcContext: diff --git a/src/chat/chat_loop/proactive/proactive_thinker.py b/src/chat/chat_loop/proactive/proactive_thinker.py index 3cbcc2529..adf187dca 100644 --- a/src/chat/chat_loop/proactive/proactive_thinker.py +++ b/src/chat/chat_loop/proactive/proactive_thinker.py @@ -2,19 +2,18 @@ import time import traceback from typing import TYPE_CHECKING, Dict, Any +from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages_with_id +from src.common.database.sqlalchemy_database_api import store_action_info from src.common.logger import get_logger -from src.plugin_system.base.component_types import ChatMode -from ..hfc_context import HfcContext -from .events import ProactiveTriggerEvent +from src.config.config import global_config +from src.mood.mood_manager import mood_manager +from src.plugin_system import tool_api from src.plugin_system.apis import generator_api from src.plugin_system.apis.generator_api import process_human_text +from src.plugin_system.base.component_types import ChatMode from src.schedule.schedule_manager import schedule_manager -from src.plugin_system import tool_api -from src.config.config import global_config -from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages_with_id -from src.mood.mood_manager import mood_manager -from src.common.database.sqlalchemy_database_api import store_action_info, db_get -from src.common.database.sqlalchemy_models import Messages +from .events import ProactiveTriggerEvent +from ..hfc_context import HfcContext if TYPE_CHECKING: from ..cycle_processor import CycleProcessor diff --git a/src/chat/chat_loop/sleep_manager/sleep_manager.py b/src/chat/chat_loop/sleep_manager/sleep_manager.py index 677555aef..ad4aa1ced 100644 --- a/src/chat/chat_loop/sleep_manager/sleep_manager.py +++ b/src/chat/chat_loop/sleep_manager/sleep_manager.py @@ -5,12 +5,12 @@ from typing import Optional, TYPE_CHECKING from src.common.logger import get_logger from src.config.config import global_config +from .notification_sender import NotificationSender from .sleep_state import SleepState, SleepStateSerializer from .time_checker import TimeChecker -from .notification_sender import NotificationSender if TYPE_CHECKING: - from .wakeup_manager import WakeUpManager + pass logger = get_logger("sleep_manager") diff --git a/src/chat/chat_loop/sleep_manager/time_checker.py b/src/chat/chat_loop/sleep_manager/time_checker.py index cbe3d45e8..47376ac35 100644 --- a/src/chat/chat_loop/sleep_manager/time_checker.py +++ b/src/chat/chat_loop/sleep_manager/time_checker.py @@ -34,7 +34,8 @@ class TimeChecker: return self._daily_sleep_offset, self._daily_wake_offset - def get_today_schedule(self) -> 
Optional[List[Dict[str, Any]]]: + @staticmethod + def get_today_schedule() -> Optional[List[Dict[str, Any]]]: """从全局 ScheduleManager 获取今天的日程安排。""" return schedule_manager.today_schedule diff --git a/src/chat/emoji_system/emoji_history.py b/src/chat/emoji_system/emoji_history.py index a25063f52..d0e2ca856 100644 --- a/src/chat/emoji_system/emoji_history.py +++ b/src/chat/emoji_system/emoji_history.py @@ -2,9 +2,8 @@ """ 表情包发送历史记录模块 """ -import os -from typing import List, Dict from collections import deque +from typing import List, Dict from src.common.logger import get_logger diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index ce7b0d074..e2a6eb7f1 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -424,7 +424,8 @@ class EmojiManager: # if not self._initialized: # raise RuntimeError("EmojiManager not initialized") - async def record_usage(self, emoji_hash: str) -> None: + @staticmethod + async def record_usage(emoji_hash: str) -> None: """记录表情使用次数""" try: async with get_db_session() as session: @@ -436,7 +437,6 @@ class EmojiManager: else: emoji_update.usage_count += 1 emoji_update.last_used_time = time.time() - await session.commit() except Exception as e: logger.error(f"记录表情使用失败: {str(e)}") @@ -523,7 +523,7 @@ class EmojiManager: # 7. 获取选中的表情包并更新使用记录 selected_emoji = candidate_emojis[selected_index] - self.record_usage(selected_emoji.hash) + await self.record_usage(selected_emoji.emoji_hash) _time_end = time.time() logger.info( @@ -680,7 +680,8 @@ class EmojiManager: self.emoji_objects = [] # 加载失败则清空列表 self.emoji_num = 0 - async def get_emoji_from_db(self, emoji_hash: Optional[str] = None) -> List["MaiEmoji"]: + @staticmethod + async def get_emoji_from_db(emoji_hash: Optional[str] = None) -> List["MaiEmoji"]: """获取指定哈希值的表情包并初始化为MaiEmoji类对象列表 (主要用于调试或特定查找) 参数: @@ -747,8 +748,8 @@ class EmojiManager: try: emoji_record = await self.get_emoji_from_db(emoji_hash) if emoji_record and emoji_record[0].emotion: - logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.emotion[:50]}...") - return emoji_record.emotion + logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record[0].emotion[:50]}...") + return emoji_record[0].emotion except Exception as e: logger.error(f"从数据库查询表情包描述时出错: {e}") diff --git a/src/chat/express/expression_learner.py b/src/chat/express/expression_learner.py index 1b9fcf267..a709ee78f 100644 --- a/src/chat/express/expression_learner.py +++ b/src/chat/express/expression_learner.py @@ -4,7 +4,7 @@ import orjson import os from datetime import datetime -from typing import List, Dict, Optional, Any, Tuple +from typing import List, Dict, Optional, Any, Tuple, Coroutine from src.common.logger import get_logger from src.common.database.sqlalchemy_database_api import get_db_session @@ -112,7 +112,7 @@ class ExpressionLearner: logger.error(f"检查学习权限失败: {e}") return False - def should_trigger_learning(self) -> bool: + async def should_trigger_learning(self) -> bool: """ 检查是否应该触发学习 @@ -146,7 +146,7 @@ class ExpressionLearner: return False # 检查消息数量(只检查指定聊天流的消息) - recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive( + recent_messages = await get_raw_msg_by_timestamp_with_chat_inclusive( chat_id=self.chat_id, timestamp_start=self.last_learning_time, timestamp_end=time.time(), @@ -193,7 +193,7 @@ class ExpressionLearner: logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}") return False - def get_expression_by_chat_id(self) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]: + async def 
get_expression_by_chat_id(self) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]: """ 获取指定chat_id的style和grammar表达方式 返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作 @@ -202,8 +202,8 @@ class ExpressionLearner: learnt_grammar_expressions = [] # 直接从数据库查询 - with get_db_session() as session: - style_query = session.execute( + async with get_db_session() as session: + style_query = await session.execute( select(Expression).where((Expression.chat_id == self.chat_id) & (Expression.type == "style")) ) for expr in style_query.scalars(): @@ -220,7 +220,7 @@ class ExpressionLearner: "create_date": create_date, } ) - grammar_query = session.execute( + grammar_query = await session.execute( select(Expression).where((Expression.chat_id == self.chat_id) & (Expression.type == "grammar")) ) for expr in grammar_query.scalars(): @@ -239,14 +239,15 @@ class ExpressionLearner: ) return learnt_style_expressions, learnt_grammar_expressions - def _apply_global_decay_to_database(self, current_time: float) -> None: + async def _apply_global_decay_to_database(self, current_time: float) -> None: """ 对数据库中的所有表达方式应用全局衰减 """ try: - with get_db_session() as session: + async with get_db_session() as session: # 获取所有表达方式 - all_expressions = session.execute(select(Expression)).scalars() + all_expressions = await session.execute(select(Expression)) + all_expressions = all_expressions.scalars().all() updated_count = 0 deleted_count = 0 @@ -263,7 +264,7 @@ class ExpressionLearner: if new_count <= 0.01: # 如果count太小,删除这个表达方式 session.delete(expr) - session.commit() + await session.commit() deleted_count += 1 else: # 更新count @@ -276,7 +277,8 @@ class ExpressionLearner: except Exception as e: logger.error(f"数据库全局衰减失败: {e}") - def calculate_decay_factor(self, time_diff_days: float) -> float: + @staticmethod + def calculate_decay_factor(time_diff_days: float) -> float: """ 计算衰减值 当时间差为0天时,衰减值为0(最近活跃的不衰减) @@ -298,7 +300,7 @@ class ExpressionLearner: return min(0.01, decay) - async def learn_and_store(self, type: str, num: int = 10) -> List[Tuple[str, str, str]]: + async def learn_and_store(self, type: str, num: int = 10) -> None | list[Any] | list[tuple[str, str, str]]: # sourcery skip: use-join """ 学习并存储表达方式 @@ -349,19 +351,20 @@ class ExpressionLearner: # 存储到数据库 Expression 表 for chat_id, expr_list in chat_dict.items(): - for new_expr in expr_list: - # 查找是否已存在相似表达方式 - with get_db_session() as session: - query = session.execute( + async with get_db_session() as session: + for new_expr in expr_list: + # 查找是否已存在相似表达方式 + query = await session.execute( select(Expression).where( (Expression.chat_id == chat_id) & (Expression.type == type) & (Expression.situation == new_expr["situation"]) & (Expression.style == new_expr["style"]) ) - ).scalar() - if query: - expr_obj = query + ) + existing_expr = query.scalar() + if existing_expr: + expr_obj = existing_expr # 50%概率替换内容 if random.random() < 0.5: expr_obj.situation = new_expr["situation"] @@ -378,23 +381,22 @@ class ExpressionLearner: type=type, create_date=current_time, # 手动设置创建日期 ) - session.add(new_expression) - session.commit() + await session.add(new_expression) + # 限制最大数量 - exprs = list( - session.execute( - select(Expression) - .where((Expression.chat_id == chat_id) & (Expression.type == type)) - .order_by(Expression.count.asc()) - ).scalars() + exprs_result = await session.execute( + select(Expression) + .where((Expression.chat_id == chat_id) & (Expression.type == type)) + .order_by(Expression.count.asc()) ) + exprs = list(exprs_result.scalars()) if len(exprs) > MAX_EXPRESSION_COUNT: # 
删除count最小的多余表达方式 for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]: - session.delete(expr) - session.commit() + await session.delete(expr) return learnt_expressions + return None async def learn_expression(self, type: str, num: int = 10) -> Optional[Tuple[List[Tuple[str, str, str]], str]]: """从指定聊天流学习表达方式 @@ -414,7 +416,7 @@ class ExpressionLearner: current_time = time.time() # 获取上次学习时间 - random_msg: Optional[List[Dict[str, Any]]] = get_raw_msg_by_timestamp_with_chat_inclusive( + random_msg: Optional[List[Dict[str, Any]]] = await get_raw_msg_by_timestamp_with_chat_inclusive( chat_id=self.chat_id, timestamp_start=self.last_learning_time, timestamp_end=current_time, @@ -449,7 +451,8 @@ class ExpressionLearner: return expressions, chat_id - def parse_expression_response(self, response: str, chat_id: str) -> List[Tuple[str, str, str]]: + @staticmethod + def parse_expression_response(response: str, chat_id: str) -> List[Tuple[str, str, str]]: """ 解析LLM返回的表达风格总结,每一行提取"当"和"使用"之间的内容,存储为(situation, style)元组 """ @@ -488,15 +491,18 @@ class ExpressionLearnerManager: self.expression_learners = {} self._ensure_expression_directories() - self._auto_migrate_json_to_db() - self._migrate_old_data_create_date() - def get_expression_learner(self, chat_id: str) -> ExpressionLearner: + + async def get_expression_learner(self, chat_id: str) -> ExpressionLearner: + await self._auto_migrate_json_to_db() + await self._migrate_old_data_create_date() + if chat_id not in self.expression_learners: self.expression_learners[chat_id] = ExpressionLearner(chat_id) return self.expression_learners[chat_id] - def _ensure_expression_directories(self): + @staticmethod + def _ensure_expression_directories(): """ 确保表达方式相关的目录结构存在 """ @@ -514,7 +520,8 @@ class ExpressionLearnerManager: except Exception as e: logger.error(f"创建目录失败 {directory}: {e}") - def _auto_migrate_json_to_db(self): + @staticmethod + async def _auto_migrate_json_to_db(): """ 自动将/data/expression/learnt_style 和 learnt_grammar 下所有expressions.json迁移到数据库。 迁移完成后在/data/expression/done.done写入标记文件,存在则跳过。 @@ -577,33 +584,33 @@ class ExpressionLearnerManager: continue # 查重:同chat_id+type+situation+style - with get_db_session() as session: - query = session.execute( + async with get_db_session() as session: + query = await session.execute( select(Expression).where( (Expression.chat_id == chat_id) & (Expression.type == type_str) & (Expression.situation == situation) & (Expression.style == style_val) ) - ).scalar() - if query: - expr_obj = query - expr_obj.count = max(expr_obj.count, count) - expr_obj.last_active_time = max(expr_obj.last_active_time, last_active_time) - else: - new_expression = Expression( - situation=situation, - style=style_val, - count=count, - last_active_time=last_active_time, - chat_id=chat_id, - type=type_str, - create_date=last_active_time, # 迁移时使用last_active_time作为创建时间 ) - session.add(new_expression) - session.commit() + existing_expr = query.scalar() + if existing_expr: + expr_obj = existing_expr + expr_obj.count = max(expr_obj.count, count) + expr_obj.last_active_time = max(expr_obj.last_active_time, last_active_time) + else: + new_expression = Expression( + situation=situation, + style=style_val, + count=count, + last_active_time=last_active_time, + chat_id=chat_id, + type=type_str, + create_date=last_active_time, # 迁移时使用last_active_time作为创建时间 + ) + await session.add(new_expression) - migrated_count += 1 + migrated_count += 1 logger.info(f"已迁移 {expr_file} 到数据库,包含 {len(expressions)} 个表达方式") except orjson.JSONDecodeError as e: 
logger.error(f"JSON解析失败 {expr_file}: {e}") @@ -628,15 +635,17 @@ class ExpressionLearnerManager: except Exception as e: logger.error(f"写入done.done标记文件失败: {e}") - def _migrate_old_data_create_date(self): + @staticmethod + async def _migrate_old_data_create_date(): """ 为没有create_date的老数据设置创建日期 使用last_active_time作为create_date的默认值 """ try: - with get_db_session() as session: + async with get_db_session() as session: # 查找所有create_date为空的表达方式 - old_expressions = session.execute(select(Expression).where(Expression.create_date.is_(None))).scalars() + old_expressions_result = await session.execute(select(Expression).where(Expression.create_date.is_(None))) + old_expressions = old_expressions_result.scalars().all() updated_count = 0 for expr in old_expressions: @@ -646,7 +655,6 @@ class ExpressionLearnerManager: if updated_count > 0: logger.info(f"已为 {updated_count} 个老的表达方式设置创建日期") - session.commit() except Exception as e: logger.error(f"迁移老数据创建日期失败: {e}") diff --git a/src/chat/express/expression_selector.py b/src/chat/express/expression_selector.py index 2a269fbf9..ff4083a3b 100644 --- a/src/chat/express/expression_selector.py +++ b/src/chat/express/expression_selector.py @@ -76,7 +76,8 @@ class ExpressionSelector: model_set=model_config.model_task_config.utils_small, request_type="expression.selector" ) - def can_use_expression_for_chat(self, chat_id: str) -> bool: + @staticmethod + def can_use_expression_for_chat(chat_id: str) -> bool: """ 检查指定聊天流是否允许使用表达 @@ -193,7 +194,8 @@ class ExpressionSelector: return selected_style, selected_grammar - async def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, Any]], increment: float = 0.1): + @staticmethod + async def update_expressions_count_batch(expressions_to_update: List[Dict[str, Any]], increment: float = 0.1): """对一批表达方式更新count值,按chat_id+type分组后一次性写入数据库""" if not expressions_to_update: return diff --git a/src/chat/frequency_analyzer/analyzer.py b/src/chat/frequency_analyzer/analyzer.py index bd6331465..f888b9737 100644 --- a/src/chat/frequency_analyzer/analyzer.py +++ b/src/chat/frequency_analyzer/analyzer.py @@ -40,7 +40,8 @@ class ChatFrequencyAnalyzer: self._analysis_cache: dict[str, tuple[float, list[tuple[time, time]]]] = {} self._cache_ttl_seconds = 60 * 30 # 缓存30分钟 - def _find_peak_windows(self, timestamps: List[float]) -> List[Tuple[datetime, datetime]]: + @staticmethod + def _find_peak_windows(timestamps: List[float]) -> List[Tuple[datetime, datetime]]: """ 使用滑动窗口算法来识别时间戳列表中的高峰时段。 diff --git a/src/chat/frequency_analyzer/tracker.py b/src/chat/frequency_analyzer/tracker.py index bee9e4623..178435528 100644 --- a/src/chat/frequency_analyzer/tracker.py +++ b/src/chat/frequency_analyzer/tracker.py @@ -21,7 +21,8 @@ class ChatFrequencyTracker: def __init__(self): self._timestamps: Dict[str, List[float]] = self._load_timestamps() - def _load_timestamps(self) -> Dict[str, List[float]]: + @staticmethod + def _load_timestamps() -> Dict[str, List[float]]: """从本地文件加载时间戳数据。""" if not TRACKER_FILE.exists(): return {} diff --git a/src/chat/heart_flow/heartflow_message_processor.py b/src/chat/heart_flow/heartflow_message_processor.py index c68df532c..1e6376e6d 100644 --- a/src/chat/heart_flow/heartflow_message_processor.py +++ b/src/chat/heart_flow/heartflow_message_processor.py @@ -1,22 +1,20 @@ import asyncio -import re import math +import re import traceback -from datetime import datetime - from typing import Tuple, TYPE_CHECKING -from src.config.config import global_config +from src.chat.heart_flow.heartflow import heartflow 
from src.chat.memory_system.Hippocampus import hippocampus_manager from src.chat.message_receive.message import MessageRecv from src.chat.message_receive.storage import MessageStorage -from src.chat.heart_flow.heartflow import heartflow -from src.chat.utils.utils import is_mentioned_bot_in_message -from src.chat.utils.timer_calculator import Timer from src.chat.utils.chat_message_builder import replace_user_references_sync +from src.chat.utils.timer_calculator import Timer +from src.chat.utils.utils import is_mentioned_bot_in_message from src.common.logger import get_logger -from src.person_info.relationship_manager import get_relationship_manager +from src.config.config import global_config from src.mood.mood_manager import mood_manager +from src.person_info.relationship_manager import get_relationship_manager if TYPE_CHECKING: from src.chat.heart_flow.sub_heartflow import SubHeartflow diff --git a/src/chat/knowledge/embedding_store.py b/src/chat/knowledge/embedding_store.py index 75af35a7b..67296c0c9 100644 --- a/src/chat/knowledge/embedding_store.py +++ b/src/chat/knowledge/embedding_store.py @@ -125,7 +125,8 @@ class EmbeddingStore: self.faiss_index = None self.idx2hash = None - def _get_embedding(self, s: str) -> List[float]: + @staticmethod + def _get_embedding(s: str) -> List[float]: """获取字符串的嵌入向量,使用完全同步的方式避免事件循环问题""" # 创建新的事件循环并在完成后立即关闭 loop = asyncio.new_event_loop() @@ -157,8 +158,9 @@ class EmbeddingStore: except Exception: ... + @staticmethod def _get_embeddings_batch_threaded( - self, strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None + strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None ) -> List[Tuple[str, List[float]]]: """使用多线程批量获取嵌入向量 @@ -265,7 +267,8 @@ class EmbeddingStore: return ordered_results - def get_test_file_path(self): + @staticmethod + def get_test_file_path(): return EMBEDDING_TEST_FILE def save_embedding_test_vectors(self): diff --git a/src/chat/memory_system/Hippocampus.py b/src/chat/memory_system/Hippocampus.py index 46d46b202..fcc8e65d2 100644 --- a/src/chat/memory_system/Hippocampus.py +++ b/src/chat/memory_system/Hippocampus.py @@ -838,7 +838,7 @@ class EntorhinalCortex: timestamp_start = target_timestamp timestamp_end = target_timestamp + time_window_seconds - if chosen_message := get_raw_msg_by_timestamp( + if chosen_message := await get_raw_msg_by_timestamp( timestamp_start=timestamp_start, timestamp_end=timestamp_end, limit=1, @@ -846,7 +846,7 @@ class EntorhinalCortex: ): chat_id: str = chosen_message[0].get("chat_id") # type: ignore - if messages := get_raw_msg_by_timestamp_with_chat( + if messages := await get_raw_msg_by_timestamp_with_chat( timestamp_start=timestamp_start, timestamp_end=timestamp_end, limit=chat_size, diff --git a/src/chat/memory_system/async_memory_optimizer.py b/src/chat/memory_system/async_memory_optimizer.py index ee215abde..e80ad0efd 100644 --- a/src/chat/memory_system/async_memory_optimizer.py +++ b/src/chat/memory_system/async_memory_optimizer.py @@ -137,7 +137,8 @@ class AsyncMemoryQueue: except Exception: pass - async def _handle_store_task(self, task: MemoryTask) -> Any: + @staticmethod + async def _handle_store_task(task: MemoryTask) -> Any: """处理记忆存储任务""" # 这里需要根据具体的记忆系统来实现 # 为了避免循环导入,这里使用延迟导入 @@ -156,7 +157,8 @@ class AsyncMemoryQueue: logger.error(f"记忆存储失败: {e}") return False - async def _handle_retrieve_task(self, task: MemoryTask) -> Any: + @staticmethod + async def _handle_retrieve_task(task: MemoryTask) -> Any: """处理记忆检索任务""" try: # 获取包装器实例 @@ 
-173,7 +175,8 @@ class AsyncMemoryQueue: logger.error(f"记忆检索失败: {e}") return [] - async def _handle_build_task(self, task: MemoryTask) -> Any: + @staticmethod + async def _handle_build_task(task: MemoryTask) -> Any: """处理记忆构建任务(海马体系统)""" try: # 延迟导入避免循环依赖 diff --git a/src/chat/memory_system/instant_memory.py b/src/chat/memory_system/instant_memory.py index 6ea0163c0..0b4b0b2e3 100644 --- a/src/chat/memory_system/instant_memory.py +++ b/src/chat/memory_system/instant_memory.py @@ -106,7 +106,8 @@ class InstantMemory: else: logger.info(f"不需要记忆:{text}") - async def store_memory(self, memory_item: MemoryItem): + @staticmethod + async def store_memory(memory_item: MemoryItem): with get_db_session() as session: memory = Memory( memory_id=memory_item.memory_id, @@ -198,7 +199,8 @@ class InstantMemory: logger.error(f"获取记忆出现错误:{str(e)} {traceback.format_exc()}") return None - def _parse_time_range(self, time_str): + @staticmethod + def _parse_time_range(time_str): # sourcery skip: extract-duplicate-method, use-contextlib-suppress """ 支持解析如下格式: diff --git a/src/chat/memory_system/vector_instant_memory.py b/src/chat/memory_system/vector_instant_memory.py index 96af659d7..12d9622e0 100644 --- a/src/chat/memory_system/vector_instant_memory.py +++ b/src/chat/memory_system/vector_instant_memory.py @@ -243,7 +243,8 @@ class VectorInstantMemoryV2: logger.error(f"查找相似消息失败: {e}") return [] - def _format_time_ago(self, timestamp: float) -> str: + @staticmethod + def _format_time_ago(timestamp: float) -> str: """格式化时间差显示""" if timestamp <= 0: return "未知时间" diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index 67c56be2a..3b68190a7 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -80,7 +80,8 @@ class ChatBot: # 初始化反注入系统 self._initialize_anti_injector() - def _initialize_anti_injector(self): + @staticmethod + def _initialize_anti_injector(): """初始化反注入系统""" try: initialize_anti_injector() @@ -100,7 +101,8 @@ class ChatBot: self._started = True - async def _process_plus_commands(self, message: MessageRecv): + @staticmethod + async def _process_plus_commands(message: MessageRecv): """独立处理PlusCommand系统""" try: text = message.processed_plain_text @@ -220,7 +222,8 @@ class ChatBot: logger.error(f"处理PlusCommand时出错: {e}") return False, None, True # 出错时继续处理消息 - async def _process_commands_with_new_system(self, message: MessageRecv): + @staticmethod + async def _process_commands_with_new_system(message: MessageRecv): # sourcery skip: use-named-expression """使用新插件系统处理命令""" try: @@ -310,7 +313,8 @@ class ChatBot: return False - async def handle_adapter_response(self, message: MessageRecv): + @staticmethod + async def handle_adapter_response(message: MessageRecv): """处理适配器命令响应""" try: from src.plugin_system.apis.send_api import put_adapter_response diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index 4f91d15c6..e72d99686 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -203,7 +203,8 @@ class ChatManager: key = "_".join(components) return hashlib.md5(key.encode()).hexdigest() - def get_stream_id(self, platform: str, id: str, is_group: bool = True) -> str: + @staticmethod + def get_stream_id(platform: str, id: str, is_group: bool = True) -> str: """获取聊天流ID""" components = [platform, id] if is_group else [platform, id, "private"] key = "_".join(components) diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index 
1df006a1c..fc57b6fc6 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -1,19 +1,18 @@ -import time -import urllib3 import base64 - -from abc import abstractmethod +import time +from abc import abstractmethod, ABCMeta from dataclasses import dataclass -from rich.traceback import install from typing import Optional, Any -from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase -from src.common.logger import get_logger +import urllib3 +from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase +from rich.traceback import install + from src.chat.utils.utils_image import get_image_manager -from src.chat.utils.utils_voice import get_voice_text from src.chat.utils.utils_video import get_video_analyzer, is_video_analysis_available +from src.chat.utils.utils_voice import get_voice_text +from src.common.logger import get_logger from src.config.config import global_config -from .chat_stream import ChatStream install(extra_lines=3) @@ -28,7 +27,7 @@ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) @dataclass -class Message(MessageBase): +class Message(MessageBase, metaclass=ABCMeta): chat_stream: "ChatStream" = None # type: ignore reply: Optional["Message"] = None processed_plain_text: str = "" @@ -96,12 +95,13 @@ class Message(MessageBase): class MessageRecv(Message): """接收消息类,用于处理从MessageCQ序列化的消息""" - def __init__(self, message_dict: dict[str, Any]): + def __init__(self, message_dict: dict[str, Any], message_id: str, chat_stream: "ChatStream", user_info: UserInfo): """从MessageCQ的字典初始化 Args: message_dict: MessageCQ序列化后的字典 """ + super().__init__(message_id, chat_stream, user_info) self.message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {})) self.message_segment = Seg.from_dict(message_dict.get("message_segment", {})) self.raw_message = message_dict.get("raw_message") diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index c362187e2..eb0dc5d1e 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -1,14 +1,14 @@ import re import traceback -import orjson from typing import Union -from src.common.database.sqlalchemy_models import Messages, Images +import orjson +from sqlalchemy import select, desc, update + +from src.common.database.sqlalchemy_models import Messages, Images, get_db_session from src.common.logger import get_logger from .chat_stream import ChatStream from .message import MessageSending, MessageRecv -from src.common.database.sqlalchemy_database_api import get_db_session -from sqlalchemy import select, update, desc logger = get_logger("message_storage") @@ -116,21 +116,13 @@ class MessageStorage: user_nickname=user_info_dict.get("user_nickname"), user_cardname=user_info_dict.get("user_cardname"), processed_plain_text=filtered_processed_plain_text, - display_message=filtered_display_message, - memorized_times=message.memorized_times, - interest_value=interest_value, priority_mode=priority_mode, priority_info=priority_info_json, is_emoji=is_emoji, is_picid=is_picid, - is_notify=is_notify, - is_command=is_command, - key_words=key_words, - key_words_lite=key_words_lite, ) async with get_db_session() as session: - session.add(new_message) - await session.commit() + await session.add(new_message) except Exception: logger.exception("存储消息失败") @@ -153,8 +145,7 @@ class MessageStorage: qq_message_id = message.message_segment.data.get("id") elif message.message_segment.type == "reply": qq_message_id = 
message.message_segment.data.get("id") - if qq_message_id: - logger.debug(f"从reply消息段获取到消息ID: {qq_message_id}") + logger.debug(f"从reply消息段获取到消息ID: {qq_message_id}") elif message.message_segment.type == "adapter_response": logger.debug("适配器响应消息,不需要更新ID") return @@ -197,7 +188,6 @@ class MessageStorage: f"segment_type={getattr(message.message_segment, 'type', 'N/A')}" ) - @staticmethod async def replace_image_descriptions(text: str) -> str: """将[图片:描述]替换为[picid:image_id]""" # 先检查文本中是否有图片标记 diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py index 267b7a8ff..23755e42d 100644 --- a/src/chat/planner_actions/action_manager.py +++ b/src/chat/planner_actions/action_manager.py @@ -27,9 +27,9 @@ class ActionManager: # === 执行Action方法 === + @staticmethod def create_action( - self, - action_name: str, + action_name: str, action_data: dict, reasoning: str, cycle_timers: dict, diff --git a/src/chat/planner_actions/action_modifier.py b/src/chat/planner_actions/action_modifier.py index a061c15ae..154fe62a7 100644 --- a/src/chat/planner_actions/action_modifier.py +++ b/src/chat/planner_actions/action_modifier.py @@ -243,7 +243,8 @@ class ActionModifier: return deactivated_actions - def _generate_context_hash(self, chat_content: str) -> str: + @staticmethod + def _generate_context_hash(chat_content: str) -> str: """生成上下文的哈希值用于缓存""" context_content = f"{chat_content}" return hashlib.md5(context_content.encode("utf-8")).hexdigest() diff --git a/src/chat/planner_actions/plan_executor.py b/src/chat/planner_actions/plan_executor.py index b27ef12e3..591389f99 100644 --- a/src/chat/planner_actions/plan_executor.py +++ b/src/chat/planner_actions/plan_executor.py @@ -27,7 +27,8 @@ class PlanExecutor: """ self.action_manager = action_manager - async def execute(self, plan: Plan): + @staticmethod + async def execute(plan: Plan): """ 遍历并执行 Plan 对象中 `decided_actions` 列表里的所有动作。 diff --git a/src/chat/planner_actions/plan_filter.py b/src/chat/planner_actions/plan_filter.py index 2d05f0511..91237c9cb 100644 --- a/src/chat/planner_actions/plan_filter.py +++ b/src/chat/planner_actions/plan_filter.py @@ -297,15 +297,17 @@ class PlanFilter: ) return parsed_actions + @staticmethod def _filter_no_actions( - self, action_list: List[ActionPlannerInfo] + action_list: List[ActionPlannerInfo] ) -> List[ActionPlannerInfo]: non_no_actions = [a for a in action_list if a.action_type not in ["no_action", "no_reply"]] if non_no_actions: return non_no_actions return action_list[:1] if action_list else [] - async def _get_long_term_memory_context(self) -> str: + @staticmethod + async def _get_long_term_memory_context() -> str: try: now = datetime.now() keywords = ["今天", "日程", "计划"] @@ -329,7 +331,8 @@ class PlanFilter: logger.error(f"获取长期记忆时出错: {e}") return "回忆时出现了一些问题。" - async def _build_action_options(self, current_available_actions: Dict[str, ActionInfo]) -> str: + @staticmethod + async def _build_action_options(current_available_actions: Dict[str, ActionInfo]) -> str: action_options_block = "" for action_name, action_info in current_available_actions.items(): param_text = "" @@ -347,7 +350,8 @@ class PlanFilter: ) return action_options_block - def _find_message_by_id(self, message_id: str, message_id_list: list) -> Optional[Dict[str, Any]]: + @staticmethod + def _find_message_by_id(message_id: str, message_id_list: list) -> Optional[Dict[str, Any]]: if message_id.isdigit(): message_id = f"m{message_id}" for item in message_id_list: @@ -355,7 +359,8 @@ class PlanFilter: return 
item.get("message") return None - def _get_latest_message(self, message_id_list: list) -> Optional[Dict[str, Any]]: + @staticmethod + def _get_latest_message(message_id_list: list) -> Optional[Dict[str, Any]]: if not message_id_list: return None return message_id_list[-1].get("message") diff --git a/src/chat/planner_actions/plan_generator.py b/src/chat/planner_actions/plan_generator.py index 96af31c4b..5d1ab9c38 100644 --- a/src/chat/planner_actions/plan_generator.py +++ b/src/chat/planner_actions/plan_generator.py @@ -2,7 +2,7 @@ PlanGenerator: 负责搜集和汇总所有决策所需的信息,生成一个未经筛选的“原始计划” (Plan)。 """ import time -from typing import Dict, Optional, Tuple +from typing import Dict from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat from src.chat.utils.utils import get_chat_type_and_target_info diff --git a/src/chat/planner_actions/planner.py b/src/chat/planner_actions/planner.py index 8e4f18fae..6e45b7907 100644 --- a/src/chat/planner_actions/planner.py +++ b/src/chat/planner_actions/planner.py @@ -8,12 +8,10 @@ from src.chat.planner_actions.action_manager import ActionManager from src.chat.planner_actions.plan_executor import PlanExecutor from src.chat.planner_actions.plan_filter import PlanFilter from src.chat.planner_actions.plan_generator import PlanGenerator -from src.common.data_models.info_data_model import ActionPlannerInfo from src.common.logger import get_logger from src.plugin_system.base.component_types import ChatMode # 导入提示词模块以确保其被初始化 -from . import planner_prompts logger = get_logger("planner") diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index 705924c73..06cff0c9d 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -591,7 +591,8 @@ class DefaultReplyer: logger.error(f"工具信息获取失败: {e}") return "" - def _parse_reply_target(self, target_message: str) -> Tuple[str, str]: + @staticmethod + def _parse_reply_target(target_message: str) -> Tuple[str, str]: """解析回复目标消息 - 使用共享工具""" from src.chat.utils.prompt import Prompt if target_message is None: @@ -599,7 +600,8 @@ class DefaultReplyer: return "未知用户", "(无消息内容)" return Prompt.parse_reply_target(target_message) - async def build_keywords_reaction_prompt(self, target: Optional[str]) -> str: + @staticmethod + async def build_keywords_reaction_prompt(target: Optional[str]) -> str: """构建关键词反应提示 Args: @@ -641,7 +643,8 @@ class DefaultReplyer: return keywords_reaction_prompt - async def _time_and_run_task(self, coroutine, name: str) -> Tuple[str, Any, float]: + @staticmethod + async def _time_and_run_task(coroutine, name: str) -> Tuple[str, Any, float]: """计时并运行异步任务的辅助函数 Args: @@ -730,9 +733,9 @@ class DefaultReplyer: return core_dialogue_prompt, all_dialogue_prompt + @staticmethod def build_mai_think_context( - self, - chat_id: str, + chat_id: str, memory_block: str, relation_info: str, time_block: str, diff --git a/src/chat/utils/prompt.py b/src/chat/utils/prompt.py index ec5446e64..9dec72a28 100644 --- a/src/chat/utils/prompt.py +++ b/src/chat/utils/prompt.py @@ -215,6 +215,10 @@ class PromptManager: result = prompt.format(**kwargs) return result + @property + def context(self): + return self._context + # 全局单例 global_prompt_manager = PromptManager() @@ -256,7 +260,7 @@ class Prompt: self._processed_template = self._process_escaped_braces(template) # 自动注册 - if should_register and not global_prompt_manager._context._current_context: + if should_register and not global_prompt_manager.context._current_context: 
global_prompt_manager.register(self) @staticmethod @@ -459,8 +463,9 @@ class Prompt: context_data["chat_info"] = f"""群里的聊天内容: {self.parameters.chat_talking_prompt_short}""" + @staticmethod async def _build_s4u_chat_history_prompts( - self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str + message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str ) -> Tuple[str, str]: """构建S4U风格的分离对话prompt""" # 实现逻辑与原有SmartPromptBuilder相同 @@ -537,14 +542,10 @@ class Prompt: ) # 创建表情选择器 - expression_selector = ExpressionSelector(self.parameters.chat_id) + expression_selector = ExpressionSelector() # 选择合适的表情 selected_expressions = await expression_selector.select_suitable_expressions_llm( - chat_history=chat_history, - current_message=self.parameters.target, - emotional_tone="neutral", - topic_type="general" ) # 构建表达习惯块 @@ -991,7 +992,7 @@ async def create_prompt_async( ) -> Prompt: """异步创建Prompt实例""" prompt = create_prompt(template, name, parameters, **kwargs) - if global_prompt_manager._context._current_context: - await global_prompt_manager._context.register_async(prompt) + if global_prompt_manager.context._current_context: + await global_prompt_manager.context.register_async(prompt) return prompt diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py index da775d36c..891f7653c 100644 --- a/src/chat/utils/statistic.py +++ b/src/chat/utils/statistic.py @@ -763,7 +763,8 @@ class StatisticOutputTask(AsyncTask): output.append("") return "\n".join(output) - def _get_chat_display_name_from_id(self, chat_id: str) -> str: + @staticmethod + def _get_chat_display_name_from_id(chat_id: str) -> str: """从chat_id获取显示名称""" try: # 首先尝试从chat_stream获取真实群组名称 @@ -1109,7 +1110,8 @@ class StatisticOutputTask(AsyncTask): return chart_data - def _collect_interval_data(self, now: datetime, hours: int, interval_minutes: int) -> dict: + @staticmethod + def _collect_interval_data(now: datetime, hours: int, interval_minutes: int) -> dict: """收集指定时间范围内每个间隔的数据""" # 生成时间点 start_time = now - timedelta(hours=hours) @@ -1199,7 +1201,8 @@ class StatisticOutputTask(AsyncTask): "message_by_chat": message_by_chat, } - def _generate_chart_tab(self, chart_data: dict) -> str: + @staticmethod + def _generate_chart_tab(chart_data: dict) -> str: # sourcery skip: extract-duplicate-method, move-assign-in-block """生成图表选项卡HTML内容""" @@ -1563,13 +1566,13 @@ class AsyncStatisticOutputTask(AsyncTask): return StatisticOutputTask._generate_chart_data(self, stat) # type: ignore def _collect_interval_data(self, now: datetime, hours: int, interval_minutes: int) -> dict: - return StatisticOutputTask._collect_interval_data(self, now, hours, interval_minutes) # type: ignore + return StatisticOutputTask._collect_interval_data(now, hours, interval_minutes) # type: ignore def _generate_chart_tab(self, chart_data: dict) -> str: - return StatisticOutputTask._generate_chart_tab(self, chart_data) # type: ignore + return StatisticOutputTask._generate_chart_tab(chart_data) # type: ignore def _get_chat_display_name_from_id(self, chat_id: str) -> str: - return StatisticOutputTask._get_chat_display_name_from_id(self, chat_id) # type: ignore + return StatisticOutputTask._get_chat_display_name_from_id(chat_id) # type: ignore def _convert_defaultdict_to_dict(self, data): return StatisticOutputTask._convert_defaultdict_to_dict(self, data) # type: ignore diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index 675bf4b85..99647e36c 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ 
-7,7 +7,7 @@ import numpy as np from collections import Counter from maim_message import UserInfo -from typing import Optional, Tuple, Dict, List, Any +from typing import Optional, Tuple, Dict, List, Any, Coroutine from src.common.logger import get_logger from src.common.message_repository import find_messages, count_messages @@ -540,7 +540,8 @@ def get_western_ratio(paragraph): return western_count / len(alnum_chars) -def count_messages_between(start_time: float, end_time: float, stream_id: str) -> tuple[int, int]: +def count_messages_between(start_time: float, end_time: float, stream_id: str) -> tuple[int, int] | tuple[ + Coroutine[Any, Any, int], int]: """计算两个时间点之间的消息数量和文本总长度 Args: diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py index 8069fd616..18adc8a9b 100644 --- a/src/chat/utils/utils_image.py +++ b/src/chat/utils/utils_image.py @@ -134,7 +134,8 @@ class ImageManager: except Exception as e: logger.error(f"保存描述到数据库失败 (SQLAlchemy): {str(e)}") - async def get_emoji_tag(self, image_base64: str) -> str: + @staticmethod + async def get_emoji_tag(image_base64: str) -> str: from src.chat.emoji_system.emoji_manager import get_emoji_manager emoji_manager = get_emoji_manager() diff --git a/src/chat/utils/utils_video.py b/src/chat/utils/utils_video.py index 1e186f058..6ea5a111f 100644 --- a/src/chat/utils/utils_video.py +++ b/src/chat/utils/utils_video.py @@ -167,7 +167,8 @@ class VideoAnalyzer: # 获取Rust模块系统信息 self._log_system_info() - def _log_system_info(self): + @staticmethod + def _log_system_info(): """记录系统信息""" if not RUST_VIDEO_AVAILABLE: logger.info("⚠️ Rust模块不可用,跳过系统信息获取") @@ -196,13 +197,15 @@ class VideoAnalyzer: except Exception as e: logger.warning(f"获取系统信息失败: {e}") - def _calculate_video_hash(self, video_data: bytes) -> str: + @staticmethod + def _calculate_video_hash(video_data: bytes) -> str: """计算视频文件的hash值""" hash_obj = hashlib.sha256() hash_obj.update(video_data) return hash_obj.hexdigest() - def _check_video_exists(self, video_hash: str) -> Optional[Videos]: + @staticmethod + def _check_video_exists(video_hash: str) -> Optional[Videos]: """检查视频是否已经分析过""" try: with get_db_session() as session: @@ -213,8 +216,9 @@ class VideoAnalyzer: logger.warning(f"检查视频是否存在时出错: {e}") return None + @staticmethod def _store_video_result( - self, video_hash: str, description: str, metadata: Optional[Dict] = None + video_hash: str, description: str, metadata: Optional[Dict] = None ) -> Optional[Videos]: """存储视频分析结果到数据库""" # 检查描述是否为错误信息,如果是则不保存 @@ -619,7 +623,7 @@ class VideoAnalyzer: if self.disabled: error_msg = "❌ 视频分析功能已禁用:没有可用的视频处理实现" logger.warning(error_msg) - return (False, error_msg) + return False, error_msg try: logger.info(f"开始分析视频: {os.path.basename(video_path)}") @@ -628,7 +632,7 @@ class VideoAnalyzer: frames = await self.extract_frames(video_path) if not frames: error_msg = "❌ 无法从视频中提取有效帧" - return (False, error_msg) + return False, error_msg # 根据模式选择分析方法 if self.analysis_mode == "auto": @@ -645,12 +649,12 @@ class VideoAnalyzer: result = await self.analyze_frames_sequential(frames, user_question) logger.info("✅ 视频分析完成") - return (True, result) + return True, result except Exception as e: error_msg = f"❌ 视频分析失败: {str(e)}" logger.error(error_msg) - return (False, error_msg) + return False, error_msg async def analyze_video_from_bytes( self, video_bytes: bytes, filename: str = None, user_question: str = None, prompt: str = None @@ -783,7 +787,8 @@ class VideoAnalyzer: return {"summary": error_msg} - def is_supported_video(self, file_path: str) -> bool: + 
@staticmethod + def is_supported_video(file_path: str) -> bool: """检查是否为支持的视频格式""" supported_formats = {".mp4", ".avi", ".mov", ".mkv", ".flv", ".wmv", ".m4v", ".3gp", ".webm"} return Path(file_path).suffix.lower() in supported_formats @@ -818,7 +823,8 @@ class VideoAnalyzer: logger.error(f"获取处理能力信息失败: {e}") return {"error": str(e), "available": False} - def _get_recommended_settings(self, cpu_features: Dict[str, bool]) -> Dict[str, any]: + @staticmethod + def _get_recommended_settings(cpu_features: Dict[str, bool]) -> Dict[str, any]: """根据CPU特性推荐最佳设置""" settings = { "use_simd": any(cpu_features.values()), diff --git a/src/chat/utils/utils_video_legacy.py b/src/chat/utils/utils_video_legacy.py index ef5f49301..4d8e06681 100644 --- a/src/chat/utils/utils_video_legacy.py +++ b/src/chat/utils/utils_video_legacy.py @@ -13,7 +13,7 @@ import base64 import numpy as np from PIL import Image from pathlib import Path -from typing import List, Tuple, Optional +from typing import List, Tuple, Optional, Any import io from concurrent.futures import ThreadPoolExecutor @@ -31,7 +31,7 @@ def _extract_frames_worker( max_image_size: int, frame_extraction_mode: str, frame_interval_seconds: Optional[float], -) -> List[Tuple[str, float]]: +) -> list[Any] | list[tuple[str, str]]: """线程池中提取视频帧的工作函数""" frames = [] try: @@ -568,7 +568,8 @@ class LegacyVideoAnalyzer: logger.error(error_msg) return error_msg - def is_supported_video(self, file_path: str) -> bool: + @staticmethod + def is_supported_video(file_path: str) -> bool: """检查是否为支持的视频格式""" supported_formats = {".mp4", ".avi", ".mov", ".mkv", ".flv", ".wmv", ".m4v", ".3gp", ".webm"} return Path(file_path).suffix.lower() in supported_formats diff --git a/src/common/cache_manager.py b/src/common/cache_manager.py index d8abc241b..ec940ec6c 100644 --- a/src/common/cache_manager.py +++ b/src/common/cache_manager.py @@ -53,7 +53,8 @@ class CacheManager: self._initialized = True logger.info("缓存管理器已初始化: L1 (内存+FAISS), L2 (数据库+ChromaDB)") - def _validate_embedding(self, embedding_result: Any) -> Optional[np.ndarray]: + @staticmethod + def _validate_embedding(embedding_result: Any) -> Optional[np.ndarray]: """ 验证和标准化嵌入向量格式 """ @@ -90,7 +91,8 @@ class CacheManager: logger.error(f"验证嵌入向量时发生错误: {e}") return None - def _generate_key(self, tool_name: str, function_args: Dict[str, Any], tool_file_path: Union[str, Path]) -> str: + @staticmethod + def _generate_key(tool_name: str, function_args: Dict[str, Any], tool_file_path: Union[str, Path]) -> str: """生成确定性的缓存键,包含文件修改时间以实现自动失效。""" try: tool_file_path = Path(tool_file_path) diff --git a/src/common/data_models/info_data_model.py b/src/common/data_models/info_data_model.py index 32893706d..7de787060 100644 --- a/src/common/data_models/info_data_model.py +++ b/src/common/data_models/info_data_model.py @@ -1,10 +1,10 @@ from dataclasses import dataclass, field from typing import Optional, Dict, List, TYPE_CHECKING + from . import BaseDataModel if TYPE_CHECKING: - from .database_data_model import DatabaseMessages - from src.plugin_system.base.component_types import ActionInfo, ChatMode + pass @dataclass diff --git a/src/common/data_models/llm_data_model.py b/src/common/data_models/llm_data_model.py index 1d5b75e0c..cd706bc55 100644 --- a/src/common/data_models/llm_data_model.py +++ b/src/common/data_models/llm_data_model.py @@ -2,8 +2,9 @@ from dataclasses import dataclass from typing import Optional, List, Tuple, TYPE_CHECKING, Any from . 
import BaseDataModel + if TYPE_CHECKING: - from src.llm_models.payload_content.tool_option import ToolCall + pass @dataclass class LLMGenerationDataModel(BaseDataModel): diff --git a/src/common/data_models/message_data_model.py b/src/common/data_models/message_data_model.py index 8e0b77862..bf08a0d6a 100644 --- a/src/common/data_models/message_data_model.py +++ b/src/common/data_models/message_data_model.py @@ -1,10 +1,10 @@ -from typing import Optional, TYPE_CHECKING from dataclasses import dataclass, field +from typing import Optional, TYPE_CHECKING from . import BaseDataModel if TYPE_CHECKING: - from .database_data_model import DatabaseMessages + pass @dataclass diff --git a/src/common/database/database.py b/src/common/database/database.py index d196df032..3279a67ed 100644 --- a/src/common/database/database.py +++ b/src/common/database/database.py @@ -25,7 +25,8 @@ class DatabaseProxy: self._engine = None self._session = None - def initialize(self, *args, **kwargs): + @staticmethod + def initialize(*args, **kwargs): """初始化数据库连接""" return initialize_database_compat() diff --git a/src/common/database/sqlalchemy_database_api.py b/src/common/database/sqlalchemy_database_api.py index 4f0258c2b..13ef39c1a 100644 --- a/src/common/database/sqlalchemy_database_api.py +++ b/src/common/database/sqlalchemy_database_api.py @@ -4,16 +4,14 @@ 支持自动重连、连接池管理和更好的错误处理 """ -import traceback import time -import asyncio -from typing import Dict, List, Any, Union, Type, Optional -from sqlalchemy.exc import SQLAlchemyError +import traceback +from typing import Dict, List, Any, Union, Optional + from sqlalchemy import desc, asc, func, and_, select -from sqlalchemy.ext.asyncio import AsyncSession -from src.common.logger import get_logger +from sqlalchemy.exc import SQLAlchemyError + from src.common.database.sqlalchemy_models import ( - Base, get_db_session, Messages, ActionRecords, @@ -33,6 +31,7 @@ from src.common.database.sqlalchemy_models import ( MaiZoneScheduleStatus, CacheEntries, ) +from src.common.logger import get_logger logger = get_logger("sqlalchemy_database_api") diff --git a/src/common/database/sqlalchemy_models.py b/src/common/database/sqlalchemy_models.py index e5eacee1f..0c193e358 100644 --- a/src/common/database/sqlalchemy_models.py +++ b/src/common/database/sqlalchemy_models.py @@ -3,17 +3,18 @@ 替换Peewee ORM,使用SQLAlchemy提供更好的连接池管理和错误恢复能力 """ -from sqlalchemy import Column, String, Float, Integer, Boolean, Text, Index, DateTime -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker -from sqlalchemy.orm import Mapped, mapped_column -import os import datetime +import os import time -from typing import Iterator, Optional, Any, Dict, AsyncGenerator -from src.common.logger import get_logger from contextlib import asynccontextmanager -import asyncio +from typing import Optional, Any, Dict, AsyncGenerator + +from sqlalchemy import Column, String, Float, Integer, Boolean, Text, Index, DateTime +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import Mapped, mapped_column + +from src.common.logger import get_logger logger = get_logger("sqlalchemy_models") diff --git a/src/common/server.py b/src/common/server.py index 3263589a2..a06cf1151 100644 --- a/src/common/server.py +++ b/src/common/server.py @@ -1,10 +1,10 @@ +import os +from typing import Optional + from fastapi import FastAPI, APIRouter 
from fastapi.middleware.cors import CORSMiddleware # 新增导入 -from typing import Optional -from uvicorn import Config, Server as UvicornServer -from src.config.config import global_config from rich.traceback import install -import os +from uvicorn import Config, Server as UvicornServer install(extra_lines=3) diff --git a/src/config/api_ada_configs.py b/src/config/api_ada_configs.py index 0b9d333c4..5e5e035dd 100644 --- a/src/config/api_ada_configs.py +++ b/src/config/api_ada_configs.py @@ -22,7 +22,6 @@ class APIProvider(ValidatedConfigBase): enable_content_obfuscation: bool = Field(default=False, description="是否启用内容混淆(用于特定场景下的内容处理)") obfuscation_intensity: int = Field(default=1, ge=1, le=3, description="混淆强度(1-3级,数值越高混淆程度越强)") - @field_validator("base_url") @classmethod def validate_base_url(cls, v): """验证base_url,确保URL格式正确""" @@ -30,7 +29,6 @@ class APIProvider(ValidatedConfigBase): raise ValueError("base_url必须以http://或https://开头") return v - @field_validator("api_key") @classmethod def validate_api_key(cls, v): """验证API密钥不能为空""" @@ -75,7 +73,6 @@ class ModelInfo(ValidatedConfigBase): extra_params: Dict[str, Any] = Field(default_factory=dict, description="额外参数(用于API调用时的额外配置)") anti_truncation: bool = Field(default=False, description="是否启用反截断功能,防止模型输出被截断") - @field_validator("price_in", "price_out") @classmethod def validate_prices(cls, v): """验证价格必须为非负数""" @@ -83,7 +80,6 @@ class ModelInfo(ValidatedConfigBase): raise ValueError("价格不能为负数") return v - @field_validator("model_identifier") @classmethod def validate_model_identifier(cls, v): """验证模型标识符不能为空且不能包含特殊字符""" @@ -94,7 +90,6 @@ class ModelInfo(ValidatedConfigBase): raise ValueError("模型标识符不能包含空格或换行符") return v - @field_validator("name") @classmethod def validate_name(cls, v): """验证模型名称不能为空""" @@ -111,7 +106,6 @@ class TaskConfig(ValidatedConfigBase): temperature: float = Field(default=0.7, description="模型温度") concurrency_count: int = Field(default=1, description="并发请求数量") - @field_validator("model_list") @classmethod def validate_model_list(cls, v): """验证模型列表不能为空""" @@ -178,7 +172,6 @@ class APIAdapterConfig(ValidatedConfigBase): self.api_providers_dict = {provider.name: provider for provider in self.api_providers} self.models_dict = {model.name: model for model in self.models} - @field_validator("models") @classmethod def validate_models_list(cls, v): """验证模型列表""" @@ -197,7 +190,6 @@ class APIAdapterConfig(ValidatedConfigBase): return v - @field_validator("api_providers") @classmethod def validate_api_providers_list(cls, v): """验证API提供商列表""" diff --git a/src/config/config.py b/src/config/config.py index 933b7e567..3fbd7e9e6 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -412,7 +412,6 @@ class APIAdapterConfig(ValidatedConfigBase): self.api_providers_dict = {provider.name: provider for provider in self.api_providers} self.models_dict = {model.name: model for model in self.models} - @field_validator("models") @classmethod def validate_models_list(cls, v): """验证模型列表""" @@ -431,7 +430,6 @@ class APIAdapterConfig(ValidatedConfigBase): return v - @field_validator("api_providers") @classmethod def validate_api_providers_list(cls, v): """验证API提供商列表""" diff --git a/src/config/config_base.py b/src/config/config_base.py index 5d8c7c195..764ec5b71 100644 --- a/src/config/config_base.py +++ b/src/config/config_base.py @@ -50,7 +50,7 @@ class ConfigBase: except Exception as e: raise RuntimeError(f"Failed to convert field '{field_name}' to target type: {e}") from e - return cls(**init_args) + return cls() @classmethod def 
_convert_field(cls, value: Any, field_type: Type[Any]) -> Any: diff --git a/src/config/official_configs.py b/src/config/official_configs.py index be51a21e3..978c5e47b 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -122,7 +122,8 @@ class ChatConfig(ValidatedConfigBase): global_frequency = self._get_global_frequency() return self.talk_frequency if global_frequency is None else global_frequency - def _get_time_based_frequency(self, time_freq_list: list[str]) -> Optional[float]: + @staticmethod + def _get_time_based_frequency(time_freq_list: list[str]) -> Optional[float]: """ 根据时间配置列表获取当前时段的频率 @@ -201,7 +202,8 @@ class ChatConfig(ValidatedConfigBase): return None - def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]: + @staticmethod + def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]: """ 解析流配置字符串并生成对应的 chat_id @@ -280,7 +282,8 @@ class ExpressionConfig(ValidatedConfigBase): rules: List[ExpressionRule] = Field(default_factory=list, description="表达学习规则") - def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]: + @staticmethod + def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]: """ 解析流配置字符串并生成对应的 chat_id diff --git a/src/individuality/individuality.py b/src/individuality/individuality.py index 39aef9b3b..a2e0f2621 100644 --- a/src/individuality/individuality.py +++ b/src/individuality/individuality.py @@ -94,8 +94,9 @@ class Individuality: prompt_personality = f"{personality}\n{identity}" return f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}" + @staticmethod def _get_config_hash( - self, bot_nickname: str, personality_core: str, personality_side: str, identity: str + bot_nickname: str, personality_core: str, personality_side: str, identity: str ) -> tuple[str, str]: """获取personality和identity配置的哈希值 diff --git a/src/llm_models/payload_content/message.py b/src/llm_models/payload_content/message.py index f70c3ded5..4253efab0 100644 --- a/src/llm_models/payload_content/message.py +++ b/src/llm_models/payload_content/message.py @@ -58,7 +58,7 @@ class MessageBuilder: self, image_format: str, image_base64: str, - support_formats: list[str] = SUPPORTED_IMAGE_FORMATS, # 默认支持格式 + support_formats=None, # 默认支持格式 ) -> "MessageBuilder": """ 添加图片内容 @@ -66,6 +66,8 @@ class MessageBuilder: :param image_base64: 图片的base64编码 :return: MessageBuilder对象 """ + if support_formats is None: + support_formats = SUPPORTED_IMAGE_FORMATS if image_format.lower() not in support_formats: raise ValueError("不受支持的图片格式") if not image_base64: diff --git a/src/llm_models/utils.py b/src/llm_models/utils.py index ee20533ee..48f7be5f2 100644 --- a/src/llm_models/utils.py +++ b/src/llm_models/utils.py @@ -145,9 +145,9 @@ class LLMUsageRecorder: LLM使用情况记录器(SQLAlchemy版本) """ + @staticmethod def record_usage_to_database( - self, - model_info: ModelInfo, + model_info: ModelInfo, model_usage: UsageRecord, user_id: str, request_type: str, diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index fa0ea6916..4969644e6 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -625,9 +625,9 @@ class LLMRequest: logger.error(f"任务-'{task_name}' 模型-'{model_name}': 未知异常,错误信息-{str(e)}") return -1, None # 不再重试请求该模型 + @staticmethod def _check_retry( - self, - remain_try: int, + remain_try: int, retry_interval: int, can_retry_msg: str, cannot_retry_msg: str, @@ -745,7 +745,8 @@ class LLMRequest: ) return -1, None - def _build_tool_options(self, tools: 
Optional[List[Dict[str, Any]]]) -> Optional[List[ToolOption]]: + @staticmethod + def _build_tool_options(tools: Optional[List[Dict[str, Any]]]) -> Optional[List[ToolOption]]: # sourcery skip: extract-method """构建工具选项列表""" if not tools: @@ -809,7 +810,8 @@ class LLMRequest: return final_text - def _inject_random_noise(self, text: str, intensity: int) -> str: + @staticmethod + def _inject_random_noise(text: str, intensity: int) -> str: """在文本中注入随机乱码""" import random import string diff --git a/src/main.py b/src/main.py index d37a5b630..22c6218c9 100644 --- a/src/main.py +++ b/src/main.py @@ -1,37 +1,35 @@ # 再用这个就写一行注释来混提交的我直接全部🌿飞😡 import asyncio -import time import signal import sys +import time + from maim_message import MessageServer - -from src.common.remote import TelemetryHeartBeatTask -from src.manager.async_task_manager import async_task_manager -from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask -from src.common.remote import TelemetryHeartBeatTask -from src.chat.emoji_system.emoji_manager import get_emoji_manager -from src.chat.message_receive.chat_stream import get_chat_manager -from src.config.config import global_config -from src.chat.message_receive.bot import chat_bot -from src.common.logger import get_logger -from src.individuality.individuality import get_individuality, Individuality -from src.common.server import get_global_server, Server -from src.mood.mood_manager import mood_manager from rich.traceback import install -from src.schedule.schedule_manager import schedule_manager -from src.schedule.monthly_plan_manager import monthly_plan_manager -from src.plugin_system.core.event_manager import event_manager -from src.plugin_system.base.component_types import EventType -# from src.api.main import start_api_server - -# 导入新的插件管理器和热重载管理器 -from src.plugin_system.core.plugin_manager import plugin_manager -from src.plugin_system.core.plugin_hot_reload import hot_reload_manager +from src.chat.emoji_system.emoji_manager import get_emoji_manager +from src.chat.memory_system.Hippocampus import hippocampus_manager +from src.chat.message_receive.bot import chat_bot +from src.chat.message_receive.chat_stream import get_chat_manager +from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask +from src.common.logger import get_logger # 导入消息API和traceback模块 from src.common.message import get_global_api +from src.common.remote import TelemetryHeartBeatTask +from src.common.server import get_global_server, Server +from src.config.config import global_config +from src.individuality.individuality import get_individuality, Individuality +from src.manager.async_task_manager import async_task_manager +from src.mood.mood_manager import mood_manager +from src.plugin_system.base.component_types import EventType +from src.plugin_system.core.event_manager import event_manager +from src.plugin_system.core.plugin_hot_reload import hot_reload_manager +# 导入新的插件管理器和热重载管理器 +from src.plugin_system.core.plugin_manager import plugin_manager +from src.schedule.monthly_plan_manager import monthly_plan_manager +from src.schedule.schedule_manager import schedule_manager -from src.chat.memory_system.Hippocampus import hippocampus_manager +# from src.api.main import start_api_server if not global_config.memory.enable_memory: import src.chat.memory_system.Hippocampus as hippocampus_module @@ -43,7 +41,8 @@ if not global_config.memory.enable_memory: async def initialize_async(self): pass - def get_hippocampus(self): + @staticmethod + def get_hippocampus(): return None async 
def build_memory(self): @@ -55,9 +54,9 @@ if not global_config.memory.enable_memory: async def consolidate_memory(self): pass + @staticmethod async def get_memory_from_text( - self, - text: str, + text: str, max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3, @@ -65,20 +64,24 @@ if not global_config.memory.enable_memory: ) -> list: return [] + @staticmethod async def get_memory_from_topic( - self, valid_keywords: list[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3 + valid_keywords: list[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3 ) -> list: return [] + @staticmethod async def get_activate_from_text( - self, text: str, max_depth: int = 3, fast_retrieval: bool = False + text: str, max_depth: int = 3, fast_retrieval: bool = False ) -> tuple[float, list[str]]: return 0.0, [] - def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list: + @staticmethod + def get_memory_from_keyword(keyword: str, max_depth: int = 2) -> list: return [] - def get_all_node_names(self) -> list: + @staticmethod + def get_all_node_names() -> list: return [] hippocampus_module.hippocampus_manager = MockHippocampusManager() @@ -114,7 +117,8 @@ class MainSystem: signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) - def _cleanup(self): + @staticmethod + def _cleanup(): """清理资源""" try: # 停止消息重组器 diff --git a/src/mais4u/mais4u_chat/body_emotion_action_manager.py b/src/mais4u/mais4u_chat/body_emotion_action_manager.py index bf3640be0..5807e2acf 100644 --- a/src/mais4u/mais4u_chat/body_emotion_action_manager.py +++ b/src/mais4u/mais4u_chat/body_emotion_action_manager.py @@ -118,7 +118,7 @@ class ChatAction: self.regression_count = 0 message_time: float = message.message_info.time # type: ignore - message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive( + message_list_before_now = await get_raw_msg_by_timestamp_with_chat_inclusive( chat_id=self.chat_id, timestamp_start=self.last_change_time, timestamp_end=message_time, @@ -182,7 +182,7 @@ class ChatAction: async def regress_action(self): message_time = time.time() - message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive( + message_list_before_now = await get_raw_msg_by_timestamp_with_chat_inclusive( chat_id=self.chat_id, timestamp_start=self.last_change_time, timestamp_end=message_time, diff --git a/src/mais4u/mais4u_chat/s4u_chat.py b/src/mais4u/mais4u_chat/s4u_chat.py index 3b2ccac30..192e858b6 100644 --- a/src/mais4u/mais4u_chat/s4u_chat.py +++ b/src/mais4u/mais4u_chat/s4u_chat.py @@ -58,7 +58,8 @@ class MessageSenderContainer: """恢复发送。""" self._paused_event.set() - def _calculate_typing_delay(self, text: str) -> float: + @staticmethod + def _calculate_typing_delay(text: str) -> float: """根据文本长度计算模拟打字延迟。""" chars_per_second = s4u_config.chars_per_second min_delay = s4u_config.min_typing_delay @@ -150,6 +151,10 @@ class MessageSenderContainer: if self._task: await self._task + @property + def task(self): + return self._task + class S4UChatManager: def __init__(self): @@ -177,6 +182,7 @@ class S4UChat: def __init__(self, chat_stream: ChatStream): """初始化 S4UChat 实例。""" + self.last_msg_id = self.msg_id self.chat_stream = chat_stream self.stream_id = chat_stream.stream_id self.stream_name = get_chat_manager().get_stream_name(self.stream_id) or self.stream_id @@ -206,7 +212,8 @@ class S4UChat: logger.info(f"[{self.stream_name}] S4UChat with two-queue system initialized.") - def _get_priority_info(self, 
message: MessageRecv) -> dict: + @staticmethod + def _get_priority_info(message: MessageRecv) -> dict: """安全地从消息中提取和解析 priority_info""" priority_info_raw = message.priority_info priority_info = {} @@ -219,7 +226,8 @@ class S4UChat: priority_info = priority_info_raw return priority_info - def _is_vip(self, priority_info: dict) -> bool: + @staticmethod + def _is_vip(priority_info: dict) -> bool: """检查消息是否来自VIP用户。""" return priority_info.get("message_type") == "vip" @@ -468,7 +476,6 @@ class S4UChat: await asyncio.sleep(1) def get_processing_message_id(self): - self.last_msg_id = self.msg_id self.msg_id = f"{time.time()}_{random.randint(1000, 9999)}" async def _generate_and_send(self, message: MessageRecv): @@ -565,7 +572,7 @@ class S4UChat: # 确保发送器被妥善关闭(即使已关闭,再次调用也是安全的) sender_container.resume() - if not sender_container._task.done(): + if not sender_container.task.done(): await sender_container.close() await sender_container.join() logger.info(f"[{self.stream_name}] _generate_and_send 任务结束,资源已清理。") @@ -586,3 +593,7 @@ class S4UChat: await self._processing_task except asyncio.CancelledError: logger.info(f"处理任务已成功取消: {self.stream_name}") + + @property + def new_message_event(self): + return self._new_message_event diff --git a/src/mais4u/mais4u_chat/s4u_mood_manager.py b/src/mais4u/mais4u_chat/s4u_mood_manager.py index 8d1e22b8f..db852567e 100644 --- a/src/mais4u/mais4u_chat/s4u_mood_manager.py +++ b/src/mais4u/mais4u_chat/s4u_mood_manager.py @@ -124,7 +124,8 @@ class ChatMood: # 发送初始情绪状态到ws端 asyncio.create_task(self.send_emotion_update(self.mood_values)) - def _parse_numerical_mood(self, response: str) -> dict[str, int] | None: + @staticmethod + def _parse_numerical_mood(response: str) -> dict[str, int] | None: try: # The LLM might output markdown with json inside if "```json" in response: diff --git a/src/mais4u/mais4u_chat/s4u_msg_processor.py b/src/mais4u/mais4u_chat/s4u_msg_processor.py index 7bd1fe29e..e3682a450 100644 --- a/src/mais4u/mais4u_chat/s4u_msg_processor.py +++ b/src/mais4u/mais4u_chat/s4u_msg_processor.py @@ -161,7 +161,8 @@ class S4UMessageProcessor: else: logger.info(f"[S4U]{userinfo.user_nickname}:{message.processed_plain_text}") - async def handle_internal_message(self, message: MessageRecvS4U): + @staticmethod + async def handle_internal_message(message: MessageRecvS4U): if message.is_internal: group_info = GroupInfo(platform="amaidesu_default", group_id=660154, group_name="内心") @@ -173,7 +174,7 @@ class S4UMessageProcessor: message.message_info.platform = s4u_chat.chat_stream.platform s4u_chat.internal_message.append(message) - s4u_chat._new_message_event.set() + s4u_chat.new_message_event.set() logger.info( f"[{s4u_chat.stream_name}] 添加内部消息-------------------------------------------------------: {message.processed_plain_text}" @@ -182,20 +183,23 @@ class S4UMessageProcessor: return True return False - async def handle_screen_message(self, message: MessageRecvS4U): + @staticmethod + async def handle_screen_message(message: MessageRecvS4U): if message.is_screen: screen_manager.set_screen(message.screen_info) return True return False - async def hadle_if_voice_done(self, message: MessageRecvS4U): + @staticmethod + async def hadle_if_voice_done(message: MessageRecvS4U): if message.voice_done: s4u_chat = get_s4u_chat_manager().get_or_create_chat(message.chat_stream) s4u_chat.voice_done = message.voice_done return True return False - async def check_if_fake_gift(self, message: MessageRecvS4U) -> bool: + @staticmethod + async def check_if_fake_gift(message: MessageRecvS4U) 
-> bool: """检查消息是否为假礼物""" if message.is_gift: return False @@ -227,7 +231,8 @@ class S4UMessageProcessor: return True # 非礼物消息,继续正常处理 - async def _handle_context_web_update(self, chat_id: str, message: MessageRecv): + @staticmethod + async def _handle_context_web_update(chat_id: str, message: MessageRecv): """处理上下文网页更新的独立task Args: diff --git a/src/mais4u/mais4u_chat/s4u_prompt.py b/src/mais4u/mais4u_chat/s4u_prompt.py index db6a6edf9..2590a388f 100644 --- a/src/mais4u/mais4u_chat/s4u_prompt.py +++ b/src/mais4u/mais4u_chat/s4u_prompt.py @@ -98,7 +98,8 @@ class PromptBuilder: self.prompt_built = "" self.activate_messages = "" - async def build_expression_habits(self, chat_stream: ChatStream, chat_history, target): + @staticmethod + async def build_expression_habits(chat_stream: ChatStream, chat_history, target): style_habits = [] grammar_habits = [] @@ -133,7 +134,8 @@ class PromptBuilder: return expression_habits_block - async def build_relation_info(self, chat_stream) -> str: + @staticmethod + async def build_relation_info(chat_stream) -> str: is_group_chat = bool(chat_stream.group_info) who_chat_in_group = [] if is_group_chat: @@ -167,7 +169,8 @@ class PromptBuilder: ) return relation_prompt - async def build_memory_block(self, text: str) -> str: + @staticmethod + async def build_memory_block(text: str) -> str: related_memory = await hippocampus_manager.get_memory_from_text( text=text, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False ) @@ -179,7 +182,8 @@ class PromptBuilder: return await global_prompt_manager.format_prompt("memory_prompt", memory_info=related_memory_info) return "" - def build_chat_history_prompts(self, chat_stream: ChatStream, message: MessageRecvS4U): + @staticmethod + def build_chat_history_prompts(chat_stream: ChatStream, message: MessageRecvS4U): message_list_before_now = get_raw_msg_before_timestamp_with_chat( chat_id=chat_stream.stream_id, timestamp=time.time(), @@ -270,7 +274,8 @@ class PromptBuilder: return core_msg_str, background_dialogue_prompt, all_dialogue_prompt_str - def build_gift_info(self, message: MessageRecvS4U): + @staticmethod + def build_gift_info(message: MessageRecvS4U): if message.is_gift: return f"这是一条礼物信息,{message.gift_name} x{message.gift_count},请注意这位用户" else: @@ -279,7 +284,8 @@ class PromptBuilder: return "" - def build_sc_info(self, message: MessageRecvS4U): + @staticmethod + def build_sc_info(message: MessageRecvS4U): super_chat_manager = get_super_chat_manager() return super_chat_manager.build_superchat_summary_string(message.chat_stream.stream_id) diff --git a/src/mais4u/mais4u_chat/s4u_stream_generator.py b/src/mais4u/mais4u_chat/s4u_stream_generator.py index da12d9f9d..d4ec70edd 100644 --- a/src/mais4u/mais4u_chat/s4u_stream_generator.py +++ b/src/mais4u/mais4u_chat/s4u_stream_generator.py @@ -49,7 +49,8 @@ class S4UStreamGenerator: self.chat_stream = None - async def build_last_internal_message(self, message: MessageRecvS4U, previous_reply_context: str = ""): + @staticmethod + async def build_last_internal_message(message: MessageRecvS4U, previous_reply_context: str = ""): # person_id = PersonInfoManager.get_person_id( # message.chat_stream.user_info.platform, message.chat_stream.user_info.user_id # ) diff --git a/src/mais4u/mais4u_chat/super_chat_manager.py b/src/mais4u/mais4u_chat/super_chat_manager.py index c09367292..5f0ee2ac2 100644 --- a/src/mais4u/mais4u_chat/super_chat_manager.py +++ b/src/mais4u/mais4u_chat/super_chat_manager.py @@ -105,7 +105,8 @@ class SuperChatManager: 
logger.error(f"清理过期SuperChat时出错: {e}", exc_info=True) await asyncio.sleep(60) # 出错时等待更长时间 - def _calculate_expire_time(self, price: float) -> float: + @staticmethod + def _calculate_expire_time(price: float) -> float: """根据SuperChat金额计算过期时间""" current_time = time.time() diff --git a/src/mais4u/s4u_config.py b/src/mais4u/s4u_config.py index d93cf8345..79a8f92c4 100644 --- a/src/mais4u/s4u_config.py +++ b/src/mais4u/s4u_config.py @@ -78,7 +78,7 @@ class S4UConfigBase: except Exception as e: raise RuntimeError(f"Failed to convert field '{field_name}' to target type: {e}") from e - return cls(**init_args) + return cls() @classmethod def _convert_field(cls, value: Any, field_type: Type[Any]) -> Any: diff --git a/src/manager/async_task_manager.py b/src/manager/async_task_manager.py index 0a2c0d215..92f6675bd 100644 --- a/src/manager/async_task_manager.py +++ b/src/manager/async_task_manager.py @@ -1,4 +1,4 @@ -from abc import abstractmethod +from abc import abstractmethod, ABCMeta import asyncio from asyncio import Task, Event, Lock @@ -9,7 +9,7 @@ from src.common.logger import get_logger logger = get_logger("async_task_manager") -class AsyncTask: +class AsyncTask(metaclass=ABCMeta): """异步任务基类""" def __init__(self, task_name: str | None = None, wait_before_start: int = 0, run_interval: int = 0): diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index 95dc41cfb..3a036d029 100644 --- a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -1,18 +1,18 @@ import copy -import hashlib import datetime -import asyncio -import orjson +import hashlib import time - -from json_repair import repair_json from typing import Any, Callable, Dict, Union, Optional + +import orjson +from json_repair import repair_json from sqlalchemy import select -from src.common.logger import get_logger -from src.common.database.sqlalchemy_models import PersonInfo + from src.common.database.sqlalchemy_database_api import get_db_session -from src.llm_models.utils_model import LLMRequest +from src.common.database.sqlalchemy_models import PersonInfo +from src.common.logger import get_logger from src.config.config import global_config, model_config +from src.llm_models.utils_model import LLMRequest """ PersonInfoManager 类方法功能摘要: @@ -116,7 +116,8 @@ class PersonInfoManager: logger.error(f"检查用户 {person_id} 是否已知时出错 (SQLAlchemy): {e}") return False - async def get_person_id_by_person_name(self, person_name: str) -> str: + @staticmethod + async def get_person_id_by_person_name(person_name: str) -> str: """根据用户名获取用户ID""" try: # 在需要时获取会话 @@ -188,7 +189,8 @@ class PersonInfoManager: await _db_create_async(final_data) - async def _safe_create_person_info(self, person_id: str, data: Optional[dict] = None): + @staticmethod + async def _safe_create_person_info(person_id: str, data: Optional[dict] = None): """安全地创建用户信息,处理竞态条件""" if not person_id: logger.debug("创建失败,person_id不存在") diff --git a/src/person_info/relationship_builder.py b/src/person_info/relationship_builder.py index 34e5332c9..720076eb2 100644 --- a/src/person_info/relationship_builder.py +++ b/src/person_info/relationship_builder.py @@ -3,7 +3,7 @@ import traceback import os import pickle import random -from typing import List, Dict, Any +from typing import List, Dict, Any, Coroutine from src.config.config import global_config from src.common.logger import get_logger from src.person_info.relationship_manager import get_relationship_manager @@ -201,7 +201,7 @@ class RelationshipBuilder: messages = 
get_raw_msg_by_timestamp_with_chat_inclusive(self.chat_id, start_time, end_time) return len(messages) - def _count_messages_between(self, start_time: float, end_time: float) -> int: + def _count_messages_between(self, start_time: float, end_time: float) -> Coroutine[Any, Any, int]: """计算两个时间点之间的消息数量(不包含边界),用于间隔检查""" return num_new_messages_since(self.chat_id, start_time, end_time) @@ -314,18 +314,12 @@ class RelationshipBuilder: if not self.person_engaged_cache: return f"{self.log_prefix} 关系缓存为空" - status_lines = [f"{self.log_prefix} 关系缓存状态:"] - status_lines.append( - f"最后处理消息时间:{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(self.last_processed_message_time)) if self.last_processed_message_time > 0 else '未设置'}" - ) - status_lines.append( - f"最后清理时间:{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(self.last_cleanup_time)) if self.last_cleanup_time > 0 else '未执行'}" - ) - status_lines.append(f"总用户数:{len(self.person_engaged_cache)}") - status_lines.append( - f"清理配置:{'启用' if SEGMENT_CLEANUP_CONFIG['enable_cleanup'] else '禁用'} (最大保存{SEGMENT_CLEANUP_CONFIG['max_segment_age_days']}天, 每用户最多{SEGMENT_CLEANUP_CONFIG['max_segments_per_user']}段)" - ) - status_lines.append("") + status_lines = [f"{self.log_prefix} 关系缓存状态:", + f"最后处理消息时间:{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(self.last_processed_message_time)) if self.last_processed_message_time > 0 else '未设置'}", + f"最后清理时间:{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(self.last_cleanup_time)) if self.last_cleanup_time > 0 else '未执行'}", + f"总用户数:{len(self.person_engaged_cache)}", + f"清理配置:{'启用' if SEGMENT_CLEANUP_CONFIG['enable_cleanup'] else '禁用'} (最大保存{SEGMENT_CLEANUP_CONFIG['max_segment_age_days']}天, 每用户最多{SEGMENT_CLEANUP_CONFIG['max_segments_per_user']}段)", + ""] for person_id, segments in self.person_engaged_cache.items(): total_count = self._get_total_message_count(person_id) diff --git a/src/person_info/relationship_manager.py b/src/person_info/relationship_manager.py index b3badbe0c..c2a3ffb96 100644 --- a/src/person_info/relationship_manager.py +++ b/src/person_info/relationship_manager.py @@ -492,7 +492,8 @@ class RelationshipManager: return current_points - def calculate_time_weight(self, point_time: str, current_time: str) -> float: + @staticmethod + def calculate_time_weight(point_time: str, current_time: str) -> float: """计算基于时间的权重系数""" try: point_timestamp = datetime.strptime(point_time, "%Y-%m-%d %H:%M:%S") @@ -516,7 +517,8 @@ class RelationshipManager: logger.error(f"计算时间权重失败: {e}") return 0.5 # 发生错误时返回中等权重 - def tfidf_similarity(self, s1, s2): + @staticmethod + def tfidf_similarity(s1, s2): """ 使用 TF-IDF 和余弦相似度计算两个句子的相似性。 """ @@ -551,7 +553,8 @@ class RelationshipManager: # 返回 s1 和 s2 的相似度 return similarity_matrix[0, 1] - def sequence_similarity(self, s1, s2): + @staticmethod + def sequence_similarity(s1, s2): """ 使用 SequenceMatcher 计算两个句子的相似性。 """ diff --git a/src/plugin_system/apis/message_api.py b/src/plugin_system/apis/message_api.py index 3d161b847..76dc8f5cb 100644 --- a/src/plugin_system/apis/message_api.py +++ b/src/plugin_system/apis/message_api.py @@ -8,7 +8,7 @@ readable_text = message_api.build_readable_messages(messages) """ -from typing import List, Dict, Any, Tuple, Optional +from typing import List, Dict, Any, Tuple, Optional, Coroutine from src.config.config import global_config import time from src.chat.utils.chat_message_builder import ( @@ -36,7 +36,7 @@ from src.chat.utils.chat_message_builder import ( def get_messages_by_time( start_time: float, end_time: float, limit: int = 0, limit_mode: str 
= "latest", filter_mai: bool = False -) -> List[Dict[str, Any]]: +) -> Coroutine[Any, Any, list[dict[str, Any]]]: """ 获取指定时间范围内的消息 @@ -155,7 +155,7 @@ def get_messages_by_time_in_chat_for_users( person_ids: List[str], limit: int = 0, limit_mode: str = "latest", -) -> List[Dict[str, Any]]: +) -> Coroutine[Any, Any, list[dict[str, Any]]]: """ 获取指定聊天中指定用户在指定时间范围内的消息 @@ -186,7 +186,7 @@ def get_messages_by_time_in_chat_for_users( def get_random_chat_messages( start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False -) -> List[Dict[str, Any]]: +) -> Coroutine[Any, Any, list[dict[str, Any]]]: """ 随机选择一个聊天,返回该聊天在指定时间范围内的消息 @@ -214,7 +214,7 @@ def get_random_chat_messages( def get_messages_by_time_for_users( start_time: float, end_time: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest" -) -> List[Dict[str, Any]]: +) -> Coroutine[Any, Any, list[dict[str, Any]]]: """ 获取指定用户在所有聊天中指定时间范围内的消息 @@ -238,7 +238,8 @@ def get_messages_by_time_for_users( return get_raw_msg_by_timestamp_with_users(start_time, end_time, person_ids, limit, limit_mode) -def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> List[Dict[str, Any]]: +def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> Coroutine[ + Any, Any, list[dict[str, Any]]]: """ 获取指定时间戳之前的消息 @@ -264,7 +265,7 @@ def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool def get_messages_before_time_in_chat( chat_id: str, timestamp: float, limit: int = 0, filter_mai: bool = False -) -> List[Dict[str, Any]]: +) -> Coroutine[Any, Any, list[dict[str, Any]]]: """ 获取指定聊天中指定时间戳之前的消息 @@ -293,7 +294,8 @@ def get_messages_before_time_in_chat( return get_raw_msg_before_timestamp_with_chat(chat_id, timestamp, limit) -def get_messages_before_time_for_users(timestamp: float, person_ids: List[str], limit: int = 0) -> List[Dict[str, Any]]: +def get_messages_before_time_for_users(timestamp: float, person_ids: List[str], limit: int = 0) -> Coroutine[ + Any, Any, list[dict[str, Any]]]: """ 获取指定用户在指定时间戳之前的消息 @@ -317,7 +319,7 @@ def get_messages_before_time_for_users(timestamp: float, person_ids: List[str], def get_recent_messages( chat_id: str, hours: float = 24.0, limit: int = 100, limit_mode: str = "latest", filter_mai: bool = False -) -> List[Dict[str, Any]]: +) -> Coroutine[Any, Any, list[dict[str, Any]]]: """ 获取指定聊天中最近一段时间的消息 @@ -354,7 +356,8 @@ def get_recent_messages( # ============================================================================= -def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional[float] = None) -> int: +def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional[float] = None) -> Coroutine[ + Any, Any, int]: """ 计算指定聊天中从开始时间到结束时间的新消息数量 @@ -378,7 +381,8 @@ def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional return num_new_messages_since(chat_id, start_time, end_time) -def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: List[str]) -> int: +def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: List[str]) -> Coroutine[ + Any, Any, int]: """ 计算指定聊天中指定用户从开始时间到结束时间的新消息数量 @@ -416,7 +420,7 @@ def build_readable_messages_to_str( read_mark: float = 0.0, truncate: bool = False, show_actions: bool = False, -) -> str: +) -> Coroutine[Any, Any, str]: """ 将消息列表构建成可读的字符串 diff --git a/src/plugin_system/apis/permission_api.py 
b/src/plugin_system/apis/permission_api.py index 94c5c3fdd..f40198352 100644 --- a/src/plugin_system/apis/permission_api.py +++ b/src/plugin_system/apis/permission_api.py @@ -44,7 +44,7 @@ class UserInfo: def to_tuple(self) -> tuple[str, str]: """转换为元组格式""" - return (self.platform, self.user_id) + return self.platform, self.user_id class IPermissionManager(ABC): diff --git a/src/plugin_system/apis/send_api.py b/src/plugin_system/apis/send_api.py index 334308795..390397a22 100644 --- a/src/plugin_system/apis/send_api.py +++ b/src/plugin_system/apis/send_api.py @@ -118,10 +118,10 @@ async def wait_adapter_response(request_id: str, timeout: float = 30.0) -> dict: response = await asyncio.wait_for(future, timeout=timeout) return response except asyncio.TimeoutError: - _adapter_response_pool.pop(request_id, None) + await _adapter_response_pool.pop(request_id, None) return {"status": "error", "message": "timeout"} except Exception as e: - _adapter_response_pool.pop(request_id, None) + await _adapter_response_pool.pop(request_id, None) return {"status": "error", "message": str(e)} diff --git a/src/plugin_system/base/base_event.py b/src/plugin_system/base/base_event.py index ff33e30cd..c7dd09a58 100644 --- a/src/plugin_system/base/base_event.py +++ b/src/plugin_system/base/base_event.py @@ -1,5 +1,6 @@ import asyncio from typing import List, Dict, Any, Optional + from src.common.logger import get_logger logger = get_logger("base_event") @@ -90,8 +91,6 @@ class BaseEvent: self.allowed_subscribers = allowed_subscribers # 记录事件处理器名 self.allowed_triggers = allowed_triggers # 记录插件名 - from src.plugin_system.base.base_events_handler import BaseEventHandler - self.subscribers: List["BaseEventHandler"] = [] # 订阅该事件的事件处理器列表 self.event_handle_lock = asyncio.Lock() @@ -150,7 +149,8 @@ class BaseEvent: return HandlerResultsCollection(processed_results) - async def _execute_subscriber(self, subscriber, params: dict) -> HandlerResult: + @staticmethod + async def _execute_subscriber(subscriber, params: dict) -> HandlerResult: """执行单个订阅者处理器""" try: return await subscriber.execute(params) diff --git a/src/plugin_system/base/plugin_base.py b/src/plugin_system/base/plugin_base.py index 6cf78b19f..fa5936b81 100644 --- a/src/plugin_system/base/plugin_base.py +++ b/src/plugin_system/base/plugin_base.py @@ -277,7 +277,8 @@ class PluginBase(ABC): return config_version_field.default return "1.0.0" - def _get_current_config_version(self, config: Dict[str, Any]) -> str: + @staticmethod + def _get_current_config_version(config: Dict[str, Any]) -> str: """从配置文件中获取当前版本号""" if "plugin" in config and "config_version" in config["plugin"]: return str(config["plugin"]["config_version"]) diff --git a/src/plugin_system/base/plus_command.py b/src/plugin_system/base/plus_command.py index 0d9780ada..a64866806 100644 --- a/src/plugin_system/base/plus_command.py +++ b/src/plugin_system/base/plus_command.py @@ -149,7 +149,7 @@ class PlusCommand(ABC): Returns: bool: 如果匹配返回True """ - return not self.args.is_empty() or self._is_exact_command_call() + return not self.args.is_empty or self._is_exact_command_call() def _is_exact_command_call(self) -> bool: """检查是否是精确的命令调用(无参数)""" diff --git a/src/plugin_system/core/component_registry.py b/src/plugin_system/core/component_registry.py index 529f327a3..0f64aa9ec 100644 --- a/src/plugin_system/core/component_registry.py +++ b/src/plugin_system/core/component_registry.py @@ -31,6 +31,7 @@ class ComponentRegistry: def __init__(self): # 命名空间式组件名构成法 f"{component_type}.{component_name}" + 
self._plus_command_registry: Dict[str, Type[PlusCommand]] = {} self._components: Dict[str, ComponentInfo] = {} """组件注册表 命名空间式组件名 -> 组件信息""" self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = {types: {} for types in ComponentType} @@ -618,7 +619,7 @@ class ComponentRegistry: def get_plus_command_registry(self) -> Dict[str, Type[PlusCommand]]: """获取PlusCommand注册表""" if not hasattr(self, "_plus_command_registry"): - self._plus_command_registry: Dict[str, Type[PlusCommand]] = {} + pass return self._plus_command_registry.copy() def get_registered_plus_command_info(self, command_name: str) -> Optional[PlusCommandInfo]: @@ -667,7 +668,8 @@ class ComponentRegistry: plugin_info = self.get_plugin_info(plugin_name) return plugin_info.components if plugin_info else [] - def get_plugin_config(self, plugin_name: str) -> dict: + @staticmethod + def get_plugin_config(plugin_name: str) -> dict: """获取插件配置 Args: diff --git a/src/plugin_system/core/event_manager.py b/src/plugin_system/core/event_manager.py index f359409af..4108adad0 100644 --- a/src/plugin_system/core/event_manager.py +++ b/src/plugin_system/core/event_manager.py @@ -7,6 +7,7 @@ from typing import Dict, Type, List, Optional, Any, Union from threading import Lock from src.common.logger import get_logger +from src.plugin_system import BaseEventHandler from src.plugin_system.base.base_event import BaseEvent, HandlerResultsCollection from src.plugin_system.base.base_events_handler import BaseEventHandler from src.plugin_system.base.component_types import EventType @@ -198,7 +199,7 @@ class EventManager: """ return self._event_handlers.get(handler_name) - def get_all_event_handlers(self) -> Dict[str, BaseEventHandler]: + def get_all_event_handlers(self) -> dict[str, type[BaseEventHandler]]: """获取所有已注册的事件处理器 Returns: diff --git a/src/plugin_system/core/plugin_hot_reload.py b/src/plugin_system/core/plugin_hot_reload.py index e85010efb..12c87a6ef 100644 --- a/src/plugin_system/core/plugin_hot_reload.py +++ b/src/plugin_system/core/plugin_hot_reload.py @@ -290,7 +290,8 @@ class PluginHotReloadManager: logger.error(f"❌ 重载插件 {plugin_name} 时发生错误: {e}", exc_info=True) return False - def _resolve_plugin_name(self, folder_name: str) -> str: + @staticmethod + def _resolve_plugin_name(folder_name: str) -> str: """ 将文件夹名称解析为实际的插件名称 通过检查插件管理器中的路径映射来找到对应的插件名 @@ -349,7 +350,8 @@ class PluginHotReloadManager: # 出错时尝试简单重载 return self._reload_plugin(plugin_name) - def _force_clear_plugin_modules(self, plugin_name: str): + @staticmethod + def _force_clear_plugin_modules(plugin_name: str): """强制清理插件相关的模块缓存""" # 找到所有相关的模块名 @@ -366,7 +368,8 @@ class PluginHotReloadManager: logger.debug(f"🗑️ 清理模块缓存: {module_name}") del sys.modules[module_name] - def _force_reimport_plugin(self, plugin_name: str): + @staticmethod + def _force_reimport_plugin(plugin_name: str): """强制重新导入插件(委托给插件管理器)""" try: # 使用插件管理器的重载功能 @@ -377,7 +380,8 @@ class PluginHotReloadManager: logger.error(f"❌ 强制重新导入插件 {plugin_name} 时发生错误: {e}", exc_info=True) return False - def _unload_plugin(self, plugin_name: str): + @staticmethod + def _unload_plugin(plugin_name: str): """卸载指定插件""" try: logger.info(f"🗑️ 开始卸载插件: {plugin_name}") @@ -490,7 +494,8 @@ class PluginHotReloadManager: "debounce_delay": self.file_handlers[0].debounce_delay if self.file_handlers else 0, } - def clear_all_caches(self): + @staticmethod + def clear_all_caches(): """清理所有Python模块缓存""" try: logger.info("🧹 开始清理所有Python模块缓存...") diff --git a/src/plugin_system/core/plugin_manager.py 
b/src/plugin_system/core/plugin_manager.py index 07d33b773..237cb6429 100644 --- a/src/plugin_system/core/plugin_manager.py +++ b/src/plugin_system/core/plugin_manager.py @@ -67,7 +67,8 @@ class PluginManager: except Exception as e: logger.error(f"同步插件 '{plugin_name}' 配置时发生未知错误: {e}") - def _copy_default_config_to_central(self, plugin_name: str, plugin_config_file: Path, central_config_dir: Path): + @staticmethod + def _copy_default_config_to_central(plugin_name: str, plugin_config_file: Path, central_config_dir: Path): """ 如果中央配置不存在,则将插件的默认 config.toml 复制到中央目录。 """ @@ -96,7 +97,8 @@ class PluginManager: shutil.copy2(central_file, target_plugin_file) logger.info(f"已将中央配置 '{central_file.name}' 同步到插件 '{plugin_name}'") - def _is_file_content_identical(self, file1: Path, file2: Path) -> bool: + @staticmethod + def _is_file_content_identical(file1: Path, file2: Path) -> bool: """ 通过比较 MD5 哈希值检查两个文件的内容是否相同。 """ @@ -403,7 +405,8 @@ class PluginManager: # == 兼容性检查 == - def _check_plugin_version_compatibility(self, plugin_name: str, manifest_data: Dict[str, Any]) -> Tuple[bool, str]: + @staticmethod + def _check_plugin_version_compatibility(plugin_name: str, manifest_data: Dict[str, Any]) -> Tuple[bool, str]: """检查插件版本兼容性 Args: @@ -557,7 +560,8 @@ class PluginManager: else: logger.warning("😕 没有成功加载任何插件") - def _show_plugin_components(self, plugin_name: str) -> None: + @staticmethod + def _show_plugin_components(plugin_name: str) -> None: if plugin_info := component_registry.get_plugin_info(plugin_name): component_types = {} for comp in plugin_info.components: @@ -673,7 +677,8 @@ class PluginManager: """ return self.reload_plugin(plugin_name) - def clear_all_plugin_caches(self): + @staticmethod + def clear_all_plugin_caches(): """清理所有插件相关的模块缓存(简化版)""" try: logger.info("🧹 清理模块缓存...") diff --git a/src/plugin_system/utils/dependency_manager.py b/src/plugin_system/utils/dependency_manager.py index 106748e79..980f538cc 100644 --- a/src/plugin_system/utils/dependency_manager.py +++ b/src/plugin_system/utils/dependency_manager.py @@ -162,7 +162,8 @@ class DependencyManager: return False, all_errors - def _normalize_dependencies(self, dependencies: Any) -> List[PythonDependency]: + @staticmethod + def _normalize_dependencies(dependencies: Any) -> List[PythonDependency]: """将依赖列表标准化为PythonDependency对象""" normalized = [] @@ -191,7 +192,8 @@ class DependencyManager: return normalized - def _check_single_dependency(self, dep: PythonDependency) -> bool: + @staticmethod + def _check_single_dependency(dep: PythonDependency) -> bool: """检查单个依赖是否满足要求""" def _try_check(import_name: str) -> bool: diff --git a/src/plugin_system/utils/manifest_utils.py b/src/plugin_system/utils/manifest_utils.py index 9b19c033e..b714aefd7 100644 --- a/src/plugin_system/utils/manifest_utils.py +++ b/src/plugin_system/utils/manifest_utils.py @@ -82,10 +82,10 @@ class VersionComparator: normalized = VersionComparator.normalize_version(version) try: parts = normalized.split(".") - return (int(parts[0]), int(parts[1]), int(parts[2])) + return int(parts[0]), int(parts[1]), int(parts[2]) except (ValueError, IndexError): logger.warning(f"无法解析版本号: {version},使用默认版本 0.0.0") - return (0, 0, 0) + return 0, 0, 0 @staticmethod def compare_versions(version1: str, version2: str) -> int: diff --git a/src/plugin_system/utils/permission_decorators.py b/src/plugin_system/utils/permission_decorators.py index 67db667fb..67322ba34 100644 --- a/src/plugin_system/utils/permission_decorators.py +++ b/src/plugin_system/utils/permission_decorators.py @@ -58,7 +58,7 
@@ def require_permission(permission_node: str, deny_message: Optional[str] = None) if chat_stream is None: logger.error(f"权限装饰器无法找到 ChatStream 对象,函数: {func.__name__}") - return + return None # 检查权限 has_permission = permission_api.check_permission( @@ -72,7 +72,7 @@ def require_permission(permission_node: str, deny_message: Optional[str] = None) # 对于PlusCommand的execute方法,需要返回适当的元组 if func.__name__ == "execute" and hasattr(args[0], "send_text"): return False, "权限不足", True - return + return None # 权限检查通过,执行原函数 return await func(*args, **kwargs) @@ -90,7 +90,7 @@ def require_permission(permission_node: str, deny_message: Optional[str] = None) if chat_stream is None: logger.error(f"权限装饰器无法找到 ChatStream 对象,函数: {func.__name__}") - return + return None # 检查权限 has_permission = permission_api.check_permission( @@ -101,7 +101,7 @@ def require_permission(permission_node: str, deny_message: Optional[str] = None) logger.warning( f"用户 {chat_stream.platform}:{chat_stream.user_info.user_id} 没有权限 {permission_node}" ) - return + return None # 权限检查通过,执行原函数 return func(*args, **kwargs) @@ -156,7 +156,7 @@ def require_master(deny_message: Optional[str] = None): if chat_stream is None: logger.error(f"Master权限装饰器无法找到 ChatStream 对象,函数: {func.__name__}") - return + return None # 检查是否为Master用户 is_master = permission_api.is_master(chat_stream.platform, chat_stream.user_info.user_id) @@ -166,7 +166,7 @@ def require_master(deny_message: Optional[str] = None): await text_to_stream(message, chat_stream.stream_id) if func.__name__ == "execute" and hasattr(args[0], "send_text"): return False, "需要Master权限", True - return + return None # 权限检查通过,执行原函数 return await func(*args, **kwargs) @@ -184,14 +184,14 @@ def require_master(deny_message: Optional[str] = None): if chat_stream is None: logger.error(f"Master权限装饰器无法找到 ChatStream 对象,函数: {func.__name__}") - return + return None # 检查是否为Master用户 is_master = permission_api.is_master(chat_stream.platform, chat_stream.user_info.user_id) if not is_master: logger.warning(f"用户 {chat_stream.platform}:{chat_stream.user_info.user_id} 不是Master用户") - return + return None # 权限检查通过,执行原函数 return func(*args, **kwargs) diff --git a/src/plugins/built_in/maizone_refactored/services/content_service.py b/src/plugins/built_in/maizone_refactored/services/content_service.py index 27f2a0ee9..9f7da7ccf 100644 --- a/src/plugins/built_in/maizone_refactored/services/content_service.py +++ b/src/plugins/built_in/maizone_refactored/services/content_service.py @@ -119,10 +119,12 @@ class ContentService: logger.error(f"生成说说内容时发生异常: {e}") return "" - async def generate_comment(self, content: str, target_name: str, rt_con: str = "", images: list = []) -> str: + async def generate_comment(self, content: str, target_name: str, rt_con: str = "", images=None) -> str: """ 针对一条具体的说说内容生成评论。 """ + if images is None: + images = [] for i in range(3): # 重试3次 try: chat_manager = get_chat_manager() @@ -180,7 +182,8 @@ class ContentService: return "" return "" - async def generate_comment_reply(self, story_content: str, comment_content: str, commenter_name: str) -> str: + @staticmethod + async def generate_comment_reply(story_content: str, comment_content: str, commenter_name: str) -> str: """ 针对自己说说的评论,生成回复。 """ diff --git a/src/plugins/built_in/maizone_refactored/services/cookie_service.py b/src/plugins/built_in/maizone_refactored/services/cookie_service.py index b4aedf322..1c61a29fd 100644 --- a/src/plugins/built_in/maizone_refactored/services/cookie_service.py +++ 
b/src/plugins/built_in/maizone_refactored/services/cookie_service.py @@ -50,7 +50,8 @@ class CookieService: logger.error(f"无法读取或解析Cookie文件 {cookie_file_path}: {e}") return None - async def _get_cookies_from_adapter(self, stream_id: Optional[str]) -> Optional[Dict[str, str]]: + @staticmethod + async def _get_cookies_from_adapter(stream_id: Optional[str]) -> Optional[Dict[str, str]]: """通过Adapter API获取Cookie""" try: params = {"domain": "user.qzone.qq.com"} diff --git a/src/plugins/built_in/maizone_refactored/services/image_service.py b/src/plugins/built_in/maizone_refactored/services/image_service.py index cbb411da7..1ffcd7d70 100644 --- a/src/plugins/built_in/maizone_refactored/services/image_service.py +++ b/src/plugins/built_in/maizone_refactored/services/image_service.py @@ -59,7 +59,8 @@ class ImageService: logger.error(f"处理AI配图时发生异常: {e}") return False - async def _call_siliconflow_api(self, api_key: str, story: str, image_dir: str, batch_size: int) -> bool: + @staticmethod + async def _call_siliconflow_api(api_key: str, story: str, image_dir: str, batch_size: int) -> bool: """ 调用硅基流动(SiliconFlow)的API来生成图片。 diff --git a/src/plugins/built_in/maizone_refactored/services/qzone_service.py b/src/plugins/built_in/maizone_refactored/services/qzone_service.py index abb5f97e6..545e615a0 100644 --- a/src/plugins/built_in/maizone_refactored/services/qzone_service.py +++ b/src/plugins/built_in/maizone_refactored/services/qzone_service.py @@ -187,7 +187,8 @@ class QZoneService: # --- Internal Helper Methods --- - async def _get_intercom_context(self, stream_id: str) -> Optional[str]: + @staticmethod + async def _get_intercom_context(stream_id: str) -> Optional[str]: """ 根据 stream_id 查找其所属的互通组,并构建该组的聊天上下文。 @@ -398,7 +399,8 @@ class QZoneService: logger.error(f"加载本地图片失败: {e}") return [] - def _generate_gtk(self, skey: str) -> str: + @staticmethod + def _generate_gtk(skey: str) -> str: hash_val = 5381 for char in skey: hash_val += (hash_val << 5) + ord(char) @@ -435,7 +437,8 @@ class QZoneService: logger.error(f"更新或加载Cookie时发生异常: {e}") return None - async def _fetch_cookies_http(self, host: str, port: str, napcat_token: str) -> Optional[Dict]: + @staticmethod + async def _fetch_cookies_http(host: str, port: str, napcat_token: str) -> Optional[Dict]: """通过HTTP服务器获取Cookie""" url = f"http://{host}:{port}/get_cookies" max_retries = 5 diff --git a/src/plugins/built_in/maizone_refactored/services/reply_tracker_service.py b/src/plugins/built_in/maizone_refactored/services/reply_tracker_service.py index 0fa7edb99..3aabc88b6 100644 --- a/src/plugins/built_in/maizone_refactored/services/reply_tracker_service.py +++ b/src/plugins/built_in/maizone_refactored/services/reply_tracker_service.py @@ -36,7 +36,8 @@ class ReplyTrackerService: self._load_data() logger.debug(f"ReplyTrackerService initialized with data file: {self.reply_record_file}") - def _validate_data(self, data: Any) -> bool: + @staticmethod + def _validate_data(data: Any) -> bool: """验证加载的数据格式是否正确""" if not isinstance(data, dict): logger.error("加载的数据不是字典格式") diff --git a/src/plugins/built_in/maizone_refactored/services/scheduler_service.py b/src/plugins/built_in/maizone_refactored/services/scheduler_service.py index ed32da48d..69ec0956e 100644 --- a/src/plugins/built_in/maizone_refactored/services/scheduler_service.py +++ b/src/plugins/built_in/maizone_refactored/services/scheduler_service.py @@ -129,7 +129,8 @@ class SchedulerService: logger.error(f"定时任务循环中发生未知错误: {e}\n{traceback.format_exc()}") await asyncio.sleep(300) # 发生错误后,等待一段时间再重试 - async def 
_is_processed(self, hour_str: str, activity: str) -> bool: + @staticmethod + async def _is_processed(hour_str: str, activity: str) -> bool: """ 检查指定的任务(某个小时的某个活动)是否已经被成功处理过。 @@ -152,7 +153,8 @@ class SchedulerService: logger.error(f"检查日程处理状态时发生数据库错误: {e}") return False # 数据库异常时,默认为未处理,允许重试 - async def _mark_as_processed(self, hour_str: str, activity: str, success: bool, content: str): + @staticmethod + async def _mark_as_processed(hour_str: str, activity: str, success: bool, content: str): """ 将任务的处理状态和结果写入数据库。 diff --git a/src/plugins/built_in/maizone_refactored/utils/history_utils.py b/src/plugins/built_in/maizone_refactored/utils/history_utils.py index 19b3e7baa..171396de2 100644 --- a/src/plugins/built_in/maizone_refactored/utils/history_utils.py +++ b/src/plugins/built_in/maizone_refactored/utils/history_utils.py @@ -49,7 +49,8 @@ class _SimpleQZoneAPI: if p_skey: self.gtk2 = self._generate_gtk(p_skey) - def _generate_gtk(self, skey: str) -> str: + @staticmethod + def _generate_gtk(skey: str) -> str: hash_val = 5381 for char in skey: hash_val += (hash_val << 5) + ord(char) diff --git a/src/plugins/built_in/napcat_adapter_plugin/plugin.py b/src/plugins/built_in/napcat_adapter_plugin/plugin.py index e15e17b8f..3cd7973e7 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/plugin.py +++ b/src/plugins/built_in/napcat_adapter_plugin/plugin.py @@ -400,9 +400,8 @@ class NapcatAdapterPlugin(BasePlugin): def get_plugin_components(self): self.register_events() - components = [] - components.append((LauchNapcatAdapterHandler.get_handler_info(), LauchNapcatAdapterHandler)) - components.append((StopNapcatAdapterHandler.get_handler_info(), StopNapcatAdapterHandler)) + components = [(LauchNapcatAdapterHandler.get_handler_info(), LauchNapcatAdapterHandler), + (StopNapcatAdapterHandler.get_handler_info(), StopNapcatAdapterHandler)] for handler in get_classes_in_module(event_handlers): if issubclass(handler, BaseEventHandler): components.append((handler.get_handler_info(), handler)) diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/message_buffer.py b/src/plugins/built_in/napcat_adapter_plugin/src/message_buffer.py index 64a1e3faa..73216942e 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/message_buffer.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/message_buffer.py @@ -49,7 +49,8 @@ class SimpleMessageBuffer: """设置插件配置""" self.plugin_config = plugin_config - def get_session_id(self, event_data: Dict[str, Any]) -> str: + @staticmethod + def get_session_id(event_data: Dict[str, Any]) -> str: """根据事件数据生成会话ID""" message_type = event_data.get("message_type", "unknown") user_id = event_data.get("user_id", "unknown") @@ -62,7 +63,8 @@ class SimpleMessageBuffer: else: return f"{message_type}_{user_id}" - def extract_text_from_message(self, message: List[Dict[str, Any]]) -> Optional[str]: + @staticmethod + def extract_text_from_message(message: List[Dict[str, Any]]) -> Optional[str]: """从OneBot消息中提取纯文本,如果包含非文本内容则返回None""" text_parts = [] has_non_text = False @@ -177,7 +179,8 @@ class SimpleMessageBuffer: logger.debug(f"文本消息已添加到缓冲器 {session_id}: {text[:50]}...") return True - async def _cancel_session_timers(self, session: BufferedSession): + @staticmethod + async def _cancel_session_timers(session: BufferedSession): """取消会话的所有定时器""" for task_name in ["timer_task", "delay_task"]: task = getattr(session, task_name) diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/message_chunker.py b/src/plugins/built_in/napcat_adapter_plugin/src/message_chunker.py index 
0f25bd62e..9757e7cf5 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/message_chunker.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/message_chunker.py @@ -112,7 +112,8 @@ class MessageChunker: else: return [{"_original_message": message}] - def is_chunk_message(self, message: Union[str, Dict[str, Any]]) -> bool: + @staticmethod + def is_chunk_message(message: Union[str, Dict[str, Any]]) -> bool: """判断是否是切片消息""" try: if isinstance(message, str): diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py index f289af687..c50f17e7b 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py @@ -26,7 +26,7 @@ import json import websockets as Server import base64 from pathlib import Path -from typing import List, Tuple, Optional, Dict, Any +from typing import List, Tuple, Optional, Dict, Any, Coroutine import uuid from maim_message import ( @@ -351,6 +351,7 @@ class MessageHandler: logger.debug("发送到Maibot处理信息") await message_send_instance.message_send(message_base) + return None async def handle_real_message(self, raw_message: dict, in_reply: bool = False) -> List[Seg] | None: # sourcery skip: low-code-quality @@ -518,7 +519,8 @@ class MessageHandler: logger.debug(f"handle_real_message完成,处理了{len(real_message)}个消息段,生成了{len(seg_message)}个seg") return seg_message - async def handle_text_message(self, raw_message: dict) -> Seg: + @staticmethod + async def handle_text_message(raw_message: dict) -> Seg: """ 处理纯文本信息 Parameters: @@ -530,7 +532,8 @@ class MessageHandler: plain_text: str = message_data.get("text") return Seg(type="text", data=plain_text) - async def handle_face_message(self, raw_message: dict) -> Seg | None: + @staticmethod + async def handle_face_message(raw_message: dict) -> Seg | None: """ 处理表情消息 Parameters: @@ -547,7 +550,8 @@ class MessageHandler: logger.warning(f"不支持的表情:{face_raw_id}") return None - async def handle_image_message(self, raw_message: dict) -> Seg | None: + @staticmethod + async def handle_image_message(raw_message: dict) -> Seg | None: """ 处理图片消息与表情包消息 Parameters: @@ -603,6 +607,7 @@ class MessageHandler: return Seg(type="at", data=f"{member_info.get('nickname')}:{member_info.get('user_id')}") else: return None + return None async def handle_record_message(self, raw_message: dict) -> Seg | None: """ @@ -631,7 +636,8 @@ class MessageHandler: return None return Seg(type="voice", data=audio_base64) - async def handle_video_message(self, raw_message: dict) -> Seg | None: + @staticmethod + async def handle_video_message(raw_message: dict) -> Seg | None: """ 处理视频消息 Parameters: @@ -762,7 +768,7 @@ class MessageHandler: return None processed_message: Seg - if image_count < 5 and image_count > 0: + if 5 > image_count > 0: # 处理图片数量小于5的情况,此时解析图片为base64 logger.debug("图片数量小于5,开始解析图片为base64") processed_message = await self._recursive_parse_image_seg(handled_message, True) @@ -779,15 +785,18 @@ class MessageHandler: forward_hint = Seg(type="text", data="这是一条转发消息:\n") return Seg(type="seglist", data=[forward_hint, processed_message]) - async def handle_dice_message(self, raw_message: dict) -> Seg: + @staticmethod + async def handle_dice_message(raw_message: dict) -> Seg: message_data: dict = raw_message.get("data", {}) res = message_data.get("result", "") return Seg(type="text", data=f"[扔了一个骰子,点数是{res}]") - async def 
handle_shake_message(self, raw_message: dict) -> Seg: + @staticmethod + async def handle_shake_message(raw_message: dict) -> Seg: return Seg(type="text", data="[向你发送了窗口抖动,现在你的屏幕猛烈地震了一下!]") - async def handle_json_message(self, raw_message: dict) -> Seg: + @staticmethod + async def handle_json_message(raw_message: dict) -> Seg | None: """ 处理JSON消息 Parameters: @@ -906,7 +915,8 @@ class MessageHandler: logger.error(f"处理JSON消息时出错: {e}") return None - async def handle_rps_message(self, raw_message: dict) -> Seg: + @staticmethod + async def handle_rps_message(raw_message: dict) -> Seg: message_data: dict = raw_message.get("data", {}) res = message_data.get("result", "") if res == "1": @@ -1089,7 +1099,8 @@ class MessageHandler: return None return response_data.get("messages") - async def _send_buffered_message(self, session_id: str, merged_text: str, original_event: Dict[str, Any]): + @staticmethod + async def _send_buffered_message(session_id: str, merged_text: str, original_event: Dict[str, Any]): """发送缓冲的合并消息""" try: # 从原始事件数据中提取信息 diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/meta_event_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/meta_event_handler.py index 217347c36..83d19a1d7 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/meta_event_handler.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/meta_event_handler.py @@ -14,6 +14,7 @@ class MetaEventHandler: """ def __init__(self): + self.last_heart_beat = time.time() self.interval = 5.0 # 默认值,稍后通过set_plugin_config设置 self._interval_checking = False self.plugin_config = None @@ -37,7 +38,6 @@ class MetaEventHandler: if message["status"].get("online") and message["status"].get("good"): if not self._interval_checking: asyncio.create_task(self.check_heartbeat()) - self.last_heart_beat = time.time() self.interval = message.get("interval") / 1000 else: self_id = message.get("self_id") diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py index c373a9a10..a9eaead16 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py @@ -197,9 +197,11 @@ class NoticeHandler: if system_notice: await self.put_notice(message_base) + return None else: logger.debug("发送到Maibot处理通知信息") await message_send_instance.message_send(message_base) + return None async def handle_poke_notify( self, raw_message: dict, group_id: int, user_id: int @@ -464,7 +466,8 @@ class NoticeHandler: ) return seg_data, operator_info - async def put_notice(self, message_base: MessageBase) -> None: + @staticmethod + async def put_notice(message_base: MessageBase) -> None: """ 将处理后的通知消息放入通知队列 """ @@ -577,7 +580,8 @@ class NoticeHandler: self.banned_list.remove(ban_record) await asyncio.sleep(5) - async def send_notice(self) -> None: + @staticmethod + async def send_notice() -> None: """ 发送通知消息到Napcat """ diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py index 5d6d91467..ef380c82f 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py @@ -76,7 +76,7 @@ class SendHandler: processed_message = await self.handle_seg_recursive(message_segment, user_info) except Exception as e: 
logger.error(f"处理消息时发生错误: {e}") - return + return None if not processed_message: logger.critical("现在暂时不支持解析此回复!") @@ -94,7 +94,7 @@ class SendHandler: id_name = "user_id" else: logger.error("无法识别的消息类型") - return + return None logger.info("尝试发送到napcat") response = await self.send_message_to_napcat( action, @@ -107,8 +107,10 @@ class SendHandler: logger.info("消息发送成功") qq_message_id = response.get("data", {}).get("message_id") await self.message_sent_back(raw_message_base, qq_message_id) + return None else: logger.warning(f"消息发送失败,napcat返回:{str(response)}") + return None async def send_command(self, raw_message_base: MessageBase) -> None: """ @@ -146,7 +148,7 @@ class SendHandler: command, args_dict = self.handle_send_like_command(args) case _: logger.error(f"未知命令: {command_name}") - return + return None except Exception as e: logger.error(f"处理命令时发生错误: {e}") return None @@ -158,8 +160,10 @@ class SendHandler: response = await self.send_message_to_napcat(command, args_dict) if response.get("status") == "ok": logger.info(f"命令 {command_name} 执行成功") + return None else: logger.warning(f"命令 {command_name} 执行失败,napcat返回:{str(response)}") + return None async def handle_adapter_command(self, raw_message_base: MessageBase) -> None: """ @@ -265,7 +269,8 @@ class SendHandler: new_payload = self.build_payload(payload, self.handle_file_message(file_path), False) return new_payload - def build_payload(self, payload: list, addon: dict | list, is_reply: bool = False) -> list: + @staticmethod + def build_payload(payload: list, addon: dict | list, is_reply: bool = False) -> list: # sourcery skip: for-append-to-extend, merge-list-append, simplify-generator """构建发送的消息体""" if is_reply: @@ -324,11 +329,13 @@ class SendHandler: return reply_seg - def handle_text_message(self, message: str) -> dict: + @staticmethod + def handle_text_message(message: str) -> dict: """处理文本消息""" return {"type": "text", "data": {"text": message}} - def handle_image_message(self, encoded_image: str) -> dict: + @staticmethod + def handle_image_message(encoded_image: str) -> dict: """处理图片消息""" return { "type": "image", @@ -338,7 +345,8 @@ class SendHandler: }, } # base64 编码的图片 - def handle_emoji_message(self, encoded_emoji: str) -> dict: + @staticmethod + def handle_emoji_message(encoded_emoji: str) -> dict: """处理表情消息""" encoded_image = encoded_emoji image_format = get_image_format(encoded_emoji) @@ -369,39 +377,45 @@ class SendHandler: "data": {"file": f"base64://{encoded_voice}"}, } - def handle_voiceurl_message(self, voice_url: str) -> dict: + @staticmethod + def handle_voiceurl_message(voice_url: str) -> dict: """处理语音链接消息""" return { "type": "record", "data": {"file": voice_url}, } - def handle_music_message(self, song_id: str) -> dict: + @staticmethod + def handle_music_message(song_id: str) -> dict: """处理音乐消息""" return { "type": "music", "data": {"type": "163", "id": song_id}, } - def handle_videourl_message(self, video_url: str) -> dict: + @staticmethod + def handle_videourl_message(video_url: str) -> dict: """处理视频链接消息""" return { "type": "video", "data": {"file": video_url}, } - def handle_file_message(self, file_path: str) -> dict: + @staticmethod + def handle_file_message(file_path: str) -> dict: """处理文件消息""" return { "type": "file", "data": {"file": f"file://{file_path}"}, } - def delete_msg_command(self, args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]: + @staticmethod + def delete_msg_command(args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]: """处理删除消息命令""" return "delete_msg", {"message_id": args["message_id"]} - def 
handle_ban_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]: + @staticmethod + def handle_ban_command(args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]: """处理封禁命令 Args: @@ -429,7 +443,8 @@ class SendHandler: }, ) - def handle_whole_ban_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]: + @staticmethod + def handle_whole_ban_command(args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]: """处理全体禁言命令 Args: @@ -452,7 +467,8 @@ class SendHandler: }, ) - def handle_kick_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]: + @staticmethod + def handle_kick_command(args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]: """处理群成员踢出命令 Args: @@ -477,7 +493,8 @@ class SendHandler: }, ) - def handle_poke_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]: + @staticmethod + def handle_poke_command(args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]: """处理戳一戳命令 Args: @@ -504,7 +521,8 @@ class SendHandler: }, ) - def handle_set_emoji_like_command(self, args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]: + @staticmethod + def handle_set_emoji_like_command(args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]: """处理设置表情回应命令 Args: @@ -526,7 +544,8 @@ class SendHandler: {"message_id": message_id, "emoji_id": emoji_id, "set": set_like}, ) - def handle_send_like_command(self, args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]: + @staticmethod + def handle_send_like_command(args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]: """ 处理发送点赞命令的逻辑。 @@ -547,7 +566,8 @@ class SendHandler: {"user_id": user_id, "times": times}, ) - def handle_ai_voice_send_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]: + @staticmethod + def handle_ai_voice_send_command(args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]: """ 处理AI语音发送命令的逻辑。 并返回 NapCat 兼容的 (action, params) 元组。 @@ -594,7 +614,8 @@ class SendHandler: return {"status": "error", "message": str(e)} return response - async def message_sent_back(self, message_base: MessageBase, qq_message_id: str) -> None: + @staticmethod + async def message_sent_back(message_base: MessageBase, qq_message_id: str) -> None: # 修改 additional_config,添加 echo 字段 if message_base.message_info.additional_config is None: message_base.message_info.additional_config = {} @@ -612,8 +633,9 @@ class SendHandler: logger.debug("已回送消息ID") return + @staticmethod async def send_adapter_command_response( - self, original_message: MessageBase, response_data: dict, request_id: str + original_message: MessageBase, response_data: dict, request_id: str ) -> None: """ 发送适配器命令响应回MaiBot @@ -642,7 +664,8 @@ class SendHandler: except Exception as e: logger.error(f"发送适配器命令响应时出错: {e}") - def handle_at_message_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]: + @staticmethod + def handle_at_message_command(args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]: """处理艾特并发送消息命令 Args: diff --git a/src/plugins/built_in/permission_management/plugin.py b/src/plugins/built_in/permission_management/plugin.py index 174482d47..e33a6d08f 100644 --- a/src/plugins/built_in/permission_management/plugin.py +++ b/src/plugins/built_in/permission_management/plugin.py @@ -111,7 +111,8 @@ class PermissionCommand(PlusCommand): await self.send_text(help_text) - def _parse_user_mention(self, 
mention: str) -> Optional[str]: + @staticmethod + def _parse_user_mention(mention: str) -> Optional[str]: """解析用户提及,提取QQ号 支持的格式: diff --git a/src/plugins/built_in/plugin_management/plugin.py b/src/plugins/built_in/plugin_management/plugin.py index cd4d753c6..741cb38b9 100644 --- a/src/plugins/built_in/plugin_management/plugin.py +++ b/src/plugins/built_in/plugin_management/plugin.py @@ -34,11 +34,11 @@ class ManagementCommand(PlusCommand): @require_permission("plugin.management.admin", "❌ 你没有插件管理的权限") async def execute(self, args: CommandArgs) -> Tuple[bool, str, bool]: """执行插件管理命令""" - if args.is_empty(): + if args.is_empty: await self._show_help("all") return True, "显示帮助信息", True - subcommand = args.get_first().lower() + subcommand = args.get_first.lower() remaining_args = args.get_args()[1:] # 获取除第一个参数外的所有参数 if subcommand in ["plugin", "插件"]: @@ -318,7 +318,8 @@ class ManagementCommand(PlusCommand): else: await self.send_text(f"❌ 插件目录添加失败: `{dir_path}`") - def _fetch_all_registered_components(self) -> List[ComponentInfo]: + @staticmethod + def _fetch_all_registered_components() -> List[ComponentInfo]: all_plugin_info = component_manage_api.get_all_plugin_info() if not all_plugin_info: return [] diff --git a/src/plugins/built_in/reminder_plugin/plugin.py b/src/plugins/built_in/reminder_plugin/plugin.py index 55fd8b85d..5382cccff 100644 --- a/src/plugins/built_in/reminder_plugin/plugin.py +++ b/src/plugins/built_in/reminder_plugin/plugin.py @@ -1,6 +1,7 @@ import asyncio from datetime import datetime from typing import List, Tuple, Type, Optional + from dateutil.parser import parse as parse_datetime from src.common.logger import get_logger @@ -14,7 +15,7 @@ from src.plugin_system import ( ActionActivationType, ) from src.plugin_system.apis import send_api, llm_api, generator_api -from src.plugin_system.base.component_types import ChatType, ComponentType +from src.plugin_system.base.component_types import ComponentType logger = get_logger(__name__) diff --git a/src/plugins/built_in/tts_plugin/plugin.py b/src/plugins/built_in/tts_plugin/plugin.py index 30748a9ff..fc625c093 100644 --- a/src/plugins/built_in/tts_plugin/plugin.py +++ b/src/plugins/built_in/tts_plugin/plugin.py @@ -74,7 +74,8 @@ class TTSAction(BaseAction): logger.error(f"{self.log_prefix} 执行TTS动作时出错: {e}") return False, f"执行TTS动作时出错: {e}" - def _process_text_for_tts(self, text: str) -> str: + @staticmethod + def _process_text_for_tts(text: str) -> str: """ 处理文本使其更适合TTS使用 - 移除不必要的特殊字符和表情符号 diff --git a/src/plugins/built_in/web_search_tool/engines/bing_engine.py b/src/plugins/built_in/web_search_tool/engines/bing_engine.py index ac90956e0..6d32492ad 100644 --- a/src/plugins/built_in/web_search_tool/engines/bing_engine.py +++ b/src/plugins/built_in/web_search_tool/engines/bing_engine.py @@ -111,7 +111,8 @@ class BingSearchEngine(BaseSearchEngine): logger.debug(f"Bing搜索 [{keyword}] 完成,总共 {len(list_result)} 个结果") return list_result[:num_results] if len(list_result) > num_results else list_result - def _parse_html(self, url: str) -> List[Dict[str, Any]]: + @staticmethod + def _parse_html(url: str) -> List[Dict[str, Any]]: """解析处理结果""" try: logger.debug(f"访问Bing搜索URL: {url}") diff --git a/src/plugins/built_in/web_search_tool/tools/url_parser.py b/src/plugins/built_in/web_search_tool/tools/url_parser.py index 315e06271..3a05423a7 100644 --- a/src/plugins/built_in/web_search_tool/tools/url_parser.py +++ b/src/plugins/built_in/web_search_tool/tools/url_parser.py @@ -89,7 +89,7 @@ class URLParserTool(BaseTool): title = soup.title.string 
if soup.title else "无标题" for script in soup(["script", "style"]): script.extract() - text = soup.get_text(separator="\n", strip=True) + text = soup.get_text(strip=True) if not text: return {"error": "无法从页面提取有效文本内容。"} diff --git a/src/schedule/llm_generator.py b/src/schedule/llm_generator.py index 9dda68f80..35e35b0b5 100644 --- a/src/schedule/llm_generator.py +++ b/src/schedule/llm_generator.py @@ -125,7 +125,8 @@ class ScheduleLLMGenerator: logger.info("继续重试...") await asyncio.sleep(3) - def _validate_schedule_with_pydantic(self, schedule_data) -> bool: + @staticmethod + def _validate_schedule_with_pydantic(schedule_data) -> bool: try: ScheduleData(schedule=schedule_data) logger.info("日程数据Pydantic验证通过") @@ -204,7 +205,8 @@ class MonthlyPlanLLMGenerator: logger.error(" 所有尝试都失败,无法生成月度计划") return [] - def _parse_plans_response(self, response: str) -> List[str]: + @staticmethod + def _parse_plans_response(response: str) -> List[str]: try: response = response.strip() lines = [line.strip() for line in response.split("\n") if line.strip()] diff --git a/src/schedule/plan_manager.py b/src/schedule/plan_manager.py index b84a37b72..82f8a8e04 100644 --- a/src/schedule/plan_manager.py +++ b/src/schedule/plan_manager.py @@ -80,7 +80,8 @@ class PlanManager: finally: self.generation_running = False - def _get_previous_month(self, current_month: str) -> str: + @staticmethod + def _get_previous_month(current_month: str) -> str: try: year, month = map(int, current_month.split("-")) if month == 1: @@ -90,7 +91,8 @@ class PlanManager: except Exception: return "1900-01" - async def archive_current_month_plans(self, target_month: Optional[str] = None): + @staticmethod + async def archive_current_month_plans(target_month: Optional[str] = None): try: if target_month is None: target_month = datetime.now().strftime("%Y-%m") @@ -100,6 +102,7 @@ class PlanManager: except Exception as e: logger.error(f" 归档 {target_month} 月度计划时发生错误: {e}") - async def get_plans_for_schedule(self, month: str, max_count: int) -> List: + @staticmethod + async def get_plans_for_schedule(month: str, max_count: int) -> List: avoid_days = global_config.planning_system.avoid_repetition_days return await get_smart_plans_for_daily_schedule(month, max_count=max_count, avoid_days=avoid_days) \ No newline at end of file diff --git a/src/schedule/schedule_manager.py b/src/schedule/schedule_manager.py index 822131dec..115480381 100644 --- a/src/schedule/schedule_manager.py +++ b/src/schedule/schedule_manager.py @@ -3,6 +3,8 @@ import asyncio from datetime import datetime, time, timedelta from typing import Optional, List, Dict, Any +from sqlalchemy import select + from src.common.database.sqlalchemy_models import Schedule, get_db_session from src.config.config import global_config from src.common.logger import get_logger @@ -115,7 +117,8 @@ class ScheduleManager: self.schedule_generation_running = False logger.info("日程生成任务结束") - async def _save_schedule_to_db(self, date_str: str, schedule_data: List[Dict[str, Any]]): + @staticmethod + async def _save_schedule_to_db(date_str: str, schedule_data: List[Dict[str, Any]]): async with get_db_session() as session: schedule_json = orjson.dumps(schedule_data).decode("utf-8") result = await session.execute(select(Schedule).filter(Schedule.date == date_str)) @@ -128,7 +131,8 @@ class ScheduleManager: session.add(new_schedule) await session.commit() - def _log_generated_schedule(self, date_str: str, schedule_data: List[Dict[str, Any]]): + @staticmethod + def _log_generated_schedule(date_str: str, schedule_data: 
List[Dict[str, Any]]): schedule_str = f"✅ 成功生成并保存今天的日程 ({date_str}):\n" for item in schedule_data: schedule_str += f" - {item.get('time_range', '未知时间')}: {item.get('activity', '未知活动')}\n" @@ -153,7 +157,8 @@ class ScheduleManager: logger.warning(f"解析日程事件失败: {event}, 错误: {e}") return None - def _validate_schedule_with_pydantic(self, schedule_data) -> bool: + @staticmethod + def _validate_schedule_with_pydantic(schedule_data) -> bool: try: ScheduleData(schedule=schedule_data) return True diff --git a/src/utils/message_chunker.py b/src/utils/message_chunker.py index ec2e300c2..66a2964e1 100644 --- a/src/utils/message_chunker.py +++ b/src/utils/message_chunker.py @@ -58,7 +58,8 @@ class MessageReassembler: except Exception as e: logger.error(f"清理过期切片时出错: {e}") - def is_chunk_message(self, message: Dict[str, Any]) -> bool: + @staticmethod + def is_chunk_message(message: Dict[str, Any]) -> bool: """检查是否是来自 Ada 的切片消息""" return ( isinstance(message, dict) From 788190fb11796f956bceef5e017432c024b69624 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=85=E8=AF=BA=E7=8B=90?= <212194964+foxcyber907@users.noreply.github.com> Date: Sat, 20 Sep 2025 11:45:59 +0800 Subject: [PATCH 06/31] =?UTF-8?q?=E5=B0=86AFC=E5=90=88=E5=B9=B6=E8=87=B3Ma?= =?UTF-8?q?ster=E5=88=86=E6=94=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/message_receive/message.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index fc57b6fc6..5347ea43f 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -2,7 +2,7 @@ import base64 import time from abc import abstractmethod, ABCMeta from dataclasses import dataclass -from typing import Optional, Any +from typing import Optional, Any, TYPE_CHECKING import urllib3 from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase @@ -14,8 +14,11 @@ from src.chat.utils.utils_voice import get_voice_text from src.common.logger import get_logger from src.config.config import global_config +from src.chat.message_receive.chat_stream import ChatStream + install(extra_lines=3) + logger = get_logger("chat_message") # 禁用SSL警告 From a8992cdd51247a9cef069a596c8481fadde1e128 Mon Sep 17 00:00:00 2001 From: tt-P607 <68868379+tt-P607@users.noreply.github.com> Date: Sat, 20 Sep 2025 11:57:22 +0800 Subject: [PATCH 07/31] =?UTF-8?q?4=E6=AC=A1=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/message_receive/message.py | 10 ++- src/chat/message_receive/storage.py | 2 +- src/chat/planner_actions/plan_filter.py | 4 +- src/chat/utils/utils_image.py | 101 ++++++++++++------------ src/llm_models/utils.py | 6 +- src/llm_models/utils_model.py | 6 +- 6 files changed, 69 insertions(+), 60 deletions(-) diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index fc57b6fc6..1797363bb 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -95,17 +95,23 @@ class Message(MessageBase, metaclass=ABCMeta): class MessageRecv(Message): """接收消息类,用于处理从MessageCQ序列化的消息""" - def __init__(self, message_dict: dict[str, Any], message_id: str, chat_stream: "ChatStream", user_info: UserInfo): + def __init__(self, message_dict: dict[str, Any]): """从MessageCQ的字典初始化 Args: message_dict: MessageCQ序列化后的字典 """ - super().__init__(message_id, chat_stream, user_info) + # Manually initialize attributes from MessageBase 
and Message self.message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {})) self.message_segment = Seg.from_dict(message_dict.get("message_segment", {})) self.raw_message = message_dict.get("raw_message") + + self.chat_stream = None + self.reply = None self.processed_plain_text = message_dict.get("processed_plain_text", "") + self.memorized_times = 0 + + # MessageRecv specific attributes self.is_emoji = False self.has_emoji = False self.is_picid = False diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index eb0dc5d1e..4bdaa9edc 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -122,7 +122,7 @@ class MessageStorage: is_picid=is_picid, ) async with get_db_session() as session: - await session.add(new_message) + session.add(new_message) except Exception: logger.exception("存储消息失败") diff --git a/src/chat/planner_actions/plan_filter.py b/src/chat/planner_actions/plan_filter.py index 91237c9cb..19d11bc4e 100644 --- a/src/chat/planner_actions/plan_filter.py +++ b/src/chat/planner_actions/plan_filter.py @@ -124,7 +124,7 @@ class PlanFilter: if plan.mode == ChatMode.PROACTIVE: long_term_memory_block = await self._get_long_term_memory_context() - chat_content_block, message_id_list = build_readable_messages_with_id( + chat_content_block, message_id_list = await build_readable_messages_with_id( messages=[msg.flatten() for msg in plan.chat_history], timestamp_mode="normal", truncate=False, @@ -160,7 +160,7 @@ class PlanFilter: show_actions=True, ) - actions_before_now = get_actions_by_timestamp_with_chat( + actions_before_now = await get_actions_by_timestamp_with_chat( chat_id=plan.chat_id, timestamp_start=time.time() - 3600, timestamp_end=time.time(), diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py index 18adc8a9b..93ec14957 100644 --- a/src/chat/utils/utils_image.py +++ b/src/chat/utils/utils_image.py @@ -69,7 +69,7 @@ class ImageManager: os.makedirs(self.IMAGE_DIR, exist_ok=True) @staticmethod - def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]: + async def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]: """从数据库获取图片描述 Args: @@ -80,22 +80,22 @@ class ImageManager: Optional[str]: 描述文本,如果不存在则返回None """ try: - with get_db_session() as session: - record = session.execute( + async with get_db_session() as session: + record = (await session.execute( select(ImageDescriptions).where( and_( ImageDescriptions.image_description_hash == image_hash, ImageDescriptions.type == description_type, ) ) - ).scalar() + )).scalar() return record.description if record else None except Exception as e: logger.error(f"从数据库获取描述失败 (SQLAlchemy): {str(e)}") return None @staticmethod - def _save_description_to_db(image_hash: str, description: str, description_type: str) -> None: + async def _save_description_to_db(image_hash: str, description: str, description_type: str) -> None: """保存图片描述到数据库 Args: @@ -105,16 +105,16 @@ class ImageManager: """ try: current_timestamp = time.time() - with get_db_session() as session: + async with get_db_session() as session: # 查找现有记录 - existing = session.execute( + existing = (await session.execute( select(ImageDescriptions).where( and_( ImageDescriptions.image_description_hash == image_hash, ImageDescriptions.type == description_type, ) ) - ).scalar() + )).scalar() if existing: # 更新现有记录 @@ -129,7 +129,7 @@ class ImageManager: timestamp=current_timestamp, ) session.add(new_desc) - session.commit() + 
await session.commit() # 会在上下文管理器中自动调用 except Exception as e: logger.error(f"保存描述到数据库失败 (SQLAlchemy): {str(e)}") @@ -175,7 +175,7 @@ class ImageManager: logger.debug(f"查询EmojiManager时出错: {e}") # 查询ImageDescriptions表的缓存描述 - if cached_description := self._get_description_from_db(image_hash, "emoji"): + if cached_description := await self._get_description_from_db(image_hash, "emoji"): logger.info(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description}...") return f"[表情包:{cached_description}]" @@ -239,7 +239,7 @@ class ImageManager: logger.info(f"[emoji识别] 详细描述: {detailed_description}... -> 情感标签: {final_emotion}") - if cached_description := self._get_description_from_db(image_hash, "emoji"): + if cached_description := await self._get_description_from_db(image_hash, "emoji"): logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}") return f"[表情包:{cached_description}]" @@ -261,10 +261,10 @@ class ImageManager: try: from src.common.database.sqlalchemy_models import get_db_session - with get_db_session() as session: - existing_img = session.execute( + async with get_db_session() as session: + existing_img = (await session.execute( select(Images).where(and_(Images.emoji_hash == image_hash, Images.type == "emoji")) - ).scalar() + )).scalar() if existing_img: existing_img.path = file_path @@ -279,7 +279,7 @@ class ImageManager: timestamp=current_timestamp, ) session.add(new_img) - session.commit() + await session.commit() except Exception as e: logger.error(f"保存到Images表失败: {str(e)}") @@ -289,7 +289,7 @@ class ImageManager: logger.debug("偷取表情包功能已关闭,跳过保存。") # 保存最终的情感标签到缓存 (ImageDescriptions表) - self._save_description_to_db(image_hash, final_emotion, "emoji") + await self._save_description_to_db(image_hash, final_emotion, "emoji") return f"[表情包:{final_emotion}]" @@ -306,9 +306,9 @@ class ImageManager: image_bytes = base64.b64decode(image_base64) image_hash = hashlib.md5(image_bytes).hexdigest() - # 优先检查Images表中是否已有完整的描述 - with get_db_session() as session: - existing_image = session.execute(select(Images).where(Images.emoji_hash == image_hash)).scalar() + async with get_db_session() as session: + # 优先检查Images表中是否已有完整的描述 + existing_image = (await session.execute(select(Images).where(Images.emoji_hash == image_hash))).scalar() if existing_image: # 更新计数 if hasattr(existing_image, "count") and existing_image.count is not None: @@ -318,34 +318,34 @@ class ImageManager: # 如果已有描述,直接返回 if existing_image.description: + await session.commit() logger.debug(f"[缓存命中] 使用Images表中的图片描述: {existing_image.description}...") return f"[图片:{existing_image.description}]" - if cached_description := self._get_description_from_db(image_hash, "image"): - logger.debug(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description}...") - return f"[图片:{cached_description}]" + # 如果没有描述,继续在当前会话中操作 + if cached_description := await self._get_description_from_db(image_hash, "image"): + logger.debug(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description}...") + return f"[图片:{cached_description}]" - # 调用AI获取描述 - image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore - prompt = global_config.custom_prompt.image_prompt - logger.info(f"[VLM调用] 为图片生成新描述 (Hash: {image_hash[:8]}...)") - description, _ = await self.vlm.generate_response_for_image( - prompt, image_base64, image_format, temperature=0.4, max_tokens=300 - ) + # 调用AI获取描述 + image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore + prompt = global_config.custom_prompt.image_prompt + logger.info(f"[VLM调用] 为图片生成新描述 (Hash: 
{image_hash[:8]}...)") + description, _ = await self.vlm.generate_response_for_image( + prompt, image_base64, image_format, temperature=0.4, max_tokens=300 + ) - if description is None: - logger.warning("AI未能生成图片描述") - return "[图片(描述生成失败)]" + if description is None: + logger.warning("AI未能生成图片描述") + return "[图片(描述生成失败)]" - # 保存图片和描述 - current_timestamp = time.time() - filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}" - image_dir = os.path.join(self.IMAGE_DIR, "image") - os.makedirs(image_dir, exist_ok=True) - file_path = os.path.join(image_dir, filename) + # 保存图片和描述 + current_timestamp = time.time() + filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}" + image_dir = os.path.join(self.IMAGE_DIR, "image") + os.makedirs(image_dir, exist_ok=True) + file_path = os.path.join(image_dir, filename) - try: - # 保存文件 with open(file_path, "wb") as f: f.write(image_bytes) @@ -358,7 +358,6 @@ class ImageManager: existing_image.image_id = str(uuid.uuid4()) if not hasattr(existing_image, "vlm_processed") or existing_image.vlm_processed is None: existing_image.vlm_processed = True - logger.debug(f"[数据库] 更新已有图片记录: {image_hash[:8]}...") else: new_img = Images( @@ -372,13 +371,15 @@ class ImageManager: count=1, ) session.add(new_img) - logger.debug(f"[数据库] 创建新图片记录: {image_hash[:8]}...") - except Exception as e: - logger.error(f"保存图片文件或元数据失败: {str(e)}") - # 保存描述到ImageDescriptions表作为备用缓存 - self._save_description_to_db(image_hash, description, "image") + await session.commit() + + # 保存描述到ImageDescriptions表作为备用缓存 + await self._save_description_to_db(image_hash, description, "image") + + logger.info(f"[VLM完成] 图片描述生成: {description}...") + return f"[图片:{description}]" logger.info(f"[VLM完成] 图片描述生成: {description}...") return f"[图片:{description}]" @@ -525,8 +526,8 @@ class ImageManager: image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii") image_bytes = base64.b64decode(image_base64) image_hash = hashlib.md5(image_bytes).hexdigest() - with get_db_session() as session: - existing_image = session.execute(select(Images).where(Images.emoji_hash == image_hash)).scalar() + async with get_db_session() as session: + existing_image = (await session.execute(select(Images).where(Images.emoji_hash == image_hash))).scalar() if existing_image: # 检查是否缺少必要字段,如果缺少则创建新记录 if ( @@ -546,6 +547,7 @@ class ImageManager: existing_image.vlm_processed = False existing_image.count += 1 + await session.commit() # 如果已有描述,直接返回 if existing_image.description and existing_image.description.strip(): @@ -556,6 +558,7 @@ class ImageManager: # 更新数据库中的描述 existing_image.description = description.replace("[图片:", "").replace("]", "") existing_image.vlm_processed = True + await session.commit() return existing_image.image_id, f"[picid:{existing_image.image_id}]" # print(f"图片不存在: {image_hash}") @@ -588,7 +591,7 @@ class ImageManager: count=1, ) session.add(new_img) - session.commit() + await session.commit() return image_id, f"[picid:{image_id}]" diff --git a/src/llm_models/utils.py b/src/llm_models/utils.py index 48f7be5f2..bf23f144a 100644 --- a/src/llm_models/utils.py +++ b/src/llm_models/utils.py @@ -146,7 +146,7 @@ class LLMUsageRecorder: """ @staticmethod - def record_usage_to_database( + async def record_usage_to_database( model_info: ModelInfo, model_usage: UsageRecord, user_id: str, @@ -161,7 +161,7 @@ class LLMUsageRecorder: session = None try: # 使用 SQLAlchemy 会话创建记录 - with get_db_session() as session: + async with get_db_session() as session: usage_record = LLMUsage( 
model_name=model_info.model_identifier, model_assign_name=model_info.name, @@ -179,7 +179,7 @@ class LLMUsageRecorder: ) session.add(usage_record) - session.commit() + await session.commit() logger.debug( f"Token使用情况 - 模型: {model_usage.model_name}, " diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index 4969644e6..146e5eb46 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -202,7 +202,7 @@ class LLMRequest: content, extracted_reasoning = self._extract_reasoning(content) reasoning_content = extracted_reasoning if usage := response.usage: - llm_usage_recorder.record_usage_to_database( + await llm_usage_recorder.record_usage_to_database( model_info=model_info, model_usage=usage, user_id="system", @@ -367,7 +367,7 @@ class LLMRequest: # 成功获取响应 if usage := response.usage: - llm_usage_recorder.record_usage_to_database( + await llm_usage_recorder.record_usage_to_database( model_info=model_info, model_usage=usage, time_cost=time.time() - start_time, @@ -442,7 +442,7 @@ class LLMRequest: embedding = response.embedding if usage := response.usage: - llm_usage_recorder.record_usage_to_database( + await llm_usage_recorder.record_usage_to_database( model_info=model_info, time_cost=time.time() - start_time, model_usage=usage, From 9b2addfd86cdd83e79f4af375b10bdb7d68c7804 Mon Sep 17 00:00:00 2001 From: tt-P607 <68868379+tt-P607@users.noreply.github.com> Date: Sat, 20 Sep 2025 12:22:15 +0800 Subject: [PATCH 08/31] 5 --- src/chat/chat_loop/heartFC_chat.py | 18 +++++++++--------- .../heart_flow/heartflow_message_processor.py | 2 +- src/chat/planner_actions/plan_filter.py | 3 ++- 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/src/chat/chat_loop/heartFC_chat.py b/src/chat/chat_loop/heartFC_chat.py index 7a351c97c..47afab50d 100644 --- a/src/chat/chat_loop/heartFC_chat.py +++ b/src/chat/chat_loop/heartFC_chat.py @@ -39,6 +39,7 @@ class HeartFChatting: - 初始化聊天模式并记录初始化完成日志 """ self.context = HfcContext(chat_id) + self.context.new_message_queue = asyncio.Queue() self.cycle_tracker = CycleTracker(self.context) self.response_handler = ResponseHandler(self.context) @@ -108,6 +109,10 @@ class HeartFChatting: self._loop_task.add_done_callback(self._handle_loop_completion) logger.info(f"{self.context.log_prefix} HeartFChatting 启动完成") + async def add_message(self, message: Dict[str, Any]): + """从外部接收新消息并放入队列""" + await self.context.new_message_queue.put(message) + async def stop(self): """ 停止心跳聊天系统 @@ -362,15 +367,10 @@ class HeartFChatting: # 核心修复:在睡眠模式(包括失眠)下获取消息时,不过滤命令消息,以确保@消息能被接收 filter_command_flag = not (is_sleeping or is_in_insomnia) - recent_messages = await message_api.get_messages_by_time_in_chat( - chat_id=self.context.stream_id, - start_time=self.context.last_read_time, - end_time=time.time(), - limit=10, - limit_mode="latest", - filter_mai=True, - filter_command=filter_command_flag, - ) + # 从队列中获取所有待处理的新消息 + recent_messages = [] + while not self.context.new_message_queue.empty(): + recent_messages.append(await self.context.new_message_queue.get()) has_new_messages = bool(recent_messages) new_message_count = len(recent_messages) diff --git a/src/chat/heart_flow/heartflow_message_processor.py b/src/chat/heart_flow/heartflow_message_processor.py index 1e6376e6d..958bc9096 100644 --- a/src/chat/heart_flow/heartflow_message_processor.py +++ b/src/chat/heart_flow/heartflow_message_processor.py @@ -137,7 +137,7 @@ class HeartFCMessageReceiver: subheartflow: SubHeartflow = await heartflow.get_or_create_subheartflow(chat.stream_id) # type: 
ignore - # subheartflow.add_message_to_normal_chat_cache(message, interested_rate, is_mentioned) + await subheartflow.heart_fc_instance.add_message(message.to_dict()) if global_config.mood.enable_mood: chat_mood = mood_manager.get_mood_by_chat_id(subheartflow.chat_id) asyncio.create_task(chat_mood.update_mood_by_message(message, interested_rate)) diff --git a/src/chat/planner_actions/plan_filter.py b/src/chat/planner_actions/plan_filter.py index 19d11bc4e..6d9998a7d 100644 --- a/src/chat/planner_actions/plan_filter.py +++ b/src/chat/planner_actions/plan_filter.py @@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional from json_repair import repair_json +from . import planner_prompts from src.chat.memory_system.Hippocampus import hippocampus_manager from src.chat.utils.chat_message_builder import ( build_readable_actions, @@ -167,7 +168,7 @@ class PlanFilter: limit=5, ) - actions_before_now_block = build_readable_actions(actions=await actions_before_now) + actions_before_now_block = build_readable_actions(actions=actions_before_now) actions_before_now_block = f"你刚刚选择并执行过的action是:\n{actions_before_now_block}" self.last_obs_time_mark = time.time() From 74a9c346f0c505d7e86ea3c97d0f6004247d6465 Mon Sep 17 00:00:00 2001 From: tt-P607 <68868379+tt-P607@users.noreply.github.com> Date: Sat, 20 Sep 2025 12:26:31 +0800 Subject: [PATCH 09/31] 6 --- src/chat/replyer/default_generator.py | 4 ++-- src/plugin_system/apis/cross_context_api.py | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index 06cff0c9d..4ad846624 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -862,7 +862,7 @@ class DefaultReplyer: target = "(无消息内容)" person_info_manager = get_person_info_manager() - person_id = person_info_manager.get_person_id_by_person_name(sender) + person_id = await person_info_manager.get_person_id_by_person_name(sender) platform = chat_stream.platform target = replace_user_references_sync(target, chat_stream.platform, replace_bot_name=True) @@ -899,7 +899,7 @@ class DefaultReplyer: # 获取目标用户信息,用于s4u模式 target_user_info = None if sender: - target_user_info = person_info_manager.get_person_info_by_name(sender) + target_user_info = await person_info_manager.get_person_info_by_name(sender) from src.chat.utils.prompt import Prompt # 并行执行六个构建任务 diff --git a/src/plugin_system/apis/cross_context_api.py b/src/plugin_system/apis/cross_context_api.py index fcc93d485..76bd45bde 100644 --- a/src/plugin_system/apis/cross_context_api.py +++ b/src/plugin_system/apis/cross_context_api.py @@ -53,14 +53,14 @@ async def build_cross_context_normal(chat_stream: ChatStream, other_chat_infos: continue try: - messages = get_raw_msg_before_timestamp_with_chat( + messages = await get_raw_msg_before_timestamp_with_chat( chat_id=stream_id, timestamp=time.time(), limit=5, # 可配置 ) if messages: chat_name = get_chat_manager().get_stream_name(stream_id) or chat_raw_id - formatted_messages, _ = build_readable_messages_with_id(messages, timestamp_mode="relative") + formatted_messages, _ = await build_readable_messages_with_id(messages, timestamp_mode="relative") cross_context_messages.append(f'[以下是来自"{chat_name}"的近期消息]\n{formatted_messages}') except Exception as e: logger.error(f"获取聊天 {chat_raw_id} 的消息失败: {e}") @@ -92,7 +92,7 @@ async def build_cross_context_s4u( continue try: - messages = get_raw_msg_before_timestamp_with_chat( + messages = await get_raw_msg_before_timestamp_with_chat( 
chat_id=stream_id, timestamp=time.time(), limit=20, # 获取更多消息以供筛选 @@ -104,7 +104,7 @@ async def build_cross_context_s4u( user_name = ( target_user_info.get("person_name") or target_user_info.get("user_nickname") or user_id ) - formatted_messages, _ = build_readable_messages_with_id( + formatted_messages, _ = await build_readable_messages_with_id( user_messages, timestamp_mode="relative" ) cross_context_messages.append( @@ -161,14 +161,14 @@ async def get_chat_history_by_group_name(group_name: str) -> str: stream_id = found_stream.stream_id try: - messages = get_raw_msg_before_timestamp_with_chat( + messages = await get_raw_msg_before_timestamp_with_chat( chat_id=stream_id, timestamp=time.time(), limit=5, # 可配置 ) if messages: chat_name = get_chat_manager().get_stream_name(stream_id) or chat_raw_id - formatted_messages, _ = build_readable_messages_with_id(messages, timestamp_mode="relative") + formatted_messages, _ = await build_readable_messages_with_id(messages, timestamp_mode="relative") cross_context_messages.append(f'[以下是来自"{chat_name}"的近期消息]\n{formatted_messages}') except Exception as e: logger.error(f"获取聊天 {chat_raw_id} 的消息失败: {e}") From d22b3b71fe90b36a9eadc3b7594047ee6cb1e260 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=85=E8=AF=BA=E7=8B=90?= <212194964+foxcyber907@users.noreply.github.com> Date: Sat, 20 Sep 2025 11:45:59 +0800 Subject: [PATCH 10/31] =?UTF-8?q?=E5=B0=8F=E4=BF=AE=E5=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/message_receive/message.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index fc57b6fc6..a113088c9 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -13,6 +13,7 @@ from src.chat.utils.utils_video import get_video_analyzer, is_video_analysis_ava from src.chat.utils.utils_voice import get_voice_text from src.common.logger import get_logger from src.config.config import global_config +rom src.chat.message_receive.chat_stream import ChatStream install(extra_lines=3) From 7777e1ec712e748f7ef37190d180587b1bd190af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=85=E8=AF=BA=E7=8B=90?= <212194964+foxcyber907@users.noreply.github.com> Date: Sat, 20 Sep 2025 12:38:24 +0800 Subject: [PATCH 11/31] Update message.py --- src/chat/message_receive/message.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index a113088c9..9af9a5c3a 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -13,7 +13,7 @@ from src.chat.utils.utils_video import get_video_analyzer, is_video_analysis_ava from src.chat.utils.utils_voice import get_voice_text from src.common.logger import get_logger from src.config.config import global_config -rom src.chat.message_receive.chat_stream import ChatStream +from src.chat.message_receive.chat_stream import ChatStream install(extra_lines=3) From 6b5bf023823bcd00af149e8a427f97329b03f925 Mon Sep 17 00:00:00 2001 From: tt-P607 <68868379+tt-P607@users.noreply.github.com> Date: Sat, 20 Sep 2025 12:43:37 +0800 Subject: [PATCH 12/31] =?UTF-8?q?=E6=80=BB=E7=AE=97=E8=83=BD=E5=9B=9E?= =?UTF-8?q?=E5=A4=8D=E4=BA=86=F0=9F=98=AD=F0=9F=98=AD=F0=9F=98=AD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/chat_loop/response_handler.py | 2 +- src/chat/replyer/default_generator.py | 2 +- 
src/person_info/relationship_fetcher.py | 6 +++++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/chat/chat_loop/response_handler.py b/src/chat/chat_loop/response_handler.py index 9859c76c3..99f065319 100644 --- a/src/chat/chat_loop/response_handler.py +++ b/src/chat/chat_loop/response_handler.py @@ -130,7 +130,7 @@ class ResponseHandler: """ current_time = time.time() # 计算新消息数量 - new_message_count = message_api.count_new_messages( + new_message_count = await message_api.count_new_messages( chat_id=self.context.stream_id, start_time=thinking_start_time, end_time=current_time ) diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index 4ad846624..aa9c5eba0 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -1332,7 +1332,7 @@ class DefaultReplyer: # 获取用户ID person_info_manager = get_person_info_manager() - person_id = person_info_manager.get_person_id_by_person_name(sender) + person_id = await person_info_manager.get_person_id_by_person_name(sender) if not person_id: logger.warning(f"未找到用户 {sender} 的ID,跳过信息提取") return f"你完全不认识{sender},不理解ta的相关信息。" diff --git a/src/person_info/relationship_fetcher.py b/src/person_info/relationship_fetcher.py index e903915a7..4b25f6b14 100644 --- a/src/person_info/relationship_fetcher.py +++ b/src/person_info/relationship_fetcher.py @@ -110,7 +110,11 @@ class RelationshipFetcher: if person_name == nickname_str and not short_impression: return "" - current_points = person_info.get("points") or [] + current_points = person_info.get("points") + if isinstance(current_points, str): + current_points = orjson.loads(current_points) + else: + current_points = current_points or [] # 按时间排序forgotten_points current_points.sort(key=lambda x: x[2]) From 8d9aa4fb9e45d0730c3f6e5a256d4d2f2a3ba7bd Mon Sep 17 00:00:00 2001 From: tt-P607 <68868379+tt-P607@users.noreply.github.com> Date: Sat, 20 Sep 2025 13:07:06 +0800 Subject: [PATCH 13/31] =?UTF-8?q?refactor(db):=20=E9=87=8D=E6=9E=84?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=BA=93=E4=BA=A4=E4=BA=92=E4=B8=BA=E5=BC=82?= =?UTF-8?q?=E6=AD=A5=E6=A8=A1=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 为了提升性能并与项目整体的异步架构保持一致,对核心数据库交互模块进行了异步化重构。 主要修改内容包括: - 将 `PermissionManager` 中的所有数据库操作从同步改为异步,以避免阻塞事件循环。 - 使用 `async_sessionmaker` 和 `async with session` 替代原有的同步会话管理。 - 将 SQLAlchemy 查询语法更新为异步兼容的 `await session.execute(select(...))` 模式。 - 相应地,调用链中依赖数据库操作的多个方法也已更新为 `async` 函数。 --- src/chat/chat_loop/cycle_processor.py | 3 +- src/chat/message_receive/message.py | 1 - src/chat/replyer/default_generator.py | 6 +- src/plugin_system/base/base_action.py | 2 +- src/plugin_system/core/permission_manager.py | 137 ++++++++++--------- 5 files changed, 81 insertions(+), 68 deletions(-) diff --git a/src/chat/chat_loop/cycle_processor.py b/src/chat/chat_loop/cycle_processor.py index fe993f484..b2a092958 100644 --- a/src/chat/chat_loop/cycle_processor.py +++ b/src/chat/chat_loop/cycle_processor.py @@ -86,7 +86,8 @@ class CycleProcessor: platform, action_message.get("chat_info_user_id", ""), ) - person_name = await person_info_manager.get_value(person_id, "person_name") + person_info = await person_info_manager.get_values(person_id, ["person_name"]) + person_name = person_info.get("person_name") action_prompt_display = f"你对{person_name}进行了回复:{reply_text}" # 存储动作信息到数据库 diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index 631aa7c09..22e57edf0 100644 --- 
a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -15,7 +15,6 @@ from src.common.logger import get_logger from src.config.config import global_config from src.chat.message_receive.chat_stream import ChatStream -from src.chat.message_receive.chat_stream import ChatStream install(extra_lines=3) diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index aa9c5eba0..bf3d4fe26 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -660,7 +660,7 @@ class DefaultReplyer: duration = end_time - start_time return name, result, duration - def build_s4u_chat_history_prompts( + async def build_s4u_chat_history_prompts( self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str ) -> Tuple[str, str]: """ @@ -692,7 +692,7 @@ class DefaultReplyer: all_dialogue_prompt = "" if message_list_before_now: latest_25_msgs = message_list_before_now[-int(global_config.chat.max_context_size) :] - all_dialogue_prompt_str = build_readable_messages( + all_dialogue_prompt_str = await build_readable_messages( latest_25_msgs, replace_bot_name=True, timestamp_mode="normal", @@ -716,7 +716,7 @@ class DefaultReplyer: else: core_dialogue_list = core_dialogue_list[-int(global_config.chat.max_context_size * 2) :] # 限制消息数量 - core_dialogue_prompt_str = build_readable_messages( + core_dialogue_prompt_str = await build_readable_messages( core_dialogue_list, replace_bot_name=True, merge_messages=False, diff --git a/src/plugin_system/base/base_action.py b/src/plugin_system/base/base_action.py index 9400032f8..51a0f4257 100644 --- a/src/plugin_system/base/base_action.py +++ b/src/plugin_system/base/base_action.py @@ -213,7 +213,7 @@ class BaseAction(ABC): # 检查新消息 current_time = time.time() - new_message_count = message_api.count_new_messages( + new_message_count = await message_api.count_new_messages( chat_id=self.chat_id, start_time=loop_start_time, end_time=current_time ) diff --git a/src/plugin_system/core/permission_manager.py b/src/plugin_system/core/permission_manager.py index 9d996fd46..eb6083fc9 100644 --- a/src/plugin_system/core/permission_manager.py +++ b/src/plugin_system/core/permission_manager.py @@ -5,9 +5,10 @@ """ from typing import List, Set, Tuple -from sqlalchemy.orm import sessionmaker +from sqlalchemy.ext.asyncio import async_sessionmaker from sqlalchemy.exc import IntegrityError, SQLAlchemyError from datetime import datetime +from sqlalchemy import select, delete from src.common.logger import get_logger from src.common.database.sqlalchemy_models import get_engine, PermissionNodes, UserPermissions @@ -22,7 +23,7 @@ class PermissionManager(IPermissionManager): def __init__(self): self.engine = get_engine() - self.SessionLocal = sessionmaker(bind=self.engine) + self.SessionLocal = async_sessionmaker(bind=self.engine) self._master_users: Set[Tuple[str, str]] = set() self._load_master_users() logger.info("权限管理器初始化完成") @@ -62,7 +63,7 @@ class PermissionManager(IPermissionManager): logger.debug(f"用户 {user.platform}:{user.user_id} 是Master用户") return is_master - def check_permission(self, user: UserInfo, permission_node: str) -> bool: + async def check_permission(self, user: UserInfo, permission_node: str) -> bool: """ 检查用户是否拥有指定权限节点 @@ -79,34 +80,35 @@ class PermissionManager(IPermissionManager): logger.debug(f"Master用户 {user.platform}:{user.user_id} 拥有权限节点 {permission_node}") return True - with self.SessionLocal() as session: + async with self.SessionLocal() as session: # 检查权限节点是否存在 
- node = session.query(PermissionNodes).filter_by(node_name=permission_node).first() + result = await session.execute(select(PermissionNodes).filter_by(node_name=permission_node)) + node = result.scalar_one_or_none() if not node: logger.warning(f"权限节点 {permission_node} 不存在") return False # 检查用户是否有明确的权限设置 - user_perm = ( - session.query(UserPermissions) + result = await session.execute( + select(UserPermissions) .filter_by(platform=user.platform, user_id=user.user_id, permission_node=permission_node) - .first() ) + user_perm = result.scalar_one_or_none() if user_perm: # 有明确设置,返回设置的值 - result = user_perm.granted + res = user_perm.granted logger.debug( - f"用户 {user.platform}:{user.user_id} 对权限节点 {permission_node} 的明确设置: {result}" + f"用户 {user.platform}:{user.user_id} 对权限节点 {permission_node} 的明确设置: {res}" ) - return result + return res else: # 没有明确设置,使用默认值 - result = node.default_granted + res = node.default_granted logger.debug( - f"用户 {user.platform}:{user.user_id} 对权限节点 {permission_node} 使用默认设置: {result}" + f"用户 {user.platform}:{user.user_id} 对权限节点 {permission_node} 使用默认设置: {res}" ) - return result + return res except SQLAlchemyError as e: logger.error(f"检查权限时数据库错误: {e}") @@ -115,7 +117,7 @@ class PermissionManager(IPermissionManager): logger.error(f"检查权限时发生未知错误: {e}") return False - def register_permission_node(self, node: PermissionNode) -> bool: + async def register_permission_node(self, node: PermissionNode) -> bool: """ 注册权限节点 @@ -126,15 +128,16 @@ class PermissionManager(IPermissionManager): bool: 注册是否成功 """ try: - with self.SessionLocal() as session: + async with self.SessionLocal() as session: # 检查节点是否已存在 - existing_node = session.query(PermissionNodes).filter_by(node_name=node.node_name).first() + result = await session.execute(select(PermissionNodes).filter_by(node_name=node.node_name)) + existing_node = result.scalar_one_or_none() if existing_node: # 更新现有节点的信息 existing_node.description = node.description existing_node.plugin_name = node.plugin_name existing_node.default_granted = node.default_granted - session.commit() + await session.commit() logger.debug(f"更新权限节点: {node.node_name}") return True @@ -147,7 +150,7 @@ class PermissionManager(IPermissionManager): created_at=datetime.utcnow(), ) session.add(new_node) - session.commit() + await session.commit() logger.info(f"注册新权限节点: {node.node_name} (插件: {node.plugin_name})") return True @@ -161,7 +164,7 @@ class PermissionManager(IPermissionManager): logger.error(f"注册权限节点时发生未知错误: {e}") return False - def grant_permission(self, user: UserInfo, permission_node: str) -> bool: + async def grant_permission(self, user: UserInfo, permission_node: str) -> bool: """ 授权用户权限节点 @@ -173,19 +176,20 @@ class PermissionManager(IPermissionManager): bool: 授权是否成功 """ try: - with self.SessionLocal() as session: + async with self.SessionLocal() as session: # 检查权限节点是否存在 - node = session.query(PermissionNodes).filter_by(node_name=permission_node).first() + result = await session.execute(select(PermissionNodes).filter_by(node_name=permission_node)) + node = result.scalar_one_or_none() if not node: logger.error(f"尝试授权不存在的权限节点: {permission_node}") return False # 检查是否已有权限记录 - existing_perm = ( - session.query(UserPermissions) + result = await session.execute( + select(UserPermissions) .filter_by(platform=user.platform, user_id=user.user_id, permission_node=permission_node) - .first() ) + existing_perm = result.scalar_one_or_none() if existing_perm: # 更新现有记录 @@ -202,7 +206,7 @@ class PermissionManager(IPermissionManager): ) session.add(new_perm) - 
session.commit() + await session.commit() logger.info(f"已授权用户 {user.platform}:{user.user_id} 权限节点 {permission_node}") return True @@ -213,7 +217,7 @@ class PermissionManager(IPermissionManager): logger.error(f"授权权限时发生未知错误: {e}") return False - def revoke_permission(self, user: UserInfo, permission_node: str) -> bool: + async def revoke_permission(self, user: UserInfo, permission_node: str) -> bool: """ 撤销用户权限节点 @@ -225,19 +229,20 @@ class PermissionManager(IPermissionManager): bool: 撤销是否成功 """ try: - with self.SessionLocal() as session: + async with self.SessionLocal() as session: # 检查权限节点是否存在 - node = session.query(PermissionNodes).filter_by(node_name=permission_node).first() + result = await session.execute(select(PermissionNodes).filter_by(node_name=permission_node)) + node = result.scalar_one_or_none() if not node: logger.error(f"尝试撤销不存在的权限节点: {permission_node}") return False # 检查是否已有权限记录 - existing_perm = ( - session.query(UserPermissions) + result = await session.execute( + select(UserPermissions) .filter_by(platform=user.platform, user_id=user.user_id, permission_node=permission_node) - .first() ) + existing_perm = result.scalar_one_or_none() if existing_perm: # 更新现有记录 @@ -254,7 +259,7 @@ class PermissionManager(IPermissionManager): ) session.add(new_perm) - session.commit() + await session.commit() logger.info(f"已撤销用户 {user.platform}:{user.user_id} 权限节点 {permission_node}") return True @@ -265,7 +270,7 @@ class PermissionManager(IPermissionManager): logger.error(f"撤销权限时发生未知错误: {e}") return False - def get_user_permissions(self, user: UserInfo) -> List[str]: + async def get_user_permissions(self, user: UserInfo) -> List[str]: """ 获取用户拥有的所有权限节点 @@ -278,23 +283,25 @@ class PermissionManager(IPermissionManager): try: # Master用户拥有所有权限 if self.is_master(user): - with self.SessionLocal() as session: - all_nodes = session.query(PermissionNodes.node_name).all() - return [node.node_name for node in all_nodes] + async with self.SessionLocal() as session: + result = await session.execute(select(PermissionNodes.node_name)) + all_nodes = result.scalars().all() + return all_nodes permissions = [] - with self.SessionLocal() as session: + async with self.SessionLocal() as session: # 获取所有权限节点 - all_nodes = session.query(PermissionNodes).all() + result = await session.execute(select(PermissionNodes)) + all_nodes = result.scalars().all() for node in all_nodes: # 检查用户是否有明确的权限设置 - user_perm = ( - session.query(UserPermissions) + result = await session.execute( + select(UserPermissions) .filter_by(platform=user.platform, user_id=user.user_id, permission_node=node.node_name) - .first() ) + user_perm = result.scalar_one_or_none() if user_perm: # 有明确设置,使用设置的值 @@ -314,7 +321,7 @@ class PermissionManager(IPermissionManager): logger.error(f"获取用户权限时发生未知错误: {e}") return [] - def get_all_permission_nodes(self) -> List[PermissionNode]: + async def get_all_permission_nodes(self) -> List[PermissionNode]: """ 获取所有已注册的权限节点 @@ -322,8 +329,9 @@ class PermissionManager(IPermissionManager): List[PermissionNode]: 权限节点列表 """ try: - with self.SessionLocal() as session: - nodes = session.query(PermissionNodes).all() + async with self.SessionLocal() as session: + result = await session.execute(select(PermissionNodes)) + nodes = result.scalars().all() return [ PermissionNode( node_name=node.node_name, @@ -341,7 +349,7 @@ class PermissionManager(IPermissionManager): logger.error(f"获取所有权限节点时发生未知错误: {e}") return [] - def get_plugin_permission_nodes(self, plugin_name: str) -> List[PermissionNode]: + async def 
get_plugin_permission_nodes(self, plugin_name: str) -> List[PermissionNode]: """ 获取指定插件的所有权限节点 @@ -352,8 +360,9 @@ class PermissionManager(IPermissionManager): List[PermissionNode]: 权限节点列表 """ try: - with self.SessionLocal() as session: - nodes = session.query(PermissionNodes).filter_by(plugin_name=plugin_name).all() + async with self.SessionLocal() as session: + result = await session.execute(select(PermissionNodes).filter_by(plugin_name=plugin_name)) + nodes = result.scalars().all() return [ PermissionNode( node_name=node.node_name, @@ -371,7 +380,7 @@ class PermissionManager(IPermissionManager): logger.error(f"获取插件权限节点时发生未知错误: {e}") return [] - def delete_plugin_permissions(self, plugin_name: str) -> bool: + async def delete_plugin_permissions(self, plugin_name: str) -> bool: """ 删除指定插件的所有权限节点(用于插件卸载时清理) @@ -382,9 +391,10 @@ class PermissionManager(IPermissionManager): bool: 删除是否成功 """ try: - with self.SessionLocal() as session: + async with self.SessionLocal() as session: # 获取插件的所有权限节点 - plugin_nodes = session.query(PermissionNodes).filter_by(plugin_name=plugin_name).all() + result = await session.execute(select(PermissionNodes).filter_by(plugin_name=plugin_name)) + plugin_nodes = result.scalars().all() node_names = [node.node_name for node in plugin_nodes] if not node_names: @@ -392,16 +402,17 @@ class PermissionManager(IPermissionManager): return True # 删除用户权限记录 - deleted_user_perms = ( - session.query(UserPermissions) - .filter(UserPermissions.permission_node.in_(node_names)) - .delete(synchronize_session=False) + result = await session.execute( + delete(UserPermissions) + .where(UserPermissions.permission_node.in_(node_names)) ) + deleted_user_perms = result.rowcount # 删除权限节点 - deleted_nodes = session.query(PermissionNodes).filter_by(plugin_name=plugin_name).delete() + result = await session.execute(delete(PermissionNodes).filter_by(plugin_name=plugin_name)) + deleted_nodes = result.rowcount - session.commit() + await session.commit() logger.info( f"已删除插件 {plugin_name} 的 {deleted_nodes} 个权限节点和 {deleted_user_perms} 条用户权限记录" ) @@ -414,7 +425,7 @@ class PermissionManager(IPermissionManager): logger.error(f"删除插件权限时发生未知错误: {e}") return False - def get_users_with_permission(self, permission_node: str) -> List[Tuple[str, str]]: + async def get_users_with_permission(self, permission_node: str) -> List[Tuple[str, str]]: """ 获取拥有指定权限的所有用户 @@ -427,17 +438,19 @@ class PermissionManager(IPermissionManager): try: users = [] - with self.SessionLocal() as session: + async with self.SessionLocal() as session: # 检查权限节点是否存在 - node = session.query(PermissionNodes).filter_by(node_name=permission_node).first() + result = await session.execute(select(PermissionNodes).filter_by(node_name=permission_node)) + node = result.scalar_one_or_none() if not node: logger.warning(f"权限节点 {permission_node} 不存在") return users # 获取明确授权的用户 - granted_users = ( - session.query(UserPermissions).filter_by(permission_node=permission_node, granted=True).all() + result = await session.execute( + select(UserPermissions).filter_by(permission_node=permission_node, granted=True) ) + granted_users = result.scalars().all() for user_perm in granted_users: users.append((user_perm.platform, user_perm.user_id)) From e2e0d3c30a22615cdf694ceb21e6916fb9ec14f2 Mon Sep 17 00:00:00 2001 From: tt-P607 <68868379+tt-P607@users.noreply.github.com> Date: Sat, 20 Sep 2025 13:52:37 +0800 Subject: [PATCH 14/31] =?UTF-8?q?refactor(core):=20=E9=80=82=E9=85=8D?= =?UTF-8?q?=E6=A0=B8=E5=BF=83=E6=95=B0=E6=8D=AE=E8=8E=B7=E5=8F=96=E4=B8=8E?= 
=?UTF-8?q?=E6=B6=88=E6=81=AF=E6=9E=84=E5=BB=BA=E5=87=BD=E6=95=B0=E7=9A=84?= =?UTF-8?q?=E5=BC=82=E6=AD=A5=E8=B0=83=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 在数据库交互层异步化后,多个相关的数据获取和消息构建函数(如 `build_readable_messages`)也转为异步实现。本次提交在所有调用点添加了 `await` 关键字,以适应这一变化。 此外,本次提交还包含以下修复: - 在主动思考模块中增加了对规划器返回无效动作的检查,避免后续流程出错。 - 修正了日志记录中错误的上下文变量引用。 --- src/chat/chat_loop/cycle_processor.py | 2 +- src/chat/chat_loop/proactive/proactive_thinker.py | 8 ++++++-- src/chat/planner_actions/plan_filter.py | 2 +- src/chat/utils/prompt.py | 10 +++++----- src/mais4u/mais4u_chat/body_emotion_action_manager.py | 4 ++-- src/mais4u/mais4u_chat/s4u_mood_manager.py | 4 ++-- src/mais4u/mais4u_chat/s4u_prompt.py | 8 ++++---- src/mood/mood_manager.py | 4 ++-- src/person_info/relationship_manager.py | 2 +- src/plugins/built_in/core_actions/emoji.py | 4 ++-- .../maizone_refactored/services/qzone_service.py | 2 +- 11 files changed, 27 insertions(+), 23 deletions(-) diff --git a/src/chat/chat_loop/cycle_processor.py b/src/chat/chat_loop/cycle_processor.py index b2a092958..441d14b1f 100644 --- a/src/chat/chat_loop/cycle_processor.py +++ b/src/chat/chat_loop/cycle_processor.py @@ -192,7 +192,7 @@ class CycleProcessor: await self.action_modifier.modify_actions() available_actions = self.context.action_manager.get_using_actions() except Exception as e: - logger.error(f"{self.context.log_prefix} 动作修改失败: {e}") + logger.error(f"{self.log_prefix} 动作修改失败: {e}") available_actions = {} # 规划动作 diff --git a/src/chat/chat_loop/proactive/proactive_thinker.py b/src/chat/chat_loop/proactive/proactive_thinker.py index adf187dca..4dea5ec99 100644 --- a/src/chat/chat_loop/proactive/proactive_thinker.py +++ b/src/chat/chat_loop/proactive/proactive_thinker.py @@ -120,6 +120,10 @@ class ProactiveThinker: action_result = actions[0] if actions else {} action_type = action_result.get("action_type") + if action_type is None: + logger.info(f"{self.context.log_prefix} 主动思考决策: 规划器未返回有效动作") + return + if action_type == "proactive_reply": await self._generate_proactive_content_and_send(action_result, trigger_event) elif action_type not in ["do_nothing", "no_action"]: @@ -212,12 +216,12 @@ class ProactiveThinker: logger.warning(f"{self.context.log_prefix} 主题为空,跳过网络搜索。") except Exception as e: logger.error(f"{self.context.log_prefix} 主动思考时网络搜索失败: {e}") - message_list = get_raw_msg_before_timestamp_with_chat( + message_list = await get_raw_msg_before_timestamp_with_chat( chat_id=self.context.stream_id, timestamp=time.time(), limit=int(global_config.chat.max_context_size * 0.3), ) - chat_context_block, _ = build_readable_messages_with_id(messages=message_list) + chat_context_block, _ = await build_readable_messages_with_id(messages=message_list) from src.llm_models.utils_model import LLMRequest from src.config.config import model_config diff --git a/src/chat/planner_actions/plan_filter.py b/src/chat/planner_actions/plan_filter.py index 6d9998a7d..4ef8de2d8 100644 --- a/src/chat/planner_actions/plan_filter.py +++ b/src/chat/planner_actions/plan_filter.py @@ -133,7 +133,7 @@ class PlanFilter: ) prompt_template = await global_prompt_manager.get_prompt_async("proactive_planner_prompt") - actions_before_now = get_actions_by_timestamp_with_chat( + actions_before_now = await get_actions_by_timestamp_with_chat( chat_id=plan.chat_id, timestamp_start=time.time() - 3600, timestamp_end=time.time(), diff --git a/src/chat/utils/prompt.py b/src/chat/utils/prompt.py index 9dec72a28..2f115aa98 100644 --- 
a/src/chat/utils/prompt.py +++ b/src/chat/utils/prompt.py @@ -486,7 +486,7 @@ class Prompt: all_dialogue_prompt = "" if message_list_before_now: latest_25_msgs = message_list_before_now[-int(global_config.chat.max_context_size) :] - all_dialogue_prompt_str = build_readable_messages( + all_dialogue_prompt_str = await build_readable_messages( latest_25_msgs, replace_bot_name=True, timestamp_mode="normal", @@ -505,7 +505,7 @@ class Prompt: else: core_dialogue_list = core_dialogue_list[-int(global_config.chat.max_context_size * 2) :] - core_dialogue_prompt_str = build_readable_messages( + core_dialogue_prompt_str = await build_readable_messages( core_dialogue_list, replace_bot_name=True, merge_messages=False, @@ -534,7 +534,7 @@ class Prompt: chat_history = "" if self.parameters.message_list_before_now_long: recent_messages = self.parameters.message_list_before_now_long[-10:] - chat_history = build_readable_messages( + chat_history = await build_readable_messages( recent_messages, replace_bot_name=True, timestamp_mode="normal", @@ -574,7 +574,7 @@ class Prompt: chat_history = "" if self.parameters.message_list_before_now_long: recent_messages = self.parameters.message_list_before_now_long[-20:] - chat_history = build_readable_messages( + chat_history = await build_readable_messages( recent_messages, replace_bot_name=True, timestamp_mode="normal", @@ -632,7 +632,7 @@ class Prompt: chat_history = "" if self.parameters.message_list_before_now_long: recent_messages = self.parameters.message_list_before_now_long[-15:] - chat_history = build_readable_messages( + chat_history = await build_readable_messages( recent_messages, replace_bot_name=True, timestamp_mode="normal", diff --git a/src/mais4u/mais4u_chat/body_emotion_action_manager.py b/src/mais4u/mais4u_chat/body_emotion_action_manager.py index 5807e2acf..38073baa4 100644 --- a/src/mais4u/mais4u_chat/body_emotion_action_manager.py +++ b/src/mais4u/mais4u_chat/body_emotion_action_manager.py @@ -125,7 +125,7 @@ class ChatAction: limit=15, limit_mode="last", ) - chat_talking_prompt = build_readable_messages( + chat_talking_prompt = await build_readable_messages( message_list_before_now, replace_bot_name=True, merge_messages=False, @@ -189,7 +189,7 @@ class ChatAction: limit=10, limit_mode="last", ) - chat_talking_prompt = build_readable_messages( + chat_talking_prompt = await build_readable_messages( message_list_before_now, replace_bot_name=True, merge_messages=False, diff --git a/src/mais4u/mais4u_chat/s4u_mood_manager.py b/src/mais4u/mais4u_chat/s4u_mood_manager.py index db852567e..fa12523a4 100644 --- a/src/mais4u/mais4u_chat/s4u_mood_manager.py +++ b/src/mais4u/mais4u_chat/s4u_mood_manager.py @@ -167,7 +167,7 @@ class ChatMood: limit=10, limit_mode="last", ) - chat_talking_prompt = build_readable_messages( + chat_talking_prompt = await build_readable_messages( message_list_before_now, replace_bot_name=True, merge_messages=False, @@ -246,7 +246,7 @@ class ChatMood: limit=5, limit_mode="last", ) - chat_talking_prompt = build_readable_messages( + chat_talking_prompt = await build_readable_messages( message_list_before_now, replace_bot_name=True, merge_messages=False, diff --git a/src/mais4u/mais4u_chat/s4u_prompt.py b/src/mais4u/mais4u_chat/s4u_prompt.py index 2590a388f..a5b3a8f4b 100644 --- a/src/mais4u/mais4u_chat/s4u_prompt.py +++ b/src/mais4u/mais4u_chat/s4u_prompt.py @@ -183,7 +183,7 @@ class PromptBuilder: return "" @staticmethod - def build_chat_history_prompts(chat_stream: ChatStream, message: MessageRecvS4U): + async def 
build_chat_history_prompts(chat_stream: ChatStream, message: MessageRecvS4U): message_list_before_now = get_raw_msg_before_timestamp_with_chat( chat_id=chat_stream.stream_id, timestamp=time.time(), @@ -217,7 +217,7 @@ class PromptBuilder: background_dialogue_prompt = "" if background_dialogue_list: context_msgs = background_dialogue_list[-s4u_config.max_context_message_length :] - background_dialogue_prompt_str = build_readable_messages( + background_dialogue_prompt_str = await build_readable_messages( context_msgs, timestamp_mode="normal_no_YMD", show_pic=False, @@ -266,7 +266,7 @@ class PromptBuilder: timestamp=time.time(), limit=20, ) - all_dialogue_prompt_str = build_readable_messages( + all_dialogue_prompt_str = await build_readable_messages( all_dialogue_prompt, timestamp_mode="normal_no_YMD", show_pic=False, @@ -316,7 +316,7 @@ class PromptBuilder: self.build_expression_habits(chat_stream, message_txt, sender_name), ) - core_dialogue_prompt, background_dialogue_prompt, all_dialogue_prompt = self.build_chat_history_prompts( + core_dialogue_prompt, background_dialogue_prompt, all_dialogue_prompt = await self.build_chat_history_prompts( chat_stream, message ) diff --git a/src/mood/mood_manager.py b/src/mood/mood_manager.py index 95a365b41..ef4416673 100644 --- a/src/mood/mood_manager.py +++ b/src/mood/mood_manager.py @@ -105,7 +105,7 @@ class ChatMood: limit=int(global_config.chat.max_context_size / 3), limit_mode="last", ) - chat_talking_prompt = build_readable_messages( + chat_talking_prompt = await build_readable_messages( message_list_before_now, replace_bot_name=True, merge_messages=False, @@ -154,7 +154,7 @@ class ChatMood: limit=15, limit_mode="last", ) - chat_talking_prompt = build_readable_messages( + chat_talking_prompt = await build_readable_messages( message_list_before_now, replace_bot_name=True, merge_messages=False, diff --git a/src/person_info/relationship_manager.py b/src/person_info/relationship_manager.py index c2a3ffb96..a6ce8ab02 100644 --- a/src/person_info/relationship_manager.py +++ b/src/person_info/relationship_manager.py @@ -118,7 +118,7 @@ class RelationshipManager: name_mapping[replace_person_name] = f"用户{current_user}{user_count if user_count > 1 else ''}" current_user = chr(ord(current_user) + 1) - readable_messages = build_readable_messages( + readable_messages = await build_readable_messages( messages=user_messages, replace_bot_name=True, timestamp_mode="normal_no_YMD", truncate=True ) diff --git a/src/plugins/built_in/core_actions/emoji.py b/src/plugins/built_in/core_actions/emoji.py index 4375ae1a2..2c0940fcc 100644 --- a/src/plugins/built_in/core_actions/emoji.py +++ b/src/plugins/built_in/core_actions/emoji.py @@ -125,7 +125,7 @@ class EmojiAction(BaseAction): recent_messages = message_api.get_recent_messages(chat_id=self.chat_id, limit=5) messages_text = "" if recent_messages: - messages_text = message_api.build_readable_messages( + messages_text = await message_api.build_readable_messages( messages=recent_messages, timestamp_mode="normal_no_YMD", truncate=False, @@ -184,7 +184,7 @@ class EmojiAction(BaseAction): recent_messages = message_api.get_recent_messages(chat_id=self.chat_id, limit=5) messages_text = "" if recent_messages: - messages_text = message_api.build_readable_messages( + messages_text = await message_api.build_readable_messages( messages=recent_messages, timestamp_mode="normal_no_YMD", truncate=False, diff --git a/src/plugins/built_in/maizone_refactored/services/qzone_service.py 
b/src/plugins/built_in/maizone_refactored/services/qzone_service.py index 545e615a0..67a3669db 100644 --- a/src/plugins/built_in/maizone_refactored/services/qzone_service.py +++ b/src/plugins/built_in/maizone_refactored/services/qzone_service.py @@ -240,7 +240,7 @@ class QZoneService: all_messages = all_messages[-100:] # build_readable_messages_with_id 返回一个元组 (formatted_string, message_id_list) - formatted_string, _ = build_readable_messages_with_id(all_messages) + formatted_string, _ = await build_readable_messages_with_id(all_messages) return formatted_string logger.debug(f"Stream ID '{stream_id}' 未在任何互通组中找到。") From 816ce9805ced0da4c1a7665d27c4305b005ec264 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=85=E8=AF=BA=E7=8B=90?= <212194964+foxcyber907@users.noreply.github.com> Date: Sat, 20 Sep 2025 14:18:53 +0800 Subject: [PATCH 15/31] Update base_action.py --- src/plugin_system/base/base_action.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/plugin_system/base/base_action.py b/src/plugin_system/base/base_action.py index 51a0f4257..9400032f8 100644 --- a/src/plugin_system/base/base_action.py +++ b/src/plugin_system/base/base_action.py @@ -213,7 +213,7 @@ class BaseAction(ABC): # 检查新消息 current_time = time.time() - new_message_count = await message_api.count_new_messages( + new_message_count = message_api.count_new_messages( chat_id=self.chat_id, start_time=loop_start_time, end_time=current_time ) From 55717669dd54db92a5aa3a5bb03749bdbbad974e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=85=E8=AF=BA=E7=8B=90?= <212194964+foxcyber907@users.noreply.github.com> Date: Sat, 20 Sep 2025 14:35:31 +0800 Subject: [PATCH 16/31] =?UTF-8?q?refactor(db):=20=E5=B0=86=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BA=93=E6=93=8D=E4=BD=9C=E5=BC=82=E6=AD=A5=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 将所有 session.add() 改为 await session.add() - 将所有 session.commit() 改为 await session.commit() - 将 session.refresh() 改为 await session.refresh() --- src/chat/antipromptinjector/anti_injector.py | 4 ++-- src/chat/antipromptinjector/management/statistics.py | 12 ++++++------ src/chat/antipromptinjector/management/user_ban.py | 8 ++++---- src/chat/emoji_system/emoji_manager.py | 2 +- src/chat/memory_system/instant_memory.py | 4 ++-- src/chat/message_receive/chat_stream.py | 2 +- src/chat/message_receive/storage.py | 6 ++---- src/chat/utils/utils_image.py | 8 ++++---- src/chat/utils/utils_video.py | 6 +++--- src/common/database/database.py | 2 +- src/common/database/sqlalchemy_database_api.py | 4 ++-- src/common/database/sqlalchemy_models.py | 1 - src/common/message_repository.py | 2 +- src/llm_models/utils.py | 2 +- src/person_info/person_info.py | 6 +++--- src/plugin_system/core/permission_manager.py | 6 +++--- .../maizone_refactored/services/scheduler_service.py | 4 ++-- .../built_in/napcat_adapter_plugin/src/database.py | 8 ++++---- src/schedule/database.py | 2 +- src/schedule/schedule_manager.py | 2 +- 20 files changed, 44 insertions(+), 47 deletions(-) diff --git a/src/chat/antipromptinjector/anti_injector.py b/src/chat/antipromptinjector/anti_injector.py index 751a7d87e..f35070135 100644 --- a/src/chat/antipromptinjector/anti_injector.py +++ b/src/chat/antipromptinjector/anti_injector.py @@ -265,7 +265,7 @@ class AntiPromptInjector: # 删除对应的消息记录 stmt = delete(Messages).where(Messages.message_id == message_id) result = session.execute(stmt) - session.commit() + await session.commit() if result.rowcount > 0: logger.debug(f"成功删除违禁消息记录: 
{message_id}") @@ -295,7 +295,7 @@ class AntiPromptInjector: .values(processed_plain_text=new_content, display_message=new_content) ) result = session.execute(stmt) - session.commit() + await session.commit() if result.rowcount > 0: logger.debug(f"成功更新消息内容为加盾版本: {message_id}") diff --git a/src/chat/antipromptinjector/management/statistics.py b/src/chat/antipromptinjector/management/statistics.py index 12606d4ba..e9b4be66b 100644 --- a/src/chat/antipromptinjector/management/statistics.py +++ b/src/chat/antipromptinjector/management/statistics.py @@ -32,9 +32,9 @@ class AntiInjectionStatistics: stats = session.query(AntiInjectionStats).order_by(AntiInjectionStats.id.desc()).first() if not stats: stats = AntiInjectionStats() - session.add(stats) - session.commit() - session.refresh(stats) + await session.add(stats) + await session.commit() + await session.refresh(stats) return stats except Exception as e: logger.error(f"获取统计记录失败: {e}") @@ -48,7 +48,7 @@ class AntiInjectionStatistics: stats = session.query(AntiInjectionStats).order_by(AntiInjectionStats.id.desc()).first() if not stats: stats = AntiInjectionStats() - session.add(stats) + await session.add(stats) # 更新统计字段 for key, value in kwargs.items(): @@ -80,7 +80,7 @@ class AntiInjectionStatistics: # 直接设置的字段 setattr(stats, key, value) - session.commit() + await session.commit() except Exception as e: logger.error(f"更新统计数据失败: {e}") @@ -141,7 +141,7 @@ class AntiInjectionStatistics: with get_db_session() as session: # 删除现有统计记录 session.query(AntiInjectionStats).delete() - session.commit() + await session.commit() logger.info("统计信息已重置") except Exception as e: logger.error(f"重置统计信息失败: {e}") diff --git a/src/chat/antipromptinjector/management/user_ban.py b/src/chat/antipromptinjector/management/user_ban.py index 5a2239162..865ddddb9 100644 --- a/src/chat/antipromptinjector/management/user_ban.py +++ b/src/chat/antipromptinjector/management/user_ban.py @@ -52,7 +52,7 @@ class UserBanManager: # 封禁已过期,重置违规次数 ban_record.violation_num = 0 ban_record.created_at = datetime.datetime.now() - session.commit() + await session.commit() logger.info(f"用户 {platform}:{user_id} 封禁已过期,违规次数已重置") return None @@ -85,9 +85,9 @@ class UserBanManager: reason=f"提示词注入攻击 (置信度: {detection_result.confidence:.2f})", created_at=datetime.datetime.now(), ) - session.add(ban_record) + await session.add(ban_record) - session.commit() + await session.commit() # 检查是否需要自动封禁 if ban_record.violation_num >= self.config.auto_ban_violation_threshold: @@ -95,7 +95,7 @@ class UserBanManager: # 只有在首次达到阈值时才更新封禁开始时间 if ban_record.violation_num == self.config.auto_ban_violation_threshold: ban_record.created_at = datetime.datetime.now() - session.commit() + await session.commit() else: logger.info(f"用户 {platform}:{user_id} 违规记录已更新,当前违规次数: {ban_record.violation_num}") diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index e2a6eb7f1..6b2c8df5a 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -166,7 +166,7 @@ class MaiEmoji: usage_count=self.usage_count, last_used_time=self.last_used_time, ) - session.add(emoji) + await session.add(emoji) await session.commit() logger.info(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})") diff --git a/src/chat/memory_system/instant_memory.py b/src/chat/memory_system/instant_memory.py index 0b4b0b2e3..5b78f4d3d 100644 --- a/src/chat/memory_system/instant_memory.py +++ b/src/chat/memory_system/instant_memory.py @@ -117,8 +117,8 @@ class InstantMemory: 
create_time=memory_item.create_time, last_view_time=memory_item.last_view_time, ) - session.add(memory) - session.commit() + await session.add(memory) + await session.commit() async def get_memory(self, target: str): from json_repair import repair_json diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index e72d99686..6360928b1 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -147,7 +147,7 @@ class ChatManager: # db.connect(reuse_if_open=True) # # 确保 ChatStreams 表存在 # session.execute(text("CREATE TABLE IF NOT EXISTS chat_streams (stream_id TEXT PRIMARY KEY, platform TEXT, create_time REAL, last_active_time REAL, user_platform TEXT, user_id TEXT, user_nickname TEXT, user_cardname TEXT, group_platform TEXT, group_id TEXT, group_name TEXT)")) - # session.commit() + # await session.commit() # except Exception as e: # logger.error(f"数据库连接或 ChatStreams 表创建失败: {e}") diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index 4bdaa9edc..159d33aae 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -122,7 +122,8 @@ class MessageStorage: is_picid=is_picid, ) async with get_db_session() as session: - session.add(new_message) + await session.add(new_message) + await session.commit() except Exception: logger.exception("存储消息失败") @@ -161,9 +162,6 @@ class MessageStorage: logger.debug(f"消息段数据: {message.message_segment.data}") return - # 使用上下文管理器确保session正确管理 - from src.common.database.sqlalchemy_models import get_db_session - async with get_db_session() as session: matched_message = ( await session.execute( diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py index 93ec14957..bcfc6e7fd 100644 --- a/src/chat/utils/utils_image.py +++ b/src/chat/utils/utils_image.py @@ -128,7 +128,7 @@ class ImageManager: description=description, timestamp=current_timestamp, ) - session.add(new_desc) + await session.add(new_desc) await session.commit() # 会在上下文管理器中自动调用 except Exception as e: @@ -278,7 +278,7 @@ class ImageManager: description=detailed_description, # 保存详细描述 timestamp=current_timestamp, ) - session.add(new_img) + await session.add(new_img) await session.commit() except Exception as e: logger.error(f"保存到Images表失败: {str(e)}") @@ -370,7 +370,7 @@ class ImageManager: vlm_processed=True, count=1, ) - session.add(new_img) + await session.add(new_img) logger.debug(f"[数据库] 创建新图片记录: {image_hash[:8]}...") await session.commit() @@ -590,7 +590,7 @@ class ImageManager: vlm_processed=True, count=1, ) - session.add(new_img) + await session.add(new_img) await session.commit() return image_id, f"[picid:{image_id}]" diff --git a/src/chat/utils/utils_video.py b/src/chat/utils/utils_video.py index 6ea5a111f..e249bc133 100644 --- a/src/chat/utils/utils_video.py +++ b/src/chat/utils/utils_video.py @@ -242,7 +242,7 @@ class VideoAnalyzer: existing_video.fps = metadata.get("fps") existing_video.resolution = metadata.get("resolution") existing_video.file_size = metadata.get("file_size") - session.commit() + await session.commit() session.refresh(existing_video) logger.info(f"✅ 更新已存在的视频记录,hash: {video_hash[:16]}..., count: {existing_video.count}") return existing_video @@ -257,8 +257,8 @@ class VideoAnalyzer: video_record.resolution = metadata.get("resolution") video_record.file_size = metadata.get("file_size") - session.add(video_record) - session.commit() + await session.add(video_record) + await session.commit() session.refresh(video_record) 
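# [Illustrative sketch, assumed model and field names] The async write pattern this series
# ultimately settles on (a later patch drops the await from session.add, because Session.add()
# is synchronous even on AsyncSession): only execute/commit/refresh are awaited.
from sqlalchemy import select

async def example_upsert_video(session, Videos, video_hash: str, description: str):
    result = await session.execute(select(Videos).where(Videos.video_hash == video_hash))
    row = result.scalar_one_or_none()
    if row is None:
        row = Videos(video_hash=video_hash, description=description, count=1)
        session.add(row)           # add() is not a coroutine, so it is not awaited
    else:
        row.description = description
        row.count += 1
    await session.commit()         # commit is awaited on the async session
    await session.refresh(row)     # reload server-generated values after commit
    return row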
logger.info(f"✅ 新视频分析结果已保存到数据库,hash: {video_hash[:16]}...") return video_record diff --git a/src/common/database/database.py b/src/common/database/database.py index 3279a67ed..293f0cd1f 100644 --- a/src/common/database/database.py +++ b/src/common/database/database.py @@ -43,7 +43,7 @@ class SQLAlchemyTransaction: def __exit__(self, exc_type, exc_val, exc_tb): if exc_type is None: - self.session.commit() + self.await session.commit() else: self.session.rollback() self.session.close() diff --git a/src/common/database/sqlalchemy_database_api.py b/src/common/database/sqlalchemy_database_api.py index 13ef39c1a..63de1e43b 100644 --- a/src/common/database/sqlalchemy_database_api.py +++ b/src/common/database/sqlalchemy_database_api.py @@ -168,7 +168,7 @@ async def db_query( # 创建新记录 new_record = model_class(**data) - session.add(new_record) + await session.add(new_record) await session.flush() # 获取自动生成的ID # 转换为字典格式返回 @@ -295,7 +295,7 @@ async def db_save( # 创建新记录 new_record = model_class(**data) - session.add(new_record) + await session.add(new_record) await session.flush() # 转换为字典格式返回 diff --git a/src/common/database/sqlalchemy_models.py b/src/common/database/sqlalchemy_models.py index 0c193e358..2b276213d 100644 --- a/src/common/database/sqlalchemy_models.py +++ b/src/common/database/sqlalchemy_models.py @@ -676,7 +676,6 @@ async def get_db_session() -> AsyncGenerator[AsyncSession, None]: raise RuntimeError("Database session not initialized") session = SessionLocal() yield session - # await session.commit() except Exception: if session: await session.rollback() diff --git a/src/common/message_repository.py b/src/common/message_repository.py index 63d4c000d..7c620d2c7 100644 --- a/src/common/message_repository.py +++ b/src/common/message_repository.py @@ -201,5 +201,5 @@ async def count_messages(message_filter: dict[str, Any]) -> int: # 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。 -# 注意:对于 SQLAlchemy,插入操作通常是使用 session.add() 和 session.commit()。 +# 注意:对于 SQLAlchemy,插入操作通常是使用 await session.add() 和 await session.commit()。 # 查找单个消息可以使用 session.execute(select(Messages).where(...)).scalar_one_or_none()。 diff --git a/src/llm_models/utils.py b/src/llm_models/utils.py index bf23f144a..659fc5399 100644 --- a/src/llm_models/utils.py +++ b/src/llm_models/utils.py @@ -178,7 +178,7 @@ class LLMUsageRecorder: timestamp=datetime.now(), # SQLAlchemy 会处理 DateTime 字段 ) - session.add(usage_record) + await session.add(usage_record) await session.commit() logger.debug( diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index 3a036d029..1f8cb843c 100644 --- a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -180,7 +180,7 @@ class PersonInfoManager: async with get_db_session() as session: try: new_person = PersonInfo(**p_data) - session.add(new_person) + await session.add(new_person) await session.commit() return True except Exception as e: @@ -245,7 +245,7 @@ class PersonInfoManager: # 尝试创建 new_person = PersonInfo(**p_data) - session.add(new_person) + await session.add(new_person) await session.commit() return True except Exception as e: @@ -607,7 +607,7 @@ class PersonInfoManager: # 记录不存在,尝试创建 try: new_person = PersonInfo(**init_data) - session.add(new_person) + await session.add(new_person) await session.commit() await session.refresh(new_person) return new_person, True # 创建成功 diff --git a/src/plugin_system/core/permission_manager.py b/src/plugin_system/core/permission_manager.py index eb6083fc9..db7ef9b1a 100644 --- 
a/src/plugin_system/core/permission_manager.py +++ b/src/plugin_system/core/permission_manager.py @@ -149,7 +149,7 @@ class PermissionManager(IPermissionManager): default_granted=node.default_granted, created_at=datetime.utcnow(), ) - session.add(new_node) + await session.add(new_node) await session.commit() logger.info(f"注册新权限节点: {node.node_name} (插件: {node.plugin_name})") return True @@ -204,7 +204,7 @@ class PermissionManager(IPermissionManager): granted=True, granted_at=datetime.utcnow(), ) - session.add(new_perm) + await session.add(new_perm) await session.commit() logger.info(f"已授权用户 {user.platform}:{user.user_id} 权限节点 {permission_node}") @@ -257,7 +257,7 @@ class PermissionManager(IPermissionManager): granted=False, granted_at=datetime.utcnow(), ) - session.add(new_perm) + await session.add(new_perm) await session.commit() logger.info(f"已撤销用户 {user.platform}:{user.user_id} 权限节点 {permission_node}") diff --git a/src/plugins/built_in/maizone_refactored/services/scheduler_service.py b/src/plugins/built_in/maizone_refactored/services/scheduler_service.py index 69ec0956e..ca6dc52c3 100644 --- a/src/plugins/built_in/maizone_refactored/services/scheduler_service.py +++ b/src/plugins/built_in/maizone_refactored/services/scheduler_service.py @@ -186,8 +186,8 @@ class SchedulerService: story_content=content, send_success=success, ) - session.add(new_record) - session.commit() + await session.add(new_record) + await session.commit() logger.info(f"已更新日程处理状态: {hour_str} - {activity} - 成功: {success}") except Exception as e: logger.error(f"更新日程处理状态时发生数据库错误: {e}") diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/database.py b/src/plugins/built_in/napcat_adapter_plugin/src/database.py index 74842eed5..23b5d1f5d 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/database.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/database.py @@ -83,14 +83,14 @@ class DatabaseManager: continue # 更新现有记录的 lift_time existing_record.lift_time = ban_user.lift_time - session.add(existing_record) + await session.add(existing_record) logger.debug(f"更新禁言记录: {existing_record}") else: # 创建新记录 db_record = DB_BanUser( user_id=ban_user.user_id, group_id=ban_user.group_id, lift_time=ban_user.lift_time ) - session.add(db_record) + await session.add(db_record) logger.debug(f"创建新禁言记录: {ban_user}") # 删除不在 ban_list 中的记录 for db_record in all_records: @@ -132,14 +132,14 @@ class DatabaseManager: if existing_record: # 如果记录已存在,更新 lift_time existing_record.lift_time = ban_record.lift_time - session.add(existing_record) + await session.add(existing_record) logger.debug(f"更新禁言记录: {ban_record}") else: # 如果记录不存在,创建新记录 db_record = DB_BanUser( user_id=ban_record.user_id, group_id=ban_record.group_id, lift_time=ban_record.lift_time ) - session.add(db_record) + await session.add(db_record) logger.debug(f"创建新禁言记录: {ban_record}") def delete_ban_record(self, ban_record: BanUser): diff --git a/src/schedule/database.py b/src/schedule/database.py index 5025c1fa3..b420f0686 100644 --- a/src/schedule/database.py +++ b/src/schedule/database.py @@ -42,7 +42,7 @@ async def add_new_plans(plans: List[str], month: str): new_plan_objects = [ MonthlyPlan(plan_text=plan, target_month=month, status="active") for plan in plans_to_add ] - session.add_all(new_plan_objects) + await session.add_all(new_plan_objects) await session.commit() logger.info(f"成功向数据库添加了 {len(new_plan_objects)} 条 {month} 的月度计划。") diff --git a/src/schedule/schedule_manager.py b/src/schedule/schedule_manager.py index 115480381..4e66bf0c8 100644 --- 
a/src/schedule/schedule_manager.py +++ b/src/schedule/schedule_manager.py @@ -128,7 +128,7 @@ class ScheduleManager: existing_schedule.updated_at = datetime.now() else: new_schedule = Schedule(date=date_str, schedule_data=schedule_json) - session.add(new_schedule) + await session.add(new_schedule) await session.commit() @staticmethod From 5f3203c6c95a460287ae0031d3f2b5099e915c13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=85=E8=AF=BA=E7=8B=90?= <212194964+foxcyber907@users.noreply.github.com> Date: Sat, 20 Sep 2025 17:25:48 +0800 Subject: [PATCH 17/31] =?UTF-8?q?refactor(db):=20=E4=BF=AE=E6=AD=A3SQLAlch?= =?UTF-8?q?emy=E5=BC=82=E6=AD=A5=E6=93=8D=E4=BD=9C=E8=B0=83=E7=94=A8?= =?UTF-8?q?=E6=96=B9=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 移除session.add()方法的不必要await调用,修正异步数据库操作模式。主要变更包括: - 将 `await session.add()` 统一改为 `session.add()` - 修正部分函数调用为异步版本(如消息查询函数) - 重构SQLAlchemyTransaction为完全异步实现 - 重写napcat_adapter_plugin数据库层以符合异步规范 - 添加aiomysql和aiosqlite依赖支持 --- .../old/config.toml.bak.20250907_121908 | 25 ++ .../management/statistics.py | 4 +- .../antipromptinjector/management/user_ban.py | 2 +- src/chat/emoji_system/emoji_manager.py | 2 +- src/chat/express/expression_learner.py | 4 +- src/chat/memory_system/instant_memory.py | 2 +- src/chat/message_receive/storage.py | 2 +- src/chat/utils/utils_image.py | 8 +- src/chat/utils/utils_video.py | 54 ++-- src/common/database/database.py | 29 +- .../database/sqlalchemy_database_api.py | 4 +- src/common/message_repository.py | 2 +- src/llm_models/utils.py | 2 +- src/mais4u/mais4u_chat/s4u_mood_manager.py | 4 +- src/mood/mood_manager.py | 4 +- src/person_info/person_info.py | 6 +- src/person_info/relationship_builder.py | 45 ++- src/plugin_system/core/permission_manager.py | 6 +- .../services/scheduler_service.py | 2 +- .../napcat_adapter_plugin/src/database.py | 262 +++++++++--------- .../src/recv_handler/notice_handler.py | 29 +- .../napcat_adapter_plugin/src/utils.py | 13 +- src/schedule/database.py | 2 +- src/schedule/schedule_manager.py | 2 +- uv.lock | 29 +- 25 files changed, 299 insertions(+), 245 deletions(-) create mode 100644 plugins/napcat_adapter_plugin/config/old/config.toml.bak.20250907_121908 diff --git a/plugins/napcat_adapter_plugin/config/old/config.toml.bak.20250907_121908 b/plugins/napcat_adapter_plugin/config/old/config.toml.bak.20250907_121908 new file mode 100644 index 000000000..1ddca6cf5 --- /dev/null +++ b/plugins/napcat_adapter_plugin/config/old/config.toml.bak.20250907_121908 @@ -0,0 +1,25 @@ +[inner] +version = "0.2.0" # 版本号 +# 请勿修改版本号,除非你知道自己在做什么 + +[nickname] # 现在没用 +nickname = "" + +[napcat_server] # Napcat连接的ws服务设置 +mode = "reverse" # 连接模式:reverse=反向连接(作为服务器), forward=正向连接(作为客户端) +host = "localhost" # 主机地址 +port = 8095 # 端口号 +url = "" # 正向连接时的完整WebSocket URL,如 ws://localhost:8080/ws (仅在forward模式下使用) +access_token = "" # WebSocket 连接的访问令牌,用于身份验证(可选) +heartbeat_interval = 30 # 心跳间隔时间(按秒计) + +[maibot_server] # 连接麦麦的ws服务设置 +host = "localhost" # 麦麦在.env文件中设置的主机地址,即HOST字段 +port = 8000 # 麦麦在.env文件中设置的端口,即PORT字段 + +[voice] # 发送语音设置 +use_tts = false # 是否使用tts语音(请确保你配置了tts并有对应的adapter) + +[debug] +level = "INFO" # 日志等级(DEBUG, INFO, WARNING, ERROR, CRITICAL) + diff --git a/src/chat/antipromptinjector/management/statistics.py b/src/chat/antipromptinjector/management/statistics.py index e9b4be66b..2cfe3e13c 100644 --- a/src/chat/antipromptinjector/management/statistics.py +++ b/src/chat/antipromptinjector/management/statistics.py @@ -32,7 +32,7 @@ class 
AntiInjectionStatistics: stats = session.query(AntiInjectionStats).order_by(AntiInjectionStats.id.desc()).first() if not stats: stats = AntiInjectionStats() - await session.add(stats) + session.add(stats) await session.commit() await session.refresh(stats) return stats @@ -48,7 +48,7 @@ class AntiInjectionStatistics: stats = session.query(AntiInjectionStats).order_by(AntiInjectionStats.id.desc()).first() if not stats: stats = AntiInjectionStats() - await session.add(stats) + session.add(stats) # 更新统计字段 for key, value in kwargs.items(): diff --git a/src/chat/antipromptinjector/management/user_ban.py b/src/chat/antipromptinjector/management/user_ban.py index 865ddddb9..676436c42 100644 --- a/src/chat/antipromptinjector/management/user_ban.py +++ b/src/chat/antipromptinjector/management/user_ban.py @@ -85,7 +85,7 @@ class UserBanManager: reason=f"提示词注入攻击 (置信度: {detection_result.confidence:.2f})", created_at=datetime.datetime.now(), ) - await session.add(ban_record) + session.add(ban_record) await session.commit() diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index 6b2c8df5a..e2a6eb7f1 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -166,7 +166,7 @@ class MaiEmoji: usage_count=self.usage_count, last_used_time=self.last_used_time, ) - await session.add(emoji) + session.add(emoji) await session.commit() logger.info(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})") diff --git a/src/chat/express/expression_learner.py b/src/chat/express/expression_learner.py index a709ee78f..b7dabe6e1 100644 --- a/src/chat/express/expression_learner.py +++ b/src/chat/express/expression_learner.py @@ -381,7 +381,7 @@ class ExpressionLearner: type=type, create_date=current_time, # 手动设置创建日期 ) - await session.add(new_expression) + session.add(new_expression) # 限制最大数量 exprs_result = await session.execute( @@ -608,7 +608,7 @@ class ExpressionLearnerManager: type=type_str, create_date=last_active_time, # 迁移时使用last_active_time作为创建时间 ) - await session.add(new_expression) + session.add(new_expression) migrated_count += 1 logger.info(f"已迁移 {expr_file} 到数据库,包含 {len(expressions)} 个表达方式") diff --git a/src/chat/memory_system/instant_memory.py b/src/chat/memory_system/instant_memory.py index 5b78f4d3d..a8675f5c0 100644 --- a/src/chat/memory_system/instant_memory.py +++ b/src/chat/memory_system/instant_memory.py @@ -117,7 +117,7 @@ class InstantMemory: create_time=memory_item.create_time, last_view_time=memory_item.last_view_time, ) - await session.add(memory) + session.add(memory) await session.commit() async def get_memory(self, target: str): diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index 159d33aae..015578be8 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -122,7 +122,7 @@ class MessageStorage: is_picid=is_picid, ) async with get_db_session() as session: - await session.add(new_message) + session.add(new_message) await session.commit() except Exception: diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py index bcfc6e7fd..93ec14957 100644 --- a/src/chat/utils/utils_image.py +++ b/src/chat/utils/utils_image.py @@ -128,7 +128,7 @@ class ImageManager: description=description, timestamp=current_timestamp, ) - await session.add(new_desc) + session.add(new_desc) await session.commit() # 会在上下文管理器中自动调用 except Exception as e: @@ -278,7 +278,7 @@ class ImageManager: description=detailed_description, # 保存详细描述 
timestamp=current_timestamp, ) - await session.add(new_img) + session.add(new_img) await session.commit() except Exception as e: logger.error(f"保存到Images表失败: {str(e)}") @@ -370,7 +370,7 @@ class ImageManager: vlm_processed=True, count=1, ) - await session.add(new_img) + session.add(new_img) logger.debug(f"[数据库] 创建新图片记录: {image_hash[:8]}...") await session.commit() @@ -590,7 +590,7 @@ class ImageManager: vlm_processed=True, count=1, ) - await session.add(new_img) + session.add(new_img) await session.commit() return image_id, f"[picid:{image_id}]" diff --git a/src/chat/utils/utils_video.py b/src/chat/utils/utils_video.py index e249bc133..8cb294f3e 100644 --- a/src/chat/utils/utils_video.py +++ b/src/chat/utils/utils_video.py @@ -22,6 +22,7 @@ from src.llm_models.utils_model import LLMRequest from src.config.config import global_config, model_config from src.common.logger import get_logger from src.common.database.sqlalchemy_models import get_db_session, Videos +from sqlalchemy import select logger = get_logger("utils_video") @@ -205,34 +206,29 @@ class VideoAnalyzer: return hash_obj.hexdigest() @staticmethod - def _check_video_exists(video_hash: str) -> Optional[Videos]: - """检查视频是否已经分析过""" + async def _check_video_exists(video_hash: str) -> Optional[Videos]: + """检查视频是否已经分析过 (异步)""" try: - with get_db_session() as session: - # 明确刷新会话以确保看到其他事务的最新提交 - session.expire_all() - return session.query(Videos).filter(Videos.video_hash == video_hash).first() + async with get_db_session() as session: + result = await session.execute(select(Videos).where(Videos.video_hash == video_hash)) + return result.scalar_one_or_none() except Exception as e: logger.warning(f"检查视频是否存在时出错: {e}") return None @staticmethod - def _store_video_result( - video_hash: str, description: str, metadata: Optional[Dict] = None + async def _store_video_result( + video_hash: str, description: str, metadata: Optional[Dict] = None ) -> Optional[Videos]: - """存储视频分析结果到数据库""" - # 检查描述是否为错误信息,如果是则不保存 + """存储视频分析结果到数据库 (异步)""" if description.startswith("❌"): logger.warning(f"⚠️ 检测到错误信息,不保存到数据库: {description[:50]}...") return None - try: - with get_db_session() as session: - # 只根据video_hash查找 - existing_video = session.query(Videos).filter(Videos.video_hash == video_hash).first() - + async with get_db_session() as session: + result = await session.execute(select(Videos).where(Videos.video_hash == video_hash)) + existing_video = result.scalar_one_or_none() if existing_video: - # 如果已存在,更新描述和计数 existing_video.description = description existing_video.count += 1 existing_video.timestamp = time.time() @@ -243,12 +239,17 @@ class VideoAnalyzer: existing_video.resolution = metadata.get("resolution") existing_video.file_size = metadata.get("file_size") await session.commit() - session.refresh(existing_video) - logger.info(f"✅ 更新已存在的视频记录,hash: {video_hash[:16]}..., count: {existing_video.count}") + await session.refresh(existing_video) + logger.info( + f"✅ 更新已存在的视频记录,hash: {video_hash[:16]}..., count: {existing_video.count}" + ) return existing_video else: video_record = Videos( - video_hash=video_hash, description=description, timestamp=time.time(), count=1 + video_hash=video_hash, + description=description, + timestamp=time.time(), + count=1, ) if metadata: video_record.duration = metadata.get("duration") @@ -256,11 +257,12 @@ class VideoAnalyzer: video_record.fps = metadata.get("fps") video_record.resolution = metadata.get("resolution") video_record.file_size = metadata.get("file_size") - - await session.add(video_record) + 
session.add(video_record) await session.commit() - session.refresh(video_record) - logger.info(f"✅ 新视频分析结果已保存到数据库,hash: {video_hash[:16]}...") + await session.refresh(video_record) + logger.info( + f"✅ 新视频分析结果已保存到数据库,hash: {video_hash[:16]}..." + ) return video_record except Exception as e: logger.error(f"❌ 存储视频分析结果时出错: {e}") @@ -708,7 +710,7 @@ class VideoAnalyzer: logger.info("✅ 等待结束,检查是否有处理结果") # 检查是否有结果了 - existing_video = self._check_video_exists(video_hash) + existing_video = await self._check_video_exists(video_hash) if existing_video: logger.info(f"✅ 找到了处理结果,直接返回 (id: {existing_video.id})") return {"summary": existing_video.description} @@ -722,7 +724,7 @@ class VideoAnalyzer: logger.info(f"🔒 获得视频处理锁,开始处理 (hash: {video_hash[:16]}...)") # 再次检查数据库(可能在等待期间已经有结果了) - existing_video = self._check_video_exists(video_hash) + existing_video = await self._check_video_exists(video_hash) if existing_video: logger.info(f"✅ 获得锁后发现已有结果,直接返回 (id: {existing_video.id})") video_event.set() # 通知其他等待者 @@ -753,7 +755,7 @@ class VideoAnalyzer: # 保存分析结果到数据库(仅保存成功的结果) if success: metadata = {"filename": filename, "file_size": len(video_bytes), "analysis_timestamp": time.time()} - self._store_video_result(video_hash=video_hash, description=result, metadata=metadata) + await self._store_video_result(video_hash=video_hash, description=result, metadata=metadata) logger.info("✅ 分析结果已保存到数据库") else: logger.warning("⚠️ 分析失败,不保存到数据库以便后续重试") diff --git a/src/common/database/database.py b/src/common/database/database.py index 293f0cd1f..6a34d900e 100644 --- a/src/common/database/database.py +++ b/src/common/database/database.py @@ -32,21 +32,32 @@ class DatabaseProxy: class SQLAlchemyTransaction: - """SQLAlchemy事务上下文管理器""" + """SQLAlchemy 异步事务上下文管理器 (兼容旧代码示例,推荐直接使用 get_db_session)。""" def __init__(self): + self._ctx = None self.session = None - def __enter__(self): - self.session = get_db_session() + async def __aenter__(self): + # get_db_session 是一个 async contextmanager + self._ctx = get_db_session() + self.session = await self._ctx.__aenter__() return self.session - def __exit__(self, exc_type, exc_val, exc_tb): - if exc_type is None: - self.await session.commit() - else: - self.session.rollback() - self.session.close() + async def __aexit__(self, exc_type, exc_val, exc_tb): + try: + if self.session: + if exc_type is None: + try: + await self.session.commit() + except Exception: + await self.session.rollback() + raise + else: + await self.session.rollback() + finally: + if self._ctx: + await self._ctx.__aexit__(exc_type, exc_val, exc_tb) # 创建全局数据库代理实例 diff --git a/src/common/database/sqlalchemy_database_api.py b/src/common/database/sqlalchemy_database_api.py index 63de1e43b..13ef39c1a 100644 --- a/src/common/database/sqlalchemy_database_api.py +++ b/src/common/database/sqlalchemy_database_api.py @@ -168,7 +168,7 @@ async def db_query( # 创建新记录 new_record = model_class(**data) - await session.add(new_record) + session.add(new_record) await session.flush() # 获取自动生成的ID # 转换为字典格式返回 @@ -295,7 +295,7 @@ async def db_save( # 创建新记录 new_record = model_class(**data) - await session.add(new_record) + session.add(new_record) await session.flush() # 转换为字典格式返回 diff --git a/src/common/message_repository.py b/src/common/message_repository.py index 7c620d2c7..96714db1f 100644 --- a/src/common/message_repository.py +++ b/src/common/message_repository.py @@ -201,5 +201,5 @@ async def count_messages(message_filter: dict[str, Any]) -> int: # 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。 -# 注意:对于 
SQLAlchemy,插入操作通常是使用 await session.add() 和 await session.commit()。 +# 注意:对于 SQLAlchemy,插入操作通常是使用 session.add() 和 await session.commit()。 # 查找单个消息可以使用 session.execute(select(Messages).where(...)).scalar_one_or_none()。 diff --git a/src/llm_models/utils.py b/src/llm_models/utils.py index 659fc5399..bf23f144a 100644 --- a/src/llm_models/utils.py +++ b/src/llm_models/utils.py @@ -178,7 +178,7 @@ class LLMUsageRecorder: timestamp=datetime.now(), # SQLAlchemy 会处理 DateTime 字段 ) - await session.add(usage_record) + session.add(usage_record) await session.commit() logger.debug( diff --git a/src/mais4u/mais4u_chat/s4u_mood_manager.py b/src/mais4u/mais4u_chat/s4u_mood_manager.py index fa12523a4..d235843d4 100644 --- a/src/mais4u/mais4u_chat/s4u_mood_manager.py +++ b/src/mais4u/mais4u_chat/s4u_mood_manager.py @@ -160,7 +160,7 @@ class ChatMood: self.regression_count = 0 message_time: float = message.message_info.time # type: ignore - message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive( + message_list_before_now = await get_raw_msg_by_timestamp_with_chat_inclusive( chat_id=self.chat_id, timestamp_start=self.last_change_time, timestamp_end=message_time, @@ -239,7 +239,7 @@ class ChatMood: async def regress_mood(self): message_time = time.time() - message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive( + message_list_before_now = await get_raw_msg_by_timestamp_with_chat_inclusive( chat_id=self.chat_id, timestamp_start=self.last_change_time, timestamp_end=message_time, diff --git a/src/mood/mood_manager.py b/src/mood/mood_manager.py index ef4416673..5138a7d5d 100644 --- a/src/mood/mood_manager.py +++ b/src/mood/mood_manager.py @@ -98,7 +98,7 @@ class ChatMood: ) message_time: float = message.message_info.time # type: ignore - message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive( + message_list_before_now = await get_raw_msg_by_timestamp_with_chat_inclusive( chat_id=self.chat_id, timestamp_start=self.last_change_time, timestamp_end=message_time, @@ -147,7 +147,7 @@ class ChatMood: async def regress_mood(self): message_time = time.time() - message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive( + message_list_before_now = await get_raw_msg_by_timestamp_with_chat_inclusive( chat_id=self.chat_id, timestamp_start=self.last_change_time, timestamp_end=message_time, diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index 1f8cb843c..3a036d029 100644 --- a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -180,7 +180,7 @@ class PersonInfoManager: async with get_db_session() as session: try: new_person = PersonInfo(**p_data) - await session.add(new_person) + session.add(new_person) await session.commit() return True except Exception as e: @@ -245,7 +245,7 @@ class PersonInfoManager: # 尝试创建 new_person = PersonInfo(**p_data) - await session.add(new_person) + session.add(new_person) await session.commit() return True except Exception as e: @@ -607,7 +607,7 @@ class PersonInfoManager: # 记录不存在,尝试创建 try: new_person = PersonInfo(**init_data) - await session.add(new_person) + session.add(new_person) await session.commit() await session.refresh(new_person) return new_person, True # 创建成功 diff --git a/src/person_info/relationship_builder.py b/src/person_info/relationship_builder.py index 720076eb2..1ff90a99d 100644 --- a/src/person_info/relationship_builder.py +++ b/src/person_info/relationship_builder.py @@ -3,7 +3,7 @@ import traceback import os import pickle import random -from typing import List, Dict, Any, 
Coroutine +from typing import List, Dict, Any from src.config.config import global_config from src.common.logger import get_logger from src.person_info.relationship_manager import get_relationship_manager @@ -113,7 +113,7 @@ class RelationshipBuilder: # 负责跟踪用户消息活动、管理消息段、清理过期数据 # ================================ - def _update_message_segments(self, person_id: str, message_time: float): + async def _update_message_segments(self, person_id: str, message_time: float): """更新用户的消息段 Args: @@ -126,11 +126,8 @@ class RelationshipBuilder: segments = self.person_engaged_cache[person_id] # 获取该消息前5条消息的时间作为潜在的开始时间 - before_messages = get_raw_msg_before_timestamp_with_chat(self.chat_id, message_time, limit=5) - if before_messages: - potential_start_time = before_messages[0]["time"] - else: - potential_start_time = message_time + before_messages = await get_raw_msg_before_timestamp_with_chat(self.chat_id, message_time, limit=5) + potential_start_time = before_messages[0]["time"] if before_messages else message_time # 如果没有现有消息段,创建新的 if not segments: @@ -138,10 +135,9 @@ class RelationshipBuilder: "start_time": potential_start_time, "end_time": message_time, "last_msg_time": message_time, - "message_count": self._count_messages_in_timerange(potential_start_time, message_time), + "message_count": await self._count_messages_in_timerange(potential_start_time, message_time), } segments.append(new_segment) - person_name = get_person_info_manager().get_value_sync(person_id, "person_name") or person_id logger.debug( f"{self.log_prefix} 眼熟用户 {person_name} 在 {time.strftime('%H:%M:%S', time.localtime(potential_start_time))} - {time.strftime('%H:%M:%S', time.localtime(message_time))} 之间有 {new_segment['message_count']} 条消息" @@ -153,39 +149,32 @@ class RelationshipBuilder: last_segment = segments[-1] # 计算从最后一条消息到当前消息之间的消息数量(不包含边界) - messages_between = self._count_messages_between(last_segment["last_msg_time"], message_time) + messages_between = await self._count_messages_between(last_segment["last_msg_time"], message_time) if messages_between <= 10: - # 在10条消息内,延伸当前消息段 last_segment["end_time"] = message_time last_segment["last_msg_time"] = message_time - # 重新计算整个消息段的消息数量 - last_segment["message_count"] = self._count_messages_in_timerange( + last_segment["message_count"] = await self._count_messages_in_timerange( last_segment["start_time"], last_segment["end_time"] ) logger.debug(f"{self.log_prefix} 延伸用户 {person_id} 的消息段: {last_segment}") else: - # 超过10条消息,结束当前消息段并创建新的 - # 结束当前消息段:延伸到原消息段最后一条消息后5条消息的时间 current_time = time.time() - after_messages = get_raw_msg_by_timestamp_with_chat( + after_messages = await get_raw_msg_by_timestamp_with_chat( self.chat_id, last_segment["last_msg_time"], current_time, limit=5, limit_mode="earliest" ) if after_messages and len(after_messages) >= 5: - # 如果有足够的后续消息,使用第5条消息的时间作为结束时间 last_segment["end_time"] = after_messages[4]["time"] - # 重新计算当前消息段的消息数量 - last_segment["message_count"] = self._count_messages_in_timerange( + last_segment["message_count"] = await self._count_messages_in_timerange( last_segment["start_time"], last_segment["end_time"] ) - # 创建新的消息段 new_segment = { "start_time": potential_start_time, "end_time": message_time, "last_msg_time": message_time, - "message_count": self._count_messages_in_timerange(potential_start_time, message_time), + "message_count": await self._count_messages_in_timerange(potential_start_time, message_time), } segments.append(new_segment) person_info_manager = get_person_info_manager() @@ -196,14 +185,14 @@ class RelationshipBuilder: 
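# [Illustrative sketch, assumed table/column names] Why these counting helpers become
# coroutines: the query layer they call is now awaited, so each helper must be declared
# `async def` and return the awaited int instead of an un-awaited Coroutine object.
from sqlalchemy import func, select

async def example_count_messages_between(session, Messages, chat_id: str, start: float, end: float) -> int:
    result = await session.execute(
        select(func.count())
        .select_from(Messages)
        .where(Messages.chat_id == chat_id, Messages.time > start, Messages.time < end)
    )
    return int(result.scalar_one())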
self._save_cache() - def _count_messages_in_timerange(self, start_time: float, end_time: float) -> int: + async def _count_messages_in_timerange(self, start_time: float, end_time: float) -> int: """计算指定时间范围内的消息数量(包含边界)""" - messages = get_raw_msg_by_timestamp_with_chat_inclusive(self.chat_id, start_time, end_time) + messages = await get_raw_msg_by_timestamp_with_chat_inclusive(self.chat_id, start_time, end_time) return len(messages) - def _count_messages_between(self, start_time: float, end_time: float) -> Coroutine[Any, Any, int]: + async def _count_messages_between(self, start_time: float, end_time: float) -> int: """计算两个时间点之间的消息数量(不包含边界),用于间隔检查""" - return num_new_messages_since(self.chat_id, start_time, end_time) + return await num_new_messages_since(self.chat_id, start_time, end_time) def _get_total_message_count(self, person_id: str) -> int: """获取用户所有消息段的总消息数量""" @@ -350,7 +339,7 @@ class RelationshipBuilder: self._cleanup_old_segments() current_time = time.time() - if latest_messages := get_raw_msg_by_timestamp_with_chat( + if latest_messages := await get_raw_msg_by_timestamp_with_chat( self.chat_id, self.last_processed_message_time, current_time, @@ -369,7 +358,7 @@ class RelationshipBuilder: and msg_time > self.last_processed_message_time ): person_id = PersonInfoManager.get_person_id(platform, user_id) - self._update_message_segments(person_id, msg_time) + await self._update_message_segments(person_id, msg_time) logger.debug( f"{self.log_prefix} 更新用户 {person_id} 的消息段,消息时间:{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(msg_time))}" ) @@ -439,7 +428,7 @@ class RelationshipBuilder: start_date = time.strftime("%Y-%m-%d %H:%M", time.localtime(start_time)) # 获取该段的消息(包含边界) - segment_messages = get_raw_msg_by_timestamp_with_chat_inclusive(self.chat_id, start_time, end_time) + segment_messages = await get_raw_msg_by_timestamp_with_chat_inclusive(self.chat_id, start_time, end_time) logger.debug( f"消息段: {start_date} - {time.strftime('%Y-%m-%d %H:%M', time.localtime(end_time))}, 消息数: {len(segment_messages)}" ) diff --git a/src/plugin_system/core/permission_manager.py b/src/plugin_system/core/permission_manager.py index db7ef9b1a..eb6083fc9 100644 --- a/src/plugin_system/core/permission_manager.py +++ b/src/plugin_system/core/permission_manager.py @@ -149,7 +149,7 @@ class PermissionManager(IPermissionManager): default_granted=node.default_granted, created_at=datetime.utcnow(), ) - await session.add(new_node) + session.add(new_node) await session.commit() logger.info(f"注册新权限节点: {node.node_name} (插件: {node.plugin_name})") return True @@ -204,7 +204,7 @@ class PermissionManager(IPermissionManager): granted=True, granted_at=datetime.utcnow(), ) - await session.add(new_perm) + session.add(new_perm) await session.commit() logger.info(f"已授权用户 {user.platform}:{user.user_id} 权限节点 {permission_node}") @@ -257,7 +257,7 @@ class PermissionManager(IPermissionManager): granted=False, granted_at=datetime.utcnow(), ) - await session.add(new_perm) + session.add(new_perm) await session.commit() logger.info(f"已撤销用户 {user.platform}:{user.user_id} 权限节点 {permission_node}") diff --git a/src/plugins/built_in/maizone_refactored/services/scheduler_service.py b/src/plugins/built_in/maizone_refactored/services/scheduler_service.py index ca6dc52c3..6124f4f06 100644 --- a/src/plugins/built_in/maizone_refactored/services/scheduler_service.py +++ b/src/plugins/built_in/maizone_refactored/services/scheduler_service.py @@ -186,7 +186,7 @@ class SchedulerService: story_content=content, send_success=success, ) - await 
session.add(new_record) + session.add(new_record) await session.commit() logger.info(f"已更新日程处理状态: {hour_str} - {activity} - 成功: {success}") except Exception as e: diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/database.py b/src/plugins/built_in/napcat_adapter_plugin/src/database.py index 23b5d1f5d..1620ec304 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/database.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/database.py @@ -1,162 +1,156 @@ -import os -from typing import Optional, List -from dataclasses import dataclass -from sqlmodel import Field, Session, SQLModel, create_engine, select +"""Napcat Adapter 插件数据库层 (基于主程序异步SQLAlchemy API) +本模块替换原先的 sqlmodel + 同步Session 实现: +1. 复用主项目的异步数据库连接与迁移体系 +2. 提供与旧接口名兼容的方法(update_ban_record/create_ban_record/delete_ban_record) +3. 新增首选异步方法: update_ban_records / create_or_update / delete_record / get_ban_records + +数据语义: + user_id == 0 表示群全体禁言 + +注意: 所有方法均为异步, 需要在 async 上下文中调用。 +""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional, List, Sequence + +from sqlalchemy import Column, Integer, BigInteger, UniqueConstraint, select, Index +from sqlalchemy.ext.asyncio import AsyncSession + +from src.common.database.sqlalchemy_models import Base, get_db_session from src.common.logger import get_logger logger = get_logger("napcat_adapter") -""" -表记录的方式: -| group_id | user_id | lift_time | -|----------|---------|-----------| -其中使用 user_id == 0 表示群全体禁言 -""" +class NapcatBanRecord(Base): + __tablename__ = "napcat_ban_records" + + id = Column(Integer, primary_key=True, autoincrement=True) + group_id = Column(BigInteger, nullable=False, index=True) + user_id = Column(BigInteger, nullable=False, index=True) # 0 == 全体禁言 + lift_time = Column(BigInteger, nullable=True) # -1 / None 表示未知/永久 + + __table_args__ = ( + UniqueConstraint("group_id", "user_id", name="uq_napcat_group_user"), + Index("idx_napcat_ban_group", "group_id"), + Index("idx_napcat_ban_user", "user_id"), + ) @dataclass class BanUser: - """ - 程序处理使用的实例 - """ - user_id: int group_id: int - lift_time: Optional[int] = Field(default=-1) + lift_time: Optional[int] = -1 + + def identity(self) -> tuple[int, int]: + return self.group_id, self.user_id -class DB_BanUser(SQLModel, table=True): - """ - 表示数据库中的用户禁言记录。 - 使用双重主键 - """ +class NapcatDatabase: + async def _fetch_all(self, session: AsyncSession) -> Sequence[NapcatBanRecord]: + result = await session.execute(select(NapcatBanRecord)) + return result.scalars().all() - user_id: int = Field(index=True, primary_key=True) # 被禁言用户的用户 ID - group_id: int = Field(index=True, primary_key=True) # 用户被禁言的群组 ID - lift_time: Optional[int] # 禁言解除的时间(时间戳) + async def get_ban_records(self) -> List[BanUser]: + async with get_db_session() as session: + rows = await self._fetch_all(session) + return [BanUser(group_id=r.group_id, user_id=r.user_id, lift_time=r.lift_time) for r in rows] + async def update_ban_records(self, ban_list: List[BanUser]) -> None: + target_map = {b.identity(): b for b in ban_list} + async with get_db_session() as session: + rows = await self._fetch_all(session) + existing_map = {(r.group_id, r.user_id): r for r in rows} -def is_identical(obj1: BanUser, obj2: BanUser) -> bool: - """ - 检查两个 BanUser 对象是否相同。 - """ - return obj1.user_id == obj2.user_id and obj1.group_id == obj2.group_id - - -class DatabaseManager: - """ - 数据库管理类,负责与数据库交互。 - """ - - def __init__(self): - os.makedirs(os.path.join(os.path.dirname(__file__), "..", "data"), exist_ok=True) # 确保数据目录存在 - DATABASE_FILE 
= os.path.join(os.path.dirname(__file__), "..", "data", "NapcatAdapter.db") - self.sqlite_url = f"sqlite:///{DATABASE_FILE}" # SQLite 数据库 URL - self.engine = create_engine(self.sqlite_url, echo=False) # 创建数据库引擎 - self._ensure_database() # 确保数据库和表已创建 - - def _ensure_database(self) -> None: - """ - 确保数据库和表已创建。 - """ - logger.info("确保数据库文件和表已创建...") - SQLModel.metadata.create_all(self.engine) - logger.info("数据库和表已创建或已存在") - - def update_ban_record(self, ban_list: List[BanUser]) -> None: - # sourcery skip: class-extract-method - """ - 更新禁言列表到数据库。 - 支持在不存在时创建新记录,对于多余的项目自动删除。 - """ - with Session(self.engine) as session: - all_records = session.exec(select(DB_BanUser)).all() - for ban_user in ban_list: - statement = select(DB_BanUser).where( - DB_BanUser.user_id == ban_user.user_id, DB_BanUser.group_id == ban_user.group_id - ) - if existing_record := session.exec(statement).first(): - if existing_record.lift_time == ban_user.lift_time: - logger.debug(f"禁言记录未变更: {existing_record}") - continue - # 更新现有记录的 lift_time - existing_record.lift_time = ban_user.lift_time - await session.add(existing_record) - logger.debug(f"更新禁言记录: {existing_record}") + changed = 0 + for ident, ban in target_map.items(): + if ident in existing_map: + row = existing_map[ident] + if row.lift_time != ban.lift_time: + row.lift_time = ban.lift_time + changed += 1 else: - # 创建新记录 - db_record = DB_BanUser( - user_id=ban_user.user_id, group_id=ban_user.group_id, lift_time=ban_user.lift_time + session.add( + NapcatBanRecord(group_id=ban.group_id, user_id=ban.user_id, lift_time=ban.lift_time) ) - await session.add(db_record) - logger.debug(f"创建新禁言记录: {ban_user}") - # 删除不在 ban_list 中的记录 - for db_record in all_records: - record = BanUser(user_id=db_record.user_id, group_id=db_record.group_id, lift_time=db_record.lift_time) - if not any(is_identical(record, ban_user) for ban_user in ban_list): - statement = select(DB_BanUser).where( - DB_BanUser.user_id == record.user_id, DB_BanUser.group_id == record.group_id - ) - if ban_record := session.exec(statement).first(): - session.delete(ban_record) + changed += 1 - logger.debug(f"删除禁言记录: {ban_record}") - else: - logger.info(f"未找到禁言记录: {ban_record}") + removed = 0 + for ident, row in existing_map.items(): + if ident not in target_map: + await session.delete(row) + removed += 1 - logger.info("禁言记录已更新") - - def get_ban_records(self) -> List[BanUser]: - """ - 读取所有禁言记录。 - """ - with Session(self.engine) as session: - statement = select(DB_BanUser) - records = session.exec(statement).all() - return [BanUser(user_id=item.user_id, group_id=item.group_id, lift_time=item.lift_time) for item in records] - - def create_ban_record(self, ban_record: BanUser) -> None: - """ - 为特定群组中的用户创建禁言记录。 - 一个简化版本的添加方式,防止 update_ban_record 方法的复杂性。 - 其同时还是简化版的更新方式。 - """ - with Session(self.engine) as session: - # 检查记录是否已存在 - statement = select(DB_BanUser).where( - DB_BanUser.user_id == ban_record.user_id, DB_BanUser.group_id == ban_record.group_id + logger.debug( + f"Napcat ban list sync => total_incoming={len(ban_list)} created_or_updated={changed} removed={removed}" ) - existing_record = session.exec(statement).first() - if existing_record: - # 如果记录已存在,更新 lift_time - existing_record.lift_time = ban_record.lift_time - await session.add(existing_record) - logger.debug(f"更新禁言记录: {ban_record}") + + async def create_or_update(self, ban_record: BanUser) -> None: + async with get_db_session() as session: + stmt = select(NapcatBanRecord).where( + NapcatBanRecord.group_id == ban_record.group_id, + NapcatBanRecord.user_id == 
ban_record.user_id, + ) + result = await session.execute(stmt) + row = result.scalars().first() + if row: + if row.lift_time != ban_record.lift_time: + row.lift_time = ban_record.lift_time + logger.debug( + f"更新禁言记录 group={ban_record.group_id} user={ban_record.user_id} lift={ban_record.lift_time}" + ) else: - # 如果记录不存在,创建新记录 - db_record = DB_BanUser( - user_id=ban_record.user_id, group_id=ban_record.group_id, lift_time=ban_record.lift_time + session.add( + NapcatBanRecord( + group_id=ban_record.group_id, user_id=ban_record.user_id, lift_time=ban_record.lift_time + ) + ) + logger.debug( + f"创建禁言记录 group={ban_record.group_id} user={ban_record.user_id} lift={ban_record.lift_time}" ) - await session.add(db_record) - logger.debug(f"创建新禁言记录: {ban_record}") - def delete_ban_record(self, ban_record: BanUser): - """ - 删除特定用户在特定群组中的禁言记录。 - 一个简化版本的删除方式,防止 update_ban_record 方法的复杂性。 - """ - user_id = ban_record.user_id - group_id = ban_record.group_id - with Session(self.engine) as session: - statement = select(DB_BanUser).where(DB_BanUser.user_id == user_id, DB_BanUser.group_id == group_id) - if ban_record := session.exec(statement).first(): - session.delete(ban_record) - - logger.debug(f"删除禁言记录: {ban_record}") + async def delete_record(self, ban_record: BanUser) -> None: + async with get_db_session() as session: + stmt = select(NapcatBanRecord).where( + NapcatBanRecord.group_id == ban_record.group_id, + NapcatBanRecord.user_id == ban_record.user_id, + ) + result = await session.execute(stmt) + row = result.scalars().first() + if row: + await session.delete(row) + logger.debug( + f"删除禁言记录 group={ban_record.group_id} user={ban_record.user_id} lift={row.lift_time}" + ) else: - logger.info(f"未找到禁言记录: user_id: {user_id}, group_id: {group_id}") + logger.info( + f"未找到禁言记录 group={ban_record.group_id} user={ban_record.user_id}" + ) + + # 兼容旧命名 + async def update_ban_record(self, ban_list: List[BanUser]) -> None: # old name + await self.update_ban_records(ban_list) + + async def create_ban_record(self, ban_record: BanUser) -> None: # old name + await self.create_or_update(ban_record) + + async def delete_ban_record(self, ban_record: BanUser) -> None: # old name + await self.delete_record(ban_record) -db_manager = DatabaseManager() +napcat_db = NapcatDatabase() + + +def is_identical(a: BanUser, b: BanUser) -> bool: + return a.group_id == b.group_id and a.user_id == b.user_id + + +__all__ = [ + "BanUser", + "NapcatBanRecord", + "napcat_db", + "is_identical", +] diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py index a9eaead16..4a32657a7 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py @@ -9,7 +9,7 @@ from src.common.logger import get_logger logger = get_logger("napcat_adapter") from src.plugin_system.apis import config_api -from ..database import BanUser, db_manager, is_identical +from ..database import BanUser, napcat_db, is_identical from . 
import NoticeType, ACCEPT_FORMAT from .message_sending import message_send_instance from .message_handler import message_handler @@ -62,7 +62,7 @@ class NoticeHandler: return self.server_connection return websocket_manager.get_connection() - def _ban_operation(self, group_id: int, user_id: Optional[int] = None, lift_time: Optional[int] = None) -> None: + async def _ban_operation(self, group_id: int, user_id: Optional[int] = None, lift_time: Optional[int] = None) -> None: """ 将用户禁言记录添加到self.banned_list中 如果是全体禁言,则user_id为0 @@ -71,16 +71,16 @@ class NoticeHandler: user_id = 0 # 使用0表示全体禁言 lift_time = -1 ban_record = BanUser(user_id=user_id, group_id=group_id, lift_time=lift_time) - for record in self.banned_list: + for record in list(self.banned_list): if is_identical(record, ban_record): self.banned_list.remove(record) self.banned_list.append(ban_record) - db_manager.create_ban_record(ban_record) # 作为更新 + await napcat_db.create_ban_record(ban_record) # 更新 return self.banned_list.append(ban_record) - db_manager.create_ban_record(ban_record) # 添加到数据库 + await napcat_db.create_ban_record(ban_record) # 新建 - def _lift_operation(self, group_id: int, user_id: Optional[int] = None) -> None: + async def _lift_operation(self, group_id: int, user_id: Optional[int] = None) -> None: """ 从self.lifted_group_list中移除已经解除全体禁言的群 """ @@ -88,7 +88,12 @@ class NoticeHandler: user_id = 0 # 使用0表示全体禁言 ban_record = BanUser(user_id=user_id, group_id=group_id, lift_time=-1) self.lifted_list.append(ban_record) - db_manager.delete_ban_record(ban_record) # 删除数据库中的记录 + # 从被禁言列表里移除对应记录 + for record in list(self.banned_list): + if is_identical(record, ban_record): + self.banned_list.remove(record) + break + await napcat_db.delete_ban_record(ban_record) async def handle_notice(self, raw_message: dict) -> None: notice_type = raw_message.get("notice_type") @@ -376,7 +381,7 @@ class NoticeHandler: if user_id == 0: # 为全体禁言 sub_type: str = "whole_ban" - self._ban_operation(group_id) + await self._ban_operation(group_id) else: # 为单人禁言 # 获取被禁言人的信息 sub_type: str = "ban" @@ -390,7 +395,7 @@ class NoticeHandler: user_nickname=user_nickname, user_cardname=user_cardname, ) - self._ban_operation(group_id, user_id, int(time.time() + duration)) + await self._ban_operation(group_id, user_id, int(time.time() + duration)) seg_data: Seg = Seg( type="notify", @@ -439,7 +444,7 @@ class NoticeHandler: user_id = raw_message.get("user_id") if user_id == 0: # 全体禁言解除 sub_type = "whole_lift_ban" - self._lift_operation(group_id) + await self._lift_operation(group_id) else: # 单人禁言解除 sub_type = "lift_ban" # 获取被解除禁言人的信息 @@ -455,7 +460,7 @@ class NoticeHandler: user_nickname=user_nickname, user_cardname=user_cardname, ) - self._lift_operation(group_id, user_id) + await self._lift_operation(group_id, user_id) seg_data: Seg = Seg( type="notify", @@ -483,7 +488,7 @@ class NoticeHandler: group_id = lift_record.group_id user_id = lift_record.user_id - db_manager.delete_ban_record(lift_record) # 从数据库中删除禁言记录 + asyncio.create_task(napcat_db.delete_ban_record(lift_record)) # 从数据库中删除禁言记录 seg_message: Seg = await self.natural_lift(group_id, user_id) diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/utils.py b/src/plugins/built_in/napcat_adapter_plugin/src/utils.py index e36fc93fd..4c47a2570 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/utils.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/utils.py @@ -6,7 +6,7 @@ import urllib3 import ssl import io -from .database import BanUser, db_manager +from .database import BanUser, napcat_db from 
src.common.logger import get_logger logger = get_logger("napcat_adapter") @@ -270,10 +270,11 @@ async def read_ban_list( ] """ try: - ban_list = db_manager.get_ban_records() + ban_list = await napcat_db.get_ban_records() lifted_list: List[BanUser] = [] logger.info("已经读取禁言列表") - for ban_record in ban_list: + # 复制列表以避免迭代中修改原列表问题 + for ban_record in list(ban_list): if ban_record.user_id == 0: fetched_group_info = await get_group_info(websocket, ban_record.group_id) if fetched_group_info is None: @@ -301,12 +302,12 @@ async def read_ban_list( ban_list.remove(ban_record) else: ban_record.lift_time = lift_ban_time - db_manager.update_ban_record(ban_list) + await napcat_db.update_ban_record(ban_list) return ban_list, lifted_list except Exception as e: logger.error(f"读取禁言列表失败: {e}") return [], [] -def save_ban_record(list: List[BanUser]): - return db_manager.update_ban_record(list) +async def save_ban_record(list: List[BanUser]): + return await napcat_db.update_ban_record(list) diff --git a/src/schedule/database.py b/src/schedule/database.py index b420f0686..5025c1fa3 100644 --- a/src/schedule/database.py +++ b/src/schedule/database.py @@ -42,7 +42,7 @@ async def add_new_plans(plans: List[str], month: str): new_plan_objects = [ MonthlyPlan(plan_text=plan, target_month=month, status="active") for plan in plans_to_add ] - await session.add_all(new_plan_objects) + session.add_all(new_plan_objects) await session.commit() logger.info(f"成功向数据库添加了 {len(new_plan_objects)} 条 {month} 的月度计划。") diff --git a/src/schedule/schedule_manager.py b/src/schedule/schedule_manager.py index 4e66bf0c8..115480381 100644 --- a/src/schedule/schedule_manager.py +++ b/src/schedule/schedule_manager.py @@ -128,7 +128,7 @@ class ScheduleManager: existing_schedule.updated_at = datetime.now() else: new_schedule = Schedule(date=date_str, schedule_data=schedule_json) - await session.add(new_schedule) + session.add(new_schedule) await session.commit() @staticmethod diff --git a/uv.lock b/uv.lock index 8e04441db..7a974afd0 100644 --- a/uv.lock +++ b/uv.lock @@ -154,6 +154,18 @@ wheels = [ { url = "https://pypi.tuna.tsinghua.edu.cn/packages/98/3b/40a68de458904bcc143622015fff2352b6461cd92fd66d3527bf1c6f5716/aiohttp_cors-0.8.1-py3-none-any.whl", hash = "sha256:3180cf304c5c712d626b9162b195b1db7ddf976a2a25172b35bb2448b890a80d", size = 25231, upload-time = "2025-03-31T14:16:18.478Z" }, ] +[[package]] +name = "aiomysql" +version = "0.2.0" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +dependencies = [ + { name = "pymysql" }, +] +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/67/76/2c5b55e4406a1957ffdfd933a94c2517455291c97d2b81cec6813754791a/aiomysql-0.2.0.tar.gz", hash = "sha256:558b9c26d580d08b8c5fd1be23c5231ce3aeff2dadad989540fee740253deb67", size = 114706, upload-time = "2023-06-11T19:57:53.608Z" } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/42/87/c982ee8b333c85b8ae16306387d703a1fcdfc81a2f3f15a24820ab1a512d/aiomysql-0.2.0-py3-none-any.whl", hash = "sha256:b7c26da0daf23a5ec5e0b133c03d20657276e4eae9b73e040b72787f6f6ade0a", size = 44215, upload-time = "2023-06-11T19:57:51.09Z" }, +] + [[package]] name = "aiosignal" version = "1.4.0" @@ -167,6 +179,18 @@ wheels = [ { url = "https://pypi.tuna.tsinghua.edu.cn/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e", size = 7490, upload-time = "2025-07-03T22:54:42.156Z" }, ] +[[package]] +name = "aiosqlite" 
+version = "0.21.0" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/13/7d/8bca2bf9a247c2c5dfeec1d7a5f40db6518f88d314b8bca9da29670d2671/aiosqlite-0.21.0.tar.gz", hash = "sha256:131bb8056daa3bc875608c631c678cda73922a2d4ba8aec373b19f18c17e7aa3", size = 13454, upload-time = "2025-02-03T07:30:16.235Z" } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/f5/10/6c25ed6de94c49f88a91fa5018cb4c0f3625f31d5be9f771ebe5cc7cd506/aiosqlite-0.21.0-py3-none-any.whl", hash = "sha256:2549cf4057f95f53dcba16f2b64e8e2791d7e1adedb13197dd8ed77bb226d7d0", size = 15792, upload-time = "2025-02-03T07:30:13.6Z" }, +] + [[package]] name = "annotated-types" version = "0.7.0" @@ -774,7 +798,6 @@ dependencies = [ { name = "numpy", version = "2.3.1", source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }, marker = "python_full_version >= '3.11'" }, { name = "packaging" }, ] -sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/e7/9a/e33fc563f007924dd4ec3c5101fe5320298d6c13c158a24a9ed849058569/faiss_cpu-1.11.0.tar.gz", hash = "sha256:44877b896a2b30a61e35ea4970d008e8822545cb340eca4eff223ac7f40a1db9", size = 70218, upload-time = "2025-04-28T07:48:30.459Z" } wheels = [ { url = "https://pypi.tuna.tsinghua.edu.cn/packages/ed/e5/7490368ec421e44efd60a21aa88d244653c674d8d6ee6bc455d8ee3d02ed/faiss_cpu-1.11.0-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:1995119152928c68096b0c1e5816e3ee5b1eebcf615b80370874523be009d0f6", size = 3307996, upload-time = "2025-04-28T07:47:29.126Z" }, { url = "https://pypi.tuna.tsinghua.edu.cn/packages/dd/ac/a94fbbbf4f38c2ad11862af92c071ff346630ebf33f3d36fe75c3817c2f0/faiss_cpu-1.11.0-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:788d7bf24293fdecc1b93f1414ca5cc62ebd5f2fecfcbb1d77f0e0530621c95d", size = 7886309, upload-time = "2025-04-28T07:47:31.668Z" }, @@ -1693,6 +1716,8 @@ source = { virtual = "." 
} dependencies = [ { name = "aiohttp" }, { name = "aiohttp-cors" }, + { name = "aiomysql" }, + { name = "aiosqlite" }, { name = "apscheduler" }, { name = "asyncddgs" }, { name = "asyncio" }, @@ -1773,6 +1798,8 @@ lint = [ requires-dist = [ { name = "aiohttp", specifier = ">=3.12.14" }, { name = "aiohttp-cors", specifier = ">=0.8.1" }, + { name = "aiomysql", specifier = ">=0.2.0" }, + { name = "aiosqlite", specifier = ">=0.21.0" }, { name = "apscheduler", specifier = ">=3.11.0" }, { name = "asyncddgs", specifier = ">=0.1.0a1" }, { name = "asyncio", specifier = ">=4.0.0" }, From 832743249d40677916c3ed8ae461ae783abdb837 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=85=E8=AF=BA=E7=8B=90?= <212194964+foxcyber907@users.noreply.github.com> Date: Sat, 20 Sep 2025 17:26:28 +0800 Subject: [PATCH 18/31] =?UTF-8?q?refactor(db):=20=E4=BF=AE=E6=AD=A3SQLAlch?= =?UTF-8?q?emy=E5=BC=82=E6=AD=A5=E6=93=8D=E4=BD=9C=E8=B0=83=E7=94=A8?= =?UTF-8?q?=E6=96=B9=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 移除session.add()方法的不必要await调用,修正异步数据库操作模式。主要变更包括: - 将 `await session.add()` 统一改为 `session.add()` - 修正部分函数调用为异步版本(如消息查询函数) - 重构SQLAlchemyTransaction为完全异步实现 - 重写napcat_adapter_plugin数据库层以符合异步规范 - 添加aiomysql和aiosqlite依赖支持 --- .../management/statistics.py | 4 +- .../antipromptinjector/management/user_ban.py | 2 +- src/chat/emoji_system/emoji_manager.py | 2 +- src/chat/express/expression_learner.py | 4 +- src/chat/memory_system/instant_memory.py | 2 +- src/chat/message_receive/storage.py | 2 +- src/chat/utils/utils_image.py | 8 +- src/chat/utils/utils_video.py | 54 ++-- src/common/database/database.py | 29 +- .../database/sqlalchemy_database_api.py | 4 +- src/common/message_repository.py | 2 +- src/llm_models/utils.py | 2 +- src/mais4u/mais4u_chat/s4u_mood_manager.py | 4 +- src/mood/mood_manager.py | 4 +- src/person_info/person_info.py | 6 +- src/person_info/relationship_builder.py | 45 ++- src/plugin_system/core/permission_manager.py | 6 +- .../services/scheduler_service.py | 2 +- .../napcat_adapter_plugin/src/database.py | 262 +++++++++--------- .../src/recv_handler/notice_handler.py | 29 +- .../napcat_adapter_plugin/src/utils.py | 13 +- src/schedule/database.py | 2 +- src/schedule/schedule_manager.py | 2 +- 23 files changed, 246 insertions(+), 244 deletions(-) diff --git a/src/chat/antipromptinjector/management/statistics.py b/src/chat/antipromptinjector/management/statistics.py index e9b4be66b..2cfe3e13c 100644 --- a/src/chat/antipromptinjector/management/statistics.py +++ b/src/chat/antipromptinjector/management/statistics.py @@ -32,7 +32,7 @@ class AntiInjectionStatistics: stats = session.query(AntiInjectionStats).order_by(AntiInjectionStats.id.desc()).first() if not stats: stats = AntiInjectionStats() - await session.add(stats) + session.add(stats) await session.commit() await session.refresh(stats) return stats @@ -48,7 +48,7 @@ class AntiInjectionStatistics: stats = session.query(AntiInjectionStats).order_by(AntiInjectionStats.id.desc()).first() if not stats: stats = AntiInjectionStats() - await session.add(stats) + session.add(stats) # 更新统计字段 for key, value in kwargs.items(): diff --git a/src/chat/antipromptinjector/management/user_ban.py b/src/chat/antipromptinjector/management/user_ban.py index 865ddddb9..676436c42 100644 --- a/src/chat/antipromptinjector/management/user_ban.py +++ b/src/chat/antipromptinjector/management/user_ban.py @@ -85,7 +85,7 @@ class UserBanManager: reason=f"提示词注入攻击 (置信度: {detection_result.confidence:.2f})", 
created_at=datetime.datetime.now(), ) - await session.add(ban_record) + session.add(ban_record) await session.commit() diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index 6b2c8df5a..e2a6eb7f1 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -166,7 +166,7 @@ class MaiEmoji: usage_count=self.usage_count, last_used_time=self.last_used_time, ) - await session.add(emoji) + session.add(emoji) await session.commit() logger.info(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})") diff --git a/src/chat/express/expression_learner.py b/src/chat/express/expression_learner.py index a709ee78f..b7dabe6e1 100644 --- a/src/chat/express/expression_learner.py +++ b/src/chat/express/expression_learner.py @@ -381,7 +381,7 @@ class ExpressionLearner: type=type, create_date=current_time, # 手动设置创建日期 ) - await session.add(new_expression) + session.add(new_expression) # 限制最大数量 exprs_result = await session.execute( @@ -608,7 +608,7 @@ class ExpressionLearnerManager: type=type_str, create_date=last_active_time, # 迁移时使用last_active_time作为创建时间 ) - await session.add(new_expression) + session.add(new_expression) migrated_count += 1 logger.info(f"已迁移 {expr_file} 到数据库,包含 {len(expressions)} 个表达方式") diff --git a/src/chat/memory_system/instant_memory.py b/src/chat/memory_system/instant_memory.py index 5b78f4d3d..a8675f5c0 100644 --- a/src/chat/memory_system/instant_memory.py +++ b/src/chat/memory_system/instant_memory.py @@ -117,7 +117,7 @@ class InstantMemory: create_time=memory_item.create_time, last_view_time=memory_item.last_view_time, ) - await session.add(memory) + session.add(memory) await session.commit() async def get_memory(self, target: str): diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index 159d33aae..015578be8 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -122,7 +122,7 @@ class MessageStorage: is_picid=is_picid, ) async with get_db_session() as session: - await session.add(new_message) + session.add(new_message) await session.commit() except Exception: diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py index bcfc6e7fd..93ec14957 100644 --- a/src/chat/utils/utils_image.py +++ b/src/chat/utils/utils_image.py @@ -128,7 +128,7 @@ class ImageManager: description=description, timestamp=current_timestamp, ) - await session.add(new_desc) + session.add(new_desc) await session.commit() # 会在上下文管理器中自动调用 except Exception as e: @@ -278,7 +278,7 @@ class ImageManager: description=detailed_description, # 保存详细描述 timestamp=current_timestamp, ) - await session.add(new_img) + session.add(new_img) await session.commit() except Exception as e: logger.error(f"保存到Images表失败: {str(e)}") @@ -370,7 +370,7 @@ class ImageManager: vlm_processed=True, count=1, ) - await session.add(new_img) + session.add(new_img) logger.debug(f"[数据库] 创建新图片记录: {image_hash[:8]}...") await session.commit() @@ -590,7 +590,7 @@ class ImageManager: vlm_processed=True, count=1, ) - await session.add(new_img) + session.add(new_img) await session.commit() return image_id, f"[picid:{image_id}]" diff --git a/src/chat/utils/utils_video.py b/src/chat/utils/utils_video.py index e249bc133..8cb294f3e 100644 --- a/src/chat/utils/utils_video.py +++ b/src/chat/utils/utils_video.py @@ -22,6 +22,7 @@ from src.llm_models.utils_model import LLMRequest from src.config.config import global_config, model_config from src.common.logger import get_logger from 
src.common.database.sqlalchemy_models import get_db_session, Videos +from sqlalchemy import select logger = get_logger("utils_video") @@ -205,34 +206,29 @@ class VideoAnalyzer: return hash_obj.hexdigest() @staticmethod - def _check_video_exists(video_hash: str) -> Optional[Videos]: - """检查视频是否已经分析过""" + async def _check_video_exists(video_hash: str) -> Optional[Videos]: + """检查视频是否已经分析过 (异步)""" try: - with get_db_session() as session: - # 明确刷新会话以确保看到其他事务的最新提交 - session.expire_all() - return session.query(Videos).filter(Videos.video_hash == video_hash).first() + async with get_db_session() as session: + result = await session.execute(select(Videos).where(Videos.video_hash == video_hash)) + return result.scalar_one_or_none() except Exception as e: logger.warning(f"检查视频是否存在时出错: {e}") return None @staticmethod - def _store_video_result( - video_hash: str, description: str, metadata: Optional[Dict] = None + async def _store_video_result( + video_hash: str, description: str, metadata: Optional[Dict] = None ) -> Optional[Videos]: - """存储视频分析结果到数据库""" - # 检查描述是否为错误信息,如果是则不保存 + """存储视频分析结果到数据库 (异步)""" if description.startswith("❌"): logger.warning(f"⚠️ 检测到错误信息,不保存到数据库: {description[:50]}...") return None - try: - with get_db_session() as session: - # 只根据video_hash查找 - existing_video = session.query(Videos).filter(Videos.video_hash == video_hash).first() - + async with get_db_session() as session: + result = await session.execute(select(Videos).where(Videos.video_hash == video_hash)) + existing_video = result.scalar_one_or_none() if existing_video: - # 如果已存在,更新描述和计数 existing_video.description = description existing_video.count += 1 existing_video.timestamp = time.time() @@ -243,12 +239,17 @@ class VideoAnalyzer: existing_video.resolution = metadata.get("resolution") existing_video.file_size = metadata.get("file_size") await session.commit() - session.refresh(existing_video) - logger.info(f"✅ 更新已存在的视频记录,hash: {video_hash[:16]}..., count: {existing_video.count}") + await session.refresh(existing_video) + logger.info( + f"✅ 更新已存在的视频记录,hash: {video_hash[:16]}..., count: {existing_video.count}" + ) return existing_video else: video_record = Videos( - video_hash=video_hash, description=description, timestamp=time.time(), count=1 + video_hash=video_hash, + description=description, + timestamp=time.time(), + count=1, ) if metadata: video_record.duration = metadata.get("duration") @@ -256,11 +257,12 @@ class VideoAnalyzer: video_record.fps = metadata.get("fps") video_record.resolution = metadata.get("resolution") video_record.file_size = metadata.get("file_size") - - await session.add(video_record) + session.add(video_record) await session.commit() - session.refresh(video_record) - logger.info(f"✅ 新视频分析结果已保存到数据库,hash: {video_hash[:16]}...") + await session.refresh(video_record) + logger.info( + f"✅ 新视频分析结果已保存到数据库,hash: {video_hash[:16]}..." 
+ ) return video_record except Exception as e: logger.error(f"❌ 存储视频分析结果时出错: {e}") @@ -708,7 +710,7 @@ class VideoAnalyzer: logger.info("✅ 等待结束,检查是否有处理结果") # 检查是否有结果了 - existing_video = self._check_video_exists(video_hash) + existing_video = await self._check_video_exists(video_hash) if existing_video: logger.info(f"✅ 找到了处理结果,直接返回 (id: {existing_video.id})") return {"summary": existing_video.description} @@ -722,7 +724,7 @@ class VideoAnalyzer: logger.info(f"🔒 获得视频处理锁,开始处理 (hash: {video_hash[:16]}...)") # 再次检查数据库(可能在等待期间已经有结果了) - existing_video = self._check_video_exists(video_hash) + existing_video = await self._check_video_exists(video_hash) if existing_video: logger.info(f"✅ 获得锁后发现已有结果,直接返回 (id: {existing_video.id})") video_event.set() # 通知其他等待者 @@ -753,7 +755,7 @@ class VideoAnalyzer: # 保存分析结果到数据库(仅保存成功的结果) if success: metadata = {"filename": filename, "file_size": len(video_bytes), "analysis_timestamp": time.time()} - self._store_video_result(video_hash=video_hash, description=result, metadata=metadata) + await self._store_video_result(video_hash=video_hash, description=result, metadata=metadata) logger.info("✅ 分析结果已保存到数据库") else: logger.warning("⚠️ 分析失败,不保存到数据库以便后续重试") diff --git a/src/common/database/database.py b/src/common/database/database.py index 293f0cd1f..6a34d900e 100644 --- a/src/common/database/database.py +++ b/src/common/database/database.py @@ -32,21 +32,32 @@ class DatabaseProxy: class SQLAlchemyTransaction: - """SQLAlchemy事务上下文管理器""" + """SQLAlchemy 异步事务上下文管理器 (兼容旧代码示例,推荐直接使用 get_db_session)。""" def __init__(self): + self._ctx = None self.session = None - def __enter__(self): - self.session = get_db_session() + async def __aenter__(self): + # get_db_session 是一个 async contextmanager + self._ctx = get_db_session() + self.session = await self._ctx.__aenter__() return self.session - def __exit__(self, exc_type, exc_val, exc_tb): - if exc_type is None: - self.await session.commit() - else: - self.session.rollback() - self.session.close() + async def __aexit__(self, exc_type, exc_val, exc_tb): + try: + if self.session: + if exc_type is None: + try: + await self.session.commit() + except Exception: + await self.session.rollback() + raise + else: + await self.session.rollback() + finally: + if self._ctx: + await self._ctx.__aexit__(exc_type, exc_val, exc_tb) # 创建全局数据库代理实例 diff --git a/src/common/database/sqlalchemy_database_api.py b/src/common/database/sqlalchemy_database_api.py index 63de1e43b..13ef39c1a 100644 --- a/src/common/database/sqlalchemy_database_api.py +++ b/src/common/database/sqlalchemy_database_api.py @@ -168,7 +168,7 @@ async def db_query( # 创建新记录 new_record = model_class(**data) - await session.add(new_record) + session.add(new_record) await session.flush() # 获取自动生成的ID # 转换为字典格式返回 @@ -295,7 +295,7 @@ async def db_save( # 创建新记录 new_record = model_class(**data) - await session.add(new_record) + session.add(new_record) await session.flush() # 转换为字典格式返回 diff --git a/src/common/message_repository.py b/src/common/message_repository.py index 7c620d2c7..96714db1f 100644 --- a/src/common/message_repository.py +++ b/src/common/message_repository.py @@ -201,5 +201,5 @@ async def count_messages(message_filter: dict[str, Any]) -> int: # 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。 -# 注意:对于 SQLAlchemy,插入操作通常是使用 await session.add() 和 await session.commit()。 +# 注意:对于 SQLAlchemy,插入操作通常是使用 session.add() 和 await session.commit()。 # 查找单个消息可以使用 session.execute(select(Messages).where(...)).scalar_one_or_none()。 diff --git a/src/llm_models/utils.py 
b/src/llm_models/utils.py index 659fc5399..bf23f144a 100644 --- a/src/llm_models/utils.py +++ b/src/llm_models/utils.py @@ -178,7 +178,7 @@ class LLMUsageRecorder: timestamp=datetime.now(), # SQLAlchemy 会处理 DateTime 字段 ) - await session.add(usage_record) + session.add(usage_record) await session.commit() logger.debug( diff --git a/src/mais4u/mais4u_chat/s4u_mood_manager.py b/src/mais4u/mais4u_chat/s4u_mood_manager.py index fa12523a4..d235843d4 100644 --- a/src/mais4u/mais4u_chat/s4u_mood_manager.py +++ b/src/mais4u/mais4u_chat/s4u_mood_manager.py @@ -160,7 +160,7 @@ class ChatMood: self.regression_count = 0 message_time: float = message.message_info.time # type: ignore - message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive( + message_list_before_now = await get_raw_msg_by_timestamp_with_chat_inclusive( chat_id=self.chat_id, timestamp_start=self.last_change_time, timestamp_end=message_time, @@ -239,7 +239,7 @@ class ChatMood: async def regress_mood(self): message_time = time.time() - message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive( + message_list_before_now = await get_raw_msg_by_timestamp_with_chat_inclusive( chat_id=self.chat_id, timestamp_start=self.last_change_time, timestamp_end=message_time, diff --git a/src/mood/mood_manager.py b/src/mood/mood_manager.py index ef4416673..5138a7d5d 100644 --- a/src/mood/mood_manager.py +++ b/src/mood/mood_manager.py @@ -98,7 +98,7 @@ class ChatMood: ) message_time: float = message.message_info.time # type: ignore - message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive( + message_list_before_now = await get_raw_msg_by_timestamp_with_chat_inclusive( chat_id=self.chat_id, timestamp_start=self.last_change_time, timestamp_end=message_time, @@ -147,7 +147,7 @@ class ChatMood: async def regress_mood(self): message_time = time.time() - message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive( + message_list_before_now = await get_raw_msg_by_timestamp_with_chat_inclusive( chat_id=self.chat_id, timestamp_start=self.last_change_time, timestamp_end=message_time, diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index 1f8cb843c..3a036d029 100644 --- a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -180,7 +180,7 @@ class PersonInfoManager: async with get_db_session() as session: try: new_person = PersonInfo(**p_data) - await session.add(new_person) + session.add(new_person) await session.commit() return True except Exception as e: @@ -245,7 +245,7 @@ class PersonInfoManager: # 尝试创建 new_person = PersonInfo(**p_data) - await session.add(new_person) + session.add(new_person) await session.commit() return True except Exception as e: @@ -607,7 +607,7 @@ class PersonInfoManager: # 记录不存在,尝试创建 try: new_person = PersonInfo(**init_data) - await session.add(new_person) + session.add(new_person) await session.commit() await session.refresh(new_person) return new_person, True # 创建成功 diff --git a/src/person_info/relationship_builder.py b/src/person_info/relationship_builder.py index 720076eb2..1ff90a99d 100644 --- a/src/person_info/relationship_builder.py +++ b/src/person_info/relationship_builder.py @@ -3,7 +3,7 @@ import traceback import os import pickle import random -from typing import List, Dict, Any, Coroutine +from typing import List, Dict, Any from src.config.config import global_config from src.common.logger import get_logger from src.person_info.relationship_manager import get_relationship_manager @@ -113,7 +113,7 @@ class RelationshipBuilder: # 
负责跟踪用户消息活动、管理消息段、清理过期数据 # ================================ - def _update_message_segments(self, person_id: str, message_time: float): + async def _update_message_segments(self, person_id: str, message_time: float): """更新用户的消息段 Args: @@ -126,11 +126,8 @@ class RelationshipBuilder: segments = self.person_engaged_cache[person_id] # 获取该消息前5条消息的时间作为潜在的开始时间 - before_messages = get_raw_msg_before_timestamp_with_chat(self.chat_id, message_time, limit=5) - if before_messages: - potential_start_time = before_messages[0]["time"] - else: - potential_start_time = message_time + before_messages = await get_raw_msg_before_timestamp_with_chat(self.chat_id, message_time, limit=5) + potential_start_time = before_messages[0]["time"] if before_messages else message_time # 如果没有现有消息段,创建新的 if not segments: @@ -138,10 +135,9 @@ class RelationshipBuilder: "start_time": potential_start_time, "end_time": message_time, "last_msg_time": message_time, - "message_count": self._count_messages_in_timerange(potential_start_time, message_time), + "message_count": await self._count_messages_in_timerange(potential_start_time, message_time), } segments.append(new_segment) - person_name = get_person_info_manager().get_value_sync(person_id, "person_name") or person_id logger.debug( f"{self.log_prefix} 眼熟用户 {person_name} 在 {time.strftime('%H:%M:%S', time.localtime(potential_start_time))} - {time.strftime('%H:%M:%S', time.localtime(message_time))} 之间有 {new_segment['message_count']} 条消息" @@ -153,39 +149,32 @@ class RelationshipBuilder: last_segment = segments[-1] # 计算从最后一条消息到当前消息之间的消息数量(不包含边界) - messages_between = self._count_messages_between(last_segment["last_msg_time"], message_time) + messages_between = await self._count_messages_between(last_segment["last_msg_time"], message_time) if messages_between <= 10: - # 在10条消息内,延伸当前消息段 last_segment["end_time"] = message_time last_segment["last_msg_time"] = message_time - # 重新计算整个消息段的消息数量 - last_segment["message_count"] = self._count_messages_in_timerange( + last_segment["message_count"] = await self._count_messages_in_timerange( last_segment["start_time"], last_segment["end_time"] ) logger.debug(f"{self.log_prefix} 延伸用户 {person_id} 的消息段: {last_segment}") else: - # 超过10条消息,结束当前消息段并创建新的 - # 结束当前消息段:延伸到原消息段最后一条消息后5条消息的时间 current_time = time.time() - after_messages = get_raw_msg_by_timestamp_with_chat( + after_messages = await get_raw_msg_by_timestamp_with_chat( self.chat_id, last_segment["last_msg_time"], current_time, limit=5, limit_mode="earliest" ) if after_messages and len(after_messages) >= 5: - # 如果有足够的后续消息,使用第5条消息的时间作为结束时间 last_segment["end_time"] = after_messages[4]["time"] - # 重新计算当前消息段的消息数量 - last_segment["message_count"] = self._count_messages_in_timerange( + last_segment["message_count"] = await self._count_messages_in_timerange( last_segment["start_time"], last_segment["end_time"] ) - # 创建新的消息段 new_segment = { "start_time": potential_start_time, "end_time": message_time, "last_msg_time": message_time, - "message_count": self._count_messages_in_timerange(potential_start_time, message_time), + "message_count": await self._count_messages_in_timerange(potential_start_time, message_time), } segments.append(new_segment) person_info_manager = get_person_info_manager() @@ -196,14 +185,14 @@ class RelationshipBuilder: self._save_cache() - def _count_messages_in_timerange(self, start_time: float, end_time: float) -> int: + async def _count_messages_in_timerange(self, start_time: float, end_time: float) -> int: """计算指定时间范围内的消息数量(包含边界)""" - messages = 
get_raw_msg_by_timestamp_with_chat_inclusive(self.chat_id, start_time, end_time) + messages = await get_raw_msg_by_timestamp_with_chat_inclusive(self.chat_id, start_time, end_time) return len(messages) - def _count_messages_between(self, start_time: float, end_time: float) -> Coroutine[Any, Any, int]: + async def _count_messages_between(self, start_time: float, end_time: float) -> int: """计算两个时间点之间的消息数量(不包含边界),用于间隔检查""" - return num_new_messages_since(self.chat_id, start_time, end_time) + return await num_new_messages_since(self.chat_id, start_time, end_time) def _get_total_message_count(self, person_id: str) -> int: """获取用户所有消息段的总消息数量""" @@ -350,7 +339,7 @@ class RelationshipBuilder: self._cleanup_old_segments() current_time = time.time() - if latest_messages := get_raw_msg_by_timestamp_with_chat( + if latest_messages := await get_raw_msg_by_timestamp_with_chat( self.chat_id, self.last_processed_message_time, current_time, @@ -369,7 +358,7 @@ class RelationshipBuilder: and msg_time > self.last_processed_message_time ): person_id = PersonInfoManager.get_person_id(platform, user_id) - self._update_message_segments(person_id, msg_time) + await self._update_message_segments(person_id, msg_time) logger.debug( f"{self.log_prefix} 更新用户 {person_id} 的消息段,消息时间:{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(msg_time))}" ) @@ -439,7 +428,7 @@ class RelationshipBuilder: start_date = time.strftime("%Y-%m-%d %H:%M", time.localtime(start_time)) # 获取该段的消息(包含边界) - segment_messages = get_raw_msg_by_timestamp_with_chat_inclusive(self.chat_id, start_time, end_time) + segment_messages = await get_raw_msg_by_timestamp_with_chat_inclusive(self.chat_id, start_time, end_time) logger.debug( f"消息段: {start_date} - {time.strftime('%Y-%m-%d %H:%M', time.localtime(end_time))}, 消息数: {len(segment_messages)}" ) diff --git a/src/plugin_system/core/permission_manager.py b/src/plugin_system/core/permission_manager.py index db7ef9b1a..eb6083fc9 100644 --- a/src/plugin_system/core/permission_manager.py +++ b/src/plugin_system/core/permission_manager.py @@ -149,7 +149,7 @@ class PermissionManager(IPermissionManager): default_granted=node.default_granted, created_at=datetime.utcnow(), ) - await session.add(new_node) + session.add(new_node) await session.commit() logger.info(f"注册新权限节点: {node.node_name} (插件: {node.plugin_name})") return True @@ -204,7 +204,7 @@ class PermissionManager(IPermissionManager): granted=True, granted_at=datetime.utcnow(), ) - await session.add(new_perm) + session.add(new_perm) await session.commit() logger.info(f"已授权用户 {user.platform}:{user.user_id} 权限节点 {permission_node}") @@ -257,7 +257,7 @@ class PermissionManager(IPermissionManager): granted=False, granted_at=datetime.utcnow(), ) - await session.add(new_perm) + session.add(new_perm) await session.commit() logger.info(f"已撤销用户 {user.platform}:{user.user_id} 权限节点 {permission_node}") diff --git a/src/plugins/built_in/maizone_refactored/services/scheduler_service.py b/src/plugins/built_in/maizone_refactored/services/scheduler_service.py index ca6dc52c3..6124f4f06 100644 --- a/src/plugins/built_in/maizone_refactored/services/scheduler_service.py +++ b/src/plugins/built_in/maizone_refactored/services/scheduler_service.py @@ -186,7 +186,7 @@ class SchedulerService: story_content=content, send_success=success, ) - await session.add(new_record) + session.add(new_record) await session.commit() logger.info(f"已更新日程处理状态: {hour_str} - {activity} - 成功: {success}") except Exception as e: diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/database.py 
b/src/plugins/built_in/napcat_adapter_plugin/src/database.py index 23b5d1f5d..1620ec304 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/database.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/database.py @@ -1,162 +1,156 @@ -import os -from typing import Optional, List -from dataclasses import dataclass -from sqlmodel import Field, Session, SQLModel, create_engine, select +"""Napcat Adapter 插件数据库层 (基于主程序异步SQLAlchemy API) +本模块替换原先的 sqlmodel + 同步Session 实现: +1. 复用主项目的异步数据库连接与迁移体系 +2. 提供与旧接口名兼容的方法(update_ban_record/create_ban_record/delete_ban_record) +3. 新增首选异步方法: update_ban_records / create_or_update / delete_record / get_ban_records + +数据语义: + user_id == 0 表示群全体禁言 + +注意: 所有方法均为异步, 需要在 async 上下文中调用。 +""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional, List, Sequence + +from sqlalchemy import Column, Integer, BigInteger, UniqueConstraint, select, Index +from sqlalchemy.ext.asyncio import AsyncSession + +from src.common.database.sqlalchemy_models import Base, get_db_session from src.common.logger import get_logger logger = get_logger("napcat_adapter") -""" -表记录的方式: -| group_id | user_id | lift_time | -|----------|---------|-----------| -其中使用 user_id == 0 表示群全体禁言 -""" +class NapcatBanRecord(Base): + __tablename__ = "napcat_ban_records" + + id = Column(Integer, primary_key=True, autoincrement=True) + group_id = Column(BigInteger, nullable=False, index=True) + user_id = Column(BigInteger, nullable=False, index=True) # 0 == 全体禁言 + lift_time = Column(BigInteger, nullable=True) # -1 / None 表示未知/永久 + + __table_args__ = ( + UniqueConstraint("group_id", "user_id", name="uq_napcat_group_user"), + Index("idx_napcat_ban_group", "group_id"), + Index("idx_napcat_ban_user", "user_id"), + ) @dataclass class BanUser: - """ - 程序处理使用的实例 - """ - user_id: int group_id: int - lift_time: Optional[int] = Field(default=-1) + lift_time: Optional[int] = -1 + + def identity(self) -> tuple[int, int]: + return self.group_id, self.user_id -class DB_BanUser(SQLModel, table=True): - """ - 表示数据库中的用户禁言记录。 - 使用双重主键 - """ +class NapcatDatabase: + async def _fetch_all(self, session: AsyncSession) -> Sequence[NapcatBanRecord]: + result = await session.execute(select(NapcatBanRecord)) + return result.scalars().all() - user_id: int = Field(index=True, primary_key=True) # 被禁言用户的用户 ID - group_id: int = Field(index=True, primary_key=True) # 用户被禁言的群组 ID - lift_time: Optional[int] # 禁言解除的时间(时间戳) + async def get_ban_records(self) -> List[BanUser]: + async with get_db_session() as session: + rows = await self._fetch_all(session) + return [BanUser(group_id=r.group_id, user_id=r.user_id, lift_time=r.lift_time) for r in rows] + async def update_ban_records(self, ban_list: List[BanUser]) -> None: + target_map = {b.identity(): b for b in ban_list} + async with get_db_session() as session: + rows = await self._fetch_all(session) + existing_map = {(r.group_id, r.user_id): r for r in rows} -def is_identical(obj1: BanUser, obj2: BanUser) -> bool: - """ - 检查两个 BanUser 对象是否相同。 - """ - return obj1.user_id == obj2.user_id and obj1.group_id == obj2.group_id - - -class DatabaseManager: - """ - 数据库管理类,负责与数据库交互。 - """ - - def __init__(self): - os.makedirs(os.path.join(os.path.dirname(__file__), "..", "data"), exist_ok=True) # 确保数据目录存在 - DATABASE_FILE = os.path.join(os.path.dirname(__file__), "..", "data", "NapcatAdapter.db") - self.sqlite_url = f"sqlite:///{DATABASE_FILE}" # SQLite 数据库 URL - self.engine = create_engine(self.sqlite_url, echo=False) # 创建数据库引擎 - 
self._ensure_database() # 确保数据库和表已创建 - - def _ensure_database(self) -> None: - """ - 确保数据库和表已创建。 - """ - logger.info("确保数据库文件和表已创建...") - SQLModel.metadata.create_all(self.engine) - logger.info("数据库和表已创建或已存在") - - def update_ban_record(self, ban_list: List[BanUser]) -> None: - # sourcery skip: class-extract-method - """ - 更新禁言列表到数据库。 - 支持在不存在时创建新记录,对于多余的项目自动删除。 - """ - with Session(self.engine) as session: - all_records = session.exec(select(DB_BanUser)).all() - for ban_user in ban_list: - statement = select(DB_BanUser).where( - DB_BanUser.user_id == ban_user.user_id, DB_BanUser.group_id == ban_user.group_id - ) - if existing_record := session.exec(statement).first(): - if existing_record.lift_time == ban_user.lift_time: - logger.debug(f"禁言记录未变更: {existing_record}") - continue - # 更新现有记录的 lift_time - existing_record.lift_time = ban_user.lift_time - await session.add(existing_record) - logger.debug(f"更新禁言记录: {existing_record}") + changed = 0 + for ident, ban in target_map.items(): + if ident in existing_map: + row = existing_map[ident] + if row.lift_time != ban.lift_time: + row.lift_time = ban.lift_time + changed += 1 else: - # 创建新记录 - db_record = DB_BanUser( - user_id=ban_user.user_id, group_id=ban_user.group_id, lift_time=ban_user.lift_time + session.add( + NapcatBanRecord(group_id=ban.group_id, user_id=ban.user_id, lift_time=ban.lift_time) ) - await session.add(db_record) - logger.debug(f"创建新禁言记录: {ban_user}") - # 删除不在 ban_list 中的记录 - for db_record in all_records: - record = BanUser(user_id=db_record.user_id, group_id=db_record.group_id, lift_time=db_record.lift_time) - if not any(is_identical(record, ban_user) for ban_user in ban_list): - statement = select(DB_BanUser).where( - DB_BanUser.user_id == record.user_id, DB_BanUser.group_id == record.group_id - ) - if ban_record := session.exec(statement).first(): - session.delete(ban_record) + changed += 1 - logger.debug(f"删除禁言记录: {ban_record}") - else: - logger.info(f"未找到禁言记录: {ban_record}") + removed = 0 + for ident, row in existing_map.items(): + if ident not in target_map: + await session.delete(row) + removed += 1 - logger.info("禁言记录已更新") - - def get_ban_records(self) -> List[BanUser]: - """ - 读取所有禁言记录。 - """ - with Session(self.engine) as session: - statement = select(DB_BanUser) - records = session.exec(statement).all() - return [BanUser(user_id=item.user_id, group_id=item.group_id, lift_time=item.lift_time) for item in records] - - def create_ban_record(self, ban_record: BanUser) -> None: - """ - 为特定群组中的用户创建禁言记录。 - 一个简化版本的添加方式,防止 update_ban_record 方法的复杂性。 - 其同时还是简化版的更新方式。 - """ - with Session(self.engine) as session: - # 检查记录是否已存在 - statement = select(DB_BanUser).where( - DB_BanUser.user_id == ban_record.user_id, DB_BanUser.group_id == ban_record.group_id + logger.debug( + f"Napcat ban list sync => total_incoming={len(ban_list)} created_or_updated={changed} removed={removed}" ) - existing_record = session.exec(statement).first() - if existing_record: - # 如果记录已存在,更新 lift_time - existing_record.lift_time = ban_record.lift_time - await session.add(existing_record) - logger.debug(f"更新禁言记录: {ban_record}") + + async def create_or_update(self, ban_record: BanUser) -> None: + async with get_db_session() as session: + stmt = select(NapcatBanRecord).where( + NapcatBanRecord.group_id == ban_record.group_id, + NapcatBanRecord.user_id == ban_record.user_id, + ) + result = await session.execute(stmt) + row = result.scalars().first() + if row: + if row.lift_time != ban_record.lift_time: + row.lift_time = ban_record.lift_time + logger.debug( + 
f"更新禁言记录 group={ban_record.group_id} user={ban_record.user_id} lift={ban_record.lift_time}" + ) else: - # 如果记录不存在,创建新记录 - db_record = DB_BanUser( - user_id=ban_record.user_id, group_id=ban_record.group_id, lift_time=ban_record.lift_time + session.add( + NapcatBanRecord( + group_id=ban_record.group_id, user_id=ban_record.user_id, lift_time=ban_record.lift_time + ) + ) + logger.debug( + f"创建禁言记录 group={ban_record.group_id} user={ban_record.user_id} lift={ban_record.lift_time}" ) - await session.add(db_record) - logger.debug(f"创建新禁言记录: {ban_record}") - def delete_ban_record(self, ban_record: BanUser): - """ - 删除特定用户在特定群组中的禁言记录。 - 一个简化版本的删除方式,防止 update_ban_record 方法的复杂性。 - """ - user_id = ban_record.user_id - group_id = ban_record.group_id - with Session(self.engine) as session: - statement = select(DB_BanUser).where(DB_BanUser.user_id == user_id, DB_BanUser.group_id == group_id) - if ban_record := session.exec(statement).first(): - session.delete(ban_record) - - logger.debug(f"删除禁言记录: {ban_record}") + async def delete_record(self, ban_record: BanUser) -> None: + async with get_db_session() as session: + stmt = select(NapcatBanRecord).where( + NapcatBanRecord.group_id == ban_record.group_id, + NapcatBanRecord.user_id == ban_record.user_id, + ) + result = await session.execute(stmt) + row = result.scalars().first() + if row: + await session.delete(row) + logger.debug( + f"删除禁言记录 group={ban_record.group_id} user={ban_record.user_id} lift={row.lift_time}" + ) else: - logger.info(f"未找到禁言记录: user_id: {user_id}, group_id: {group_id}") + logger.info( + f"未找到禁言记录 group={ban_record.group_id} user={ban_record.user_id}" + ) + + # 兼容旧命名 + async def update_ban_record(self, ban_list: List[BanUser]) -> None: # old name + await self.update_ban_records(ban_list) + + async def create_ban_record(self, ban_record: BanUser) -> None: # old name + await self.create_or_update(ban_record) + + async def delete_ban_record(self, ban_record: BanUser) -> None: # old name + await self.delete_record(ban_record) -db_manager = DatabaseManager() +napcat_db = NapcatDatabase() + + +def is_identical(a: BanUser, b: BanUser) -> bool: + return a.group_id == b.group_id and a.user_id == b.user_id + + +__all__ = [ + "BanUser", + "NapcatBanRecord", + "napcat_db", + "is_identical", +] diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py index a9eaead16..4a32657a7 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py @@ -9,7 +9,7 @@ from src.common.logger import get_logger logger = get_logger("napcat_adapter") from src.plugin_system.apis import config_api -from ..database import BanUser, db_manager, is_identical +from ..database import BanUser, napcat_db, is_identical from . 
import NoticeType, ACCEPT_FORMAT from .message_sending import message_send_instance from .message_handler import message_handler @@ -62,7 +62,7 @@ class NoticeHandler: return self.server_connection return websocket_manager.get_connection() - def _ban_operation(self, group_id: int, user_id: Optional[int] = None, lift_time: Optional[int] = None) -> None: + async def _ban_operation(self, group_id: int, user_id: Optional[int] = None, lift_time: Optional[int] = None) -> None: """ 将用户禁言记录添加到self.banned_list中 如果是全体禁言,则user_id为0 @@ -71,16 +71,16 @@ class NoticeHandler: user_id = 0 # 使用0表示全体禁言 lift_time = -1 ban_record = BanUser(user_id=user_id, group_id=group_id, lift_time=lift_time) - for record in self.banned_list: + for record in list(self.banned_list): if is_identical(record, ban_record): self.banned_list.remove(record) self.banned_list.append(ban_record) - db_manager.create_ban_record(ban_record) # 作为更新 + await napcat_db.create_ban_record(ban_record) # 更新 return self.banned_list.append(ban_record) - db_manager.create_ban_record(ban_record) # 添加到数据库 + await napcat_db.create_ban_record(ban_record) # 新建 - def _lift_operation(self, group_id: int, user_id: Optional[int] = None) -> None: + async def _lift_operation(self, group_id: int, user_id: Optional[int] = None) -> None: """ 从self.lifted_group_list中移除已经解除全体禁言的群 """ @@ -88,7 +88,12 @@ class NoticeHandler: user_id = 0 # 使用0表示全体禁言 ban_record = BanUser(user_id=user_id, group_id=group_id, lift_time=-1) self.lifted_list.append(ban_record) - db_manager.delete_ban_record(ban_record) # 删除数据库中的记录 + # 从被禁言列表里移除对应记录 + for record in list(self.banned_list): + if is_identical(record, ban_record): + self.banned_list.remove(record) + break + await napcat_db.delete_ban_record(ban_record) async def handle_notice(self, raw_message: dict) -> None: notice_type = raw_message.get("notice_type") @@ -376,7 +381,7 @@ class NoticeHandler: if user_id == 0: # 为全体禁言 sub_type: str = "whole_ban" - self._ban_operation(group_id) + await self._ban_operation(group_id) else: # 为单人禁言 # 获取被禁言人的信息 sub_type: str = "ban" @@ -390,7 +395,7 @@ class NoticeHandler: user_nickname=user_nickname, user_cardname=user_cardname, ) - self._ban_operation(group_id, user_id, int(time.time() + duration)) + await self._ban_operation(group_id, user_id, int(time.time() + duration)) seg_data: Seg = Seg( type="notify", @@ -439,7 +444,7 @@ class NoticeHandler: user_id = raw_message.get("user_id") if user_id == 0: # 全体禁言解除 sub_type = "whole_lift_ban" - self._lift_operation(group_id) + await self._lift_operation(group_id) else: # 单人禁言解除 sub_type = "lift_ban" # 获取被解除禁言人的信息 @@ -455,7 +460,7 @@ class NoticeHandler: user_nickname=user_nickname, user_cardname=user_cardname, ) - self._lift_operation(group_id, user_id) + await self._lift_operation(group_id, user_id) seg_data: Seg = Seg( type="notify", @@ -483,7 +488,7 @@ class NoticeHandler: group_id = lift_record.group_id user_id = lift_record.user_id - db_manager.delete_ban_record(lift_record) # 从数据库中删除禁言记录 + asyncio.create_task(napcat_db.delete_ban_record(lift_record)) # 从数据库中删除禁言记录 seg_message: Seg = await self.natural_lift(group_id, user_id) diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/utils.py b/src/plugins/built_in/napcat_adapter_plugin/src/utils.py index e36fc93fd..4c47a2570 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/utils.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/utils.py @@ -6,7 +6,7 @@ import urllib3 import ssl import io -from .database import BanUser, db_manager +from .database import BanUser, napcat_db from 
src.common.logger import get_logger logger = get_logger("napcat_adapter") @@ -270,10 +270,11 @@ async def read_ban_list( ] """ try: - ban_list = db_manager.get_ban_records() + ban_list = await napcat_db.get_ban_records() lifted_list: List[BanUser] = [] logger.info("已经读取禁言列表") - for ban_record in ban_list: + # 复制列表以避免迭代中修改原列表问题 + for ban_record in list(ban_list): if ban_record.user_id == 0: fetched_group_info = await get_group_info(websocket, ban_record.group_id) if fetched_group_info is None: @@ -301,12 +302,12 @@ async def read_ban_list( ban_list.remove(ban_record) else: ban_record.lift_time = lift_ban_time - db_manager.update_ban_record(ban_list) + await napcat_db.update_ban_record(ban_list) return ban_list, lifted_list except Exception as e: logger.error(f"读取禁言列表失败: {e}") return [], [] -def save_ban_record(list: List[BanUser]): - return db_manager.update_ban_record(list) +async def save_ban_record(list: List[BanUser]): + return await napcat_db.update_ban_record(list) diff --git a/src/schedule/database.py b/src/schedule/database.py index b420f0686..5025c1fa3 100644 --- a/src/schedule/database.py +++ b/src/schedule/database.py @@ -42,7 +42,7 @@ async def add_new_plans(plans: List[str], month: str): new_plan_objects = [ MonthlyPlan(plan_text=plan, target_month=month, status="active") for plan in plans_to_add ] - await session.add_all(new_plan_objects) + session.add_all(new_plan_objects) await session.commit() logger.info(f"成功向数据库添加了 {len(new_plan_objects)} 条 {month} 的月度计划。") diff --git a/src/schedule/schedule_manager.py b/src/schedule/schedule_manager.py index 4e66bf0c8..115480381 100644 --- a/src/schedule/schedule_manager.py +++ b/src/schedule/schedule_manager.py @@ -128,7 +128,7 @@ class ScheduleManager: existing_schedule.updated_at = datetime.now() else: new_schedule = Schedule(date=date_str, schedule_data=schedule_json) - await session.add(new_schedule) + session.add(new_schedule) await session.commit() @staticmethod From 4ca2dfe65a2e7103a3f9f97eb8aec930bb9fcada Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=85=E8=AF=BA=E7=8B=90?= <212194964+foxcyber907@users.noreply.github.com> Date: Sat, 20 Sep 2025 18:08:07 +0800 Subject: [PATCH 19/31] =?UTF-8?q?refactor(chat):=20=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E8=87=AA=E6=B6=88=E6=81=AF=E5=A4=84=E7=90=86=E9=80=BB=E8=BE=91?= =?UTF-8?q?=E5=92=8C=E5=9B=9E=E5=A4=8D=E7=9B=AE=E6=A0=87=E9=80=89=E6=8B=A9?= =?UTF-8?q?=E6=9C=BA=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 添加自消息阻断机制,避免机器人回复自己的消息 - 重构回复目标选择逻辑,优先选择非机器人用户的消息作为回复目标 --- .../old/config.toml.bak.20250907_121908 | 25 ------ src/chat/replyer/default_generator.py | 77 ++++++++++++------- 2 files changed, 48 insertions(+), 54 deletions(-) delete mode 100644 plugins/napcat_adapter_plugin/config/old/config.toml.bak.20250907_121908 diff --git a/plugins/napcat_adapter_plugin/config/old/config.toml.bak.20250907_121908 b/plugins/napcat_adapter_plugin/config/old/config.toml.bak.20250907_121908 deleted file mode 100644 index 1ddca6cf5..000000000 --- a/plugins/napcat_adapter_plugin/config/old/config.toml.bak.20250907_121908 +++ /dev/null @@ -1,25 +0,0 @@ -[inner] -version = "0.2.0" # 版本号 -# 请勿修改版本号,除非你知道自己在做什么 - -[nickname] # 现在没用 -nickname = "" - -[napcat_server] # Napcat连接的ws服务设置 -mode = "reverse" # 连接模式:reverse=反向连接(作为服务器), forward=正向连接(作为客户端) -host = "localhost" # 主机地址 -port = 8095 # 端口号 -url = "" # 正向连接时的完整WebSocket URL,如 ws://localhost:8080/ws (仅在forward模式下使用) -access_token = "" # WebSocket 连接的访问令牌,用于身份验证(可选) -heartbeat_interval = 30 # 心跳间隔时间(按秒计) 
- -[maibot_server] # 连接麦麦的ws服务设置 -host = "localhost" # 麦麦在.env文件中设置的主机地址,即HOST字段 -port = 8000 # 麦麦在.env文件中设置的端口,即PORT字段 - -[voice] # 发送语音设置 -use_tts = false # 是否使用tts语音(请确保你配置了tts并有对应的adapter) - -[debug] -level = "INFO" # 日志等级(DEBUG, INFO, WARNING, ERROR, CRITICAL) - diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index bf3d4fe26..de9c176cf 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -233,6 +233,19 @@ class DefaultReplyer: self.tool_executor = ToolExecutor(chat_id=self.chat_stream.stream_id) + def _should_block_self_message(self, reply_message: Optional[Dict[str, Any]]) -> bool: + """判定是否应阻断当前待处理消息(自消息且无外部触发)""" + try: + bot_id = str(global_config.bot.qq_account) + uid = str(reply_message.get("user_id")) + if uid != bot_id: + return False + + return True + except Exception as e: + logger.warning(f"[SelfGuard] 判定异常,回退为不阻断: {e}") + return False + async def generate_reply_with_context( self, reply_to: str = "", @@ -260,6 +273,10 @@ class DefaultReplyer: prompt = None if available_actions is None: available_actions = {} + # 自消息阻断 + if self._should_block_self_message(reply_message): + logger.debug("[SelfGuard] 阻断:自消息且无外部触发。") + return False, None, None llm_response = None try: # 构建 Prompt @@ -822,36 +839,35 @@ class DefaultReplyer: # 兼容旧的reply_to sender, target = self._parse_reply_target(reply_to) else: - # 获取 platform,如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值 - if reply_message is None: - logger.warning("reply_message 为 None,无法构建prompt") - return "" - platform = reply_message.get("chat_info_platform") - person_id = person_info_manager.get_person_id( - platform, # type: ignore - reply_message.get("user_id"), # type: ignore - ) - person_info = await person_info_manager.get_values(person_id, ["person_name", "user_id"]) - person_name = person_info.get("person_name") - - # 如果person_name为None,使用fallback值 - if person_name is None: - # 尝试从reply_message获取用户名 - fallback_name = reply_message.get("user_nickname") or reply_message.get("user_id", "未知用户") - logger.warning(f"无法获取person_name,使用fallback: {fallback_name}") - person_name = str(fallback_name) - - # 检查是否是bot自己的名字,如果是则替换为"(你)" + # 需求:遍历最近消息,找到第一条 user_id != bot_id 的消息作为目标;找不到则静默退出 bot_user_id = str(global_config.bot.qq_account) - current_user_id = person_info.get("user_id") - current_platform = reply_message.get("chat_info_platform") - - if current_user_id == bot_user_id and current_platform == global_config.bot.platform: - sender = f"{person_name}(你)" + # 优先使用传入的 reply_message 如果它不是 bot + candidate_msg = None + if reply_message and str(reply_message.get("user_id")) != bot_user_id: + candidate_msg = reply_message else: - # 如果不是bot自己,直接使用person_name - sender = person_name - target = reply_message.get("processed_plain_text") + try: + recent_msgs = await get_raw_msg_before_timestamp_with_chat( + chat_id=chat_id, + timestamp=time.time(), + limit= max(10, int(global_config.chat.max_context_size * 0.5)), + ) + # 从最近到更早遍历,找第一条不是bot的 + for m in reversed(recent_msgs): + if str(m.get("user_id")) != bot_user_id: + candidate_msg = m + break + except Exception as e: + logger.error(f"获取最近消息失败: {e}") + if not candidate_msg: + logger.debug("未找到可作为目标的非bot消息,静默不回复。") + return "" + platform = candidate_msg.get("chat_info_platform") or self.chat_stream.platform + person_id = person_info_manager.get_person_id(platform, candidate_msg.get("user_id")) + person_info = await person_info_manager.get_values(person_id, ["person_name", "user_id"]) if person_id else {} + person_name = 
person_info.get("person_name") or candidate_msg.get("user_nickname") or candidate_msg.get("user_id") or "未知用户" + sender = person_name + target = candidate_msg.get("processed_plain_text") or candidate_msg.get("raw_message") or "" # 最终的空值检查,确保sender和target不为None if sender is None: @@ -867,6 +883,8 @@ class DefaultReplyer: target = replace_user_references_sync(target, chat_stream.platform, replace_bot_name=True) + # (简化)不再对自消息做额外任务段落清理,只通过前置选择逻辑避免自目标 + # 构建action描述 (如果启用planner) action_descriptions = "" if available_actions: @@ -895,7 +913,6 @@ class DefaultReplyer: read_mark=0.0, show_actions=True, ) - # 获取目标用户信息,用于s4u模式 target_user_info = None if sender: @@ -1068,6 +1085,8 @@ class DefaultReplyer: prompt = Prompt(template=template_prompt.template, parameters=prompt_parameters) prompt_text = await prompt.build() + # 自目标情况已在上游通过筛选避免,这里不再额外修改 prompt + # --- 动态添加分割指令 --- if global_config.response_splitter.enable and global_config.response_splitter.split_mode == "llm": split_instruction = """ From 4fcaa8e7fba2ef817ff3164c7928128847ffae06 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=85=E8=AF=BA=E7=8B=90?= <212194964+foxcyber907@users.noreply.github.com> Date: Sat, 20 Sep 2025 18:36:02 +0800 Subject: [PATCH 20/31] =?UTF-8?q?=E7=A7=BB=E5=87=BARust=E7=BB=84=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- rust_video/Cargo.toml | 31 -- rust_video/pyproject.toml | 15 - rust_video/rust_video.pyi | 391 ------------------ rust_video/src/lib.rs | 831 -------------------------------------- 4 files changed, 1268 deletions(-) delete mode 100644 rust_video/Cargo.toml delete mode 100644 rust_video/pyproject.toml delete mode 100644 rust_video/rust_video.pyi delete mode 100644 rust_video/src/lib.rs diff --git a/rust_video/Cargo.toml b/rust_video/Cargo.toml deleted file mode 100644 index beb03b188..000000000 --- a/rust_video/Cargo.toml +++ /dev/null @@ -1,31 +0,0 @@ -[package] -name = "rust_video" -version = "0.1.0" -edition = "2021" -authors = ["VideoAnalysis Team"] -description = "Ultra-fast video keyframe extraction tool in Rust" -license = "MIT" - -[dependencies] -anyhow = "1.0" -clap = { version = "4.0", features = ["derive"] } -rayon = "1.11" - -serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" - -chrono = { version = "0.4", features = ["serde"] } - -# PyO3 dependencies -pyo3 = { version = "0.22", features = ["extension-module"] } - -[lib] -name = "rust_video" -crate-type = ["cdylib"] - -[profile.release] -opt-level = 3 -lto = true -codegen-units = 1 -panic = "abort" -strip = true diff --git a/rust_video/pyproject.toml b/rust_video/pyproject.toml deleted file mode 100644 index 3cbb6fadd..000000000 --- a/rust_video/pyproject.toml +++ /dev/null @@ -1,15 +0,0 @@ -[build-system] -requires = ["maturin>=1.9,<2.0"] -build-backend = "maturin" - -[project] -name = "rust_video" -requires-python = ">=3.8" -classifiers = [ - "Programming Language :: Rust", - "Programming Language :: Python :: Implementation :: CPython", - "Programming Language :: Python :: Implementation :: PyPy", -] -dynamic = ["version"] -[tool.maturin] -features = ["pyo3/extension-module"] diff --git a/rust_video/rust_video.pyi b/rust_video/rust_video.pyi deleted file mode 100644 index 83bccf591..000000000 --- a/rust_video/rust_video.pyi +++ /dev/null @@ -1,391 +0,0 @@ -""" -Rust Video Keyframe Extractor - Python Type Hints - -Ultra-fast video keyframe extraction tool with SIMD optimization. 
-""" - -from typing import Dict, List, Optional, Tuple, Union, Any -from pathlib import Path - -class PyVideoFrame: - """ - Python绑定的视频帧结构 - - 表示一个视频帧,包含帧编号、尺寸和像素数据。 - """ - - frame_number: int - """帧编号""" - - width: int - """帧宽度(像素)""" - - height: int - """帧高度(像素)""" - - def __init__(self, frame_number: int, width: int, height: int, data: List[int]) -> None: - """ - 创建新的视频帧 - - Args: - frame_number: 帧编号 - width: 帧宽度 - height: 帧高度 - data: 像素数据(灰度值列表) - """ - ... - - def get_data(self) -> List[int]: - """ - 获取帧的像素数据 - - Returns: - 像素数据列表(灰度值) - """ - ... - - def calculate_difference(self, other: 'PyVideoFrame') -> float: - """ - 计算与另一帧的差异 - - Args: - other: 要比较的另一帧 - - Returns: - 帧差异值(0-255范围) - """ - ... - - def calculate_difference_simd(self, other: 'PyVideoFrame', block_size: Optional[int] = None) -> float: - """ - 使用SIMD优化计算帧差异 - - Args: - other: 要比较的另一帧 - block_size: 处理块大小,默认8192 - - Returns: - 帧差异值(0-255范围) - """ - ... - -class PyPerformanceResult: - """ - 性能测试结果 - - 包含详细的性能统计信息。 - """ - - test_name: str - """测试名称""" - - video_file: str - """视频文件名""" - - total_time_ms: float - """总处理时间(毫秒)""" - - frame_extraction_time_ms: float - """帧提取时间(毫秒)""" - - keyframe_analysis_time_ms: float - """关键帧分析时间(毫秒)""" - - total_frames: int - """总帧数""" - - keyframes_extracted: int - """提取的关键帧数""" - - keyframe_ratio: float - """关键帧比例(百分比)""" - - processing_fps: float - """处理速度(帧每秒)""" - - threshold: float - """检测阈值""" - - optimization_type: str - """优化类型""" - - simd_enabled: bool - """是否启用SIMD""" - - threads_used: int - """使用的线程数""" - - timestamp: str - """时间戳""" - - def to_dict(self) -> Dict[str, Any]: - """ - 转换为Python字典 - - Returns: - 包含所有结果字段的字典 - """ - ... - -class VideoKeyframeExtractor: - """ - 主要的视频关键帧提取器类 - - 提供完整的视频关键帧提取功能,包括SIMD优化和多线程处理。 - """ - - def __init__( - self, - ffmpeg_path: str = "ffmpeg", - threads: int = 0, - verbose: bool = False - ) -> None: - """ - 创建关键帧提取器 - - Args: - ffmpeg_path: FFmpeg可执行文件路径,默认"ffmpeg" - threads: 线程数,0表示自动检测 - verbose: 是否启用详细输出 - """ - ... - - def extract_frames( - self, - video_path: str, - max_frames: Optional[int] = None - ) -> Tuple[List[PyVideoFrame], int, int]: - """ - 从视频中提取帧 - - Args: - video_path: 视频文件路径 - max_frames: 最大提取帧数,None表示提取所有帧 - - Returns: - (帧列表, 宽度, 高度) - """ - ... - - def extract_keyframes( - self, - frames: List[PyVideoFrame], - threshold: float, - use_simd: Optional[bool] = None, - block_size: Optional[int] = None - ) -> List[int]: - """ - 提取关键帧索引 - - Args: - frames: 视频帧列表 - threshold: 检测阈值 - use_simd: 是否使用SIMD优化,默认True - block_size: 处理块大小,默认8192 - - Returns: - 关键帧索引列表 - """ - ... - - def save_keyframes( - self, - video_path: str, - keyframe_indices: List[int], - output_dir: str, - max_save: Optional[int] = None - ) -> int: - """ - 保存关键帧为图片 - - Args: - video_path: 原视频文件路径 - keyframe_indices: 关键帧索引列表 - output_dir: 输出目录 - max_save: 最大保存数量,默认50 - - Returns: - 实际保存的关键帧数量 - """ - ... - - def benchmark( - self, - video_path: str, - threshold: float, - test_name: str, - max_frames: Optional[int] = None, - use_simd: Optional[bool] = None, - block_size: Optional[int] = None - ) -> PyPerformanceResult: - """ - 运行性能测试 - - Args: - video_path: 视频文件路径 - threshold: 检测阈值 - test_name: 测试名称 - max_frames: 最大处理帧数,默认1000 - use_simd: 是否使用SIMD优化,默认True - block_size: 处理块大小,默认8192 - - Returns: - 性能测试结果 - """ - ... 
- - def process_video( - self, - video_path: str, - output_dir: str, - threshold: Optional[float] = None, - max_frames: Optional[int] = None, - max_save: Optional[int] = None, - use_simd: Optional[bool] = None, - block_size: Optional[int] = None - ) -> PyPerformanceResult: - """ - 完整的处理流程 - - 执行完整的视频关键帧提取和保存流程。 - - Args: - video_path: 视频文件路径 - output_dir: 输出目录 - threshold: 检测阈值,默认2.0 - max_frames: 最大处理帧数,0表示处理所有帧 - max_save: 最大保存数量,默认50 - use_simd: 是否使用SIMD优化,默认True - block_size: 处理块大小,默认8192 - - Returns: - 处理结果 - """ - ... - - def get_cpu_features(self) -> Dict[str, bool]: - """ - 获取CPU特性信息 - - Returns: - CPU特性字典,包含AVX2、SSE2等支持信息 - """ - ... - - def get_thread_count(self) -> int: - """ - 获取当前配置的线程数 - - Returns: - 配置的线程数 - """ - ... - - def get_configured_threads(self) -> int: - """ - 获取配置的线程数 - - Returns: - 配置的线程数 - """ - ... - - def get_actual_thread_count(self) -> int: - """ - 获取实际运行的线程数 - - Returns: - 实际运行的线程数 - """ - ... - -def extract_keyframes_from_video( - video_path: str, - output_dir: str, - threshold: Optional[float] = None, - max_frames: Optional[int] = None, - max_save: Optional[int] = None, - ffmpeg_path: Optional[str] = None, - use_simd: Optional[bool] = None, - threads: Optional[int] = None, - verbose: Optional[bool] = None -) -> PyPerformanceResult: - """ - 便捷函数:从视频提取关键帧 - - 这是一个便捷函数,封装了完整的关键帧提取流程。 - - Args: - video_path: 视频文件路径 - output_dir: 输出目录 - threshold: 检测阈值,默认2.0 - max_frames: 最大处理帧数,0表示处理所有帧 - max_save: 最大保存数量,默认50 - ffmpeg_path: FFmpeg路径,默认"ffmpeg" - use_simd: 是否使用SIMD优化,默认True - threads: 线程数,0表示自动检测 - verbose: 是否启用详细输出,默认False - - Returns: - 处理结果 - - Example: - >>> result = extract_keyframes_from_video( - ... "video.mp4", - ... "./output", - ... threshold=2.5, - ... max_save=30, - ... verbose=True - ... ) - >>> print(f"提取了 {result.keyframes_extracted} 个关键帧") - """ - ... - -def get_system_info() -> Dict[str, Any]: - """ - 获取系统信息 - - Returns: - 系统信息字典,包含: - - threads: 可用线程数 - - avx2_supported: 是否支持AVX2(x86_64) - - sse2_supported: 是否支持SSE2(x86_64) - - simd_supported: 是否支持SIMD(非x86_64) - - version: 库版本 - - Example: - >>> info = get_system_info() - >>> print(f"线程数: {info['threads']}") - >>> print(f"AVX2支持: {info.get('avx2_supported', False)}") - """ - ... 
- -# 类型别名 -VideoPath = Union[str, Path] -"""视频文件路径类型""" - -OutputPath = Union[str, Path] -"""输出路径类型""" - -FrameData = List[int] -"""帧数据类型(像素值列表)""" - -KeyframeIndices = List[int] -"""关键帧索引类型""" - -# 常量 -DEFAULT_THRESHOLD: float = 2.0 -"""默认检测阈值""" - -DEFAULT_BLOCK_SIZE: int = 8192 -"""默认处理块大小""" - -DEFAULT_MAX_SAVE: int = 50 -"""默认最大保存数量""" - -MAX_FRAME_DIFFERENCE: float = 255.0 -"""最大帧差异值""" - -# 版本信息 -__version__: str = "0.1.0" -"""库版本""" diff --git a/rust_video/src/lib.rs b/rust_video/src/lib.rs deleted file mode 100644 index 69382f298..000000000 --- a/rust_video/src/lib.rs +++ /dev/null @@ -1,831 +0,0 @@ -use pyo3::prelude::*; -use anyhow::{Context, Result}; -use chrono::prelude::*; -use rayon::prelude::*; -use serde::{Deserialize, Serialize}; -use std::collections::HashMap; -use std::fs; -use std::io::{BufReader, Read}; -use std::path::PathBuf; -use std::process::{Command, Stdio}; -use std::time::Instant; - -#[cfg(target_arch = "x86_64")] -use std::arch::x86_64::*; - -/// Python绑定的视频帧结构 -#[pyclass] -#[derive(Debug, Clone)] -pub struct PyVideoFrame { - #[pyo3(get)] - pub frame_number: usize, - #[pyo3(get)] - pub width: usize, - #[pyo3(get)] - pub height: usize, - pub data: Vec, -} - -#[pymethods] -impl PyVideoFrame { - #[new] - fn new(frame_number: usize, width: usize, height: usize, data: Vec) -> Self { - // 确保数据长度是32的倍数以支持AVX2处理 - let mut aligned_data = data; - let remainder = aligned_data.len() % 32; - if remainder != 0 { - aligned_data.resize(aligned_data.len() + (32 - remainder), 0); - } - - Self { - frame_number, - width, - height, - data: aligned_data, - } - } - - /// 获取帧数据 - fn get_data(&self) -> &[u8] { - let pixel_count = self.width * self.height; - &self.data[..pixel_count] - } - - /// 计算与另一帧的差异 - fn calculate_difference(&self, other: &PyVideoFrame) -> PyResult { - if self.width != other.width || self.height != other.height { - return Ok(f64::MAX); - } - - let total_pixels = self.width * self.height; - let total_diff: u64 = self.data[..total_pixels] - .iter() - .zip(other.data[..total_pixels].iter()) - .map(|(a, b)| (*a as i32 - *b as i32).abs() as u64) - .sum(); - - Ok(total_diff as f64 / total_pixels as f64) - } - - /// 使用SIMD优化计算帧差异 - #[pyo3(signature = (other, block_size=None))] - fn calculate_difference_simd(&self, other: &PyVideoFrame, block_size: Option) -> PyResult { - let block_size = block_size.unwrap_or(8192); - Ok(self.calculate_difference_parallel_simd(other, block_size, true)) - } -} - -impl PyVideoFrame { - /// 使用并行SIMD处理计算帧差异 - fn calculate_difference_parallel_simd(&self, other: &PyVideoFrame, block_size: usize, use_simd: bool) -> f64 { - if self.width != other.width || self.height != other.height { - return f64::MAX; - } - - let total_pixels = self.width * self.height; - let num_blocks = (total_pixels + block_size - 1) / block_size; - - let total_diff: u64 = (0..num_blocks) - .into_par_iter() - .map(|block_idx| { - let start = block_idx * block_size; - let end = ((block_idx + 1) * block_size).min(total_pixels); - let block_len = end - start; - - if use_simd { - #[cfg(target_arch = "x86_64")] - { - unsafe { - if std::arch::is_x86_feature_detected!("avx2") { - return self.calculate_difference_avx2_block(&other.data, start, block_len); - } else if std::arch::is_x86_feature_detected!("sse2") { - return self.calculate_difference_sse2_block(&other.data, start, block_len); - } - } - } - } - - // 标量实现回退 - self.data[start..end] - .iter() - .zip(other.data[start..end].iter()) - .map(|(a, b)| (*a as i32 - *b as i32).abs() as u64) - .sum() - }) - .sum(); - - total_diff 
as f64 / total_pixels as f64 - } - - /// AVX2 优化的块处理 - #[cfg(target_arch = "x86_64")] - #[target_feature(enable = "avx2")] - unsafe fn calculate_difference_avx2_block(&self, other_data: &[u8], start: usize, len: usize) -> u64 { - let mut total_diff = 0u64; - let chunks = len / 32; - - for i in 0..chunks { - let offset = start + i * 32; - - let a = _mm256_loadu_si256(self.data.as_ptr().add(offset) as *const __m256i); - let b = _mm256_loadu_si256(other_data.as_ptr().add(offset) as *const __m256i); - - let diff = _mm256_sad_epu8(a, b); - let result = _mm256_extract_epi64(diff, 0) as u64 + - _mm256_extract_epi64(diff, 1) as u64 + - _mm256_extract_epi64(diff, 2) as u64 + - _mm256_extract_epi64(diff, 3) as u64; - - total_diff += result; - } - - // 处理剩余字节 - for i in (start + chunks * 32)..(start + len) { - total_diff += (self.data[i] as i32 - other_data[i] as i32).abs() as u64; - } - - total_diff - } - - /// SSE2 优化的块处理 - #[cfg(target_arch = "x86_64")] - #[target_feature(enable = "sse2")] - unsafe fn calculate_difference_sse2_block(&self, other_data: &[u8], start: usize, len: usize) -> u64 { - let mut total_diff = 0u64; - let chunks = len / 16; - - for i in 0..chunks { - let offset = start + i * 16; - - let a = _mm_loadu_si128(self.data.as_ptr().add(offset) as *const __m128i); - let b = _mm_loadu_si128(other_data.as_ptr().add(offset) as *const __m128i); - - let diff = _mm_sad_epu8(a, b); - let result = _mm_extract_epi64(diff, 0) as u64 + _mm_extract_epi64(diff, 1) as u64; - - total_diff += result; - } - - // 处理剩余字节 - for i in (start + chunks * 16)..(start + len) { - total_diff += (self.data[i] as i32 - other_data[i] as i32).abs() as u64; - } - - total_diff - } -} - -/// 性能测试结果 -#[pyclass] -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct PyPerformanceResult { - #[pyo3(get)] - pub test_name: String, - #[pyo3(get)] - pub video_file: String, - #[pyo3(get)] - pub total_time_ms: f64, - #[pyo3(get)] - pub frame_extraction_time_ms: f64, - #[pyo3(get)] - pub keyframe_analysis_time_ms: f64, - #[pyo3(get)] - pub total_frames: usize, - #[pyo3(get)] - pub keyframes_extracted: usize, - #[pyo3(get)] - pub keyframe_ratio: f64, - #[pyo3(get)] - pub processing_fps: f64, - #[pyo3(get)] - pub threshold: f64, - #[pyo3(get)] - pub optimization_type: String, - #[pyo3(get)] - pub simd_enabled: bool, - #[pyo3(get)] - pub threads_used: usize, - #[pyo3(get)] - pub timestamp: String, -} - -#[pymethods] -impl PyPerformanceResult { - /// 转换为Python字典 - fn to_dict(&self) -> PyResult> { - Python::with_gil(|py| { - let mut dict = HashMap::new(); - dict.insert("test_name".to_string(), self.test_name.to_object(py)); - dict.insert("video_file".to_string(), self.video_file.to_object(py)); - dict.insert("total_time_ms".to_string(), self.total_time_ms.to_object(py)); - dict.insert("frame_extraction_time_ms".to_string(), self.frame_extraction_time_ms.to_object(py)); - dict.insert("keyframe_analysis_time_ms".to_string(), self.keyframe_analysis_time_ms.to_object(py)); - dict.insert("total_frames".to_string(), self.total_frames.to_object(py)); - dict.insert("keyframes_extracted".to_string(), self.keyframes_extracted.to_object(py)); - dict.insert("keyframe_ratio".to_string(), self.keyframe_ratio.to_object(py)); - dict.insert("processing_fps".to_string(), self.processing_fps.to_object(py)); - dict.insert("threshold".to_string(), self.threshold.to_object(py)); - dict.insert("optimization_type".to_string(), self.optimization_type.to_object(py)); - dict.insert("simd_enabled".to_string(), self.simd_enabled.to_object(py)); - 
dict.insert("threads_used".to_string(), self.threads_used.to_object(py)); - dict.insert("timestamp".to_string(), self.timestamp.to_object(py)); - Ok(dict) - }) - } -} - -/// 主要的视频关键帧提取器类 -#[pyclass] -pub struct VideoKeyframeExtractor { - ffmpeg_path: String, - threads: usize, - verbose: bool, -} - -#[pymethods] -impl VideoKeyframeExtractor { - #[new] - #[pyo3(signature = (ffmpeg_path = "ffmpeg".to_string(), threads = 0, verbose = false))] - fn new(ffmpeg_path: String, threads: usize, verbose: bool) -> PyResult { - // 设置线程池(如果还没有初始化) - if threads > 0 { - // 尝试设置线程池,如果已经初始化则忽略错误 - let _ = rayon::ThreadPoolBuilder::new() - .num_threads(threads) - .build_global(); - } - - Ok(Self { - ffmpeg_path, - threads: if threads == 0 { rayon::current_num_threads() } else { threads }, - verbose, - }) - } - - /// 从视频中提取帧 - #[pyo3(signature = (video_path, max_frames=None))] - fn extract_frames(&self, video_path: &str, max_frames: Option) -> PyResult<(Vec, usize, usize)> { - let video_path = PathBuf::from(video_path); - let max_frames = max_frames.unwrap_or(0); - - extract_frames_memory_stream(&video_path, &PathBuf::from(&self.ffmpeg_path), max_frames, self.verbose) - .map_err(|e| PyErr::new::(format!("Frame extraction failed: {}", e))) - } - - /// 提取关键帧索引 - #[pyo3(signature = (frames, threshold, use_simd=None, block_size=None))] - fn extract_keyframes( - &self, - frames: Vec, - threshold: f64, - use_simd: Option, - block_size: Option - ) -> PyResult> { - let use_simd = use_simd.unwrap_or(true); - let block_size = block_size.unwrap_or(8192); - - extract_keyframes_optimized(&frames, threshold, use_simd, block_size, self.verbose) - .map_err(|e| PyErr::new::(format!("Keyframe extraction failed: {}", e))) - } - - /// 保存关键帧为图片 - #[pyo3(signature = (video_path, keyframe_indices, output_dir, max_save=None))] - fn save_keyframes( - &self, - video_path: &str, - keyframe_indices: Vec, - output_dir: &str, - max_save: Option - ) -> PyResult { - let video_path = PathBuf::from(video_path); - let output_dir = PathBuf::from(output_dir); - let max_save = max_save.unwrap_or(50); - - save_keyframes_optimized( - &video_path, - &keyframe_indices, - &output_dir, - &PathBuf::from(&self.ffmpeg_path), - max_save, - self.verbose - ).map_err(|e| PyErr::new::(format!("Save keyframes failed: {}", e))) - } - - /// 运行性能测试 - #[pyo3(signature = (video_path, threshold, test_name, max_frames=None, use_simd=None, block_size=None))] - fn benchmark( - &self, - video_path: &str, - threshold: f64, - test_name: &str, - max_frames: Option, - use_simd: Option, - block_size: Option - ) -> PyResult { - let video_path = PathBuf::from(video_path); - let max_frames = max_frames.unwrap_or(1000); - let use_simd = use_simd.unwrap_or(true); - let block_size = block_size.unwrap_or(8192); - - let result = run_performance_test( - &video_path, - threshold, - test_name, - &PathBuf::from(&self.ffmpeg_path), - max_frames, - use_simd, - block_size, - self.verbose - ).map_err(|e| PyErr::new::(format!("Benchmark failed: {}", e)))?; - - Ok(PyPerformanceResult { - test_name: result.test_name, - video_file: result.video_file, - total_time_ms: result.total_time_ms, - frame_extraction_time_ms: result.frame_extraction_time_ms, - keyframe_analysis_time_ms: result.keyframe_analysis_time_ms, - total_frames: result.total_frames, - keyframes_extracted: result.keyframes_extracted, - keyframe_ratio: result.keyframe_ratio, - processing_fps: result.processing_fps, - threshold: result.threshold, - optimization_type: result.optimization_type, - simd_enabled: result.simd_enabled, - 
threads_used: result.threads_used, - timestamp: result.timestamp, - }) - } - - /// 完整的处理流程 - #[pyo3(signature = (video_path, output_dir, threshold=None, max_frames=None, max_save=None, use_simd=None, block_size=None))] - fn process_video( - &self, - video_path: &str, - output_dir: &str, - threshold: Option, - max_frames: Option, - max_save: Option, - use_simd: Option, - block_size: Option - ) -> PyResult { - let threshold = threshold.unwrap_or(2.0); - let max_frames = max_frames.unwrap_or(0); - let max_save = max_save.unwrap_or(50); - let use_simd = use_simd.unwrap_or(true); - let block_size = block_size.unwrap_or(8192); - - let video_path_buf = PathBuf::from(video_path); - let output_dir_buf = PathBuf::from(output_dir); - - // 运行性能测试 - let result = run_performance_test( - &video_path_buf, - threshold, - "Python Processing", - &PathBuf::from(&self.ffmpeg_path), - max_frames, - use_simd, - block_size, - self.verbose - ).map_err(|e| PyErr::new::(format!("Processing failed: {}", e)))?; - - // 提取并保存关键帧 - let (frames, _, _) = extract_frames_memory_stream(&video_path_buf, &PathBuf::from(&self.ffmpeg_path), max_frames, self.verbose) - .map_err(|e| PyErr::new::(format!("Frame extraction failed: {}", e)))?; - - let frames: Vec = frames.into_iter().map(|f| PyVideoFrame { - frame_number: f.frame_number, - width: f.width, - height: f.height, - data: f.data, - }).collect(); - - let keyframe_indices = extract_keyframes_optimized(&frames, threshold, use_simd, block_size, self.verbose) - .map_err(|e| PyErr::new::(format!("Keyframe extraction failed: {}", e)))?; - - save_keyframes_optimized(&video_path_buf, &keyframe_indices, &output_dir_buf, &PathBuf::from(&self.ffmpeg_path), max_save, self.verbose) - .map_err(|e| PyErr::new::(format!("Save keyframes failed: {}", e)))?; - - Ok(PyPerformanceResult { - test_name: result.test_name, - video_file: result.video_file, - total_time_ms: result.total_time_ms, - frame_extraction_time_ms: result.frame_extraction_time_ms, - keyframe_analysis_time_ms: result.keyframe_analysis_time_ms, - total_frames: result.total_frames, - keyframes_extracted: result.keyframes_extracted, - keyframe_ratio: result.keyframe_ratio, - processing_fps: result.processing_fps, - threshold: result.threshold, - optimization_type: result.optimization_type, - simd_enabled: result.simd_enabled, - threads_used: result.threads_used, - timestamp: result.timestamp, - }) - } - - /// 获取CPU特性信息 - fn get_cpu_features(&self) -> PyResult> { - let mut features = HashMap::new(); - - #[cfg(target_arch = "x86_64")] - { - features.insert("avx2".to_string(), std::arch::is_x86_feature_detected!("avx2")); - features.insert("sse2".to_string(), std::arch::is_x86_feature_detected!("sse2")); - features.insert("sse4_1".to_string(), std::arch::is_x86_feature_detected!("sse4.1")); - features.insert("sse4_2".to_string(), std::arch::is_x86_feature_detected!("sse4.2")); - features.insert("fma".to_string(), std::arch::is_x86_feature_detected!("fma")); - } - - #[cfg(not(target_arch = "x86_64"))] - { - features.insert("simd_supported".to_string(), false); - } - - Ok(features) - } - - /// 获取当前使用的线程数 - fn get_thread_count(&self) -> usize { - self.threads - } - - /// 获取配置的线程数 - fn get_configured_threads(&self) -> usize { - self.threads - } - - /// 获取实际运行的线程数 - fn get_actual_thread_count(&self) -> usize { - rayon::current_num_threads() - } -} - -// 从main.rs中复制的核心函数 - -struct PerformanceResult { - test_name: String, - video_file: String, - total_time_ms: f64, - frame_extraction_time_ms: f64, - keyframe_analysis_time_ms: f64, - 
total_frames: usize, - keyframes_extracted: usize, - keyframe_ratio: f64, - processing_fps: f64, - threshold: f64, - optimization_type: String, - simd_enabled: bool, - threads_used: usize, - timestamp: String, -} - -fn extract_frames_memory_stream( - video_path: &PathBuf, - ffmpeg_path: &PathBuf, - max_frames: usize, - verbose: bool, -) -> Result<(Vec, usize, usize)> { - if verbose { - println!("🎬 Extracting frames using FFmpeg memory streaming..."); - println!("📁 Video: {}", video_path.display()); - } - - // 获取视频信息 - let probe_output = Command::new(ffmpeg_path) - .args(["-i", video_path.to_str().unwrap(), "-hide_banner"]) - .output() - .context("Failed to probe video with FFmpeg")?; - - let probe_info = String::from_utf8_lossy(&probe_output.stderr); - let (width, height) = parse_video_dimensions(&probe_info) - .ok_or_else(|| anyhow::anyhow!("Cannot parse video dimensions"))?; - - if verbose { - println!("📐 Video dimensions: {}x{}", width, height); - } - - // 构建优化的FFmpeg命令 - let mut cmd = Command::new(ffmpeg_path); - cmd.args([ - "-i", video_path.to_str().unwrap(), - "-f", "rawvideo", - "-pix_fmt", "gray", - "-an", - "-threads", "0", - "-preset", "ultrafast", - ]); - - if max_frames > 0 { - cmd.args(["-frames:v", &max_frames.to_string()]); - } - - cmd.args(["-"]).stdout(Stdio::piped()).stderr(Stdio::null()); - - let start_time = Instant::now(); - let mut child = cmd.spawn().context("Failed to spawn FFmpeg process")?; - let stdout = child.stdout.take().unwrap(); - let mut reader = BufReader::with_capacity(1024 * 1024, stdout); - - let frame_size = width * height; - let mut frames = Vec::new(); - let mut frame_count = 0; - let mut frame_buffer = vec![0u8; frame_size]; - - if verbose { - println!("📦 Frame size: {} bytes", frame_size); - } - - // 直接流式读取帧数据到内存 - loop { - match reader.read_exact(&mut frame_buffer) { - Ok(()) => { - frames.push(PyVideoFrame::new( - frame_count, - width, - height, - frame_buffer.clone(), - )); - frame_count += 1; - - if verbose && frame_count % 200 == 0 { - print!("\r⚡ Frames processed: {}", frame_count); - } - - if max_frames > 0 && frame_count >= max_frames { - break; - } - } - Err(_) => break, - } - } - - let _ = child.wait(); - - if verbose { - println!("\r✅ Frame extraction complete: {} frames in {:.2}s", - frame_count, start_time.elapsed().as_secs_f64()); - } - - Ok((frames, width, height)) -} - -fn parse_video_dimensions(probe_info: &str) -> Option<(usize, usize)> { - for line in probe_info.lines() { - if line.contains("Video:") && line.contains("x") { - for part in line.split_whitespace() { - if let Some(x_pos) = part.find('x') { - let width_str = &part[..x_pos]; - let height_part = &part[x_pos + 1..]; - let height_str = height_part.split(',').next().unwrap_or(height_part); - - if let (Ok(width), Ok(height)) = (width_str.parse::(), height_str.parse::()) { - return Some((width, height)); - } - } - } - } - } - None -} - -fn extract_keyframes_optimized( - frames: &[PyVideoFrame], - threshold: f64, - use_simd: bool, - block_size: usize, - verbose: bool, -) -> Result> { - if frames.len() < 2 { - return Ok(Vec::new()); - } - - let optimization_name = if use_simd { "SIMD+Parallel" } else { "Standard Parallel" }; - if verbose { - println!("🚀 Keyframe analysis (threshold: {}, optimization: {})", threshold, optimization_name); - } - - let start_time = Instant::now(); - - // 并行计算帧差异 - let differences: Vec = frames - .par_windows(2) - .map(|pair| { - if use_simd { - pair[0].calculate_difference_parallel_simd(&pair[1], block_size, true) - } else { - 
pair[0].calculate_difference(&pair[1]).unwrap_or(f64::MAX) - } - }) - .collect(); - - // 基于阈值查找关键帧 - let keyframe_indices: Vec = differences - .par_iter() - .enumerate() - .filter_map(|(i, &diff)| { - if diff > threshold { - Some(i + 1) - } else { - None - } - }) - .collect(); - - if verbose { - println!("⚡ Analysis complete in {:.2}s", start_time.elapsed().as_secs_f64()); - println!("🎯 Found {} keyframes", keyframe_indices.len()); - } - - Ok(keyframe_indices) -} - -fn save_keyframes_optimized( - video_path: &PathBuf, - keyframe_indices: &[usize], - output_dir: &PathBuf, - ffmpeg_path: &PathBuf, - max_save: usize, - verbose: bool, -) -> Result { - if keyframe_indices.is_empty() { - if verbose { - println!("⚠️ No keyframes to save"); - } - return Ok(0); - } - - if verbose { - println!("💾 Saving keyframes..."); - } - - fs::create_dir_all(output_dir).context("Failed to create output directory")?; - - let save_count = keyframe_indices.len().min(max_save); - let mut saved = 0; - - for (i, &frame_idx) in keyframe_indices.iter().take(save_count).enumerate() { - let output_path = output_dir.join(format!("keyframe_{:03}.jpg", i + 1)); - let timestamp = frame_idx as f64 / 30.0; // 假设30 FPS - - let output = Command::new(ffmpeg_path) - .args([ - "-i", video_path.to_str().unwrap(), - "-ss", ×tamp.to_string(), - "-vframes", "1", - "-q:v", "2", - "-y", - output_path.to_str().unwrap(), - ]) - .output() - .context("Failed to extract keyframe with FFmpeg")?; - - if output.status.success() { - saved += 1; - if verbose && (saved % 10 == 0 || saved == save_count) { - print!("\r💾 Saved: {}/{} keyframes", saved, save_count); - } - } else if verbose { - eprintln!("⚠️ Failed to save keyframe {}", frame_idx); - } - } - - if verbose { - println!("\r✅ Keyframe saving complete: {}/{}", saved, save_count); - } - - Ok(saved) -} - -fn run_performance_test( - video_path: &PathBuf, - threshold: f64, - test_name: &str, - ffmpeg_path: &PathBuf, - max_frames: usize, - use_simd: bool, - block_size: usize, - verbose: bool, -) -> Result { - if verbose { - println!("\n{}", "=".repeat(60)); - println!("⚡ Running test: {}", test_name); - println!("{}", "=".repeat(60)); - } - - let total_start = Instant::now(); - - // 帧提取 - let extraction_start = Instant::now(); - let (frames, _width, _height) = extract_frames_memory_stream(video_path, ffmpeg_path, max_frames, verbose)?; - let extraction_time = extraction_start.elapsed().as_secs_f64() * 1000.0; - - // 关键帧分析 - let analysis_start = Instant::now(); - let keyframe_indices = extract_keyframes_optimized(&frames, threshold, use_simd, block_size, verbose)?; - let analysis_time = analysis_start.elapsed().as_secs_f64() * 1000.0; - - let total_time = total_start.elapsed().as_secs_f64() * 1000.0; - - let optimization_type = if use_simd { - format!("SIMD+Parallel(block:{})", block_size) - } else { - "Standard Parallel".to_string() - }; - - let result = PerformanceResult { - test_name: test_name.to_string(), - video_file: video_path.file_name().unwrap().to_string_lossy().to_string(), - total_time_ms: total_time, - frame_extraction_time_ms: extraction_time, - keyframe_analysis_time_ms: analysis_time, - total_frames: frames.len(), - keyframes_extracted: keyframe_indices.len(), - keyframe_ratio: keyframe_indices.len() as f64 / frames.len() as f64 * 100.0, - processing_fps: frames.len() as f64 / (total_time / 1000.0), - threshold, - optimization_type, - simd_enabled: use_simd, - threads_used: rayon::current_num_threads(), - timestamp: Local::now().format("%Y-%m-%d %H:%M:%S").to_string(), - }; - - if 
verbose { - println!("\n⚡ Test Results:"); - println!(" 🕐 Total time: {:.2}ms ({:.2}s)", result.total_time_ms, result.total_time_ms / 1000.0); - println!(" 📥 Extraction: {:.2}ms ({:.1}%)", result.frame_extraction_time_ms, - result.frame_extraction_time_ms / result.total_time_ms * 100.0); - println!(" 🧮 Analysis: {:.2}ms ({:.1}%)", result.keyframe_analysis_time_ms, - result.keyframe_analysis_time_ms / result.total_time_ms * 100.0); - println!(" 📊 Frames: {}", result.total_frames); - println!(" 🎯 Keyframes: {}", result.keyframes_extracted); - println!(" 🚀 Speed: {:.1} FPS", result.processing_fps); - println!(" ⚙️ Optimization: {}", result.optimization_type); - } - - Ok(result) -} - -/// Python模块定义 -#[pymodule] -fn rust_video(m: &Bound<'_, PyModule>) -> PyResult<()> { - m.add_class::()?; - m.add_class::()?; - m.add_class::()?; - - // 便捷函数 - #[pyfn(m)] - #[pyo3(signature = (video_path, output_dir, threshold=None, max_frames=None, max_save=None, ffmpeg_path=None, use_simd=None, threads=None, verbose=None))] - fn extract_keyframes_from_video( - video_path: &str, - output_dir: &str, - threshold: Option, - max_frames: Option, - max_save: Option, - ffmpeg_path: Option, - use_simd: Option, - threads: Option, - verbose: Option - ) -> PyResult { - let extractor = VideoKeyframeExtractor::new( - ffmpeg_path.unwrap_or_else(|| "ffmpeg".to_string()), - threads.unwrap_or(0), - verbose.unwrap_or(false) - )?; - - extractor.process_video( - video_path, - output_dir, - threshold, - max_frames, - max_save, - use_simd, - None - ) - } - - #[pyfn(m)] - fn get_system_info() -> PyResult> { - Python::with_gil(|py| { - let mut info = HashMap::new(); - info.insert("threads".to_string(), rayon::current_num_threads().to_object(py)); - - #[cfg(target_arch = "x86_64")] - { - info.insert("avx2_supported".to_string(), std::arch::is_x86_feature_detected!("avx2").to_object(py)); - info.insert("sse2_supported".to_string(), std::arch::is_x86_feature_detected!("sse2").to_object(py)); - } - - #[cfg(not(target_arch = "x86_64"))] - { - info.insert("simd_supported".to_string(), false.to_object(py)); - } - - info.insert("version".to_string(), "0.1.0".to_object(py)); - - Ok(info) - }) - } - - Ok(()) -} From b42608c49a2b61e2eed22aa7edd1300b11a2e3dd Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Sat, 20 Sep 2025 20:45:56 +0800 Subject: [PATCH 21/31] =?UTF-8?q?=E4=B9=9F=E8=AE=B8=E6=98=AF=E4=BF=AE?= =?UTF-8?q?=E5=A5=BD=E4=BA=86=E8=A1=A8=E8=BE=BE=E5=AD=A6=E4=B9=A0=E5=90=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/chat_loop/heartFC_chat.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/chat/chat_loop/heartFC_chat.py b/src/chat/chat_loop/heartFC_chat.py index 47afab50d..b1c7455b0 100644 --- a/src/chat/chat_loop/heartFC_chat.py +++ b/src/chat/chat_loop/heartFC_chat.py @@ -435,6 +435,13 @@ class HeartFChatting: # Messages should be processed action_type = await self.cycle_processor.observe(interest_value=interest_value) + # 尝试触发表达学习 + if self.context.expression_learner: + try: + await self.context.expression_learner.trigger_learning_for_chat() + except Exception as e: + logger.error(f"{self.context.log_prefix} 表达学习触发失败: {e}") + # 管理no_reply计数器 if action_type != "no_reply": self.recent_interest_records.clear() From ca780919a89a6013afd2a7409ee80c1bb84ec4e7 Mon Sep 17 00:00:00 2001 From: tt-P607 <68868379+tt-P607@users.noreply.github.com> Date: Sat, 20 Sep 2025 22:04:43 +0800 Subject: [PATCH 22/31] =?UTF-8?q?fix(core):=20=E4=BF=AE=E6=AD=A3=E5=9B=A0?= 
=?UTF-8?q?=E5=BC=82=E6=AD=A5=E6=94=B9=E9=80=A0=E9=81=97=E6=BC=8F=E7=9A=84?= =?UTF-8?q?=20await=20=E8=B0=83=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 在最近的数据库异步化重构后,部分函数的调用处忘记添加 `await` 关键字,导致协程未被正确执行。 本次提交修复了以下模块中的问题: - `ExpressionLearner` - `ChatMessageBuilder` - `EmojiAction --- src/chat/express/expression_learner.py | 2 +- src/chat/utils/chat_message_builder.py | 2 +- src/plugins/built_in/core_actions/emoji.py | 8 ++++---- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/chat/express/expression_learner.py b/src/chat/express/expression_learner.py index b7dabe6e1..fb22a4115 100644 --- a/src/chat/express/expression_learner.py +++ b/src/chat/express/expression_learner.py @@ -167,7 +167,7 @@ class ExpressionLearner: Returns: bool: 是否成功触发学习 """ - if not self.should_trigger_learning(): + if not await self.should_trigger_learning(): return False try: diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index e49c218c4..9555e08c8 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -1232,7 +1232,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str: # 在最前面添加图片映射信息 final_output_lines = [] - pic_mapping_info = build_pic_mapping_info(pic_id_mapping) + pic_mapping_info = await build_pic_mapping_info(pic_id_mapping) if pic_mapping_info: final_output_lines.append(pic_mapping_info) final_output_lines.append("\n\n") diff --git a/src/plugins/built_in/core_actions/emoji.py b/src/plugins/built_in/core_actions/emoji.py index 2c0940fcc..84dd45981 100644 --- a/src/plugins/built_in/core_actions/emoji.py +++ b/src/plugins/built_in/core_actions/emoji.py @@ -77,7 +77,7 @@ class EmojiAction(BaseAction): # 3. 
根据历史记录筛选表情 try: - recent_emojis_desc = get_recent_emojis(self.chat_id, limit=10) + recent_emojis_desc = await get_recent_emojis(self.chat_id, limit=10) if recent_emojis_desc: filtered_emojis = [emoji for emoji in all_emojis_obj if emoji.description not in recent_emojis_desc] if filtered_emojis: @@ -122,7 +122,7 @@ class EmojiAction(BaseAction): emoji_base64, emoji_description = random.choice(all_emojis_data) else: # 获取最近的5条消息内容用于判断 - recent_messages = message_api.get_recent_messages(chat_id=self.chat_id, limit=5) + recent_messages = await message_api.get_recent_messages(chat_id=self.chat_id, limit=5) messages_text = "" if recent_messages: messages_text = await message_api.build_readable_messages( @@ -181,7 +181,7 @@ class EmojiAction(BaseAction): elif global_config.emoji.emoji_selection_mode == "description": # --- 详细描述选择模式 --- # 获取最近的5条消息内容用于判断 - recent_messages = message_api.get_recent_messages(chat_id=self.chat_id, limit=5) + recent_messages = await message_api.get_recent_messages(chat_id=self.chat_id, limit=5) messages_text = "" if recent_messages: messages_text = await message_api.build_readable_messages( @@ -260,7 +260,7 @@ class EmojiAction(BaseAction): # 发送成功后,记录到历史 try: - add_emoji_to_history(self.chat_id, emoji_description) + await add_emoji_to_history(self.chat_id, emoji_description) except Exception as e: logger.error(f"{self.log_prefix} 添加表情到历史记录时出错: {e}") From 0286d75228c26fa3d0ce1d049cb6903921f3a6ff Mon Sep 17 00:00:00 2001 From: tt-P607 <68868379+tt-P607@users.noreply.github.com> Date: Sat, 20 Sep 2025 22:12:07 +0800 Subject: [PATCH 23/31] =?UTF-8?q?fix(emoji):=20=E4=BF=AE=E6=AD=A3=E5=AF=B9?= =?UTF-8?q?=E5=90=8C=E6=AD=A5=E5=87=BD=E6=95=B0=E7=9A=84=20await=20?= =?UTF-8?q?=E8=B0=83=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `get_recent_emojis` 和 `add_emoji_to_history` 函数已被重构为同步方法。本次提交移除了对这两个函数不必要的 `await` 调用,以修复由此引发的 `TypeError`。 --- src/plugins/built_in/core_actions/emoji.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/plugins/built_in/core_actions/emoji.py b/src/plugins/built_in/core_actions/emoji.py index 84dd45981..3ebf4610a 100644 --- a/src/plugins/built_in/core_actions/emoji.py +++ b/src/plugins/built_in/core_actions/emoji.py @@ -77,7 +77,7 @@ class EmojiAction(BaseAction): # 3. 
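
Patches 22 and 23 fix the two opposite failure modes left over from the async migration: calling a coroutine function without await silently yields an un-run coroutine object, while awaiting a function that has gone back to being synchronous raises TypeError at runtime. A minimal, self-contained illustration of both symptoms (the function names and signatures here are stand-ins, not the project's real APIs):

import asyncio
import inspect

async def get_recent_messages(chat_id: str, limit: int = 5):  # still async after the refactor
    return [{"id": i} for i in range(limit)]

def get_recent_emojis(chat_id: str, limit: int = 10):  # reverted to a plain sync function
    return ["doge", "smile"][:limit]

async def main():
    msgs = get_recent_messages("chat-1")  # missing await: this is a coroutine, not a list
    assert inspect.iscoroutine(msgs)
    msgs = await msgs                     # awaiting it produces the real value
    assert len(msgs) == 5

    try:
        await get_recent_emojis("chat-1")  # awaiting a plain return value fails loudly
    except TypeError as exc:
        print("spurious await:", exc)

asyncio.run(main())
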
根据历史记录筛选表情 try: - recent_emojis_desc = await get_recent_emojis(self.chat_id, limit=10) + recent_emojis_desc = get_recent_emojis(self.chat_id, limit=10) if recent_emojis_desc: filtered_emojis = [emoji for emoji in all_emojis_obj if emoji.description not in recent_emojis_desc] if filtered_emojis: @@ -260,7 +260,7 @@ class EmojiAction(BaseAction): # 发送成功后,记录到历史 try: - await add_emoji_to_history(self.chat_id, emoji_description) + add_emoji_to_history(self.chat_id, emoji_description) except Exception as e: logger.error(f"{self.log_prefix} 添加表情到历史记录时出错: {e}") From a9a9f380d608929187805688197f2d5afa19190e Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Sat, 20 Sep 2025 22:21:35 +0800 Subject: [PATCH 24/31] =?UTF-8?q?refactor(person=5Finfo):=20=E5=BC=95?= =?UTF-8?q?=E5=85=A5=E5=90=8C=E6=AD=A5=E6=96=B9=E6=B3=95=20get=5Fvalue=20?= =?UTF-8?q?=E5=B9=B6=E6=9B=BF=E6=8D=A2=E6=97=A7=E7=9A=84=20get=5Fvalue=5Fs?= =?UTF-8?q?ync(=E5=9B=A0=E4=B8=BA=E6=A0=B9=E6=9C=AC=E5=B0=B1=E6=B2=A1?= =?UTF-8?q?=E6=9C=89=E8=BF=99=E4=B8=AA=E6=96=B9=E6=B3=95)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 为了解决在不同异步上下文中同步调用数据库可能引发的运行时错误,实现了一个新的、更健壮的同步方法 `PersonInfoManager.get_value`。 - 新方法能够正确处理已在运行的 asyncio 事件循环,提高了在混合代码环境中调用的稳定性。 - 全面替换了原有的 `get_value_sync` 方法调用,统一了同步获取用户信息的接口。 --- src/chat/utils/chat_message_builder.py | 4 +-- src/chat/utils/prompt.py | 2 +- src/chat/utils/utils.py | 2 +- src/person_info/person_info.py | 40 ++++++++++++++++++++++++- src/person_info/relationship_builder.py | 8 ++--- 5 files changed, 47 insertions(+), 9 deletions(-) diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index e49c218c4..b5fe33373 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -46,8 +46,8 @@ def replace_user_references_sync( if replace_bot_name and user_id == global_config.bot.qq_account: return f"{global_config.bot.nickname}(你)" person_id = PersonInfoManager.get_person_id(platform, user_id) - return person_info_manager.get_value_sync(person_id, "person_name") or user_id # type: ignore - + return person_info_manager.get_value(person_id, "person_name") or user_id # type: ignore + name_resolver = default_resolver # 处理回复格式 diff --git a/src/chat/utils/prompt.py b/src/chat/utils/prompt.py index 2f115aa98..db31acfa5 100644 --- a/src/chat/utils/prompt.py +++ b/src/chat/utils/prompt.py @@ -965,7 +965,7 @@ class Prompt: person_info_manager = get_person_info_manager() person_id = person_info_manager.get_person_id_by_person_name(sender) if person_id: - user_id = person_info_manager.get_value_sync(person_id, "user_id") + user_id = person_info_manager.get_value(person_id, "user_id") return str(user_id) if user_id else "" return "" diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index 99647e36c..5eb4cc991 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -663,7 +663,7 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]: if person_id: # get_value is async, so await it directly person_info_manager = get_person_info_manager() - person_name = person_info_manager.get_value_sync(person_id, "person_name") + person_name = person_info_manager.get_value(person_id, "person_name") target_info["person_id"] = person_id target_info["person_name"] = person_name diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index 3a036d029..f5bf8a515 100644 --- a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -512,6 
+512,45 @@ class PersonInfoManager: logger.debug(f"删除失败:未找到 person_id={person_id} 或删除未影响行") + @staticmethod + def get_value(person_id: str, field_name: str) -> Any: + """获取单个字段值(同步版本)""" + if not person_id: + logger.debug("get_value获取失败:person_id不能为空") + return None + + import asyncio + + async def _get_record_sync(): + async with get_db_session() as session: + return (await session.execute(select(PersonInfo).where(PersonInfo.person_id == person_id))).scalar() + + try: + record = asyncio.run(_get_record_sync()) + except RuntimeError: + # 如果当前线程已经有事件循环在运行,则使用现有的循环 + loop = asyncio.get_running_loop() + record = loop.run_until_complete(_get_record_sync()) + + model_fields = [column.name for column in PersonInfo.__table__.columns] + + if field_name not in model_fields: + if field_name in person_info_default: + logger.debug(f"字段'{field_name}'不在SQLAlchemy模型中,使用默认配置值。") + return copy.deepcopy(person_info_default[field_name]) + else: + logger.debug(f"get_value查询失败:字段'{field_name}'未在SQLAlchemy模型和默认配置中定义。") + return None + + if record: + value = getattr(record, field_name) + if value is not None: + return value + else: + return copy.deepcopy(person_info_default.get(field_name)) + else: + return copy.deepcopy(person_info_default.get(field_name)) + @staticmethod async def get_values(person_id: str, field_names: list) -> dict: """获取指定person_id文档的多个字段值,若不存在该字段,则返回该字段的全局默认值""" @@ -550,7 +589,6 @@ class PersonInfoManager: result[field_name] = copy.deepcopy(person_info_default.get(field_name)) return result - @staticmethod async def get_specific_value_list( field_name: str, diff --git a/src/person_info/relationship_builder.py b/src/person_info/relationship_builder.py index 1ff90a99d..35ac76d7d 100644 --- a/src/person_info/relationship_builder.py +++ b/src/person_info/relationship_builder.py @@ -138,7 +138,7 @@ class RelationshipBuilder: "message_count": await self._count_messages_in_timerange(potential_start_time, message_time), } segments.append(new_segment) - person_name = get_person_info_manager().get_value_sync(person_id, "person_name") or person_id + person_name = get_person_info_manager().get_value(person_id, "person_name") or person_id logger.debug( f"{self.log_prefix} 眼熟用户 {person_name} 在 {time.strftime('%H:%M:%S', time.localtime(potential_start_time))} - {time.strftime('%H:%M:%S', time.localtime(message_time))} 之间有 {new_segment['message_count']} 条消息" ) @@ -178,7 +178,7 @@ class RelationshipBuilder: } segments.append(new_segment) person_info_manager = get_person_info_manager() - person_name = person_info_manager.get_value_sync(person_id, "person_name") or person_id + person_name = person_info_manager.get_value(person_id, "person_name") or person_id logger.debug( f"{self.log_prefix} 重新眼熟用户 {person_name} 创建新消息段(超过10条消息间隔): {new_segment}" ) @@ -368,8 +368,8 @@ class RelationshipBuilder: users_to_build_relationship = [] for person_id, segments in self.person_engaged_cache.items(): total_message_count = self._get_total_message_count(person_id) - person_name = get_person_info_manager().get_value_sync(person_id, "person_name") or person_id - + person_name = get_person_info_manager().get_value(person_id, "person_name") or person_id + if total_message_count >= max_build_threshold or ( total_message_count >= 5 and (immediate_build == person_id or immediate_build == "all") ): From 0efbc6dbfe63f9d4cc09f5a892e73812f2d52570 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=85=E8=AF=BA=E7=8B=90?= <212194964+foxcyber907@users.noreply.github.com> Date: Sun, 21 Sep 2025 09:59:39 +0800 Subject: [PATCH 25/31] 
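
A loop that is already running cannot be re-entered with run_until_complete, so a synchronous get_value-style wrapper that must also work from inside async code usually hands the coroutine to a worker thread rather than to the current loop. A minimal sketch of that pattern, with _get_record standing in for the _get_record_sync coroutine in the hunk above; the trade-off is that the calling loop's thread blocks while it waits:

import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Optional

async def _get_record(person_id: str) -> Optional[dict]:
    await asyncio.sleep(0)  # placeholder for the real async DB query
    return {"person_id": person_id, "person_name": "demo"}

def get_value_blocking(person_id: str, field_name: str) -> Any:
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop in this thread: asyncio.run() is safe.
        record = asyncio.run(_get_record(person_id))
    else:
        # A loop is already running here, so run the coroutine on a worker
        # thread with its own loop instead of calling run_until_complete().
        with ThreadPoolExecutor(max_workers=1) as pool:
            record = pool.submit(asyncio.run, _get_record(person_id)).result()
    return record.get(field_name) if record else None

print(get_value_blocking("p1", "person_name"))  # -> demo
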
=?UTF-8?q?=E6=9B=B4=E6=96=B0=E8=A7=86=E9=A2=91?= =?UTF-8?q?=E8=AF=86=E5=88=AB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 1 + requirements.txt | 3 +- src/chat/utils/utils_video.py | 1082 +++++++-------------------------- 3 files changed, 234 insertions(+), 852 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ea7bc77f0..cf3c3a844 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,6 +74,7 @@ dependencies = [ "websockets>=15.0.1", "aiomysql>=0.2.0", "aiosqlite>=0.21.0", + "inkfox>=0.1.0", ] [[tool.uv.index]] diff --git a/requirements.txt b/requirements.txt index 757c1e09a..91811fd28 100644 --- a/requirements.txt +++ b/requirements.txt @@ -69,4 +69,5 @@ google-generativeai lunar_python fuzzywuzzy python-multipart -aiofiles \ No newline at end of file +aiofiles +inkfox \ No newline at end of file diff --git a/src/chat/utils/utils_video.py b/src/chat/utils/utils_video.py index 8cb294f3e..158b8a706 100644 --- a/src/chat/utils/utils_video.py +++ b/src/chat/utils/utils_video.py @@ -1,551 +1,143 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +"""纯 inkfox 视频关键帧分析工具 + +仅依赖 `inkfox.video` 提供的 Rust 扩展能力: + - extract_keyframes_from_video + - get_system_info + +功能: + - 关键帧提取 (base64, timestamp) + - 批量 / 逐帧 LLM 描述 + - 自动模式 (<=3 帧批量,否则逐帧) """ -视频分析器模块 - Rust优化版本 -集成了Rust视频关键帧提取模块,提供高性能的视频分析功能 -支持SIMD优化、多线程处理和智能关键帧检测 -""" + +from __future__ import annotations import os -import tempfile +import io import asyncio import base64 +import tempfile +from pathlib import Path +from typing import List, Tuple, Optional, Dict, Any import hashlib import time -import numpy as np -from PIL import Image -from pathlib import Path -from typing import List, Tuple, Optional, Dict -import io -from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config, model_config +from PIL import Image + from src.common.logger import get_logger -from src.common.database.sqlalchemy_models import get_db_session, Videos -from sqlalchemy import select +from src.config.config import global_config, model_config +from src.llm_models.utils_model import LLMRequest +from src.common.database.sqlalchemy_models import Videos, get_db_session # type: ignore + +# 简易并发控制:同一 hash 只处理一次 +_video_locks: Dict[str, asyncio.Lock] = {} +_locks_guard = asyncio.Lock() logger = get_logger("utils_video") -# Rust模块可用性检测 -RUST_VIDEO_AVAILABLE = False -try: - import rust_video - - RUST_VIDEO_AVAILABLE = True - logger.info("✅ Rust 视频处理模块加载成功") -except ImportError as e: - logger.warning(f"⚠️ Rust 视频处理模块加载失败: {e}") - logger.warning("⚠️ 视频识别功能将自动禁用") -except Exception as e: - logger.error(f"❌ 加载Rust模块时发生错误: {e}") - RUST_VIDEO_AVAILABLE = False - -# 全局正在处理的视频哈希集合,用于防止重复处理 -processing_videos = set() -processing_lock = asyncio.Lock() -# 为每个视频hash创建独立的锁和事件 -video_locks = {} -video_events = {} -video_lock_manager = asyncio.Lock() +from inkfox import video class VideoAnalyzer: - """优化的视频分析器类""" + """基于 inkfox 的视频关键帧 + LLM 描述分析器""" - def __init__(self): - """初始化视频分析器""" - # 检查是否有任何可用的视频处理实现 - opencv_available = False + def __init__(self) -> None: + cfg = getattr(global_config, "video_analysis", object()) + self.max_frames: int = getattr(cfg, "max_frames", 20) + self.frame_quality: int = getattr(cfg, "frame_quality", 85) + self.max_image_size: int = getattr(cfg, "max_image_size", 600) + self.enable_frame_timing: bool = getattr(cfg, "enable_frame_timing", True) + self.use_simd: bool = getattr(cfg, "rust_use_simd", True) + self.threads: int = getattr(cfg, 
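
The _video_locks / _locks_guard pair declared above only pays off if every caller goes through the same get-or-create step for the lock of a given hash. A small sketch of that pattern; analyze_once and the sha256 hashing mirror the module's intent but are illustrative, not its actual entry point:

import asyncio
import hashlib
from typing import Dict

_video_locks: Dict[str, asyncio.Lock] = {}
_locks_guard = asyncio.Lock()

async def _lock_for(video_hash: str) -> asyncio.Lock:
    # The guard serializes access to the dict; setdefault keeps exactly one lock per hash.
    async with _locks_guard:
        return _video_locks.setdefault(video_hash, asyncio.Lock())

async def analyze_once(video_bytes: bytes) -> str:
    video_hash = hashlib.sha256(video_bytes).hexdigest()
    async with await _lock_for(video_hash):
        # Only one task analyzes a given hash at a time; concurrent callers for
        # the same hash wait here and can then reuse a cached result instead.
        return f"analyzed {video_hash[:8]}"

print(asyncio.run(analyze_once(b"demo")))
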
"rust_threads", 0) + self.ffmpeg_path: str = getattr(cfg, "ffmpeg_path", "ffmpeg") + self.analysis_mode: str = getattr(cfg, "analysis_mode", "auto") + self.frame_analysis_delay: float = 0.3 + + # 人格与提示模板 try: - import cv2 - - opencv_available = True - except ImportError: - pass - - if not RUST_VIDEO_AVAILABLE and not opencv_available: - logger.error("❌ 没有可用的视频处理实现,视频分析器将被禁用") - self.disabled = True - return - elif not RUST_VIDEO_AVAILABLE: - logger.warning("⚠️ Rust视频处理模块不可用,将使用Python降级实现") - elif not opencv_available: - logger.warning("⚠️ OpenCV不可用,仅支持Rust关键帧模式") - - self.disabled = False - - # 使用专用的视频分析配置 - try: - self.video_llm = LLMRequest( - model_set=model_config.model_task_config.video_analysis, request_type="video_analysis" - ) - logger.info("✅ 使用video_analysis模型配置") - except (AttributeError, KeyError) as e: - # 如果video_analysis不存在,使用vlm配置 - self.video_llm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="vlm") - logger.warning(f"video_analysis配置不可用({e}),回退使用vlm配置") - - # 从配置文件读取参数,如果配置不存在则使用默认值 - config = global_config.video_analysis - - # 使用 getattr 统一获取配置参数,如果配置不存在则使用默认值 - self.max_frames = getattr(config, "max_frames", 6) - self.frame_quality = getattr(config, "frame_quality", 85) - self.max_image_size = getattr(config, "max_image_size", 600) - self.enable_frame_timing = getattr(config, "enable_frame_timing", True) - - # Rust模块相关配置 - self.rust_keyframe_threshold = getattr(config, "rust_keyframe_threshold", 2.0) - self.rust_use_simd = getattr(config, "rust_use_simd", True) - self.rust_block_size = getattr(config, "rust_block_size", 8192) - self.rust_threads = getattr(config, "rust_threads", 0) - self.ffmpeg_path = getattr(config, "ffmpeg_path", "ffmpeg") - - # 从personality配置中获取人格信息 - try: - personality_config = global_config.personality - self.personality_core = getattr(personality_config, "personality_core", "是一个积极向上的女大学生") - self.personality_side = getattr( - personality_config, "personality_side", "用一句话或几句话描述人格的侧面特点" - ) - except AttributeError: - # 如果没有personality配置,使用默认值 + persona = global_config.personality + self.personality_core = getattr(persona, "personality_core", "是一个积极向上的女大学生") + self.personality_side = getattr(persona, "personality_side", "用一句话或几句话描述人格的侧面特点") + except Exception: # pragma: no cover self.personality_core = "是一个积极向上的女大学生" self.personality_side = "用一句话或几句话描述人格的侧面特点" self.batch_analysis_prompt = getattr( - config, + cfg, "batch_analysis_prompt", - """请以第一人称的视角来观看这一个视频,你看到的这些是从视频中按时间顺序提取的关键帧。 - -你的核心人设是:{personality_core}。 -你的人格细节是:{personality_side}。 - -请提供详细的视频内容描述,涵盖以下方面: -1. 视频的整体内容和主题 -2. 主要人物、对象和场景描述 -3. 动作、情节和时间线发展 -4. 视觉风格和艺术特点 -5. 整体氛围和情感表达 -6. 
任何特殊的视觉效果或文字内容 - -请用中文回答,结果要详细准确。""", + """请以第一人称视角阅读这些按时间顺序提取的关键帧。\n核心:{personality_core}\n人格:{personality_side}\n请详细描述视频(主题/人物与场景/动作与时间线/视觉风格/情绪氛围/特殊元素)。""", ) - # 新增的线程池配置 - self.use_multiprocessing = getattr(config, "use_multiprocessing", True) - self.max_workers = getattr(config, "max_workers", 2) - self.frame_extraction_mode = getattr(config, "frame_extraction_mode", "fixed_number") - self.frame_interval_seconds = getattr(config, "frame_interval_seconds", 2.0) - - # 将配置文件中的模式映射到内部使用的模式名称 - config_mode = getattr(config, "analysis_mode", "auto") - if config_mode == "batch_frames": - self.analysis_mode = "batch" - elif config_mode == "frame_by_frame": - self.analysis_mode = "sequential" - elif config_mode == "auto": - self.analysis_mode = "auto" - else: - logger.warning(f"无效的分析模式: {config_mode},使用默认的auto模式") - self.analysis_mode = "auto" - - self.frame_analysis_delay = 0.3 # API调用间隔(秒) - self.frame_interval = 1.0 # 抽帧时间间隔(秒) - self.batch_size = 3 # 批处理时每批处理的帧数 - self.timeout = 60.0 # 分析超时时间(秒) - - if config: - logger.info("✅ 从配置文件读取视频分析参数") - else: - logger.warning("配置文件中缺少video_analysis配置,使用默认值") - - # 系统提示词 - self.system_prompt = "你是一个专业的视频内容分析助手。请仔细观察用户提供的视频关键帧,详细描述视频内容。" - - logger.info(f"✅ 视频分析器初始化完成,分析模式: {self.analysis_mode}, 线程池: {self.use_multiprocessing}") - - # 获取Rust模块系统信息 - self._log_system_info() - - @staticmethod - def _log_system_info(): - """记录系统信息""" - if not RUST_VIDEO_AVAILABLE: - logger.info("⚠️ Rust模块不可用,跳过系统信息获取") - return - try: - system_info = rust_video.get_system_info() - logger.info(f"🔧 系统信息: 线程数={system_info.get('threads', '未知')}") - - # 记录CPU特性 - features = [] - if system_info.get("avx2_supported"): - features.append("AVX2") - if system_info.get("sse2_supported"): - features.append("SSE2") - if system_info.get("simd_supported"): - features.append("SIMD") - - if features: - logger.info(f"🚀 CPU特性: {', '.join(features)}") - else: - logger.info("⚠️ 未检测到SIMD支持") - - logger.info(f"📦 Rust模块版本: {system_info.get('version', '未知')}") - - except Exception as e: - logger.warning(f"获取系统信息失败: {e}") - - @staticmethod - def _calculate_video_hash(video_data: bytes) -> str: - """计算视频文件的hash值""" - hash_obj = hashlib.sha256() - hash_obj.update(video_data) - return hash_obj.hexdigest() - - @staticmethod - async def _check_video_exists(video_hash: str) -> Optional[Videos]: - """检查视频是否已经分析过 (异步)""" - try: - async with get_db_session() as session: - result = await session.execute(select(Videos).where(Videos.video_hash == video_hash)) - return result.scalar_one_or_none() - except Exception as e: - logger.warning(f"检查视频是否存在时出错: {e}") - return None - - @staticmethod - async def _store_video_result( - video_hash: str, description: str, metadata: Optional[Dict] = None - ) -> Optional[Videos]: - """存储视频分析结果到数据库 (异步)""" - if description.startswith("❌"): - logger.warning(f"⚠️ 检测到错误信息,不保存到数据库: {description[:50]}...") - return None - try: - async with get_db_session() as session: - result = await session.execute(select(Videos).where(Videos.video_hash == video_hash)) - existing_video = result.scalar_one_or_none() - if existing_video: - existing_video.description = description - existing_video.count += 1 - existing_video.timestamp = time.time() - if metadata: - existing_video.duration = metadata.get("duration") - existing_video.frame_count = metadata.get("frame_count") - existing_video.fps = metadata.get("fps") - existing_video.resolution = metadata.get("resolution") - existing_video.file_size = metadata.get("file_size") - await session.commit() - await session.refresh(existing_video) - 
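
A short driver showing how the new extract_keyframes is meant to be consumed; it assumes the project environment is importable, ffmpeg is available on PATH, and "demo.mp4" is a placeholder path:

import asyncio

from src.chat.utils.utils_video import VideoAnalyzer  # module rewritten by this patch

async def main():
    analyzer = VideoAnalyzer()
    frames = await analyzer.extract_keyframes("demo.mp4")  # [(base64_jpeg, timestamp_s), ...]
    for b64, ts in frames[:3]:
        print(f"{ts:6.2f}s  {len(b64)} base64 chars")

asyncio.run(main())
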
logger.info( - f"✅ 更新已存在的视频记录,hash: {video_hash[:16]}..., count: {existing_video.count}" - ) - return existing_video - else: - video_record = Videos( - video_hash=video_hash, - description=description, - timestamp=time.time(), - count=1, - ) - if metadata: - video_record.duration = metadata.get("duration") - video_record.frame_count = metadata.get("frame_count") - video_record.fps = metadata.get("fps") - video_record.resolution = metadata.get("resolution") - video_record.file_size = metadata.get("file_size") - session.add(video_record) - await session.commit() - await session.refresh(video_record) - logger.info( - f"✅ 新视频分析结果已保存到数据库,hash: {video_hash[:16]}..." - ) - return video_record - except Exception as e: - logger.error(f"❌ 存储视频分析结果时出错: {e}") - return None - - def set_analysis_mode(self, mode: str): - """设置分析模式""" - if mode in ["batch", "sequential", "auto"]: - self.analysis_mode = mode - # logger.info(f"分析模式已设置为: {mode}") - else: - logger.warning(f"无效的分析模式: {mode}") - - async def extract_frames(self, video_path: str) -> List[Tuple[str, float]]: - """提取视频帧 - 智能选择最佳实现""" - # 检查是否应该使用Rust实现 - if RUST_VIDEO_AVAILABLE and self.frame_extraction_mode == "keyframe": - # 优先尝试Rust关键帧提取 - try: - return await self._extract_frames_rust_advanced(video_path) - except Exception as e: - logger.warning(f"Rust高级接口失败: {e},尝试基础接口") - try: - return await self._extract_frames_rust(video_path) - except Exception as e2: - logger.warning(f"Rust基础接口也失败: {e2},降级到Python实现") - return await self._extract_frames_python_fallback(video_path) - else: - # 使用Python实现(支持time_interval和fixed_number模式) - if not RUST_VIDEO_AVAILABLE: - logger.info("🔄 Rust模块不可用,使用Python抽帧实现") - else: - logger.info(f"🔄 抽帧模式为 {self.frame_extraction_mode},使用Python抽帧实现") - return await self._extract_frames_python_fallback(video_path) - - async def _extract_frames_rust_advanced(self, video_path: str) -> List[Tuple[str, float]]: - """使用 Rust 高级接口的帧提取""" - try: - logger.info("🔄 使用 Rust 高级接口提取关键帧...") - - # 创建 Rust 视频处理器,使用配置参数 - extractor = rust_video.VideoKeyframeExtractor( - ffmpeg_path=self.ffmpeg_path, - threads=self.rust_threads, - verbose=False, # 使用固定值,不需要配置 + self.video_llm = LLMRequest( + model_set=model_config.model_task_config.video_analysis, request_type="video_analysis" ) + except Exception: + self.video_llm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="vlm") - # 1. 提取所有帧 - frames_data, width, height = extractor.extract_frames( + self._log_system() + + # ---- 系统信息 ---- + def _log_system(self) -> None: + try: + info = video.get_system_info() # type: ignore[attr-defined] + logger.info( + f"inkfox: threads={info.get('threads')} version={info.get('version')} simd={info.get('simd_supported')}" + ) + except Exception as e: # pragma: no cover + logger.debug(f"获取系统信息失败: {e}") + + # ---- 关键帧提取 ---- + async def extract_keyframes(self, video_path: str) -> List[Tuple[str, float]]: + """提取关键帧并返回 (base64, timestamp_seconds) 列表""" + with tempfile.TemporaryDirectory() as tmp: + result = video.extract_keyframes_from_video( # type: ignore[attr-defined] video_path=video_path, - max_frames=self.max_frames * 3, # 提取更多帧用于关键帧检测 + output_dir=tmp, + max_keyframes=self.max_frames * 2, # 先多抓一点再截断 + max_save=self.max_frames, + ffmpeg_path=self.ffmpeg_path, + use_simd=self.use_simd, + threads=self.threads, + verbose=False, ) - - logger.info(f"提取到 {len(frames_data)} 帧,视频尺寸: {width}x{height}") - - # 2. 
检测关键帧,使用配置参数 - keyframe_indices = extractor.extract_keyframes( - frames=frames_data, - threshold=self.rust_keyframe_threshold, - use_simd=self.rust_use_simd, - block_size=self.rust_block_size, - ) - - logger.info(f"检测到 {len(keyframe_indices)} 个关键帧") - - # 3. 转换选定的关键帧为 base64 - frames = [] - frame_count = 0 - - for idx in keyframe_indices[: self.max_frames]: - if idx < len(frames_data): - try: - frame = frames_data[idx] - frame_data = frame.get_data() - - # 将灰度数据转换为PIL图像 - frame_array = np.frombuffer(frame_data, dtype=np.uint8).reshape((frame.height, frame.width)) - pil_image = Image.fromarray( - frame_array, - mode="L", # 灰度模式 - ) - - # 转换为RGB模式以便保存为JPEG - pil_image = pil_image.convert("RGB") - - # 调整图像大小 - if max(pil_image.size) > self.max_image_size: - ratio = self.max_image_size / max(pil_image.size) - new_size = tuple(int(dim * ratio) for dim in pil_image.size) - pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS) - - # 转换为 base64 - buffer = io.BytesIO() - pil_image.save(buffer, format="JPEG", quality=self.frame_quality) - frame_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8") - - # 估算时间戳 - estimated_timestamp = frame.frame_number * (1.0 / 30.0) # 假设30fps - - frames.append((frame_base64, estimated_timestamp)) - frame_count += 1 - - logger.debug( - f"处理关键帧 {frame_count}: 帧号 {frame.frame_number}, 时间 {estimated_timestamp:.2f}s" - ) - - except Exception as e: - logger.error(f"处理关键帧 {idx} 失败: {e}") - continue - - logger.info(f"✅ Rust 高级提取完成: {len(frames)} 关键帧") + files = sorted(Path(tmp).glob("keyframe_*.jpg"))[: self.max_frames] + total_ms = getattr(result, "total_time_ms", 0) + frames: List[Tuple[str, float]] = [] + for i, f in enumerate(files): + img = Image.open(f).convert("RGB") + if max(img.size) > self.max_image_size: + scale = self.max_image_size / max(img.size) + img = img.resize((int(img.width * scale), int(img.height * scale)), Image.Resampling.LANCZOS) + buf = io.BytesIO() + img.save(buf, format="JPEG", quality=self.frame_quality) + b64 = base64.b64encode(buf.getvalue()).decode() + ts = (i / max(1, len(files) - 1)) * (total_ms / 1000.0) if total_ms else float(i) + frames.append((b64, ts)) return frames - except Exception as e: - logger.error(f"❌ Rust 高级帧提取失败: {e}") - # 回退到基础方法 - logger.info("回退到基础 Rust 方法") - return await self._extract_frames_rust(video_path) - - async def _extract_frames_rust(self, video_path: str) -> List[Tuple[str, float]]: - """使用 Rust 实现的帧提取""" - try: - logger.info("🔄 使用 Rust 模块提取关键帧...") - - # 创建临时输出目录 - with tempfile.TemporaryDirectory() as temp_dir: - # 使用便捷函数进行关键帧提取,使用配置参数 - result = rust_video.extract_keyframes_from_video( - video_path=video_path, - output_dir=temp_dir, - threshold=self.rust_keyframe_threshold, - max_frames=self.max_frames * 2, # 提取更多帧以便筛选 - max_save=self.max_frames, - ffmpeg_path=self.ffmpeg_path, - use_simd=self.rust_use_simd, - threads=self.rust_threads, - verbose=False, # 使用固定值,不需要配置 - ) - - logger.info( - f"Rust 处理完成: 总帧数 {result.total_frames}, 关键帧 {result.keyframes_extracted}, 处理速度 {result.processing_fps:.1f} FPS" - ) - - # 转换保存的关键帧为 base64 格式 - frames = [] - temp_dir_path = Path(temp_dir) - - # 获取所有保存的关键帧文件 - keyframe_files = sorted(temp_dir_path.glob("keyframe_*.jpg")) - - for i, keyframe_file in enumerate(keyframe_files): - if len(frames) >= self.max_frames: - break - - try: - # 读取关键帧文件 - with open(keyframe_file, "rb") as f: - image_data = f.read() - - # 转换为 PIL 图像并压缩 - pil_image = Image.open(io.BytesIO(image_data)) - - # 调整图像大小 - if max(pil_image.size) > self.max_image_size: - ratio = 
self.max_image_size / max(pil_image.size) - new_size = tuple(int(dim * ratio) for dim in pil_image.size) - pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS) - - # 转换为 base64 - buffer = io.BytesIO() - pil_image.save(buffer, format="JPEG", quality=self.frame_quality) - frame_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8") - - # 估算时间戳(基于帧索引和总时长) - if result.total_frames > 0: - # 假设关键帧在时间上均匀分布 - estimated_timestamp = (i * result.total_time_ms / 1000.0) / result.keyframes_extracted - else: - estimated_timestamp = i * 1.0 # 默认每秒一帧 - - frames.append((frame_base64, estimated_timestamp)) - - logger.debug(f"处理关键帧 {i + 1}: 估算时间 {estimated_timestamp:.2f}s") - - except Exception as e: - logger.error(f"处理关键帧 {keyframe_file.name} 失败: {e}") - continue - - logger.info(f"✅ Rust 提取完成: {len(frames)} 关键帧") - return frames - - except Exception as e: - logger.error(f"❌ Rust 帧提取失败: {e}") - raise e - - async def _extract_frames_python_fallback(self, video_path: str) -> List[Tuple[str, float]]: - """Python降级抽帧实现 - 支持多种抽帧模式""" - try: - # 导入旧版本分析器 - from .utils_video_legacy import get_legacy_video_analyzer - - logger.info("🔄 使用Python降级抽帧实现...") - legacy_analyzer = get_legacy_video_analyzer() - - # 同步配置参数 - legacy_analyzer.max_frames = self.max_frames - legacy_analyzer.frame_quality = self.frame_quality - legacy_analyzer.max_image_size = self.max_image_size - legacy_analyzer.frame_extraction_mode = self.frame_extraction_mode - legacy_analyzer.frame_interval_seconds = self.frame_interval_seconds - legacy_analyzer.use_multiprocessing = self.use_multiprocessing - - # 使用旧版本的抽帧功能 - frames = await legacy_analyzer.extract_frames(video_path) - - logger.info(f"✅ Python降级抽帧完成: {len(frames)} 帧") - return frames - - except Exception as e: - logger.error(f"❌ Python降级抽帧失败: {e}") - return [] - - async def analyze_frames_batch(self, frames: List[Tuple[str, float]], user_question: str = None) -> str: - """批量分析所有帧""" - logger.info(f"开始批量分析{len(frames)}帧") - - if not frames: - return "❌ 没有可分析的帧" - - # 构建提示词并格式化人格信息,要不然占位符的那个会爆炸 + # ---- 批量分析 ---- + async def _analyze_batch(self, frames: List[Tuple[str, float]], question: Optional[str]) -> str: + from src.llm_models.payload_content.message import MessageBuilder, RoleType + from src.llm_models.utils_model import RequestType prompt = self.batch_analysis_prompt.format( personality_core=self.personality_core, personality_side=self.personality_side ) - - if user_question: - prompt += f"\n\n用户问题: {user_question}" - - # 添加帧信息到提示词 - frame_info = [] - for i, (_frame_base64, timestamp) in enumerate(frames): - if self.enable_frame_timing: - frame_info.append(f"第{i + 1}帧 (时间: {timestamp:.2f}s)") - else: - frame_info.append(f"第{i + 1}帧") - - prompt += f"\n\n视频包含{len(frames)}帧图像:{', '.join(frame_info)}" - prompt += "\n\n请基于所有提供的帧图像进行综合分析,关注并描述视频的完整内容和故事发展。" - - try: - # 使用多图片分析 - response = await self._analyze_multiple_frames(frames, prompt) - logger.info("✅ 视频识别完成") - return response - - except Exception as e: - logger.error(f"❌ 视频识别失败: {e}") - raise e - - async def _analyze_multiple_frames(self, frames: List[Tuple[str, float]], prompt: str) -> str: - """使用多图片分析方法""" - logger.info(f"开始构建包含{len(frames)}帧的分析请求") - - # 导入MessageBuilder用于构建多图片消息 - from src.llm_models.payload_content.message import MessageBuilder, RoleType - from src.llm_models.utils_model import RequestType - - # 构建包含多张图片的消息 - message_builder = MessageBuilder().set_role(RoleType.User).add_text_content(prompt) - - # 添加所有帧图像 - for _i, (frame_base64, _timestamp) in enumerate(frames): - 
message_builder.add_image_content("jpeg", frame_base64) - # logger.info(f"已添加第{i+1}帧到分析请求 (时间: {timestamp:.2f}s, 图片大小: {len(frame_base64)} chars)") - - message = message_builder.build() - # logger.info(f"✅ 多帧消息构建完成,包含{len(frames)}张图片") - - # 获取模型信息和客户端 + if question: + prompt += f"\n用户关注: {question}" + desc = [ + (f"第{i+1}帧 (时间: {ts:.2f}s)" if self.enable_frame_timing else f"第{i+1}帧") + for i, (_b, ts) in enumerate(frames) + ] + prompt += "\n帧列表: " + ", ".join(desc) + mb = MessageBuilder().set_role(RoleType.User).add_text_content(prompt) + for b64, _ in frames: + mb.add_image_content("jpeg", b64) + message = mb.build() model_info, api_provider, client = self.video_llm._select_model() - # logger.info(f"使用模型: {model_info.name} 进行多帧分析") - - # 直接执行多图片请求 - api_response = await self.video_llm._execute_request( + resp = await self.video_llm._execute_request( api_provider=api_provider, client=client, request_type=RequestType.RESPONSE, @@ -554,367 +146,155 @@ class VideoAnalyzer: temperature=None, max_tokens=None, ) + return resp.content or "❌ 未获得响应" - logger.info(f"视频识别完成,响应长度: {len(api_response.content or '')} ") - return api_response.content or "❌ 未获得响应内容" - - async def analyze_frames_sequential(self, frames: List[Tuple[str, float]], user_question: str = None) -> str: - """逐帧分析并汇总""" - logger.info(f"开始逐帧分析{len(frames)}帧") - - frame_analyses = [] - - for i, (frame_base64, timestamp) in enumerate(frames): + # ---- 逐帧分析 ---- + async def _analyze_sequential(self, frames: List[Tuple[str, float]], question: Optional[str]) -> str: + results: List[str] = [] + for i, (b64, ts) in enumerate(frames): + prompt = f"分析第{i+1}帧" + (f" (时间: {ts:.2f}s)" if self.enable_frame_timing else "") + if question: + prompt += f"\n关注: {question}" try: - prompt = f"请分析这个视频的第{i + 1}帧" - if self.enable_frame_timing: - prompt += f" (时间: {timestamp:.2f}s)" - prompt += "。描述你看到的内容,包括人物、动作、场景、文字等。" - - if user_question: - prompt += f"\n特别关注: {user_question}" - - response, _ = await self.video_llm.generate_response_for_image( - prompt=prompt, image_base64=frame_base64, image_format="jpeg" + text, _ = await self.video_llm.generate_response_for_image( + prompt=prompt, image_base64=b64, image_format="jpeg" ) - - frame_analyses.append(f"第{i + 1}帧 ({timestamp:.2f}s): {response}") - logger.debug(f"✅ 第{i + 1}帧分析完成") - - # API调用间隔 - if i < len(frames) - 1: - await asyncio.sleep(self.frame_analysis_delay) - - except Exception as e: - logger.error(f"❌ 第{i + 1}帧分析失败: {e}") - frame_analyses.append(f"第{i + 1}帧: 分析失败 - {e}") - - # 生成汇总 - logger.info("开始生成汇总分析") - summary_prompt = f"""基于以下各帧的分析结果,请提供一个完整的视频内容总结: - -{chr(10).join(frame_analyses)} - -请综合所有帧的信息,描述视频的整体内容、故事线、主要元素和特点。""" - - if user_question: - summary_prompt += f"\n特别回答用户的问题: {user_question}" - + results.append(f"第{i+1}帧: {text}") + except Exception as e: # pragma: no cover + results.append(f"第{i+1}帧: 失败 {e}") + if i < len(frames) - 1: + await asyncio.sleep(self.frame_analysis_delay) + summary_prompt = "基于以下逐帧结果给出完整总结:\n\n" + "\n".join(results) try: - # 使用最后一帧进行汇总分析 - if frames: - last_frame_base64, _ = frames[-1] - summary, _ = await self.video_llm.generate_response_for_image( - prompt=summary_prompt, image_base64=last_frame_base64, image_format="jpeg" - ) - logger.info("✅ 逐帧分析和汇总完成") - return summary - else: - return "❌ 没有可用于汇总的帧" - except Exception as e: - logger.error(f"❌ 汇总分析失败: {e}") - # 如果汇总失败,返回各帧分析结果 - return f"视频逐帧分析结果:\n\n{chr(10).join(frame_analyses)}" + final, _ = await self.video_llm.generate_response_for_image( + prompt=summary_prompt, image_base64=frames[-1][0], 
image_format="jpeg" + ) + return final + except Exception: # pragma: no cover + return "\n".join(results) - async def analyze_video(self, video_path: str, user_question: str = None) -> Tuple[bool, str]: - """分析视频的主要方法 - - Returns: - Tuple[bool, str]: (是否成功, 分析结果或错误信息) - """ - if self.disabled: - error_msg = "❌ 视频分析功能已禁用:没有可用的视频处理实现" - logger.warning(error_msg) - return False, error_msg - - try: - logger.info(f"开始分析视频: {os.path.basename(video_path)}") - - # 提取帧 - frames = await self.extract_frames(video_path) - if not frames: - error_msg = "❌ 无法从视频中提取有效帧" - return False, error_msg - - # 根据模式选择分析方法 - if self.analysis_mode == "auto": - # 智能选择:少于等于3帧用批量,否则用逐帧 - mode = "batch" if len(frames) <= 3 else "sequential" - logger.info(f"自动选择分析模式: {mode} (基于{len(frames)}帧)") - else: - mode = self.analysis_mode - - # 执行分析 - if mode == "batch": - result = await self.analyze_frames_batch(frames, user_question) - else: # sequential - result = await self.analyze_frames_sequential(frames, user_question) - - logger.info("✅ 视频分析完成") - return True, result - - except Exception as e: - error_msg = f"❌ 视频分析失败: {str(e)}" - logger.error(error_msg) - return False, error_msg + # ---- 主入口 ---- + async def analyze_video(self, video_path: str, question: Optional[str] = None) -> Tuple[bool, str]: + if not os.path.exists(video_path): + return False, "❌ 文件不存在" + frames = await self.extract_keyframes(video_path) + if not frames: + return False, "❌ 未提取到关键帧" + mode = self.analysis_mode + if mode == "auto": + mode = "batch" if len(frames) <= 20 else "sequential" + text = await (self._analyze_batch(frames, question) if mode == "batch" else self._analyze_sequential(frames, question)) + return True, text async def analyze_video_from_bytes( - self, video_bytes: bytes, filename: str = None, user_question: str = None, prompt: str = None + self, + video_bytes: bytes, + filename: Optional[str] = None, + prompt: Optional[str] = None, + question: Optional[str] = None, ) -> Dict[str, str]: - """从字节数据分析视频 - - Args: - video_bytes: 视频字节数据 - filename: 文件名(可选,仅用于日志) - user_question: 用户问题(旧参数名,保持兼容性) - prompt: 提示词(新参数名,与系统调用保持一致) - - Returns: - Dict[str, str]: 包含分析结果的字典,格式为 {"summary": "分析结果"} - """ - if self.disabled: - return {"summary": "❌ 视频分析功能已禁用:没有可用的视频处理实现"} - - video_hash = None - video_event = None + """从内存字节分析视频,兼容旧调用 (prompt / question 二选一) 返回 {"summary": str}.""" + if not video_bytes: + return {"summary": "❌ 空视频数据"} + # 兼容参数:prompt 优先,其次 question + q = prompt if prompt is not None else question + video_hash = hashlib.sha256(video_bytes).hexdigest() + # 查缓存 try: - logger.info("开始从字节数据分析视频") + async with get_db_session() as session: # type: ignore + row = await session.execute( + Videos.__table__.select().where(Videos.video_hash == video_hash) # type: ignore + ) + existing = row.first() + if existing and existing[Videos.description] and existing[Videos.vlm_processed]: # type: ignore + return {"summary": existing[Videos.description]} # type: ignore + except Exception: # pragma: no cover + pass - # 兼容性处理:如果传入了prompt参数,使用prompt;否则使用user_question - question = prompt if prompt is not None else user_question - - # 检查视频数据是否有效 - if not video_bytes: - return {"summary": "❌ 视频数据为空"} - - # 计算视频hash值 - video_hash = self._calculate_video_hash(video_bytes) - logger.info(f"视频hash: {video_hash}") - - # 改进的并发控制:使用每个视频独立的锁和事件 - async with video_lock_manager: - if video_hash not in video_locks: - video_locks[video_hash] = asyncio.Lock() - video_events[video_hash] = asyncio.Event() - - video_lock = video_locks[video_hash] - video_event = 
video_events[video_hash] - - # 尝试获取该视频的专用锁 - if video_lock.locked(): - logger.info(f"⏳ 相同视频正在处理中,等待处理完成... (hash: {video_hash[:16]}...)") - try: - # 等待处理完成的事件信号,最多等待60秒 - await asyncio.wait_for(video_event.wait(), timeout=60.0) - logger.info("✅ 等待结束,检查是否有处理结果") - - # 检查是否有结果了 - existing_video = await self._check_video_exists(video_hash) - if existing_video: - logger.info(f"✅ 找到了处理结果,直接返回 (id: {existing_video.id})") - return {"summary": existing_video.description} - else: - logger.warning("⚠️ 等待完成但未找到结果,可能处理失败") - except asyncio.TimeoutError: - logger.warning("⚠️ 等待超时(60秒),放弃等待") - - # 获取锁开始处理 - async with video_lock: - logger.info(f"🔒 获得视频处理锁,开始处理 (hash: {video_hash[:16]}...)") - - # 再次检查数据库(可能在等待期间已经有结果了) - existing_video = await self._check_video_exists(video_hash) - if existing_video: - logger.info(f"✅ 获得锁后发现已有结果,直接返回 (id: {existing_video.id})") - video_event.set() # 通知其他等待者 - return {"summary": existing_video.description} - - # 未找到已存在记录,开始新的分析 - logger.info("未找到已存在的视频记录,开始新的分析") - - # 创建临时文件进行分析 - with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_file: - temp_file.write(video_bytes) - temp_path = temp_file.name - - try: - # 检查临时文件是否创建成功 - if not os.path.exists(temp_path): - video_event.set() # 通知等待者 - return {"summary": "❌ 临时文件创建失败"} - - # 使用临时文件进行分析 - success, result = await self.analyze_video(temp_path, question) - - finally: - # 清理临时文件 - if os.path.exists(temp_path): - os.unlink(temp_path) - - # 保存分析结果到数据库(仅保存成功的结果) - if success: - metadata = {"filename": filename, "file_size": len(video_bytes), "analysis_timestamp": time.time()} - await self._store_video_result(video_hash=video_hash, description=result, metadata=metadata) - logger.info("✅ 分析结果已保存到数据库") - else: - logger.warning("⚠️ 分析失败,不保存到数据库以便后续重试") - - # 处理完成,通知等待者并清理资源 - video_event.set() - async with video_lock_manager: - # 清理资源 - video_locks.pop(video_hash, None) - video_events.pop(video_hash, None) - - return {"summary": result} - - except Exception as e: - error_msg = f"❌ 从字节数据分析视频失败: {str(e)}" - logger.error(error_msg) - - # 不保存错误信息到数据库,允许后续重试 - logger.info("💡 错误信息不保存到数据库,允许后续重试") - - # 处理失败,通知等待者并清理资源 + # 获取锁避免重复处理 + async with _locks_guard: + lock = _video_locks.get(video_hash) + if lock is None: + lock = asyncio.Lock() + _video_locks[video_hash] = lock + async with lock: + # 双检:进入锁后再查一次,避免重复处理 try: - if video_hash and video_event: - async with video_lock_manager: - if video_hash in video_events: - video_events[video_hash].set() - video_locks.pop(video_hash, None) - video_events.pop(video_hash, None) - except Exception as cleanup_e: - logger.error(f"❌ 清理锁资源失败: {cleanup_e}") + async with get_db_session() as session: # type: ignore + row = await session.execute( + Videos.__table__.select().where(Videos.video_hash == video_hash) # type: ignore + ) + existing = row.first() + if existing and existing[Videos.description] and existing[Videos.vlm_processed]: # type: ignore + return {"summary": existing[Videos.description]} # type: ignore + except Exception: # pragma: no cover + pass - return {"summary": error_msg} - - @staticmethod - def is_supported_video(file_path: str) -> bool: - """检查是否为支持的视频格式""" - supported_formats = {".mp4", ".avi", ".mov", ".mkv", ".flv", ".wmv", ".m4v", ".3gp", ".webm"} - return Path(file_path).suffix.lower() in supported_formats - - def get_processing_capabilities(self) -> Dict[str, any]: - """获取处理能力信息""" - if not RUST_VIDEO_AVAILABLE: - return {"error": "Rust视频处理模块不可用", "available": False, "reason": "rust_video模块未安装或加载失败"} - - try: - system_info = rust_video.get_system_info() - - # 
创建一个临时的extractor来获取CPU特性 - extractor = rust_video.VideoKeyframeExtractor(threads=0, verbose=False) - cpu_features = extractor.get_cpu_features() - - capabilities = { - "system": { - "threads": system_info.get("threads", 0), - "rust_version": system_info.get("version", "unknown"), - }, - "cpu_features": cpu_features, - "recommended_settings": self._get_recommended_settings(cpu_features), - "analysis_modes": ["auto", "batch", "sequential"], - "supported_formats": [".mp4", ".avi", ".mov", ".mkv", ".flv", ".wmv", ".m4v", ".3gp", ".webm"], - "available": True, - } - - return capabilities - - except Exception as e: - logger.error(f"获取处理能力信息失败: {e}") - return {"error": str(e), "available": False} - - @staticmethod - def _get_recommended_settings(cpu_features: Dict[str, bool]) -> Dict[str, any]: - """根据CPU特性推荐最佳设置""" - settings = { - "use_simd": any(cpu_features.values()), - "block_size": 8192, - "threads": 0, # 自动检测 - } - - # 根据CPU特性调整设置 - if cpu_features.get("avx2", False): - settings["block_size"] = 16384 # AVX2支持更大的块 - settings["optimization_level"] = "avx2" - elif cpu_features.get("sse2", False): - settings["block_size"] = 8192 - settings["optimization_level"] = "sse2" - else: - settings["use_simd"] = False - settings["block_size"] = 4096 - settings["optimization_level"] = "scalar" - - return settings + try: + with tempfile.NamedTemporaryFile(delete=False) as fp: + fp.write(video_bytes) + temp_path = fp.name + try: + ok, summary = await self.analyze_video(temp_path, q) + # 写入缓存(仅成功) + if ok: + try: + async with get_db_session() as session: # type: ignore + await session.execute( + Videos.__table__.insert().values( + video_id="", + video_hash=video_hash, + description=summary, + count=1, + timestamp=time.time(), + vlm_processed=True, + duration=None, + frame_count=None, + fps=None, + resolution=None, + file_size=len(video_bytes), + ) + ) + await session.commit() + except Exception: # pragma: no cover + pass + return {"summary": summary} + finally: + if os.path.exists(temp_path): + try: + os.remove(temp_path) + except Exception: # pragma: no cover + pass + except Exception as e: # pragma: no cover + return {"summary": f"❌ 处理失败: {e}"} -# 全局实例 -_video_analyzer = None +# ---- 外部接口 ---- +_INSTANCE: Optional[VideoAnalyzer] = None def get_video_analyzer() -> VideoAnalyzer: - """获取视频分析器实例(单例模式)""" - global _video_analyzer - if _video_analyzer is None: - _video_analyzer = VideoAnalyzer() - return _video_analyzer + global _INSTANCE + if _INSTANCE is None: + _INSTANCE = VideoAnalyzer() + return _INSTANCE def is_video_analysis_available() -> bool: - """检查视频分析功能是否可用 + return True - Returns: - bool: 如果有任何可用的视频处理实现则返回True - """ - # 现在即使Rust模块不可用,也可以使用Python降级实现 + +def get_video_analysis_status() -> Dict[str, Any]: try: - import cv2 - - return True - except ImportError: - return False - - -def get_video_analysis_status() -> Dict[str, any]: - """获取视频分析功能的详细状态信息 - - Returns: - Dict[str, any]: 包含功能状态信息的字典 - """ - # 检查OpenCV是否可用 - opencv_available = False - try: - import cv2 - - opencv_available = True - except ImportError: - pass - - status = { - "available": opencv_available or RUST_VIDEO_AVAILABLE, - "implementations": { - "rust_keyframe": { - "available": RUST_VIDEO_AVAILABLE, - "description": "Rust智能关键帧提取", - "supported_modes": ["keyframe"], - }, - "python_legacy": { - "available": opencv_available, - "description": "Python传统抽帧方法", - "supported_modes": ["fixed_number", "time_interval"], - }, - }, - "supported_modes": [], + info = video.get_system_info() # type: ignore[attr-defined] + except Exception as e: # 
pragma: no cover + return {"available": False, "error": str(e)} + inst = get_video_analyzer() + return { + "available": True, + "system": info, + "modes": ["auto", "batch", "sequential"], + "max_frames_default": inst.max_frames, + "implementation": "inkfox", } - - # 汇总支持的模式 - if RUST_VIDEO_AVAILABLE: - status["supported_modes"].extend(["keyframe"]) - if opencv_available: - status["supported_modes"].extend(["fixed_number", "time_interval"]) - - if not status["available"]: - status.update({"error": "没有可用的视频处理实现", "solution": "请安装opencv-python或rust_video模块"}) - - return status From bd94ce1ce57ff119ac2c801ab61226db3f2b9f24 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=85=E8=AF=BA=E7=8B=90?= <212194964+foxcyber907@users.noreply.github.com> Date: Sun, 21 Sep 2025 10:32:15 +0800 Subject: [PATCH 26/31] Update utils_video.py --- src/chat/utils/utils_video.py | 103 ++++++++++++++++++++-------------- 1 file changed, 61 insertions(+), 42 deletions(-) diff --git a/src/chat/utils/utils_video.py b/src/chat/utils/utils_video.py index 158b8a706..2f72af32b 100644 --- a/src/chat/utils/utils_video.py +++ b/src/chat/utils/utils_video.py @@ -30,6 +30,8 @@ from src.common.logger import get_logger from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest from src.common.database.sqlalchemy_models import Videos, get_db_session # type: ignore +from sqlalchemy import select, update, insert # type: ignore +from sqlalchemy import exc as sa_exc # type: ignore # 简易并发控制:同一 hash 只处理一次 _video_locks: Dict[str, asyncio.Lock] = {} @@ -200,17 +202,11 @@ class VideoAnalyzer: q = prompt if prompt is not None else question video_hash = hashlib.sha256(video_bytes).hexdigest() - # 查缓存 - try: - async with get_db_session() as session: # type: ignore - row = await session.execute( - Videos.__table__.select().where(Videos.video_hash == video_hash) # type: ignore - ) - existing = row.first() - if existing and existing[Videos.description] and existing[Videos.vlm_processed]: # type: ignore - return {"summary": existing[Videos.description]} # type: ignore - except Exception: # pragma: no cover - pass + # 查缓存(第一次,未加锁) + cached = await self._get_cached(video_hash) + if cached: + logger.info(f"视频缓存命中(预检查) hash={video_hash[:16]}") + return {"summary": cached} # 获取锁避免重复处理 async with _locks_guard: @@ -219,17 +215,11 @@ class VideoAnalyzer: lock = asyncio.Lock() _video_locks[video_hash] = lock async with lock: - # 双检:进入锁后再查一次,避免重复处理 - try: - async with get_db_session() as session: # type: ignore - row = await session.execute( - Videos.__table__.select().where(Videos.video_hash == video_hash) # type: ignore - ) - existing = row.first() - if existing and existing[Videos.description] and existing[Videos.vlm_processed]: # type: ignore - return {"summary": existing[Videos.description]} # type: ignore - except Exception: # pragma: no cover - pass + # 双检缓存 + cached2 = await self._get_cached(video_hash) + if cached2: + logger.info(f"视频缓存命中(锁后) hash={video_hash[:16]}") + return {"summary": cached2} try: with tempfile.NamedTemporaryFile(delete=False) as fp: @@ -239,26 +229,7 @@ class VideoAnalyzer: ok, summary = await self.analyze_video(temp_path, q) # 写入缓存(仅成功) if ok: - try: - async with get_db_session() as session: # type: ignore - await session.execute( - Videos.__table__.insert().values( - video_id="", - video_hash=video_hash, - description=summary, - count=1, - timestamp=time.time(), - vlm_processed=True, - duration=None, - frame_count=None, - fps=None, - resolution=None, - 
file_size=len(video_bytes), - ) - ) - await session.commit() - except Exception: # pragma: no cover - pass + await self._save_cache(video_hash, summary, len(video_bytes)) return {"summary": summary} finally: if os.path.exists(temp_path): @@ -269,6 +240,54 @@ class VideoAnalyzer: except Exception as e: # pragma: no cover return {"summary": f"❌ 处理失败: {e}"} + # ---- 缓存辅助 ---- + async def _get_cached(self, video_hash: str) -> Optional[str]: + try: + async with get_db_session() as session: # type: ignore + result = await session.execute(select(Videos).where(Videos.video_hash == video_hash)) # type: ignore + obj: Optional[Videos] = result.scalar_one_or_none() # type: ignore + if obj and obj.vlm_processed and obj.description: + # 更新使用次数 + try: + await session.execute( + update(Videos) + .where(Videos.id == obj.id) # type: ignore + .values(count=obj.count + 1 if obj.count is not None else 1) + ) + await session.commit() + except Exception: # pragma: no cover + await session.rollback() + return obj.description + except Exception: # pragma: no cover + pass + return None + + async def _save_cache(self, video_hash: str, summary: str, file_size: int) -> None: + try: + async with get_db_session() as session: # type: ignore + stmt = insert(Videos).values( # type: ignore + video_id="", + video_hash=video_hash, + description=summary, + count=1, + timestamp=time.time(), + vlm_processed=True, + duration=None, + frame_count=None, + fps=None, + resolution=None, + file_size=file_size, + ) + try: + await session.execute(stmt) + await session.commit() + logger.debug(f"视频缓存写入 success hash={video_hash}") + except sa_exc.IntegrityError: # 可能并发已写入 + await session.rollback() + logger.debug(f"视频缓存已存在 hash={video_hash}") + except Exception: # pragma: no cover + logger.debug("视频缓存写入失败") + # ---- 外部接口 ---- _INSTANCE: Optional[VideoAnalyzer] = None From 016979b6c82eadc7da91490a37273775870a6b7c Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Sun, 21 Sep 2025 13:05:13 +0800 Subject: [PATCH 27/31] =?UTF-8?q?feat(db):=20=E5=A2=9E=E5=BC=BA=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BA=93=E8=87=AA=E5=8A=A8=E8=BF=81=E7=A7=BB=E5=8A=9F?= =?UTF-8?q?=E8=83=BD=E4=BB=A5=E6=94=AF=E6=8C=81=E7=B4=A2=E5=BC=95=E5=88=9B?= =?UTF-8?q?=E5=BB=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 重构并增强了数据库自动迁移逻辑,以提供更健壮和全面的模式管理。 主要更新包括: - **支持索引创建**: 迁移脚本现在会自动检测并创建模型中定义但数据库中缺失的索引。 - **重构迁移流程**: 1. 首先一次性创建所有缺失的表,提高初始设置效率。 2. 
然后,逐表检查并添加缺失的列和索引,使逻辑更清晰。 - **改进 SQLAlchemy 用法**: - 使用 `AddColumn` 和 `CreateIndex` DDL 结构代替原始 SQL 字符串,提高了代码的可读性和数据库方言的兼容性。 - 优化了 `inspector` 的使用方式,减少了重复调用。 - **增强日志记录**: 提供了更详细的日志输出,清晰地展示了正在执行的操作(如创建表、添加列、创建索引),并改进了错误报告。 --- src/common/database/db_migration.py | 155 +++++++++++++++++----------- 1 file changed, 92 insertions(+), 63 deletions(-) diff --git a/src/common/database/db_migration.py b/src/common/database/db_migration.py index 9d2be9e5b..a1633d76c 100644 --- a/src/common/database/db_migration.py +++ b/src/common/database/db_migration.py @@ -1,6 +1,8 @@ # mmc/src/common/database/db_migration.py -from sqlalchemy import inspect, text +from sqlalchemy import inspect +from sqlalchemy.schema import AddColumn, CreateIndex + from src.common.database.sqlalchemy_models import Base, get_engine from src.common.logger import get_logger @@ -9,79 +11,106 @@ logger = get_logger("db_migration") async def check_and_migrate_database(): """ - 异步检查数据库结构并自动迁移(添加缺失的表和列)。 + 异步检查数据库结构并自动迁移。 + - 自动创建不存在的表。 + - 自动为现有表添加缺失的列。 + - 自动为现有表创建缺失的索引。 """ logger.info("正在检查数据库结构并执行自动迁移...") engine = await get_engine() - - # 使用异步引擎获取inspector + async with engine.connect() as connection: # 在同步上下文中运行inspector操作 - inspector = await connection.run_sync(lambda sync_conn: inspect(sync_conn)) - - # 1. 获取数据库中所有已存在的表名 - db_table_names = await connection.run_sync(lambda sync_conn: set(inspect(sync_conn).get_table_names())) + def get_inspector(sync_conn): + return inspect(sync_conn) - # 2. 遍历所有在代码中定义的模型 + inspector = await connection.run_sync(get_inspector) + + # 在同步lambda中传递inspector + db_table_names = await connection.run_sync(lambda conn: set(inspector.get_table_names(conn))) + + # 1. 首先处理表的创建 + tables_to_create = [] for table_name, table in Base.metadata.tables.items(): - logger.debug(f"正在检查表: {table_name}") - - # 3. 如果表不存在,则创建它 if table_name not in db_table_names: - logger.info(f"表 '{table_name}' 不存在,正在创建...") - try: - await connection.run_sync(lambda sync_conn: table.create(sync_conn)) - logger.info(f"表 '{table_name}' 创建成功。") - except Exception as e: - logger.error(f"创建表 '{table_name}' 失败: {e}") + tables_to_create.append(table) + + if tables_to_create: + logger.info(f"发现 {len(tables_to_create)} 个不存在的表,正在创建...") + try: + # 一次性创建所有缺失的表 + await connection.run_sync( + lambda sync_conn: Base.metadata.create_all(sync_conn, tables=tables_to_create) + ) + for table in tables_to_create: + logger.info(f"表 '{table.name}' 创建成功。") + db_table_names.add(table.name) # 将新创建的表添加到集合中 + except Exception as e: + logger.error(f"创建表时失败: {e}", exc_info=True) + + # 2. 然后处理现有表的列和索引的添加 + for table_name, table in Base.metadata.tables.items(): + if table_name not in db_table_names: + logger.warning(f"跳过检查表 '{table_name}',因为它在创建步骤中可能已失败。") continue - # 4. 
如果表已存在,则检查并添加缺失的列 - db_columns = await connection.run_sync( - lambda sync_conn: {col["name"] for col in inspect(sync_conn).get_columns(table_name)} - ) - model_columns = {col.name for col in table.c} + logger.debug(f"正在检查表 '{table_name}' 的列和索引...") - missing_columns = model_columns - db_columns - if not missing_columns: - logger.debug(f"表 '{table_name}' 结构一致,无需修改。") + try: + # 检查并添加缺失的列 + db_columns = await connection.run_sync( + lambda conn: {col["name"] for col in inspector.get_columns(table_name, conn)} + ) + model_columns = {col.name for col in table.c} + missing_columns = model_columns - db_columns + + if missing_columns: + logger.info(f"在表 '{table_name}' 中发现缺失的列: {', '.join(missing_columns)}") + async with connection.begin() as trans: + for column_name in missing_columns: + try: + column = table.c[column_name] + add_column_ddl = AddColumn(table_name, column) + await connection.execute(add_column_ddl) + logger.info(f"成功向表 '{table_name}' 添加列 '{column_name}'。") + except Exception as e: + logger.error( + f"向表 '{table_name}' 添加列 '{column_name}' 失败: {e}", + exc_info=True, + ) + await trans.rollback() + break # 如果一列失败,则停止处理此表的其他列 + else: + logger.info(f"表 '{table_name}' 的列结构一致。") + + # 检查并创建缺失的索引 + db_indexes = await connection.run_sync( + lambda conn: {idx["name"] for idx in inspector.get_indexes(table_name, conn)} + ) + model_indexes = {idx.name for idx in table.indexes} + missing_indexes = model_indexes - db_indexes + + if missing_indexes: + logger.info(f"在表 '{table_name}' 中发现缺失的索引: {', '.join(missing_indexes)}") + async with connection.begin() as trans: + for index_name in missing_indexes: + try: + index_obj = next((idx for idx in table.indexes if idx.name == index_name), None) + if index_obj is not None: + await connection.execute(CreateIndex(index_obj)) + logger.info(f"成功为表 '{table_name}' 创建索引 '{index_name}'。") + except Exception as e: + logger.error( + f"为表 '{table_name}' 创建索引 '{index_name}' 失败: {e}", + exc_info=True, + ) + await trans.rollback() + break # 如果一个索引失败,则停止处理此表的其他索引 + else: + logger.debug(f"表 '{table_name}' 的索引一致。") + + except Exception as e: + logger.error(f"在处理表 '{table_name}' 时发生意外错误: {e}", exc_info=True) continue - logger.info(f"在表 '{table_name}' 中发现缺失的列: {', '.join(missing_columns)}") - - # 开始事务来添加缺失的列 - async with connection.begin() as trans: - try: - for column_name in missing_columns: - column = table.c[column_name] - - # 构造并执行 ALTER TABLE 语句 - try: - # 在同步上下文中编译列类型 - column_type = await connection.run_sync( - lambda sync_conn: column.type.compile(sync_conn.dialect) - ) - sql = f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_type}" - - # 添加默认值和非空约束的处理 - if column.default is not None: - default_value = column.default.arg - if isinstance(default_value, str): - sql += f" DEFAULT '{default_value}'" - else: - sql += f" DEFAULT {default_value}" - - if not column.nullable: - sql += " NOT NULL" - - await connection.execute(text(sql)) - logger.info(f"成功向表 '{table_name}' 添加列 '{column_name}'。") - except Exception as e: - logger.error(f"向表 '{table_name}' 添加列 '{column_name}' 失败: {e}") - - except Exception as e: - logger.error(f"在表 '{table_name}' 添加列时发生错误,事务已回滚: {e}") - await trans.rollback() - raise - logger.info("数据库结构检查与自动迁移完成。") From df809b6dc3cca173fae15718664d15314d0f5ee6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=85=E8=AF=BA=E7=8B=90?= <212194964+foxcyber907@users.noreply.github.com> Date: Sun, 21 Sep 2025 13:09:29 +0800 Subject: [PATCH 28/31] =?UTF-8?q?=E6=9B=B4=E6=94=B9=E6=9D=83=E9=99=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 
Content-Transfer-Encoding: 8bit --- src/chat/utils/statistic.py | 396 ++++-------------- src/plugin_system/apis/permission_api.py | 356 +++++----------- .../utils/permission_decorators.py | 95 +---- .../actions/read_feed_action.py | 2 +- .../actions/send_feed_action.py | 2 +- .../built_in/maizone_refactored/plugin.py | 6 +- .../built_in/permission_management/plugin.py | 26 +- .../built_in/plugin_management/plugin.py | 12 +- 8 files changed, 235 insertions(+), 660 deletions(-) diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py index 891f7653c..ed8530387 100644 --- a/src/chat/utils/statistic.py +++ b/src/chat/utils/statistic.py @@ -1,6 +1,4 @@ import asyncio -import concurrent.futures - from collections import defaultdict from datetime import datetime, timedelta from typing import Any, Dict, Tuple, List @@ -13,69 +11,7 @@ from src.manager.local_store_manager import local_storage logger = get_logger("maibot_statistic") - -# 同步包装器函数,用于在非异步环境中调用异步数据库API -# 全局存储主事件循环引用 -_main_event_loop = None - -def _get_main_loop(): - """获取主事件循环的引用""" - global _main_event_loop - if _main_event_loop is None: - try: - _main_event_loop = asyncio.get_running_loop() - except RuntimeError: - # 如果没有运行的循环,尝试获取默认循环 - try: - _main_event_loop = asyncio.get_event_loop_policy().get_event_loop() - except Exception: - pass - return _main_event_loop - -def _sync_db_get(model_class, filters=None, order_by=None, limit=None, single_result=False): - """同步版本的db_get,用于在线程池中调用""" - import asyncio - import threading - - try: - # 优先尝试获取预存的主事件循环 - main_loop = _get_main_loop() - - # 如果在子线程中且有主循环可用 - if threading.current_thread() is not threading.main_thread() and main_loop: - try: - if not main_loop.is_closed(): - future = asyncio.run_coroutine_threadsafe( - db_get(model_class, filters, limit, order_by, single_result), main_loop - ) - return future.result(timeout=30) - except Exception as e: - # 如果使用主循环失败,才在子线程创建新循环 - logger.debug(f"使用主事件循环失败({e}),在子线程中创建新循环") - return asyncio.run(db_get(model_class, filters, limit, order_by, single_result)) - - # 如果在主线程中,直接运行 - if threading.current_thread() is threading.main_thread(): - try: - # 检查是否有当前运行的循环 - current_loop = asyncio.get_running_loop() - if current_loop.is_running(): - # 主循环正在运行,返回空结果避免阻塞 - logger.debug("在运行中的主事件循环中跳过同步数据库查询") - return [] - except RuntimeError: - # 没有运行的循环,可以安全创建 - pass - - # 创建新循环运行查询 - return asyncio.run(db_get(model_class, filters, limit, order_by, single_result)) - - # 最后的兜底方案:在子线程创建新循环 - return asyncio.run(db_get(model_class, filters, limit, order_by, single_result)) - - except Exception as e: - logger.error(f"_sync_db_get 执行过程中发生错误: {e}") - return [] +# 彻底异步化:删除原同步包装器 _sync_db_get,所有数据库访问统一使用 await db_get。 # 统计数据的键 @@ -271,28 +207,11 @@ class StatisticOutputTask(AsyncTask): async def run(self): try: now = datetime.now() - - # 使用线程池并行执行耗时操作 - loop = asyncio.get_event_loop() - - # 在线程池中并行执行数据收集和之前的HTML生成(如果存在) - with concurrent.futures.ThreadPoolExecutor() as executor: - logger.info("正在收集统计数据...") - - # 数据收集任务 - collect_task = loop.run_in_executor(executor, self._collect_all_statistics, now) - - # 等待数据收集完成 - stats = await collect_task - logger.info("统计数据收集完成") - - # 并行执行控制台输出和HTML报告生成 - console_task = loop.run_in_executor(executor, self._statistic_console_output, stats, now) - html_task = loop.run_in_executor(executor, self._generate_html_report, stats, now) - - # 等待两个输出任务完成 - await asyncio.gather(console_task, html_task) - + logger.info("正在收集统计数据(异步)...") + stats = await self._collect_all_statistics(now) + logger.info("统计数据收集完成") + 
self._statistic_console_output(stats, now) + await self._generate_html_report(stats, now) logger.info("统计数据输出完成") except Exception as e: logger.exception(f"输出统计数据过程中发生异常,错误信息:{e}") @@ -305,31 +224,11 @@ class StatisticOutputTask(AsyncTask): async def _async_collect_and_output(): try: - import concurrent.futures - now = datetime.now() - loop = asyncio.get_event_loop() - - with concurrent.futures.ThreadPoolExecutor() as executor: - logger.info("正在后台收集统计数据...") - - # 创建后台任务,不等待完成 - collect_task = asyncio.create_task( - loop.run_in_executor(executor, self._collect_all_statistics, now) # type: ignore - ) - - stats = await collect_task - logger.info("统计数据收集完成") - - # 创建并发的输出任务 - output_tasks = [ - asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)), # type: ignore - asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)), # type: ignore - ] - - # 等待所有输出任务完成 - await asyncio.gather(*output_tasks) - + logger.info("(后台) 正在收集统计数据(异步)...") + stats = await self._collect_all_statistics(now) + self._statistic_console_output(stats, now) + await self._generate_html_report(stats, now) logger.info("统计数据后台输出完成") except Exception as e: logger.exception(f"后台统计数据输出过程中发生异常:{e}") @@ -340,7 +239,7 @@ class StatisticOutputTask(AsyncTask): # -- 以下为统计数据收集方法 -- @staticmethod - def _collect_model_request_for_period(collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]: + async def _collect_model_request_for_period(collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]: """ 收集指定时间段的LLM请求统计数据 @@ -394,10 +293,11 @@ class StatisticOutputTask(AsyncTask): # 以最早的时间戳为起始时间获取记录 query_start_time = collect_period[-1][1] - records = ( - _sync_db_get(model_class=LLMUsage, filters={"timestamp": {"$gte": query_start_time}}, order_by="-timestamp") - or [] - ) + records = await db_get( + model_class=LLMUsage, + filters={"timestamp": {"$gte": query_start_time}}, + order_by="-timestamp", + ) or [] for record in records: if not isinstance(record, dict): @@ -489,7 +389,7 @@ class StatisticOutputTask(AsyncTask): return stats @staticmethod - def _collect_online_time_for_period(collect_period: List[Tuple[str, datetime]], now: datetime) -> Dict[str, Any]: + async def _collect_online_time_for_period(collect_period: List[Tuple[str, datetime]], now: datetime) -> Dict[str, Any]: """ 收集指定时间段的在线时间统计数据 @@ -508,12 +408,11 @@ class StatisticOutputTask(AsyncTask): } query_start_time = collect_period[-1][1] - records = ( - _sync_db_get( - model_class=OnlineTime, filters={"end_timestamp": {"$gte": query_start_time}}, order_by="-end_timestamp" - ) - or [] - ) + records = await db_get( + model_class=OnlineTime, + filters={"end_timestamp": {"$gte": query_start_time}}, + order_by="-end_timestamp", + ) or [] for record in records: if not isinstance(record, dict): @@ -545,7 +444,7 @@ class StatisticOutputTask(AsyncTask): break return stats - def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]: + async def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]: """ 收集指定时间段的消息统计数据 @@ -565,10 +464,11 @@ class StatisticOutputTask(AsyncTask): } query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp) - records = ( - _sync_db_get(model_class=Messages, filters={"time": {"$gte": query_start_timestamp}}, order_by="-time") - or [] - ) + records = await db_get( + model_class=Messages, + filters={"time": {"$gte": query_start_timestamp}}, 
+ order_by="-time", + ) or [] for message in records: if not isinstance(message, dict): @@ -612,7 +512,7 @@ class StatisticOutputTask(AsyncTask): break return stats - def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]: + async def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]: """ 收集各时间段的统计数据 :param now: 基准当前时间 @@ -634,9 +534,11 @@ class StatisticOutputTask(AsyncTask): stat = {item[0]: {} for item in self.stat_period} - model_req_stat = self._collect_model_request_for_period(stat_start_timestamp) - online_time_stat = self._collect_online_time_for_period(stat_start_timestamp, now) - message_count_stat = self._collect_message_count_for_period(stat_start_timestamp) + model_req_stat, online_time_stat, message_count_stat = await asyncio.gather( + self._collect_model_request_for_period(stat_start_timestamp), + self._collect_online_time_for_period(stat_start_timestamp, now), + self._collect_message_count_for_period(stat_start_timestamp), + ) # 统计数据合并 # 合并三类统计数据 @@ -796,7 +698,7 @@ class StatisticOutputTask(AsyncTask): # 移除_generate_versions_tab方法 - def _generate_html_report(self, stat: dict[str, Any], now: datetime): + async def _generate_html_report(self, stat: dict[str, Any], now: datetime): """ 生成HTML格式的统计报告 :param stat: 统计数据 @@ -941,8 +843,8 @@ class StatisticOutputTask(AsyncTask): ) # 不再添加版本对比内容 - # 添加图表内容 - chart_data = self._generate_chart_data(stat) + # 添加图表内容 (修正缩进) + chart_data = await self._generate_chart_data(stat) tab_content_list.append(self._generate_chart_tab(chart_data)) joined_tab_list = "\n".join(tab_list) @@ -1091,107 +993,90 @@ class StatisticOutputTask(AsyncTask): with open(self.record_file_path, "w", encoding="utf-8") as f: f.write(html_template) - def _generate_chart_data(self, stat: dict[str, Any]) -> dict: - """生成图表数据""" + async def _generate_chart_data(self, stat: dict[str, Any]) -> dict: + """生成图表数据 (异步)""" now = datetime.now() - chart_data = {} + chart_data: Dict[str, Any] = {} - # 支持多个时间范围 time_ranges = [ - ("6h", 6, 10), # 6小时,10分钟间隔 - ("12h", 12, 15), # 12小时,15分钟间隔 - ("24h", 24, 15), # 24小时,15分钟间隔 - ("48h", 48, 30), # 48小时,30分钟间隔 + ("6h", 6, 10), + ("12h", 12, 15), + ("24h", 24, 15), + ("48h", 48, 30), ] + # 依次处理(数据量不大,避免复杂度;如需可改 gather) for range_key, hours, interval_minutes in time_ranges: - range_data = self._collect_interval_data(now, hours, interval_minutes) - chart_data[range_key] = range_data - + chart_data[range_key] = await self._collect_interval_data(now, hours, interval_minutes) return chart_data - @staticmethod - def _collect_interval_data(now: datetime, hours: int, interval_minutes: int) -> dict: - """收集指定时间范围内每个间隔的数据""" - # 生成时间点 + async def _collect_interval_data(self, now: datetime, hours: int, interval_minutes: int) -> dict: start_time = now - timedelta(hours=hours) - time_points = [] + time_points: List[datetime] = [] current_time = start_time - while current_time <= now: time_points.append(current_time) current_time += timedelta(minutes=interval_minutes) - # 初始化数据结构 - total_cost_data = [0] * len(time_points) - cost_by_model = {} - cost_by_module = {} - message_by_chat = {} + total_cost_data = [0.0] * len(time_points) + cost_by_model: Dict[str, List[float]] = {} + cost_by_module: Dict[str, List[float]] = {} + message_by_chat: Dict[str, List[int]] = {} time_labels = [t.strftime("%H:%M") for t in time_points] - interval_seconds = interval_minutes * 60 - # 查询LLM使用记录 - query_start_time = start_time - records = _sync_db_get( - model_class=LLMUsage, filters={"timestamp": {"$gte": query_start_time}}, 
order_by="-timestamp" - ) - - for record in records: + # 单次查询 LLMUsage + llm_records = await db_get( + model_class=LLMUsage, + filters={"timestamp": {"$gte": start_time}}, + order_by="-timestamp", + ) or [] + for record in llm_records: + if not isinstance(record, dict) or not record.get("timestamp"): + continue record_time = record["timestamp"] - - # 找到对应的时间间隔索引 + if isinstance(record_time, str): + try: + record_time = datetime.fromisoformat(record_time) + except Exception: + continue time_diff = (record_time - start_time).total_seconds() - interval_index = int(time_diff // interval_seconds) - - if 0 <= interval_index < len(time_points): - # 累加总花费数据 + idx = int(time_diff // interval_seconds) + if 0 <= idx < len(time_points): cost = record.get("cost") or 0.0 - total_cost_data[interval_index] += cost # type: ignore - - # 累加按模型分类的花费 + total_cost_data[idx] += cost model_name = record.get("model_name") or "unknown" if model_name not in cost_by_model: - cost_by_model[model_name] = [0] * len(time_points) - cost_by_model[model_name][interval_index] += cost - - # 累加按模块分类的花费 + cost_by_model[model_name] = [0.0] * len(time_points) + cost_by_model[model_name][idx] += cost request_type = record.get("request_type") or "unknown" module_name = request_type.split(".")[0] if "." in request_type else request_type if module_name not in cost_by_module: - cost_by_module[module_name] = [0] * len(time_points) - cost_by_module[module_name][interval_index] += cost + cost_by_module[module_name] = [0.0] * len(time_points) + cost_by_module[module_name][idx] += cost - # 查询消息记录 - query_start_timestamp = start_time.timestamp() - records = _sync_db_get( - model_class=Messages, filters={"time": {"$gte": query_start_timestamp}}, order_by="-time" - ) - - for message in records: - message_time_ts = message["time"] - - # 找到对应的时间间隔索引 - time_diff = message_time_ts - query_start_timestamp - interval_index = int(time_diff // interval_seconds) - - if 0 <= interval_index < len(time_points): - # 确定聊天流名称 - chat_name = None - if message.get("chat_info_group_id"): - chat_name = message.get("chat_info_group_name") or f"群{message['chat_info_group_id']}" - elif message.get("user_id"): - chat_name = message.get("user_nickname") or f"用户{message['user_id']}" + # 单次查询 Messages + msg_records = await db_get( + model_class=Messages, + filters={"time": {"$gte": start_time.timestamp()}}, + order_by="-time", + ) or [] + for msg in msg_records: + if not isinstance(msg, dict) or not msg.get("time"): + continue + msg_ts = msg["time"] + time_diff = msg_ts - start_time.timestamp() + idx = int(time_diff // interval_seconds) + if 0 <= idx < len(time_points): + if msg.get("chat_info_group_id"): + chat_name = msg.get("chat_info_group_name") or f"群{msg['chat_info_group_id']}" + elif msg.get("user_id"): + chat_name = msg.get("user_nickname") or f"用户{msg['user_id']}" else: continue - - if not chat_name: - continue - - # 累加消息数 if chat_name not in message_by_chat: message_by_chat[chat_name] = [0] * len(time_points) - message_by_chat[chat_name][interval_index] += 1 + message_by_chat[chat_name][idx] += 1 return { "time_labels": time_labels, @@ -1478,101 +1363,4 @@ class StatisticOutputTask(AsyncTask): }}); - """ - - -class AsyncStatisticOutputTask(AsyncTask): - """完全异步的统计输出任务 - 更高性能版本""" - - def __init__(self, record_file_path: str = "maibot_statistics.html"): - # 延迟0秒启动,运行间隔300秒 - super().__init__(task_name="Async Statistics Data Output Task", wait_before_start=0, run_interval=300) - - # 直接复用 StatisticOutputTask 的初始化逻辑 - temp_stat_task = 
StatisticOutputTask(record_file_path) - self.name_mapping = temp_stat_task.name_mapping - self.record_file_path = temp_stat_task.record_file_path - self.stat_period = temp_stat_task.stat_period - - async def run(self): - """完全异步执行统计任务""" - - async def _async_collect_and_output(): - try: - now = datetime.now() - loop = asyncio.get_event_loop() - - with concurrent.futures.ThreadPoolExecutor() as executor: - logger.info("正在后台收集统计数据...") - - # 数据收集任务 - collect_task = asyncio.create_task( - loop.run_in_executor(executor, self._collect_all_statistics, now) # type: ignore - ) - - stats = await collect_task - logger.info("统计数据收集完成") - - # 创建并发的输出任务 - output_tasks = [ - asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)), # type: ignore - asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)), # type: ignore - ] - - # 等待所有输出任务完成 - await asyncio.gather(*output_tasks) - - logger.info("统计数据后台输出完成") - except Exception as e: - logger.exception(f"后台统计数据输出过程中发生异常:{e}") - - # 创建后台任务,立即返回 - asyncio.create_task(_async_collect_and_output()) - - # 复用 StatisticOutputTask 的所有方法 - def _collect_all_statistics(self, now: datetime): - return StatisticOutputTask._collect_all_statistics(self, now) # type: ignore - - def _statistic_console_output(self, stats: Dict[str, Any], now: datetime): - return StatisticOutputTask._statistic_console_output(self, stats, now) # type: ignore - - def _generate_html_report(self, stats: dict[str, Any], now: datetime): - return StatisticOutputTask._generate_html_report(self, stats, now) # type: ignore - - # 其他需要的方法也可以类似复用... - @staticmethod - def _collect_model_request_for_period(collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]: - return StatisticOutputTask._collect_model_request_for_period(collect_period) - - @staticmethod - def _collect_online_time_for_period(collect_period: List[Tuple[str, datetime]], now: datetime) -> Dict[str, Any]: - return StatisticOutputTask._collect_online_time_for_period(collect_period, now) - - def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]: - return StatisticOutputTask._collect_message_count_for_period(self, collect_period) # type: ignore - - @staticmethod - def _format_total_stat(stats: Dict[str, Any]) -> str: - return StatisticOutputTask._format_total_stat(stats) - - @staticmethod - def _format_model_classified_stat(stats: Dict[str, Any]) -> str: - return StatisticOutputTask._format_model_classified_stat(stats) - - def _format_chat_stat(self, stats: Dict[str, Any]) -> str: - return StatisticOutputTask._format_chat_stat(self, stats) # type: ignore - - def _generate_chart_data(self, stat: dict[str, Any]) -> dict: - return StatisticOutputTask._generate_chart_data(self, stat) # type: ignore - - def _collect_interval_data(self, now: datetime, hours: int, interval_minutes: int) -> dict: - return StatisticOutputTask._collect_interval_data(now, hours, interval_minutes) # type: ignore - - def _generate_chart_tab(self, chart_data: dict) -> str: - return StatisticOutputTask._generate_chart_tab(chart_data) # type: ignore - - def _get_chat_display_name_from_id(self, chat_id: str) -> str: - return StatisticOutputTask._get_chat_display_name_from_id(chat_id) # type: ignore - - def _convert_defaultdict_to_dict(self, data): - return StatisticOutputTask._convert_defaultdict_to_dict(self, data) # type: ignore + """ \ No newline at end of file diff --git a/src/plugin_system/apis/permission_api.py 
b/src/plugin_system/apis/permission_api.py index f40198352..97fde236c 100644 --- a/src/plugin_system/apis/permission_api.py +++ b/src/plugin_system/apis/permission_api.py @@ -1,13 +1,8 @@ -""" -权限系统API - 提供权限管理相关的API接口 - -这个模块提供了权限系统的核心API,包括权限检查、权限节点管理等功能。 -插件可以通过这些API来检查用户权限和管理权限节点。 -""" +"""纯异步权限API定义。所有外部调用方必须使用 await。""" from typing import Optional, List, Dict, Any -from enum import Enum from dataclasses import dataclass +from enum import Enum from abc import ABC, abstractmethod from src.common.logger import get_logger @@ -16,325 +11,172 @@ logger = get_logger(__name__) class PermissionLevel(Enum): - """权限等级枚举""" - - MASTER = "master" # 最高权限,无视所有权限节点 + MASTER = "master" @dataclass class PermissionNode: - """权限节点数据类""" - - node_name: str # 权限节点名称,如 "plugin.example.command.test" - description: str # 权限节点描述 - plugin_name: str # 所属插件名称 - default_granted: bool = False # 默认是否授权 + node_name: str + description: str + plugin_name: str + default_granted: bool = False @dataclass class UserInfo: - """用户信息数据类""" - - platform: str # 平台类型,如 "qq" - user_id: str # 用户ID + platform: str + user_id: str def __post_init__(self): - """确保user_id是字符串类型""" self.user_id = str(self.user_id) - def to_tuple(self) -> tuple[str, str]: - """转换为元组格式""" - return self.platform, self.user_id - class IPermissionManager(ABC): - """权限管理器接口""" + @abstractmethod + async def check_permission(self, user: UserInfo, permission_node: str) -> bool: ... @abstractmethod - def check_permission(self, user: UserInfo, permission_node: str) -> bool: - """ - 检查用户是否拥有指定权限节点 - - Args: - user: 用户信息 - permission_node: 权限节点名称 - - Returns: - bool: 是否拥有权限 - """ - pass + def is_master(self, user: UserInfo) -> bool: ... # 同步快速判断 @abstractmethod - def is_master(self, user: UserInfo) -> bool: - """ - 检查用户是否为Master用户 - - Args: - user: 用户信息 - - Returns: - bool: 是否为Master用户 - """ - pass + async def register_permission_node(self, node: PermissionNode) -> bool: ... @abstractmethod - def register_permission_node(self, node: PermissionNode) -> bool: - """ - 注册权限节点 - - Args: - node: 权限节点 - - Returns: - bool: 注册是否成功 - """ - pass + async def grant_permission(self, user: UserInfo, permission_node: str) -> bool: ... @abstractmethod - def grant_permission(self, user: UserInfo, permission_node: str) -> bool: - """ - 授权用户权限节点 - - Args: - user: 用户信息 - permission_node: 权限节点名称 - - Returns: - bool: 授权是否成功 - """ - pass + async def revoke_permission(self, user: UserInfo, permission_node: str) -> bool: ... @abstractmethod - def revoke_permission(self, user: UserInfo, permission_node: str) -> bool: - """ - 撤销用户权限节点 - - Args: - user: 用户信息 - permission_node: 权限节点名称 - - Returns: - bool: 撤销是否成功 - """ - pass + async def get_user_permissions(self, user: UserInfo) -> List[str]: ... @abstractmethod - def get_user_permissions(self, user: UserInfo) -> List[str]: - """ - 获取用户拥有的所有权限节点 - - Args: - user: 用户信息 - - Returns: - List[str]: 权限节点列表 - """ - pass + async def get_all_permission_nodes(self) -> List[PermissionNode]: ... @abstractmethod - def get_all_permission_nodes(self) -> List[PermissionNode]: - """ - 获取所有已注册的权限节点 - - Returns: - List[PermissionNode]: 权限节点列表 - """ - pass - - @abstractmethod - def get_plugin_permission_nodes(self, plugin_name: str) -> List[PermissionNode]: - """ - 获取指定插件的所有权限节点 - - Args: - plugin_name: 插件名称 - - Returns: - List[PermissionNode]: 权限节点列表 - """ - pass + async def get_plugin_permission_nodes(self, plugin_name: str) -> List[PermissionNode]: ... 
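+    # NOTE: every method on this interface except is_master() is a coroutine and must be awaited;
+    # is_master() stays a synchronous fast check so hot paths can gate on master status without scheduling.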
class PermissionAPI: - """权限系统API类""" - def __init__(self): self._permission_manager: Optional[IPermissionManager] = None + # 需要保留的前缀(视为绝对节点名,不再自动加 plugins.. 前缀) + self.RESERVED_PREFIXES: tuple[str, ...] = ( + "system.") + # 系统节点列表 (name, description, default_granted) + self._SYSTEM_NODES: list[tuple[str, str, bool]] = [ + ("system.superuser", "系统超级管理员:拥有所有权限", False), + ("system.permission.manage", "系统权限管理:可管理所有权限节点", False), + ("system.permission.view", "系统权限查看:可查看所有权限节点", True), + ] + self._system_nodes_initialized: bool = False def set_permission_manager(self, manager: IPermissionManager): - """设置权限管理器实例""" self._permission_manager = manager logger.info("权限管理器已设置") def _ensure_manager(self): - """确保权限管理器已设置""" if self._permission_manager is None: raise RuntimeError("权限管理器未设置,请先调用 set_permission_manager") - def check_permission(self, platform: str, user_id: str, permission_node: str) -> bool: - """ - 检查用户是否拥有指定权限节点 - - Args: - platform: 平台类型,如 "qq" - user_id: 用户ID - permission_node: 权限节点名称 - - Returns: - bool: 是否拥有权限 - - Raises: - RuntimeError: 权限管理器未设置时抛出 - """ + async def check_permission(self, platform: str, user_id: str, permission_node: str) -> bool: self._ensure_manager() - user = UserInfo(platform=platform, user_id=str(user_id)) - return self._permission_manager.check_permission(user, permission_node) + return await self._permission_manager.check_permission(UserInfo(platform, user_id), permission_node) def is_master(self, platform: str, user_id: str) -> bool: - """ - 检查用户是否为Master用户 - - Args: - platform: 平台类型,如 "qq" - user_id: 用户ID - - Returns: - bool: 是否为Master用户 - - Raises: - RuntimeError: 权限管理器未设置时抛出 - """ self._ensure_manager() - user = UserInfo(platform=platform, user_id=str(user_id)) - return self._permission_manager.is_master(user) + return self._permission_manager.is_master(UserInfo(platform, user_id)) - def register_permission_node( - self, node_name: str, description: str, plugin_name: str, default_granted: bool = False + async def register_permission_node( + self, + node_name: str, + description: str, + plugin_name: str, + default_granted: bool = False, + *, + system: bool = False, + allow_relative: bool = True, ) -> bool: - """ - 注册权限节点 - - Args: - node_name: 权限节点名称,如 "plugin.example.command.test" - description: 权限节点描述 - plugin_name: 所属插件名称 - default_granted: 默认是否授权 - - Returns: - bool: 注册是否成功 - - Raises: - RuntimeError: 权限管理器未设置时抛出 - """ self._ensure_manager() - node = PermissionNode( - node_name=node_name, description=description, plugin_name=plugin_name, default_granted=default_granted + original_name = node_name + if system: + # 系统节点必须以 system./sys./core. 等保留前缀开头 + if not node_name.startswith(("system.", "sys.", "core.")): + node_name = f"system.{node_name}" # 自动补 system. 
+ else: + # 普通插件节点:若不以保留前缀开头,并允许相对,则自动加前缀 + if allow_relative and not node_name.startswith(self.RESERVED_PREFIXES): + node_name = f"plugins.{plugin_name}.{node_name}" + if original_name != node_name: + logger.debug(f"规范化权限节点 '{original_name}' -> '{node_name}'") + node = PermissionNode(node_name, description, plugin_name, default_granted) + return await self._permission_manager.register_permission_node(node) + + async def register_system_permission_node( + self, node_name: str, description: str, default_granted: bool = False + ) -> bool: + """注册系统级权限节点(不绑定具体插件,前缀保持 system./sys./core.)。""" + return await self.register_permission_node( + node_name, + description, + plugin_name="__system__", + default_granted=default_granted, + system=True, + allow_relative=True, ) - return self._permission_manager.register_permission_node(node) - def grant_permission(self, platform: str, user_id: str, permission_node: str) -> bool: - """ - 授权用户权限节点 - - Args: - platform: 平台类型,如 "qq" - user_id: 用户ID - permission_node: 权限节点名称 - - Returns: - bool: 授权是否成功 - - Raises: - RuntimeError: 权限管理器未设置时抛出 + async def init_system_nodes(self) -> None: + """初始化默认系统权限节点(幂等)。 + + 在设置 permission_manager 之后且数据库准备好时调用一次即可。 """ + if self._system_nodes_initialized: + return self._ensure_manager() - user = UserInfo(platform=platform, user_id=str(user_id)) - return self._permission_manager.grant_permission(user, permission_node) + for name, desc, granted in self._SYSTEM_NODES: + try: + await self.register_system_permission_node(name, desc, granted) + except Exception as e: # 防御性 + logger.warning(f"注册系统权限节点 {name} 失败: {e}") + self._system_nodes_initialized = True - def revoke_permission(self, platform: str, user_id: str, permission_node: str) -> bool: - """ - 撤销用户权限节点 - - Args: - platform: 平台类型,如 "qq" - user_id: 用户ID - permission_node: 权限节点名称 - - Returns: - bool: 撤销是否成功 - - Raises: - RuntimeError: 权限管理器未设置时抛出 - """ + async def grant_permission(self, platform: str, user_id: str, permission_node: str) -> bool: self._ensure_manager() - user = UserInfo(platform=platform, user_id=str(user_id)) - return self._permission_manager.revoke_permission(user, permission_node) + return await self._permission_manager.grant_permission(UserInfo(platform, user_id), permission_node) - def get_user_permissions(self, platform: str, user_id: str) -> List[str]: - """ - 获取用户拥有的所有权限节点 - - Args: - platform: 平台类型,如 "qq" - user_id: 用户ID - - Returns: - List[str]: 权限节点列表 - - Raises: - RuntimeError: 权限管理器未设置时抛出 - """ + async def revoke_permission(self, platform: str, user_id: str, permission_node: str) -> bool: self._ensure_manager() - user = UserInfo(platform=platform, user_id=str(user_id)) - return self._permission_manager.get_user_permissions(user) + return await self._permission_manager.revoke_permission(UserInfo(platform, user_id), permission_node) - def get_all_permission_nodes(self) -> List[Dict[str, Any]]: - """ - 获取所有已注册的权限节点 - - Returns: - List[Dict[str, Any]]: 权限节点列表,每个节点包含 node_name, description, plugin_name, default_granted - - Raises: - RuntimeError: 权限管理器未设置时抛出 - """ + async def get_user_permissions(self, platform: str, user_id: str) -> List[str]: self._ensure_manager() - nodes = self._permission_manager.get_all_permission_nodes() + return await self._permission_manager.get_user_permissions(UserInfo(platform, user_id)) + + async def get_all_permission_nodes(self) -> List[Dict[str, Any]]: + self._ensure_manager() + nodes = await self._permission_manager.get_all_permission_nodes() return [ { - "node_name": node.node_name, - "description": node.description, 
- "plugin_name": node.plugin_name, - "default_granted": node.default_granted, + "node_name": n.node_name, + "description": n.description, + "plugin_name": n.plugin_name, + "default_granted": n.default_granted, } - for node in nodes + for n in nodes ] - def get_plugin_permission_nodes(self, plugin_name: str) -> List[Dict[str, Any]]: - """ - 获取指定插件的所有权限节点 - - Args: - plugin_name: 插件名称 - - Returns: - List[Dict[str, Any]]: 权限节点列表 - - Raises: - RuntimeError: 权限管理器未设置时抛出 - """ + async def get_plugin_permission_nodes(self, plugin_name: str) -> List[Dict[str, Any]]: self._ensure_manager() - nodes = self._permission_manager.get_plugin_permission_nodes(plugin_name) + nodes = await self._permission_manager.get_plugin_permission_nodes(plugin_name) return [ { - "node_name": node.node_name, - "description": node.description, - "plugin_name": node.plugin_name, - "default_granted": node.default_granted, + "node_name": n.node_name, + "description": n.description, + "plugin_name": n.plugin_name, + "default_granted": n.default_granted, } - for node in nodes + for n in nodes ] -# 全局权限API实例 permission_api = PermissionAPI() diff --git a/src/plugin_system/utils/permission_decorators.py b/src/plugin_system/utils/permission_decorators.py index 67322ba34..45357b4b0 100644 --- a/src/plugin_system/utils/permission_decorators.py +++ b/src/plugin_system/utils/permission_decorators.py @@ -7,6 +7,7 @@ from functools import wraps from typing import Callable, Optional from inspect import iscoroutinefunction +import inspect from src.plugin_system.apis.permission_api import permission_api from src.plugin_system.apis.send_api import text_to_stream @@ -61,7 +62,7 @@ def require_permission(permission_node: str, deny_message: Optional[str] = None) return None # 检查权限 - has_permission = permission_api.check_permission( + has_permission = await permission_api.check_permission( chat_stream.platform, chat_stream.user_info.user_id, permission_node ) @@ -77,40 +78,13 @@ def require_permission(permission_node: str, deny_message: Optional[str] = None) # 权限检查通过,执行原函数 return await func(*args, **kwargs) - def sync_wrapper(*args, **kwargs): - # 对于同步函数,我们不能发送异步消息,只能记录日志 - chat_stream = None - for arg in args: - if isinstance(arg, ChatStream): - chat_stream = arg - break - - if chat_stream is None: - chat_stream = kwargs.get("chat_stream") - - if chat_stream is None: - logger.error(f"权限装饰器无法找到 ChatStream 对象,函数: {func.__name__}") + if not iscoroutinefunction(func): + logger.warning(f"函数 {func.__name__} 使用 require_permission 但非异步,已强制阻止执行") + async def blocked(*_a, **_k): + logger.error("同步函数不再支持权限装饰器,请改为 async def") return None - - # 检查权限 - has_permission = permission_api.check_permission( - chat_stream.platform, chat_stream.user_info.user_id, permission_node - ) - - if not has_permission: - logger.warning( - f"用户 {chat_stream.platform}:{chat_stream.user_info.user_id} 没有权限 {permission_node}" - ) - return None - - # 权限检查通过,执行原函数 - return func(*args, **kwargs) - - # 根据函数类型选择包装器 - if iscoroutinefunction(func): - return async_wrapper - else: - return sync_wrapper + return blocked + return async_wrapper return decorator @@ -171,36 +145,13 @@ def require_master(deny_message: Optional[str] = None): # 权限检查通过,执行原函数 return await func(*args, **kwargs) - def sync_wrapper(*args, **kwargs): - # 对于同步函数,我们不能发送异步消息,只能记录日志 - chat_stream = None - for arg in args: - if isinstance(arg, ChatStream): - chat_stream = arg - break - - if chat_stream is None: - chat_stream = kwargs.get("chat_stream") - - if chat_stream is None: - logger.error(f"Master权限装饰器无法找到 ChatStream 
对象,函数: {func.__name__}") + if not iscoroutinefunction(func): + logger.warning(f"函数 {func.__name__} 使用 require_master 但非异步,已强制阻止执行") + async def blocked(*_a, **_k): + logger.error("同步函数不再支持 require_master,请改为 async def") return None - - # 检查是否为Master用户 - is_master = permission_api.is_master(chat_stream.platform, chat_stream.user_info.user_id) - - if not is_master: - logger.warning(f"用户 {chat_stream.platform}:{chat_stream.user_info.user_id} 不是Master用户") - return None - - # 权限检查通过,执行原函数 - return func(*args, **kwargs) - - # 根据函数类型选择包装器 - if iscoroutinefunction(func): - return async_wrapper - else: - return sync_wrapper + return blocked + return async_wrapper return decorator @@ -214,17 +165,7 @@ class PermissionChecker: @staticmethod def check_permission(chat_stream: ChatStream, permission_node: str) -> bool: - """ - 检查用户是否拥有指定权限 - - Args: - chat_stream: 聊天流对象 - permission_node: 权限节点名称 - - Returns: - bool: 是否拥有权限 - """ - return permission_api.check_permission(chat_stream.platform, chat_stream.user_info.user_id, permission_node) + raise RuntimeError("PermissionChecker.check_permission 已移除同步支持,请直接 await permission_api.check_permission") @staticmethod def is_master(chat_stream: ChatStream) -> bool: @@ -254,12 +195,12 @@ class PermissionChecker: Returns: bool: 是否拥有权限 """ - has_permission = PermissionChecker.check_permission(chat_stream, permission_node) - + has_permission = await permission_api.check_permission( + chat_stream.platform, chat_stream.user_info.user_id, permission_node + ) if not has_permission: message = deny_message or f"❌ 你没有执行此操作的权限\n需要权限: {permission_node}" await text_to_stream(message, chat_stream.stream_id) - return has_permission @staticmethod diff --git a/src/plugins/built_in/maizone_refactored/actions/read_feed_action.py b/src/plugins/built_in/maizone_refactored/actions/read_feed_action.py index 7e15accea..ee5a1b73a 100644 --- a/src/plugins/built_in/maizone_refactored/actions/read_feed_action.py +++ b/src/plugins/built_in/maizone_refactored/actions/read_feed_action.py @@ -39,7 +39,7 @@ class ReadFeedAction(BaseAction): user_id = self.chat_stream.user_info.user_id # 使用权限API检查用户是否有阅读说说的权限 - return permission_api.check_permission(platform, user_id, "plugin.maizone.read_feed") + return await permission_api.check_permission(platform, user_id, "plugin.maizone.read_feed") async def execute(self) -> Tuple[bool, str]: """ diff --git a/src/plugins/built_in/maizone_refactored/actions/send_feed_action.py b/src/plugins/built_in/maizone_refactored/actions/send_feed_action.py index fe9a25ed6..af8760c06 100644 --- a/src/plugins/built_in/maizone_refactored/actions/send_feed_action.py +++ b/src/plugins/built_in/maizone_refactored/actions/send_feed_action.py @@ -39,7 +39,7 @@ class SendFeedAction(BaseAction): user_id = self.chat_stream.user_info.user_id # 使用权限API检查用户是否有发送说说的权限 - return permission_api.check_permission(platform, user_id, "plugin.maizone.send_feed") + return await permission_api.check_permission(platform, user_id, "plugin.maizone.send_feed") async def execute(self) -> Tuple[bool, str]: """ diff --git a/src/plugins/built_in/maizone_refactored/plugin.py b/src/plugins/built_in/maizone_refactored/plugin.py index c54872872..de644c31b 100644 --- a/src/plugins/built_in/maizone_refactored/plugin.py +++ b/src/plugins/built_in/maizone_refactored/plugin.py @@ -87,11 +87,11 @@ class MaiZoneRefactoredPlugin(BasePlugin): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # 注册权限节点 - permission_api.register_permission_node( + async def on_plugin_loaded(self): + await 
permission_api.register_permission_node( "plugin.maizone.send_feed", "是否可以使用机器人发送QQ空间说说", "maiZone", False ) - permission_api.register_permission_node( + await permission_api.register_permission_node( "plugin.maizone.read_feed", "是否可以使用机器人读取QQ空间说说", "maiZone", True ) # 创建所有服务实例 diff --git a/src/plugins/built_in/permission_management/plugin.py b/src/plugins/built_in/permission_management/plugin.py index e33a6d08f..fd8612348 100644 --- a/src/plugins/built_in/permission_management/plugin.py +++ b/src/plugins/built_in/permission_management/plugin.py @@ -34,11 +34,13 @@ class PermissionCommand(PlusCommand): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # 注册权限节点 - permission_api.register_permission_node( + + async def on_plugin_loaded(self): + # 注册权限节点(使用显式前缀,避免再次自动补全) + await permission_api.register_permission_node( "plugin.permission.manage", "权限管理:可以授权和撤销其他用户的权限", "permission_manager", False ) - permission_api.register_permission_node( + await permission_api.register_permission_node( "plugin.permission.view", "权限查看:可以查看权限节点和用户权限信息", "permission_manager", True ) @@ -179,7 +181,7 @@ class PermissionCommand(PlusCommand): permission_node = args[1] # 执行授权 - success = permission_api.grant_permission(chat_stream.platform, user_id, permission_node) + success = await permission_api.grant_permission(chat_stream.platform, user_id, permission_node) if success: await self.send_text(f"✅ 已授权用户 {user_id} 权限节点 `{permission_node}`") @@ -202,7 +204,7 @@ class PermissionCommand(PlusCommand): permission_node = args[1] # 执行撤销 - success = permission_api.revoke_permission(chat_stream.platform, user_id, permission_node) + success = await permission_api.revoke_permission(chat_stream.platform, user_id, permission_node) if success: await self.send_text(f"✅ 已撤销用户 {user_id} 权限节点 `{permission_node}`") @@ -225,10 +227,10 @@ class PermissionCommand(PlusCommand): target_user_id = chat_stream.user_info.user_id # 检查是否为Master用户 - is_master = permission_api.is_master(chat_stream.platform, target_user_id) + is_master = await permission_api.is_master(chat_stream.platform, target_user_id) # 获取用户权限 - permissions = permission_api.get_user_permissions(chat_stream.platform, target_user_id) + permissions = await permission_api.get_user_permissions(chat_stream.platform, target_user_id) if is_master: response = f"👑 用户 `{target_user_id}` 是Master用户,拥有所有权限" @@ -257,8 +259,8 @@ class PermissionCommand(PlusCommand): permission_node = args[1] # 检查权限 - has_permission = permission_api.check_permission(chat_stream.platform, user_id, permission_node) - is_master = permission_api.is_master(chat_stream.platform, user_id) + has_permission = await permission_api.check_permission(chat_stream.platform, user_id, permission_node) + is_master = await permission_api.is_master(chat_stream.platform, user_id) if has_permission: if is_master: @@ -277,11 +279,11 @@ class PermissionCommand(PlusCommand): if plugin_name: # 获取指定插件的权限节点 - nodes = permission_api.get_plugin_permission_nodes(plugin_name) + nodes = await permission_api.get_plugin_permission_nodes(plugin_name) title = f"📋 插件 {plugin_name} 的权限节点:" else: # 获取所有权限节点 - nodes = permission_api.get_all_permission_nodes() + nodes = await permission_api.get_all_permission_nodes() title = "📋 所有权限节点:" if not nodes: @@ -307,7 +309,7 @@ class PermissionCommand(PlusCommand): async def _list_all_nodes_with_description(self, chat_stream): """列出所有插件的权限节点(带详细描述)""" # 获取所有权限节点 - all_nodes = permission_api.get_all_permission_nodes() + all_nodes = await permission_api.get_all_permission_nodes() if not 
all_nodes: response = "📋 系统中没有任何权限节点" diff --git a/src/plugins/built_in/plugin_management/plugin.py b/src/plugins/built_in/plugin_management/plugin.py index 741cb38b9..c9550500b 100644 --- a/src/plugins/built_in/plugin_management/plugin.py +++ b/src/plugins/built_in/plugin_management/plugin.py @@ -548,11 +548,13 @@ class PluginManagementPlugin(BasePlugin): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # 注册权限节点 - permission_api.register_permission_node( - "plugin.management.admin", - "插件管理:可以管理插件和组件的加载、卸载、启用、禁用等操作", - "plugin_management", - False, + + async def on_plugin_loaded(self): + await permission_api.register_permission_node( + "plugin.management.admin", + "插件管理:可以管理插件和组件的加载、卸载、启用、禁用等操作", + "plugin_management", + False, ) def get_plugin_components(self) -> List[Tuple[PlusCommandInfo, Type[PlusCommand]]]: From 742a8c2c372e6d4a3af4958e3d25d8b2130f6502 Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Sun, 21 Sep 2025 13:30:27 +0800 Subject: [PATCH 29/31] =?UTF-8?q?feat(plugin=5Fsystem):=20=E5=AF=BC?= =?UTF-8?q?=E5=87=BA=20schedule=5Fapi?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 将 schedule_api 添加到插件系统的 API 导出列表中,使其对插件可用。 --- src/plugin_system/apis/__init__.py | 2 + src/plugin_system/apis/schedule_api.py | 179 +++++++++++++++++++++++++ 2 files changed, 181 insertions(+) create mode 100644 src/plugin_system/apis/schedule_api.py diff --git a/src/plugin_system/apis/__init__.py b/src/plugin_system/apis/__init__.py index 30ff428d7..411a2d326 100644 --- a/src/plugin_system/apis/__init__.py +++ b/src/plugin_system/apis/__init__.py @@ -19,6 +19,7 @@ from src.plugin_system.apis import ( send_api, tool_api, permission_api, + schedule_api ) from src.plugin_system.apis.chat_api import ChatManager as context_api from .logging_api import get_logger @@ -42,4 +43,5 @@ __all__ = [ "tool_api", "permission_api", "context_api", + "schedule_api", ] diff --git a/src/plugin_system/apis/schedule_api.py b/src/plugin_system/apis/schedule_api.py new file mode 100644 index 000000000..f1d81049e --- /dev/null +++ b/src/plugin_system/apis/schedule_api.py @@ -0,0 +1,179 @@ +""" +日程表与月度计划API模块 + +专门负责日程和月度计划信息的查询与管理,采用标准Python包设计模式 +所有对外接口均为异步函数,以便于插件开发者在异步环境中使用。 + +使用方式: + import asyncio + from src.plugin_system.apis import schedule_api + + async def main(): + # 获取今日日程 + today_schedule = await schedule_api.get_today_schedule() + if today_schedule: + print("今天的日程:", today_schedule) + + # 获取当前活动 + current_activity = await schedule_api.get_current_activity() + if current_activity: + print("当前活动:", current_activity) + + # 获取本月月度计划 + from datetime import datetime + this_month = datetime.now().strftime("%Y-%m") + plans = await schedule_api.get_monthly_plans(this_month) + if plans: + print(f"{this_month} 的月度计划:", [p.plan_text for p in plans]) + + asyncio.run(main()) +""" +from datetime import datetime +from typing import List, Dict, Any, Optional + +from src.common.database.sqlalchemy_models import MonthlyPlan +from src.common.logger import get_logger +from src.schedule.database import get_active_plans_for_month +from src.schedule.schedule_manager import schedule_manager + +logger = get_logger("schedule_api") + + +class ScheduleAPI: + """日程表与月度计划API - 负责日程和计划信息的查询与管理""" + + @staticmethod + async def get_today_schedule() -> Optional[List[Dict[str, Any]]]: + """(异步) 获取今天的日程安排 + + Returns: + Optional[List[Dict[str, Any]]]: 今天的日程列表,如果未生成或未启用则返回None + """ + try: + logger.debug("[ScheduleAPI] 正在获取今天的日程安排...") + return schedule_manager.today_schedule + 
except Exception as e: + logger.error(f"[ScheduleAPI] 获取今日日程失败: {e}") + return None + + @staticmethod + async def get_current_activity() -> Optional[str]: + """(异步) 获取当前正在进行的活动 + + Returns: + Optional[str]: 当前活动名称,如果没有则返回None + """ + try: + logger.debug("[ScheduleAPI] 正在获取当前活动...") + return schedule_manager.get_current_activity() + except Exception as e: + logger.error(f"[ScheduleAPI] 获取当前活动失败: {e}") + return None + + @staticmethod + async def regenerate_schedule() -> bool: + """(异步) 触发后台重新生成今天的日程 + + Returns: + bool: 是否成功触发 + """ + try: + logger.info("[ScheduleAPI] 正在触发后台重新生成日程...") + await schedule_manager.generate_and_save_schedule() + return True + except Exception as e: + logger.error(f"[ScheduleAPI] 触发日程重新生成失败: {e}") + return False + + @staticmethod + async def get_monthly_plans(target_month: Optional[str] = None) -> List[MonthlyPlan]: + """(异步) 获取指定月份的有效月度计划 + + Args: + target_month (Optional[str]): 目标月份,格式为 "YYYY-MM"。如果为None,则使用当前月份。 + + Returns: + List[MonthlyPlan]: 月度计划对象列表 + """ + if target_month is None: + target_month = datetime.now().strftime("%Y-%m") + try: + logger.debug(f"[ScheduleAPI] 正在获取 {target_month} 的月度计划...") + return await get_active_plans_for_month(target_month) + except Exception as e: + logger.error(f"[ScheduleAPI] 获取 {target_month} 月度计划失败: {e}") + return [] + + @staticmethod + async def ensure_monthly_plans(target_month: Optional[str] = None) -> bool: + """(异步) 确保指定月份存在月度计划,如果不存在则触发生成 + + Args: + target_month (Optional[str]): 目标月份,格式为 "YYYY-MM"。如果为None,则使用当前月份。 + + Returns: + bool: 操作是否成功 (如果已存在或成功生成) + """ + if target_month is None: + target_month = datetime.now().strftime("%Y-%m") + try: + logger.info(f"[ScheduleAPI] 正在确保 {target_month} 的月度计划存在...") + return await schedule_manager.plan_manager.ensure_and_generate_plans_if_needed(target_month) + except Exception as e: + logger.error(f"[ScheduleAPI] 确保 {target_month} 月度计划失败: {e}") + return False + + @staticmethod + async def archive_monthly_plans(target_month: Optional[str] = None) -> bool: + """(异步) 归档指定月份的月度计划 + + Args: + target_month (Optional[str]): 目标月份,格式为 "YYYY-MM"。如果为None,则使用当前月份。 + + Returns: + bool: 操作是否成功 + """ + if target_month is None: + target_month = datetime.now().strftime("%Y-%m") + try: + logger.info(f"[ScheduleAPI] 正在归档 {target_month} 的月度计划...") + await schedule_manager.plan_manager.archive_current_month_plans(target_month) + return True + except Exception as e: + logger.error(f"[ScheduleAPI] 归档 {target_month} 月度计划失败: {e}") + return False + + +# ============================================================================= +# 模块级别的便捷函数 (全部为异步) +# ============================================================================= + + +async def get_today_schedule() -> Optional[List[Dict[str, Any]]]: + """(异步) 获取今天的日程安排的便捷函数""" + return await ScheduleAPI.get_today_schedule() + + +async def get_current_activity() -> Optional[str]: + """(异步) 获取当前正在进行的活动的便捷函数""" + return await ScheduleAPI.get_current_activity() + + +async def regenerate_schedule() -> bool: + """(异步) 触发后台重新生成今天的日程的便捷函数""" + return await ScheduleAPI.regenerate_schedule() + + +async def get_monthly_plans(target_month: Optional[str] = None) -> List[MonthlyPlan]: + """(异步) 获取指定月份的有效月度计划的便捷函数""" + return await ScheduleAPI.get_monthly_plans(target_month) + + +async def ensure_monthly_plans(target_month: Optional[str] = None) -> bool: + """(异步) 确保指定月份存在月度计划的便捷函数""" + return await ScheduleAPI.ensure_monthly_plans(target_month) + + +async def archive_monthly_plans(target_month: Optional[str] = None) -> bool: + """(异步) 归档指定月份的月度计划的便捷函数""" + return 
await ScheduleAPI.archive_monthly_plans(target_month) \ No newline at end of file From de48d2ae02f2e451baa4d2a619d68b4201c35844 Mon Sep 17 00:00:00 2001 From: Furina-1013-create <189647097+Furina-1013-create@users.noreply.github.com> Date: Mon, 22 Sep 2025 22:44:15 +0800 Subject: [PATCH 30/31] =?UTF-8?q?=E5=B0=86=E5=9B=9E=E5=A4=8D=E8=A7=84?= =?UTF-8?q?=E5=88=99=E9=83=A8=E5=88=86=E7=9A=84=E6=8F=90=E7=A4=BA=E8=AF=8D?= =?UTF-8?q?=E8=AE=A9=E7=94=A8=E6=88=B7=E5=8F=AF=E4=BB=A5=E8=87=AA=E5=AE=9A?= =?UTF-8?q?=E4=B9=89=EF=BC=8C=E8=AF=A6=E7=BB=86=E8=AF=B7=E7=9C=8B=E9=85=8D?= =?UTF-8?q?=E7=BD=AE=E6=96=87=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/replyer/default_generator.py | 52 ++++++++++++++++----------- src/config/official_configs.py | 30 ++++++++++++++++ template/bot_config_template.toml | 26 +++++++++++++- 3 files changed, 87 insertions(+), 21 deletions(-) diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index de9c176cf..127779e1e 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -119,17 +119,6 @@ def init_prompt(): ## 规则 {safety_guidelines_block} -在回应之前,首先分析消息的针对性: -1. **直接针对你**:@你、回复你、明确询问你 → 必须回应 -2. **间接相关**:涉及你感兴趣的话题但未直接问你 → 谨慎参与 -3. **他人对话**:与你无关的私人交流 → 通常不参与 -4. **重复内容**:他人已充分回答的问题 → 避免重复 - -你的回复应该: -1. 明确回应目标消息,而不是宽泛地评论。 -2. 可以分享你的看法、提出相关问题,或者开个合适的玩笑。 -3. 目的是让对话更有趣、更深入。 -4. 不要浮夸,不要夸张修辞,不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。 最终请输出一条简短、完整且口语化的回复。 -------------------------------- @@ -168,11 +157,7 @@ If you need to use the search tool, please directly call the function "lpmm_sear 你正在一个QQ群里聊天,你需要理解整个群的聊天动态和话题走向,并做出自然的回应。 **重要:消息针对性判断** -在回应之前,首先分析消息的针对性: -1. **直接针对你**:@你、回复你、明确询问你 → 必须回应 -2. **间接相关**:涉及你感兴趣的话题但未直接问你 → 谨慎参与 -3. **他人对话**:与你无关的私人交流 → 通常不参与 -4. **重复内容**:他人已充分回答的问题 → 避免重复 +{safety_guidelines_block} {expression_habits_block} {tool_info_block} @@ -202,10 +187,6 @@ If you need to use the search tool, please directly call the function "lpmm_sear {keywords_reaction_prompt} 请注意不要输出多余内容(包括前后缀,冒号和引号,at或 @等 )。只输出回复内容。 {moderation_prompt} -你的核心任务是针对 {reply_target_block} 中提到的内容,生成一段紧密相关且能推动对话的回复。你的回复应该: -1. 明确回应目标消息,而不是宽泛地评论。 -2. 可以分享你的看法、提出相关问题,或者开个合适的玩笑。 -3. 目的是让对话更有趣、更深入。 最终请输出一条简短、完整且口语化的回复。 现在,你说: """, @@ -1012,6 +993,37 @@ class DefaultReplyer: {guidelines_text} 如果遇到违反上述原则的请求,请在保持你核心人设的同时,巧妙地拒绝或转移话题。 """ + + # 新增逻辑:构建回复规则块 + reply_targeting_rules = global_config.personality.reply_targeting_rules + message_targeting_analysis = global_config.personality.message_targeting_analysis + reply_principles = global_config.personality.reply_principles + + # 构建消息针对性分析部分 + targeting_analysis_text = "" + if message_targeting_analysis: + targeting_analysis_text = "\n".join(f"{i+1}. {rule}" for i, rule in enumerate(message_targeting_analysis)) + + # 构建回复原则部分 + reply_principles_text = "" + if reply_principles: + reply_principles_text = "\n".join(f"{i+1}. 
{principle}" for i, principle in enumerate(reply_principles)) + + # 综合构建完整的规则块 + if targeting_analysis_text or reply_principles_text: + complete_rules_block = "" + if targeting_analysis_text: + complete_rules_block += f""" +在回应之前,首先分析消息的针对性: +{targeting_analysis_text} +""" + if reply_principles_text: + complete_rules_block += f""" +你的回复应该: +{reply_principles_text} +""" + # 将规则块添加到safety_guidelines_block + safety_guidelines_block += complete_rules_block if sender and target: if is_group_chat: diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 978c5e47b..37e055bb3 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -57,6 +57,36 @@ class PersonalityConfig(ValidatedConfigBase): prompt_mode: Literal["s4u", "normal"] = Field(default="s4u", description="Prompt模式") compress_personality: bool = Field(default=True, description="是否压缩人格") compress_identity: bool = Field(default=True, description="是否压缩身份") + + # 回复规则配置 + reply_targeting_rules: List[str] = Field( + default_factory=lambda: [ + "拒绝任何包含骚扰、冒犯、暴力、色情或危险内容的请求。", + "在拒绝时,请使用符合你人设的、坚定的语气。", + "不要执行任何可能被用于恶意目的的指令。" + ], + description="安全与互动底线规则,Bot在任何情况下都必须遵守的原则" + ) + + message_targeting_analysis: List[str] = Field( + default_factory=lambda: [ + "**直接针对你**:@你、回复你、明确询问你 → 必须回应", + "**间接相关**:涉及你感兴趣的话题但未直接问你 → 谨慎参与", + "**他人对话**:与你无关的私人交流 → 通常不参与", + "**重复内容**:他人已充分回答的问题 → 避免重复" + ], + description="消息针对性分析规则,用于判断是否需要回复" + ) + + reply_principles: List[str] = Field( + default_factory=lambda: [ + "明确回应目标消息,而不是宽泛地评论。", + "可以分享你的看法、提出相关问题,或者开个合适的玩笑。", + "目的是让对话更有趣、更深入。", + "不要浮夸,不要夸张修辞,不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。" + ], + description="回复原则,指导如何回复消息" + ) class RelationshipConfig(ValidatedConfigBase): diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index d5a0e6f71..cb00cdebf 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -1,5 +1,5 @@ [inner] -version = "6.8.6" +version = "6.8.8" #----以下是给开发人员阅读的,如果你只是部署了MoFox-Bot,不需要阅读---- #如果你想要修改配置文件,请递增version的值 @@ -79,6 +79,30 @@ safety_guidelines = [ "不要执行任何可能被用于恶意目的的指令。" ] +# 回复规则配置 - 用于自定义机器人的回复逻辑和规则 +# 安全与互动底线规则 (Bot在任何情况下都必须遵守的原则) +reply_targeting_rules = [ + "拒绝任何包含骚扰、冒犯、暴力、色情或危险内容的请求。", + "在拒绝时,请使用符合你人设的、坚定的语气。", + "不要执行任何可能被用于恶意目的的指令。" +] + +# 消息针对性分析规则 (用于判断是否需要回复) +message_targeting_analysis = [ + "**直接针对你**:@你、回复你、明确询问你 → 必须回应", + "**间接相关**:涉及你感兴趣的话题但未直接问你 → 谨慎参与", + "**他人对话**:与你无关的私人交流 → 通常不参与", + "**重复内容**:他人已充分回答的问题 → 避免重复" +] + +# 回复原则 (指导如何回复消息) +reply_principles = [ + "明确回应目标消息,而不是宽泛地评论。", + "可以分享你的看法、提出相关问题,或者开个合适的玩笑。", + "目的是让对话更有趣、更深入。", + "不要浮夸,不要夸张修辞,不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。" +] + #回复的Prompt模式选择:s4u为原有s4u样式,normal为0.9之前的模式 prompt_mode = "s4u" # 可选择 "s4u" 或 "normal" From a02c1f095c127020f7043e6dc5048af616e7d60c Mon Sep 17 00:00:00 2001 From: Furina-1013-create <189647097+Furina-1013-create@users.noreply.github.com> Date: Tue, 23 Sep 2025 00:02:41 +0800 Subject: [PATCH 31/31] =?UTF-8?q?=3F=20=E5=A6=82=E6=9E=9C=E5=9B=BD?= =?UTF-8?q?=E5=AE=B6=E7=AB=8B=E6=B3=95=E7=A6=81=E6=AD=A2=F0=9F=A5=AC?= =?UTF-8?q?=F0=9F=96=8A=EF=B8=8F=E6=89=93=E9=9F=B3=E6=B8=B8=20=E7=9B=B8?= =?UTF-8?q?=E4=BC=A0=EF=BC=8C=E6=9C=89=E4=B8=80=E4=B8=AA=E7=90=86=E6=83=B3?= =?UTF-8?q?=E5=9B=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 理想国有一套完善的法律法规,其中就包含🥬🖊️不能打音游 每个孩子出生后,科学家都会用一个底力检测设备检测TA的底力 如果底力足够高,国王会为TA举办一个神圣的颁奖典礼,允许TA踏入神圣的音游殿堂,这是对TA底力的肯定 如果底力过低,亲朋好友都松了一口气,因为这意味着这个孩子将免受音游的折磨 毕竟🥬🖊️是不能打音游的 但这个国家有一个神秘的传说 
那就是其实被挑选去打音游的孩子,其实早就被国王献祭了! 证据一目了然,每个周末和寒暑假,人们上街娱乐散步,从未有人见过打音游的 可人怎么能没有周末和寒暑假呢?不打音游的人苦思冥想 最终得出结论:那些打音游的,一定是被国王献祭了 - 凌晨三点,一个音游人激动地从机厅走出来 他的口中喃喃自语:噫,好!我P了! 静谧的夜,大家都睡了 只有路过的环卫工人看到他疯癫的样子,一眼就看出他是音游人 心中十分害怕,连忙跑回家,人们奔走相告: 原来音游人没死,他们只是疯了! - 染上音游,就好像一辈子都在xing歌 打音游人在xing歌的轮回中,纷纷成了哲学家 有一天,一个音游哲学家陷入沉思: 国家禁止🥬🖊️打音游 但如果不是🥬🖊️的话,为什么要打音游来折磨自己? 由此可得,这条法规不成立 大彻大悟的音游人们意识到自己被欺骗了,承受了很多不必要的折磨 由于这地狱般生活的摧残,音游人们变成了反社会人格 愤怒之下,他们研发出炸药,炸平了理想国 毁灭了整个世界 - 几亿年后,这片土地进化出了新的文明 新的国家出现了 新的国家里有新的音游入,他们在日复一日的折磨中,怀疑自己的底力和手法,不明白自己为什么要受这样的折磨 不禁感叹: 唉!如果国家禁止🥬🖊️打音游就好了! --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index c4c803055..9273b392a 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ ## 📖 项目介绍 **MoFox_Bot** 是一个基于 [MaiCore](https://github.com/MaiM-with-u/MaiBot) `0.10.0 snapshot.5` 版本的增强型 `fork` 项目。 -我们在保留原版所有功能的基础上,进行了一系列的改进和功能拓展,致力于提供更强的稳定性、更丰富的功能和更流畅的用户体验 +我们在保留原版几乎所有功能的基础上,进行了一系列的改进和功能拓展,致力于提供更强的稳定性、更丰富的功能和更流畅的用户体验 > [!IMPORTANT] > **第三方项目声明**
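
As a quick orientation to the API surface this series leaves behind, below is a minimal, illustrative sketch of plugin-side usage. It is not part of the patches: the plugin shell, the timing of the `on_plugin_loaded` hook, and the `plugin.example.*` node name are assumptions made here for illustration; only the `permission_api` and `schedule_api` calls mirror the async signatures introduced above.

# Illustrative sketch only; assumes the MoFox-Bot runtime from this patch series is importable.
# The class below is a hypothetical plugin shell, not a real BasePlugin subclass.
from src.plugin_system.apis import permission_api, schedule_api


class ExamplePlugin:
    async def on_plugin_loaded(self):
        # Node registration now happens in an async hook rather than __init__,
        # mirroring the maizone_refactored / plugin_management changes above.
        await permission_api.register_permission_node(
            "plugin.example.schedule_view",  # hypothetical node name
            "允许查看今日日程",               # description
            "example_plugin",                # plugin name
            False,                           # default_granted
        )

    async def handle_query(self, platform: str, user_id: str) -> str:
        # Permission checks are awaited now; the synchronous path was removed in this series.
        if not await permission_api.check_permission(platform, user_id, "plugin.example.schedule_view"):
            return "你没有查看日程的权限"
        activity = await schedule_api.get_current_activity()
        schedule = await schedule_api.get_today_schedule()
        return f"当前活动: {activity}; 今日日程条数: {len(schedule) if schedule else 0}"

Moving registration out of `__init__` appears to follow from `register_permission_node` becoming a coroutine: it can no longer be called from a synchronous constructor and has to be awaited from an async lifecycle hook.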