修复代码格式和文件名大小写问题
This commit is contained in:
@@ -11,38 +11,51 @@ from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy import desc, asc, func, and_
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.sqlalchemy_models import (
|
||||
Base, get_db_session, Messages, ActionRecords, PersonInfo, ChatStreams,
|
||||
LLMUsage, Emoji, Images, ImageDescriptions, OnlineTime, Memory,
|
||||
Expression, ThinkingLog, GraphNodes, GraphEdges, Schedule, MaiZoneScheduleStatus,
|
||||
CacheEntries
|
||||
Base,
|
||||
get_db_session,
|
||||
Messages,
|
||||
ActionRecords,
|
||||
PersonInfo,
|
||||
ChatStreams,
|
||||
LLMUsage,
|
||||
Emoji,
|
||||
Images,
|
||||
ImageDescriptions,
|
||||
OnlineTime,
|
||||
Memory,
|
||||
Expression,
|
||||
ThinkingLog,
|
||||
GraphNodes,
|
||||
GraphEdges,
|
||||
Schedule,
|
||||
MaiZoneScheduleStatus,
|
||||
CacheEntries,
|
||||
)
|
||||
|
||||
logger = get_logger("sqlalchemy_database_api")
|
||||
|
||||
# 模型映射表,用于通过名称获取模型类
|
||||
MODEL_MAPPING = {
|
||||
'Messages': Messages,
|
||||
'ActionRecords': ActionRecords,
|
||||
'PersonInfo': PersonInfo,
|
||||
'ChatStreams': ChatStreams,
|
||||
'LLMUsage': LLMUsage,
|
||||
'Emoji': Emoji,
|
||||
'Images': Images,
|
||||
'ImageDescriptions': ImageDescriptions,
|
||||
'OnlineTime': OnlineTime,
|
||||
'Memory': Memory,
|
||||
'Expression': Expression,
|
||||
'ThinkingLog': ThinkingLog,
|
||||
'GraphNodes': GraphNodes,
|
||||
'GraphEdges': GraphEdges,
|
||||
'Schedule': Schedule,
|
||||
'MaiZoneScheduleStatus': MaiZoneScheduleStatus,
|
||||
'CacheEntries': CacheEntries,
|
||||
"Messages": Messages,
|
||||
"ActionRecords": ActionRecords,
|
||||
"PersonInfo": PersonInfo,
|
||||
"ChatStreams": ChatStreams,
|
||||
"LLMUsage": LLMUsage,
|
||||
"Emoji": Emoji,
|
||||
"Images": Images,
|
||||
"ImageDescriptions": ImageDescriptions,
|
||||
"OnlineTime": OnlineTime,
|
||||
"Memory": Memory,
|
||||
"Expression": Expression,
|
||||
"ThinkingLog": ThinkingLog,
|
||||
"GraphNodes": GraphNodes,
|
||||
"GraphEdges": GraphEdges,
|
||||
"Schedule": Schedule,
|
||||
"MaiZoneScheduleStatus": MaiZoneScheduleStatus,
|
||||
"CacheEntries": CacheEntries,
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
def build_filters(session, model_class: Type[Base], filters: Dict[str, Any]):
|
||||
"""构建查询过滤条件"""
|
||||
conditions = []
|
||||
@@ -225,10 +238,7 @@ async def db_query(
|
||||
|
||||
|
||||
async def db_save(
|
||||
model_class: Type[Base],
|
||||
data: Dict[str, Any],
|
||||
key_field: Optional[str] = None,
|
||||
key_value: Optional[Any] = None
|
||||
model_class: Type[Base], data: Dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""保存数据到数据库(创建或更新)
|
||||
|
||||
@@ -246,9 +256,9 @@ async def db_save(
|
||||
# 如果提供了key_field和key_value,尝试更新现有记录
|
||||
if key_field and key_value is not None:
|
||||
if hasattr(model_class, key_field):
|
||||
existing_record = session.query(model_class).filter(
|
||||
getattr(model_class, key_field) == key_value
|
||||
).first()
|
||||
existing_record = (
|
||||
session.query(model_class).filter(getattr(model_class, key_field) == key_value).first()
|
||||
)
|
||||
|
||||
if existing_record:
|
||||
# 更新现有记录
|
||||
@@ -312,7 +322,7 @@ async def db_get(
|
||||
filters=filters,
|
||||
limit=limit,
|
||||
order_by=order_by_list,
|
||||
single_result=single_result
|
||||
single_result=single_result,
|
||||
)
|
||||
|
||||
|
||||
@@ -347,7 +357,7 @@ async def store_action_info(
|
||||
"action_id": thinking_id or str(int(time.time() * 1000000)),
|
||||
"time": time.time(),
|
||||
"action_name": action_name,
|
||||
"action_data": orjson.dumps(action_data or {}).decode('utf-8'),
|
||||
"action_data": orjson.dumps(action_data or {}).decode("utf-8"),
|
||||
"action_done": action_done,
|
||||
"action_build_into_prompt": action_build_into_prompt,
|
||||
"action_prompt_display": action_prompt_display,
|
||||
@@ -355,24 +365,25 @@ async def store_action_info(
|
||||
|
||||
# 从chat_stream获取聊天信息
|
||||
if chat_stream:
|
||||
record_data.update({
|
||||
"chat_id": getattr(chat_stream, "stream_id", ""),
|
||||
"chat_info_stream_id": getattr(chat_stream, "stream_id", ""),
|
||||
"chat_info_platform": getattr(chat_stream, "platform", ""),
|
||||
})
|
||||
record_data.update(
|
||||
{
|
||||
"chat_id": getattr(chat_stream, "stream_id", ""),
|
||||
"chat_info_stream_id": getattr(chat_stream, "stream_id", ""),
|
||||
"chat_info_platform": getattr(chat_stream, "platform", ""),
|
||||
}
|
||||
)
|
||||
else:
|
||||
record_data.update({
|
||||
"chat_id": "",
|
||||
"chat_info_stream_id": "",
|
||||
"chat_info_platform": "",
|
||||
})
|
||||
record_data.update(
|
||||
{
|
||||
"chat_id": "",
|
||||
"chat_info_stream_id": "",
|
||||
"chat_info_platform": "",
|
||||
}
|
||||
)
|
||||
|
||||
# 保存记录
|
||||
saved_record = await db_save(
|
||||
ActionRecords,
|
||||
data=record_data,
|
||||
key_field="action_id",
|
||||
key_value=record_data["action_id"]
|
||||
ActionRecords, data=record_data, key_field="action_id", key_value=record_data["action_id"]
|
||||
)
|
||||
|
||||
if saved_record:
|
||||
@@ -386,4 +397,3 @@ async def store_action_info(
|
||||
logger.error(f"[SQLAlchemy] 存储动作信息时发生错误: {e}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user