fix typing, api change

This commit is contained in:
UnCLASPrommer
2025-07-15 00:57:43 +08:00
parent eae399fb95
commit af02f2ab57
5 changed files with 110 additions and 77 deletions

View File

@@ -8,7 +8,7 @@
"""
import traceback
from typing import Dict, List, Any, Union, Type
from typing import Dict, List, Any, Union, Type, Optional
from src.common.logger import get_logger
from peewee import Model, DoesNotExist
@@ -21,12 +21,12 @@ logger = get_logger("database_api")
async def db_query(
model_class: Type[Model],
query_type: str = "get",
filters: Dict[str, Any] = None,
data: Dict[str, Any] = None,
limit: int = None,
order_by: List[str] = None,
single_result: bool = False,
data: Optional[Dict[str, Any]] = None,
query_type: Optional[str] = "get",
filters: Optional[Dict[str, Any]] = None,
limit: Optional[int] = None,
order_by: Optional[List[str]] = None,
single_result: Optional[bool] = False,
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
"""执行数据库查询操作
@@ -34,11 +34,11 @@ async def db_query(
Args:
model_class: Peewee 模型类,例如 ActionRecords, Messages 等
data: 用于创建或更新的数据字典
query_type: 查询类型,可选值: "get", "create", "update", "delete", "count"
filters: 过滤条件字典,键为字段名,值为要匹配的值
data: 用于创建或更新的数据字典
limit: 限制结果数量
order_by: 排序字段列表,使用字段名,前缀'-'表示降序
order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间字段即time字段降序
single_result: 是否只返回单个结果
Returns:
@@ -48,7 +48,8 @@ async def db_query(
- "update": 返回受影响的行数
- "delete": 返回受影响的行数
- "count": 返回记录数量
"""
"""
示例:
# 查询最近10条消息
messages = await database_api.db_query(
@@ -62,16 +63,16 @@ async def db_query(
# 创建一条记录
new_record = await database_api.db_query(
ActionRecords,
data={"action_id": "123", "time": time.time(), "action_name": "TestAction"},
query_type="create",
data={"action_id": "123", "time": time.time(), "action_name": "TestAction"}
)
# 更新记录
updated_count = await database_api.db_query(
ActionRecords,
data={"action_done": True},
query_type="update",
filters={"action_id": "123"},
data={"action_done": True}
)
# 删除记录
@@ -129,7 +130,7 @@ async def db_query(
# 创建记录
record = model_class.create(**data)
# 返回创建的记录
return model_class.select().where(model_class.id == record.id).dicts().get()
return model_class.select().where(model_class.id == record.id).dicts().get() # type: ignore
elif query_type == "update":
if not data:
@@ -168,7 +169,7 @@ async def db_query(
async def db_save(
model_class: Type[Model], data: Dict[str, Any], key_field: str = None, key_value: Any = None
model_class: Type[Model], data: Dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None
) -> Union[Dict[str, Any], None]:
"""保存数据到数据库(创建或更新)
@@ -213,14 +214,14 @@ async def db_save(
existing_record.save()
# 返回更新后的记录
updated_record = model_class.select().where(model_class.id == existing_record.id).dicts().get()
updated_record = model_class.select().where(model_class.id == existing_record.id).dicts().get() # type: ignore
return updated_record
# 如果没有找到现有记录或未提供key_field和key_value创建新记录
new_record = model_class.create(**data)
# 返回创建的记录
created_record = model_class.select().where(model_class.id == new_record.id).dicts().get()
created_record = model_class.select().where(model_class.id == new_record.id).dicts().get() # type: ignore
return created_record
except Exception as e:
@@ -230,7 +231,11 @@ async def db_save(
async def db_get(
model_class: Type[Model], filters: Dict[str, Any] = None, order_by: str = None, limit: int = None
model_class: Type[Model],
filters: Optional[Dict[str, Any]] = None,
limit: Optional[int] = None,
order_by: Optional[str] = None,
single_result: Optional[bool] = False,
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
"""从数据库获取记录
@@ -239,11 +244,12 @@ async def db_get(
Args:
model_class: Peewee模型类
filters: 过滤条件,字段名和值的字典
order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间降序
limit: 结果数量限制如果为1则返回单个记录而不是列表
order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间字段即time字段降序
limit: 结果数量限制
single_result: 是否只返回单个结果如果为True则返回单个记录字典或None否则返回记录字典列表或空列表
Returns:
如果limit=1返回单个记录字典或None
如果single_result为True返回单个记录字典或None
否则返回记录字典列表或空列表。
示例:
@@ -258,8 +264,8 @@ async def db_get(
records = await database_api.db_get(
Messages,
filters={"chat_id": chat_stream.stream_id},
limit=10,
order_by="-time",
limit=10
)
"""
try:
@@ -286,14 +292,14 @@ async def db_get(
results = list(query.dicts())
# 返回结果
if limit == 1:
if single_result:
return results[0] if results else None
return results
except Exception as e:
logger.error(f"[DatabaseAPI] 获取数据库记录出错: {e}")
traceback.print_exc()
return None if limit == 1 else []
return None if single_result else []
async def store_action_info(
@@ -302,7 +308,7 @@ async def store_action_info(
action_prompt_display: str = "",
action_done: bool = True,
thinking_id: str = "",
action_data: dict = None,
action_data: Optional[dict] = None,
action_name: str = "",
) -> Union[Dict[str, Any], None]:
"""存储动作信息到数据库