Files
Mofox-Core/src/plugin_system/apis/database_api.py
2025-06-11 15:17:08 +09:00

380 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import traceback
import time
from typing import Dict, List, Any, Union, Type
from src.common.logger import get_logger
from src.common.database.database_model import ActionRecords
from src.common.database.database import db
from peewee import Model, DoesNotExist
logger = get_logger("database_api")
class DatabaseAPI:
"""数据库API模块
提供了数据库操作相关的功能
"""
async def store_action_info(
self,
action_build_into_prompt: bool = False,
action_prompt_display: str = "",
action_done: bool = True,
thinking_id: str = "",
action_data: dict = None,
) -> None:
"""存储action信息到数据库
Args:
action_build_into_prompt: 是否构建到提示中
action_prompt_display: 显示的action提示信息
action_done: action是否完成
thinking_id: 思考ID
action_data: action数据如果不提供则使用空字典
"""
try:
chat_stream = self.get_service("chat_stream")
if not chat_stream:
logger.error(f"{self.log_prefix} 无法存储action信息缺少chat_stream服务")
return
action_time = time.time()
action_id = f"{action_time}_{thinking_id}"
ActionRecords.create(
action_id=action_id,
time=action_time,
action_name=self.__class__.__name__,
action_data=str(action_data or {}),
action_done=action_done,
action_build_into_prompt=action_build_into_prompt,
action_prompt_display=action_prompt_display,
chat_id=chat_stream.stream_id,
chat_info_stream_id=chat_stream.stream_id,
chat_info_platform=chat_stream.platform,
user_id=chat_stream.user_info.user_id if chat_stream.user_info else "",
user_nickname=chat_stream.user_info.user_nickname if chat_stream.user_info else "",
user_cardname=chat_stream.user_info.user_cardname if chat_stream.user_info else "",
)
logger.debug(f"{self.log_prefix} 已存储action信息: {action_prompt_display}")
except Exception as e:
logger.error(f"{self.log_prefix} 存储action信息时出错: {e}")
traceback.print_exc()
async def db_query(
self,
model_class: Type[Model],
query_type: str = "get",
filters: Dict[str, Any] = None,
data: Dict[str, Any] = None,
limit: int = None,
order_by: List[str] = None,
single_result: bool = False,
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
"""执行数据库查询操作
这个方法提供了一个通用接口来执行数据库操作,包括查询、创建、更新和删除记录。
Args:
model_class: Peewee 模型类,例如 ActionRecords, Messages 等
query_type: 查询类型,可选值: "get", "create", "update", "delete", "count"
filters: 过滤条件字典,键为字段名,值为要匹配的值
data: 用于创建或更新的数据字典
limit: 限制结果数量
order_by: 排序字段列表,使用字段名,前缀'-'表示降序
single_result: 是否只返回单个结果
Returns:
根据查询类型返回不同的结果:
- "get": 返回查询结果列表或单个结果(如果 single_result=True
- "create": 返回创建的记录
- "update": 返回受影响的行数
- "delete": 返回受影响的行数
- "count": 返回记录数量
示例:
# 查询最近10条消息
messages = await self.db_query(
Messages,
query_type="get",
filters={"chat_id": chat_stream.stream_id},
limit=10,
order_by=["-time"]
)
# 创建一条记录
new_record = await self.db_query(
ActionRecords,
query_type="create",
data={"action_id": "123", "time": time.time(), "action_name": "TestAction"}
)
# 更新记录
updated_count = await self.db_query(
ActionRecords,
query_type="update",
filters={"action_id": "123"},
data={"action_done": True}
)
# 删除记录
deleted_count = await self.db_query(
ActionRecords,
query_type="delete",
filters={"action_id": "123"}
)
# 计数
count = await self.db_query(
Messages,
query_type="count",
filters={"chat_id": chat_stream.stream_id}
)
"""
try:
if query_type not in ["get", "create", "update", "delete", "count"]:
raise ValueError("query_type must be 'get' or 'create' or 'update' or 'delete' or 'count'")
# 构建基本查询
if query_type in ["get", "update", "delete", "count"]:
query = model_class.select()
# 应用过滤条件
if filters:
for field, value in filters.items():
query = query.where(getattr(model_class, field) == value)
# 执行查询
if query_type == "get":
# 应用排序
if order_by:
for field in order_by:
if field.startswith("-"):
query = query.order_by(getattr(model_class, field[1:]).desc())
else:
query = query.order_by(getattr(model_class, field))
# 应用限制
if limit:
query = query.limit(limit)
# 执行查询
results = list(query.dicts())
# 返回结果
if single_result:
return results[0] if results else None
return results
elif query_type == "create":
if not data:
raise ValueError("创建记录需要提供data参数")
# 创建记录
record = model_class.create(**data)
# 返回创建的记录
return model_class.select().where(model_class.id == record.id).dicts().get()
elif query_type == "update":
if not data:
raise ValueError("更新记录需要提供data参数")
# 更新记录
return query.update(**data).execute()
elif query_type == "delete":
# 删除记录
return query.delete().execute()
elif query_type == "count":
# 计数
return query.count()
else:
raise ValueError(f"不支持的查询类型: {query_type}")
except DoesNotExist:
# 记录不存在
if query_type == "get" and single_result:
return None
return []
except Exception as e:
logger.error(f"{self.log_prefix} 数据库操作出错: {e}")
traceback.print_exc()
# 根据查询类型返回合适的默认值
if query_type == "get":
return None if single_result else []
elif query_type in ["create", "update", "delete", "count"]:
return None
return None
async def db_raw_query(
self, sql: str, params: List[Any] = None, fetch_results: bool = True
) -> Union[List[Dict[str, Any]], int, None]:
"""执行原始SQL查询
警告: 使用此方法需要小心确保SQL语句已正确构造以避免SQL注入风险。
Args:
sql: 原始SQL查询字符串
params: 查询参数列表用于替换SQL中的占位符
fetch_results: 是否获取查询结果对于SELECT查询设为True对于
UPDATE/INSERT/DELETE等操作设为False
Returns:
如果fetch_results为True返回查询结果列表
如果fetch_results为False返回受影响的行数
如果出错返回None
"""
try:
cursor = db.execute_sql(sql, params or [])
if fetch_results:
# 获取列名
columns = [col[0] for col in cursor.description]
# 构建结果字典列表
results = []
for row in cursor.fetchall():
results.append(dict(zip(columns, row)))
return results
else:
# 返回受影响的行数
return cursor.rowcount
except Exception as e:
logger.error(f"{self.log_prefix} 执行原始SQL查询出错: {e}")
traceback.print_exc()
return None
async def db_save(
self, model_class: Type[Model], data: Dict[str, Any], key_field: str = None, key_value: Any = None
) -> Union[Dict[str, Any], None]:
"""保存数据到数据库(创建或更新)
如果提供了key_field和key_value会先尝试查找匹配的记录进行更新
如果没有找到匹配记录或未提供key_field和key_value则创建新记录。
Args:
model_class: Peewee模型类如ActionRecords, Messages等
data: 要保存的数据字典
key_field: 用于查找现有记录的字段名,例如"action_id"
key_value: 用于查找现有记录的字段值
Returns:
Dict[str, Any]: 保存后的记录数据
None: 如果操作失败
示例:
# 创建或更新一条记录
record = await self.db_save(
ActionRecords,
{
"action_id": "123",
"time": time.time(),
"action_name": "TestAction",
"action_done": True
},
key_field="action_id",
key_value="123"
)
"""
try:
# 如果提供了key_field和key_value尝试更新现有记录
if key_field and key_value is not None:
# 查找现有记录
existing_records = list(
model_class.select().where(getattr(model_class, key_field) == key_value).limit(1)
)
if existing_records:
# 更新现有记录
existing_record = existing_records[0]
for field, value in data.items():
setattr(existing_record, field, value)
existing_record.save()
# 返回更新后的记录
updated_record = model_class.select().where(model_class.id == existing_record.id).dicts().get()
return updated_record
# 如果没有找到现有记录或未提供key_field和key_value创建新记录
new_record = model_class.create(**data)
# 返回创建的记录
created_record = model_class.select().where(model_class.id == new_record.id).dicts().get()
return created_record
except Exception as e:
logger.error(f"{self.log_prefix} 保存数据库记录出错: {e}")
traceback.print_exc()
return None
async def db_get(
self, model_class: Type[Model], filters: Dict[str, Any] = None, order_by: str = None, limit: int = None
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
"""从数据库获取记录
这是db_query方法的简化版本专注于数据检索操作。
Args:
model_class: Peewee模型类
filters: 过滤条件,字段名和值的字典
order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间降序
limit: 结果数量限制如果为1则返回单个记录而不是列表
Returns:
如果limit=1返回单个记录字典或None
否则返回记录字典列表或空列表。
示例:
# 获取单个记录
record = await self.db_get(
ActionRecords,
filters={"action_id": "123"},
limit=1
)
# 获取最近10条记录
records = await self.db_get(
Messages,
filters={"chat_id": chat_stream.stream_id},
order_by="-time",
limit=10
)
"""
try:
# 构建查询
query = model_class.select()
# 应用过滤条件
if filters:
for field, value in filters.items():
query = query.where(getattr(model_class, field) == value)
# 应用排序
if order_by:
if order_by.startswith("-"):
query = query.order_by(getattr(model_class, order_by[1:]).desc())
else:
query = query.order_by(getattr(model_class, order_by))
# 应用限制
if limit:
query = query.limit(limit)
# 执行查询
results = list(query.dicts())
# 返回结果
if limit == 1:
return results[0] if results else None
return results
except Exception as e:
logger.error(f"{self.log_prefix} 获取数据库记录出错: {e}")
traceback.print_exc()
return None if limit == 1 else []