Files
Mofox-Core/src/plugin_system/apis/database_api.py

387 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""数据库API模块
提供数据库操作相关功能采用标准Python包设计模式
使用方式:
from src.plugin_system.apis import database_api
records = await database_api.db_query(ActionRecords, query_type="get")
record = await database_api.db_save(ActionRecords, data={"action_id": "123"})
"""
import traceback
from typing import Dict, List, Any, Union, Type
from src.common.logger import get_logger
from peewee import Model, DoesNotExist
logger = get_logger("database_api")
# =============================================================================
# 通用数据库查询API函数
# =============================================================================
async def db_query(
model_class: Type[Model],
query_type: str = "get",
filters: Dict[str, Any] = None,
data: Dict[str, Any] = None,
limit: int = None,
order_by: List[str] = None,
single_result: bool = False,
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
"""执行数据库查询操作
这个方法提供了一个通用接口来执行数据库操作,包括查询、创建、更新和删除记录。
Args:
model_class: Peewee 模型类,例如 ActionRecords, Messages 等
query_type: 查询类型,可选值: "get", "create", "update", "delete", "count"
filters: 过滤条件字典,键为字段名,值为要匹配的值
data: 用于创建或更新的数据字典
limit: 限制结果数量
order_by: 排序字段列表,使用字段名,前缀'-'表示降序
single_result: 是否只返回单个结果
Returns:
根据查询类型返回不同的结果:
- "get": 返回查询结果列表或单个结果(如果 single_result=True
- "create": 返回创建的记录
- "update": 返回受影响的行数
- "delete": 返回受影响的行数
- "count": 返回记录数量
示例:
# 查询最近10条消息
messages = await database_api.db_query(
Messages,
query_type="get",
filters={"chat_id": chat_stream.stream_id},
limit=10,
order_by=["-time"]
)
# 创建一条记录
new_record = await database_api.db_query(
ActionRecords,
query_type="create",
data={"action_id": "123", "time": time.time(), "action_name": "TestAction"}
)
# 更新记录
updated_count = await database_api.db_query(
ActionRecords,
query_type="update",
filters={"action_id": "123"},
data={"action_done": True}
)
# 删除记录
deleted_count = await database_api.db_query(
ActionRecords,
query_type="delete",
filters={"action_id": "123"}
)
# 计数
count = await database_api.db_query(
Messages,
query_type="count",
filters={"chat_id": chat_stream.stream_id}
)
"""
try:
if query_type not in ["get", "create", "update", "delete", "count"]:
raise ValueError("query_type must be 'get' or 'create' or 'update' or 'delete' or 'count'")
# 构建基本查询
if query_type in ["get", "update", "delete", "count"]:
query = model_class.select()
# 应用过滤条件
if filters:
for field, value in filters.items():
query = query.where(getattr(model_class, field) == value)
# 执行查询
if query_type == "get":
# 应用排序
if order_by:
for field in order_by:
if field.startswith("-"):
query = query.order_by(getattr(model_class, field[1:]).desc())
else:
query = query.order_by(getattr(model_class, field))
# 应用限制
if limit:
query = query.limit(limit)
# 执行查询
results = list(query.dicts())
# 返回结果
if single_result:
return results[0] if results else None
return results
elif query_type == "create":
if not data:
raise ValueError("创建记录需要提供data参数")
# 创建记录
record = model_class.create(**data)
# 返回创建的记录
return model_class.select().where(model_class.id == record.id).dicts().get()
elif query_type == "update":
if not data:
raise ValueError("更新记录需要提供data参数")
# 更新记录
return query.update(**data).execute()
elif query_type == "delete":
# 删除记录
return query.delete().execute()
elif query_type == "count":
# 计数
return query.count()
else:
raise ValueError(f"不支持的查询类型: {query_type}")
except DoesNotExist:
# 记录不存在
if query_type == "get" and single_result:
return None
return []
except Exception as e:
logger.error(f"[DatabaseAPI] 数据库操作出错: {e}")
traceback.print_exc()
# 根据查询类型返回合适的默认值
if query_type == "get":
return None if single_result else []
elif query_type in ["create", "update", "delete", "count"]:
return None
return None
async def db_save(
model_class: Type[Model], data: Dict[str, Any], key_field: str = None, key_value: Any = None
) -> Union[Dict[str, Any], None]:
"""保存数据到数据库(创建或更新)
如果提供了key_field和key_value会先尝试查找匹配的记录进行更新
如果没有找到匹配记录或未提供key_field和key_value则创建新记录。
Args:
model_class: Peewee模型类如ActionRecords, Messages等
data: 要保存的数据字典
key_field: 用于查找现有记录的字段名,例如"action_id"
key_value: 用于查找现有记录的字段值
Returns:
Dict[str, Any]: 保存后的记录数据
None: 如果操作失败
示例:
# 创建或更新一条记录
record = await database_api.db_save(
ActionRecords,
{
"action_id": "123",
"time": time.time(),
"action_name": "TestAction",
"action_done": True
},
key_field="action_id",
key_value="123"
)
"""
try:
# 如果提供了key_field和key_value尝试更新现有记录
if key_field and key_value is not None:
# 查找现有记录
existing_records = list(model_class.select().where(getattr(model_class, key_field) == key_value).limit(1))
if existing_records:
# 更新现有记录
existing_record = existing_records[0]
for field, value in data.items():
setattr(existing_record, field, value)
existing_record.save()
# 返回更新后的记录
updated_record = model_class.select().where(model_class.id == existing_record.id).dicts().get()
return updated_record
# 如果没有找到现有记录或未提供key_field和key_value创建新记录
new_record = model_class.create(**data)
# 返回创建的记录
created_record = model_class.select().where(model_class.id == new_record.id).dicts().get()
return created_record
except Exception as e:
logger.error(f"[DatabaseAPI] 保存数据库记录出错: {e}")
traceback.print_exc()
return None
async def db_get(
model_class: Type[Model], filters: Dict[str, Any] = None, order_by: str = None, limit: int = None
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
"""从数据库获取记录
这是db_query方法的简化版本专注于数据检索操作。
Args:
model_class: Peewee模型类
filters: 过滤条件,字段名和值的字典
order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间降序
limit: 结果数量限制如果为1则返回单个记录而不是列表
Returns:
如果limit=1返回单个记录字典或None
否则返回记录字典列表或空列表。
示例:
# 获取单个记录
record = await database_api.db_get(
ActionRecords,
filters={"action_id": "123"},
limit=1
)
# 获取最近10条记录
records = await database_api.db_get(
Messages,
filters={"chat_id": chat_stream.stream_id},
order_by="-time",
limit=10
)
"""
try:
# 构建查询
query = model_class.select()
# 应用过滤条件
if filters:
for field, value in filters.items():
query = query.where(getattr(model_class, field) == value)
# 应用排序
if order_by:
if order_by.startswith("-"):
query = query.order_by(getattr(model_class, order_by[1:]).desc())
else:
query = query.order_by(getattr(model_class, order_by))
# 应用限制
if limit:
query = query.limit(limit)
# 执行查询
results = list(query.dicts())
# 返回结果
if limit == 1:
return results[0] if results else None
return results
except Exception as e:
logger.error(f"[DatabaseAPI] 获取数据库记录出错: {e}")
traceback.print_exc()
return None if limit == 1 else []
async def store_action_info(
chat_stream=None,
action_build_into_prompt: bool = False,
action_prompt_display: str = "",
action_done: bool = True,
thinking_id: str = "",
action_data: dict = None,
action_name: str = "",
) -> Union[Dict[str, Any], None]:
"""存储动作信息到数据库
将Action执行的相关信息保存到ActionRecords表中用于后续的记忆和上下文构建。
Args:
chat_stream: 聊天流对象,包含聊天相关信息
action_build_into_prompt: 是否将此动作构建到提示中
action_prompt_display: 动作的提示显示文本
action_done: 动作是否完成
thinking_id: 关联的思考ID
action_data: 动作数据字典
action_name: 动作名称
Returns:
Dict[str, Any]: 保存的记录数据
None: 如果保存失败
示例:
record = await database_api.store_action_info(
chat_stream=chat_stream,
action_build_into_prompt=True,
action_prompt_display="执行了回复动作",
action_done=True,
thinking_id="thinking_123",
action_data={"content": "Hello"},
action_name="reply_action"
)
"""
try:
import time
import json
from src.common.database.database_model import ActionRecords
# 构建动作记录数据
record_data = {
"action_id": thinking_id or str(int(time.time() * 1000000)), # 使用thinking_id或生成唯一ID
"time": time.time(),
"action_name": action_name,
"action_data": json.dumps(action_data or {}, ensure_ascii=False),
"action_done": action_done,
"action_build_into_prompt": action_build_into_prompt,
"action_prompt_display": action_prompt_display,
}
# 从chat_stream获取聊天信息
if chat_stream:
record_data.update(
{
"chat_id": getattr(chat_stream, "stream_id", ""),
"chat_info_stream_id": getattr(chat_stream, "stream_id", ""),
"chat_info_platform": getattr(chat_stream, "platform", ""),
}
)
else:
# 如果没有chat_stream设置默认值
record_data.update(
{
"chat_id": "",
"chat_info_stream_id": "",
"chat_info_platform": "",
}
)
# 使用已有的db_save函数保存记录
saved_record = await db_save(
ActionRecords, data=record_data, key_field="action_id", key_value=record_data["action_id"]
)
if saved_record:
logger.debug(f"[DatabaseAPI] 成功存储动作信息: {action_name} (ID: {record_data['action_id']})")
else:
logger.error(f"[DatabaseAPI] 存储动作信息失败: {action_name}")
return saved_record
except Exception as e:
logger.error(f"[DatabaseAPI] 存储动作信息时发生错误: {e}")
traceback.print_exc()
return None