"""数据库API模块 提供数据库操作相关功能,采用标准Python包设计模式 使用方式: from src.plugin_system.apis import database_api records = await database_api.db_query(ActionRecords, query_type="get") record = await database_api.db_save(ActionRecords, data={"action_id": "123"}) """ import traceback from typing import Dict, List, Any, Union, Type from src.common.logger import get_logger from peewee import Model, DoesNotExist logger = get_logger("database_api") # ============================================================================= # 通用数据库查询API函数 # ============================================================================= async def db_query( model_class: Type[Model], query_type: str = "get", filters: Dict[str, Any] = None, data: Dict[str, Any] = None, limit: int = None, order_by: List[str] = None, single_result: bool = False, ) -> Union[List[Dict[str, Any]], Dict[str, Any], None]: """执行数据库查询操作 这个方法提供了一个通用接口来执行数据库操作,包括查询、创建、更新和删除记录。 Args: model_class: Peewee 模型类,例如 ActionRecords, Messages 等 query_type: 查询类型,可选值: "get", "create", "update", "delete", "count" filters: 过滤条件字典,键为字段名,值为要匹配的值 data: 用于创建或更新的数据字典 limit: 限制结果数量 order_by: 排序字段列表,使用字段名,前缀'-'表示降序 single_result: 是否只返回单个结果 Returns: 根据查询类型返回不同的结果: - "get": 返回查询结果列表或单个结果(如果 single_result=True) - "create": 返回创建的记录 - "update": 返回受影响的行数 - "delete": 返回受影响的行数 - "count": 返回记录数量 示例: # 查询最近10条消息 messages = await database_api.db_query( Messages, query_type="get", filters={"chat_id": chat_stream.stream_id}, limit=10, order_by=["-time"] ) # 创建一条记录 new_record = await database_api.db_query( ActionRecords, query_type="create", data={"action_id": "123", "time": time.time(), "action_name": "TestAction"} ) # 更新记录 updated_count = await database_api.db_query( ActionRecords, query_type="update", filters={"action_id": "123"}, data={"action_done": True} ) # 删除记录 deleted_count = await database_api.db_query( ActionRecords, query_type="delete", filters={"action_id": "123"} ) # 计数 count = await database_api.db_query( Messages, query_type="count", filters={"chat_id": chat_stream.stream_id} ) """ try: if query_type not in ["get", "create", "update", "delete", "count"]: raise ValueError("query_type must be 'get' or 'create' or 'update' or 'delete' or 'count'") # 构建基本查询 if query_type in ["get", "update", "delete", "count"]: query = model_class.select() # 应用过滤条件 if filters: for field, value in filters.items(): query = query.where(getattr(model_class, field) == value) # 执行查询 if query_type == "get": # 应用排序 if order_by: for field in order_by: if field.startswith("-"): query = query.order_by(getattr(model_class, field[1:]).desc()) else: query = query.order_by(getattr(model_class, field)) # 应用限制 if limit: query = query.limit(limit) # 执行查询 results = list(query.dicts()) # 返回结果 if single_result: return results[0] if results else None return results elif query_type == "create": if not data: raise ValueError("创建记录需要提供data参数") # 创建记录 record = model_class.create(**data) # 返回创建的记录 return model_class.select().where(model_class.id == record.id).dicts().get() elif query_type == "update": if not data: raise ValueError("更新记录需要提供data参数") # 更新记录 return query.update(**data).execute() elif query_type == "delete": # 删除记录 return query.delete().execute() elif query_type == "count": # 计数 return query.count() else: raise ValueError(f"不支持的查询类型: {query_type}") except DoesNotExist: # 记录不存在 if query_type == "get" and single_result: return None return [] except Exception as e: logger.error(f"[DatabaseAPI] 数据库操作出错: {e}") traceback.print_exc() # 根据查询类型返回合适的默认值 if query_type == "get": return None if single_result else [] elif query_type in ["create", "update", "delete", "count"]: return None return None async def db_save( model_class: Type[Model], data: Dict[str, Any], key_field: str = None, key_value: Any = None ) -> Union[Dict[str, Any], None]: """保存数据到数据库(创建或更新) 如果提供了key_field和key_value,会先尝试查找匹配的记录进行更新; 如果没有找到匹配记录,或未提供key_field和key_value,则创建新记录。 Args: model_class: Peewee模型类,如ActionRecords, Messages等 data: 要保存的数据字典 key_field: 用于查找现有记录的字段名,例如"action_id" key_value: 用于查找现有记录的字段值 Returns: Dict[str, Any]: 保存后的记录数据 None: 如果操作失败 示例: # 创建或更新一条记录 record = await database_api.db_save( ActionRecords, { "action_id": "123", "time": time.time(), "action_name": "TestAction", "action_done": True }, key_field="action_id", key_value="123" ) """ try: # 如果提供了key_field和key_value,尝试更新现有记录 if key_field and key_value is not None: # 查找现有记录 existing_records = list(model_class.select().where(getattr(model_class, key_field) == key_value).limit(1)) if existing_records: # 更新现有记录 existing_record = existing_records[0] for field, value in data.items(): setattr(existing_record, field, value) existing_record.save() # 返回更新后的记录 updated_record = model_class.select().where(model_class.id == existing_record.id).dicts().get() return updated_record # 如果没有找到现有记录或未提供key_field和key_value,创建新记录 new_record = model_class.create(**data) # 返回创建的记录 created_record = model_class.select().where(model_class.id == new_record.id).dicts().get() return created_record except Exception as e: logger.error(f"[DatabaseAPI] 保存数据库记录出错: {e}") traceback.print_exc() return None async def db_get( model_class: Type[Model], filters: Dict[str, Any] = None, order_by: str = None, limit: int = None ) -> Union[List[Dict[str, Any]], Dict[str, Any], None]: """从数据库获取记录 这是db_query方法的简化版本,专注于数据检索操作。 Args: model_class: Peewee模型类 filters: 过滤条件,字段名和值的字典 order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间降序 limit: 结果数量限制,如果为1则返回单个记录而不是列表 Returns: 如果limit=1,返回单个记录字典或None; 否则返回记录字典列表或空列表。 示例: # 获取单个记录 record = await database_api.db_get( ActionRecords, filters={"action_id": "123"}, limit=1 ) # 获取最近10条记录 records = await database_api.db_get( Messages, filters={"chat_id": chat_stream.stream_id}, order_by="-time", limit=10 ) """ try: # 构建查询 query = model_class.select() # 应用过滤条件 if filters: for field, value in filters.items(): query = query.where(getattr(model_class, field) == value) # 应用排序 if order_by: if order_by.startswith("-"): query = query.order_by(getattr(model_class, order_by[1:]).desc()) else: query = query.order_by(getattr(model_class, order_by)) # 应用限制 if limit: query = query.limit(limit) # 执行查询 results = list(query.dicts()) # 返回结果 if limit == 1: return results[0] if results else None return results except Exception as e: logger.error(f"[DatabaseAPI] 获取数据库记录出错: {e}") traceback.print_exc() return None if limit == 1 else [] async def store_action_info( chat_stream=None, action_build_into_prompt: bool = False, action_prompt_display: str = "", action_done: bool = True, thinking_id: str = "", action_data: dict = None, action_name: str = "", ) -> Union[Dict[str, Any], None]: """存储动作信息到数据库 将Action执行的相关信息保存到ActionRecords表中,用于后续的记忆和上下文构建。 Args: chat_stream: 聊天流对象,包含聊天相关信息 action_build_into_prompt: 是否将此动作构建到提示中 action_prompt_display: 动作的提示显示文本 action_done: 动作是否完成 thinking_id: 关联的思考ID action_data: 动作数据字典 action_name: 动作名称 Returns: Dict[str, Any]: 保存的记录数据 None: 如果保存失败 示例: record = await database_api.store_action_info( chat_stream=chat_stream, action_build_into_prompt=True, action_prompt_display="执行了回复动作", action_done=True, thinking_id="thinking_123", action_data={"content": "Hello"}, action_name="reply_action" ) """ try: import time import json from src.common.database.database_model import ActionRecords # 构建动作记录数据 record_data = { "action_id": thinking_id or str(int(time.time() * 1000000)), # 使用thinking_id或生成唯一ID "time": time.time(), "action_name": action_name, "action_data": json.dumps(action_data or {}, ensure_ascii=False), "action_done": action_done, "action_build_into_prompt": action_build_into_prompt, "action_prompt_display": action_prompt_display, } # 从chat_stream获取聊天信息 if chat_stream: record_data.update( { "chat_id": getattr(chat_stream, "stream_id", ""), "chat_info_stream_id": getattr(chat_stream, "stream_id", ""), "chat_info_platform": getattr(chat_stream, "platform", ""), } ) else: # 如果没有chat_stream,设置默认值 record_data.update( { "chat_id": "", "chat_info_stream_id": "", "chat_info_platform": "", } ) # 使用已有的db_save函数保存记录 saved_record = await db_save( ActionRecords, data=record_data, key_field="action_id", key_value=record_data["action_id"] ) if saved_record: logger.debug(f"[DatabaseAPI] 成功存储动作信息: {action_name} (ID: {record_data['action_id']})") else: logger.error(f"[DatabaseAPI] 存储动作信息失败: {action_name}") return saved_record except Exception as e: logger.error(f"[DatabaseAPI] 存储动作信息时发生错误: {e}") traceback.print_exc() return None