"""SQLAlchemy数据库API模块 提供基于SQLAlchemy的数据库操作,替换Peewee以解决MySQL连接问题 支持自动重连、连接池管理和更好的错误处理 """ import time import traceback from typing import Dict, List, Any, Union, Optional from sqlalchemy import desc, asc, func, and_, select from sqlalchemy.exc import SQLAlchemyError from src.common.database.sqlalchemy_models import ( get_db_session, Messages, ActionRecords, PersonInfo, ChatStreams, LLMUsage, Emoji, Images, ImageDescriptions, OnlineTime, Memory, Expression, ThinkingLog, GraphNodes, GraphEdges, Schedule, MaiZoneScheduleStatus, CacheEntries, UserRelationships, ) from src.common.logger import get_logger logger = get_logger("sqlalchemy_database_api") # 模型映射表,用于通过名称获取模型类 MODEL_MAPPING = { "Messages": Messages, "ActionRecords": ActionRecords, "PersonInfo": PersonInfo, "ChatStreams": ChatStreams, "LLMUsage": LLMUsage, "Emoji": Emoji, "Images": Images, "ImageDescriptions": ImageDescriptions, "OnlineTime": OnlineTime, "Memory": Memory, "Expression": Expression, "ThinkingLog": ThinkingLog, "GraphNodes": GraphNodes, "GraphEdges": GraphEdges, "Schedule": Schedule, "MaiZoneScheduleStatus": MaiZoneScheduleStatus, "CacheEntries": CacheEntries, "UserRelationships": UserRelationships, } async def build_filters(model_class, filters: Dict[str, Any]): """构建查询过滤条件""" conditions = [] for field_name, value in filters.items(): if not hasattr(model_class, field_name): logger.warning(f"模型 {model_class.__name__} 中不存在字段 '{field_name}'") continue field = getattr(model_class, field_name) if isinstance(value, dict): # 处理 MongoDB 风格的操作符 for op, op_value in value.items(): if op == "$gt": conditions.append(field > op_value) elif op == "$lt": conditions.append(field < op_value) elif op == "$gte": conditions.append(field >= op_value) elif op == "$lte": conditions.append(field <= op_value) elif op == "$ne": conditions.append(field != op_value) elif op == "$in": conditions.append(field.in_(op_value)) elif op == "$nin": conditions.append(~field.in_(op_value)) else: logger.warning(f"未知操作符 '{op}' (字段: '{field_name}')") else: # 直接相等比较 conditions.append(field == value) return conditions async def db_query( model_class, data: Optional[Dict[str, Any]] = None, query_type: Optional[str] = "get", filters: Optional[Dict[str, Any]] = None, limit: Optional[int] = None, order_by: Optional[List[str]] = None, single_result: Optional[bool] = False, ) -> Union[List[Dict[str, Any]], Dict[str, Any], None]: """执行异步数据库查询操作 Args: model_class: SQLAlchemy模型类 data: 用于创建或更新的数据字典 query_type: 查询类型 ("get", "create", "update", "delete", "count") filters: 过滤条件字典 limit: 限制结果数量 order_by: 排序字段,前缀'-'表示降序 single_result: 是否只返回单个结果 Returns: 根据查询类型返回相应结果 """ try: if query_type not in ["get", "create", "update", "delete", "count"]: raise ValueError("query_type must be 'get', 'create', 'update', 'delete' or 'count'") async with get_db_session() as session: if query_type == "get": query = select(model_class) # 应用过滤条件 if filters: conditions = await build_filters(model_class, filters) if conditions: query = query.where(and_(*conditions)) # 应用排序 if order_by: for field_name in order_by: if field_name.startswith("-"): field_name = field_name[1:] if hasattr(model_class, field_name): query = query.order_by(desc(getattr(model_class, field_name))) else: if hasattr(model_class, field_name): query = query.order_by(asc(getattr(model_class, field_name))) # 应用限制 if limit and limit > 0: query = query.limit(limit) # 执行查询 result = await session.execute(query) results = result.scalars().all() # 转换为字典格式 result_dicts = [] for result_obj in results: result_dict = {} for column in result_obj.__table__.columns: result_dict[column.name] = getattr(result_obj, column.name) result_dicts.append(result_dict) if single_result: return result_dicts[0] if result_dicts else None return result_dicts elif query_type == "create": if not data: raise ValueError("创建记录需要提供data参数") # 创建新记录 new_record = model_class(**data) session.add(new_record) await session.flush() # 获取自动生成的ID # 转换为字典格式返回 result_dict = {} for column in new_record.__table__.columns: result_dict[column.name] = getattr(new_record, column.name) return result_dict elif query_type == "update": if not data: raise ValueError("更新记录需要提供data参数") query = select(model_class) # 应用过滤条件 if filters: conditions = await build_filters(model_class, filters) if conditions: query = query.where(and_(*conditions)) # 首先获取要更新的记录 result = await session.execute(query) records_to_update = result.scalars().all() # 更新每个记录 affected_rows = 0 for record in records_to_update: for field, value in data.items(): if hasattr(record, field): setattr(record, field, value) affected_rows += 1 return affected_rows elif query_type == "delete": query = select(model_class) # 应用过滤条件 if filters: conditions = await build_filters(model_class, filters) if conditions: query = query.where(and_(*conditions)) # 首先获取要删除的记录 result = await session.execute(query) records_to_delete = result.scalars().all() # 删除记录 affected_rows = 0 for record in records_to_delete: session.delete(record) affected_rows += 1 return affected_rows elif query_type == "count": query = select(func.count(model_class.id)) # 应用过滤条件 if filters: conditions = await build_filters(model_class, filters) if conditions: query = query.where(and_(*conditions)) result = await session.execute(query) return result.scalar() except SQLAlchemyError as e: logger.error(f"[SQLAlchemy] 数据库操作出错: {e}") traceback.print_exc() # 根据查询类型返回合适的默认值 if query_type == "get": return None if single_result else [] elif query_type in ["create", "update", "delete", "count"]: return None return None except Exception as e: logger.error(f"[SQLAlchemy] 意外错误: {e}") traceback.print_exc() if query_type == "get": return None if single_result else [] return None async def db_save( model_class, data: Dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None ) -> Optional[Dict[str, Any]]: """异步保存数据到数据库(创建或更新) Args: model_class: SQLAlchemy模型类 data: 要保存的数据字典 key_field: 用于查找现有记录的字段名 key_value: 用于查找现有记录的字段值 Returns: 保存后的记录数据或None """ try: async with get_db_session() as session: # 如果提供了key_field和key_value,尝试更新现有记录 if key_field and key_value is not None: if hasattr(model_class, key_field): query = select(model_class).where(getattr(model_class, key_field) == key_value) result = await session.execute(query) existing_record = result.scalars().first() if existing_record: # 更新现有记录 for field, value in data.items(): if hasattr(existing_record, field): setattr(existing_record, field, value) await session.flush() # 转换为字典格式返回 result_dict = {} for column in existing_record.__table__.columns: result_dict[column.name] = getattr(existing_record, column.name) return result_dict # 创建新记录 new_record = model_class(**data) session.add(new_record) await session.flush() # 转换为字典格式返回 result_dict = {} for column in new_record.__table__.columns: result_dict[column.name] = getattr(new_record, column.name) return result_dict except SQLAlchemyError as e: logger.error(f"[SQLAlchemy] 保存数据库记录出错: {e}") traceback.print_exc() return None except Exception as e: logger.error(f"[SQLAlchemy] 保存时意外错误: {e}") traceback.print_exc() return None async def db_get( model_class, filters: Optional[Dict[str, Any]] = None, limit: Optional[int] = None, order_by: Optional[str] = None, single_result: Optional[bool] = False, ) -> Union[List[Dict[str, Any]], Dict[str, Any], None]: """异步从数据库获取记录 Args: model_class: SQLAlchemy模型类 filters: 过滤条件 limit: 结果数量限制 order_by: 排序字段,前缀'-'表示降序 single_result: 是否只返回单个结果 Returns: 记录数据或None """ order_by_list = [order_by] if order_by else None return await db_query( model_class=model_class, query_type="get", filters=filters, limit=limit, order_by=order_by_list, single_result=single_result, ) async def store_action_info( chat_stream=None, action_build_into_prompt: bool = False, action_prompt_display: str = "", action_done: bool = True, thinking_id: str = "", action_data: Optional[dict] = None, action_name: str = "", ) -> Optional[Dict[str, Any]]: """异步存储动作信息到数据库 Args: chat_stream: 聊天流对象 action_build_into_prompt: 是否将此动作构建到提示中 action_prompt_display: 动作的提示显示文本 action_done: 动作是否完成 thinking_id: 关联的思考ID action_data: 动作数据字典 action_name: 动作名称 Returns: 保存的记录数据或None """ try: import orjson # 构建动作记录数据 record_data = { "action_id": thinking_id or str(int(time.time() * 1000000)), "time": time.time(), "action_name": action_name, "action_data": orjson.dumps(action_data or {}).decode("utf-8"), "action_done": action_done, "action_build_into_prompt": action_build_into_prompt, "action_prompt_display": action_prompt_display, } # 从chat_stream获取聊天信息 if chat_stream: record_data.update( { "chat_id": getattr(chat_stream, "stream_id", ""), "chat_info_stream_id": getattr(chat_stream, "stream_id", ""), "chat_info_platform": getattr(chat_stream, "platform", ""), } ) else: record_data.update( { "chat_id": "", "chat_info_stream_id": "", "chat_info_platform": "", } ) # 保存记录 saved_record = await db_save( ActionRecords, data=record_data, key_field="action_id", key_value=record_data["action_id"] ) if saved_record: logger.debug(f"[SQLAlchemy] 成功存储动作信息: {action_name} (ID: {record_data['action_id']})") else: logger.error(f"[SQLAlchemy] 存储动作信息失败: {action_name}") return saved_record except Exception as e: logger.error(f"[SQLAlchemy] 存储动作信息时发生错误: {e}") traceback.print_exc() return None