import traceback import time from typing import Dict, List, Any, Union, Type from src.common.logger import get_logger from src.common.database.database_model import ActionRecords from src.common.database.database import db from peewee import Model, DoesNotExist logger = get_logger("database_api") class DatabaseAPI: """数据库API模块 提供了数据库操作相关的功能 """ async def store_action_info( self, action_build_into_prompt: bool = False, action_prompt_display: str = "", action_done: bool = True, thinking_id: str = "", action_data: dict = None, ) -> None: """存储action信息到数据库 Args: action_build_into_prompt: 是否构建到提示中 action_prompt_display: 显示的action提示信息 action_done: action是否完成 thinking_id: 思考ID action_data: action数据,如果不提供则使用空字典 """ try: chat_stream = self.get_service("chat_stream") if not chat_stream: logger.error(f"{self.log_prefix} 无法存储action信息:缺少chat_stream服务") return action_time = time.time() action_id = f"{action_time}_{thinking_id}" ActionRecords.create( action_id=action_id, time=action_time, action_name=self.__class__.__name__, action_data=str(action_data or {}), action_done=action_done, action_build_into_prompt=action_build_into_prompt, action_prompt_display=action_prompt_display, chat_id=chat_stream.stream_id, chat_info_stream_id=chat_stream.stream_id, chat_info_platform=chat_stream.platform, user_id=chat_stream.user_info.user_id if chat_stream.user_info else "", user_nickname=chat_stream.user_info.user_nickname if chat_stream.user_info else "", user_cardname=chat_stream.user_info.user_cardname if chat_stream.user_info else "", ) logger.debug(f"{self.log_prefix} 已存储action信息: {action_prompt_display}") except Exception as e: logger.error(f"{self.log_prefix} 存储action信息时出错: {e}") traceback.print_exc() async def db_query( self, model_class: Type[Model], query_type: str = "get", filters: Dict[str, Any] = None, data: Dict[str, Any] = None, limit: int = None, order_by: List[str] = None, single_result: bool = False, ) -> Union[List[Dict[str, Any]], Dict[str, Any], None]: """执行数据库查询操作 这个方法提供了一个通用接口来执行数据库操作,包括查询、创建、更新和删除记录。 Args: model_class: Peewee 模型类,例如 ActionRecords, Messages 等 query_type: 查询类型,可选值: "get", "create", "update", "delete", "count" filters: 过滤条件字典,键为字段名,值为要匹配的值 data: 用于创建或更新的数据字典 limit: 限制结果数量 order_by: 排序字段列表,使用字段名,前缀'-'表示降序 single_result: 是否只返回单个结果 Returns: 根据查询类型返回不同的结果: - "get": 返回查询结果列表或单个结果(如果 single_result=True) - "create": 返回创建的记录 - "update": 返回受影响的行数 - "delete": 返回受影响的行数 - "count": 返回记录数量 示例: # 查询最近10条消息 messages = await self.db_query( Messages, query_type="get", filters={"chat_id": chat_stream.stream_id}, limit=10, order_by=["-time"] ) # 创建一条记录 new_record = await self.db_query( ActionRecords, query_type="create", data={"action_id": "123", "time": time.time(), "action_name": "TestAction"} ) # 更新记录 updated_count = await self.db_query( ActionRecords, query_type="update", filters={"action_id": "123"}, data={"action_done": True} ) # 删除记录 deleted_count = await self.db_query( ActionRecords, query_type="delete", filters={"action_id": "123"} ) # 计数 count = await self.db_query( Messages, query_type="count", filters={"chat_id": chat_stream.stream_id} ) """ try: if query_type not in ["get", "create", "update", "delete", "count"]: raise ValueError("query_type must be 'get' or 'create' or 'update' or 'delete' or 'count'") # 构建基本查询 if query_type in ["get", "update", "delete", "count"]: query = model_class.select() # 应用过滤条件 if filters: for field, value in filters.items(): query = query.where(getattr(model_class, field) == value) # 执行查询 if query_type == "get": # 应用排序 if order_by: for field in order_by: if field.startswith("-"): query = query.order_by(getattr(model_class, field[1:]).desc()) else: query = query.order_by(getattr(model_class, field)) # 应用限制 if limit: query = query.limit(limit) # 执行查询 results = list(query.dicts()) # 返回结果 if single_result: return results[0] if results else None return results elif query_type == "create": if not data: raise ValueError("创建记录需要提供data参数") # 创建记录 record = model_class.create(**data) # 返回创建的记录 return model_class.select().where(model_class.id == record.id).dicts().get() elif query_type == "update": if not data: raise ValueError("更新记录需要提供data参数") # 更新记录 return query.update(**data).execute() elif query_type == "delete": # 删除记录 return query.delete().execute() elif query_type == "count": # 计数 return query.count() else: raise ValueError(f"不支持的查询类型: {query_type}") except DoesNotExist: # 记录不存在 if query_type == "get" and single_result: return None return [] except Exception as e: logger.error(f"{self.log_prefix} 数据库操作出错: {e}") traceback.print_exc() # 根据查询类型返回合适的默认值 if query_type == "get": return None if single_result else [] elif query_type in ["create", "update", "delete", "count"]: return None return None async def db_raw_query( self, sql: str, params: List[Any] = None, fetch_results: bool = True ) -> Union[List[Dict[str, Any]], int, None]: """执行原始SQL查询 警告: 使用此方法需要小心,确保SQL语句已正确构造以避免SQL注入风险。 Args: sql: 原始SQL查询字符串 params: 查询参数列表,用于替换SQL中的占位符 fetch_results: 是否获取查询结果,对于SELECT查询设为True,对于 UPDATE/INSERT/DELETE等操作设为False Returns: 如果fetch_results为True,返回查询结果列表; 如果fetch_results为False,返回受影响的行数; 如果出错,返回None """ try: cursor = db.execute_sql(sql, params or []) if fetch_results: # 获取列名 columns = [col[0] for col in cursor.description] # 构建结果字典列表 results = [] for row in cursor.fetchall(): results.append(dict(zip(columns, row))) return results else: # 返回受影响的行数 return cursor.rowcount except Exception as e: logger.error(f"{self.log_prefix} 执行原始SQL查询出错: {e}") traceback.print_exc() return None async def db_save( self, model_class: Type[Model], data: Dict[str, Any], key_field: str = None, key_value: Any = None ) -> Union[Dict[str, Any], None]: """保存数据到数据库(创建或更新) 如果提供了key_field和key_value,会先尝试查找匹配的记录进行更新; 如果没有找到匹配记录,或未提供key_field和key_value,则创建新记录。 Args: model_class: Peewee模型类,如ActionRecords, Messages等 data: 要保存的数据字典 key_field: 用于查找现有记录的字段名,例如"action_id" key_value: 用于查找现有记录的字段值 Returns: Dict[str, Any]: 保存后的记录数据 None: 如果操作失败 示例: # 创建或更新一条记录 record = await self.db_save( ActionRecords, { "action_id": "123", "time": time.time(), "action_name": "TestAction", "action_done": True }, key_field="action_id", key_value="123" ) """ try: # 如果提供了key_field和key_value,尝试更新现有记录 if key_field and key_value is not None: # 查找现有记录 existing_records = list( model_class.select().where(getattr(model_class, key_field) == key_value).limit(1) ) if existing_records: # 更新现有记录 existing_record = existing_records[0] for field, value in data.items(): setattr(existing_record, field, value) existing_record.save() # 返回更新后的记录 updated_record = model_class.select().where(model_class.id == existing_record.id).dicts().get() return updated_record # 如果没有找到现有记录或未提供key_field和key_value,创建新记录 new_record = model_class.create(**data) # 返回创建的记录 created_record = model_class.select().where(model_class.id == new_record.id).dicts().get() return created_record except Exception as e: logger.error(f"{self.log_prefix} 保存数据库记录出错: {e}") traceback.print_exc() return None async def db_get( self, model_class: Type[Model], filters: Dict[str, Any] = None, order_by: str = None, limit: int = None ) -> Union[List[Dict[str, Any]], Dict[str, Any], None]: """从数据库获取记录 这是db_query方法的简化版本,专注于数据检索操作。 Args: model_class: Peewee模型类 filters: 过滤条件,字段名和值的字典 order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间降序 limit: 结果数量限制,如果为1则返回单个记录而不是列表 Returns: 如果limit=1,返回单个记录字典或None; 否则返回记录字典列表或空列表。 示例: # 获取单个记录 record = await self.db_get( ActionRecords, filters={"action_id": "123"}, limit=1 ) # 获取最近10条记录 records = await self.db_get( Messages, filters={"chat_id": chat_stream.stream_id}, order_by="-time", limit=10 ) """ try: # 构建查询 query = model_class.select() # 应用过滤条件 if filters: for field, value in filters.items(): query = query.where(getattr(model_class, field) == value) # 应用排序 if order_by: if order_by.startswith("-"): query = query.order_by(getattr(model_class, order_by[1:]).desc()) else: query = query.order_by(getattr(model_class, order_by)) # 应用限制 if limit: query = query.limit(limit) # 执行查询 results = list(query.dicts()) # 返回结果 if limit == 1: return results[0] if results else None return results except Exception as e: logger.error(f"{self.log_prefix} 获取数据库记录出错: {e}") traceback.print_exc() return None if limit == 1 else []