三次修改

This commit is contained in:
tt-P607
2025-09-20 02:21:53 +08:00
parent 6a98ae6208
commit 0cc4f5bb27
20 changed files with 478 additions and 481 deletions

View File

@@ -246,11 +246,11 @@ class ChatManager:
return stream
# 检查数据库中是否存在
def _db_find_stream_sync(s_id: str):
with get_db_session() as session:
return session.execute(select(ChatStreams).where(ChatStreams.stream_id == s_id)).scalar()
async def _db_find_stream_async(s_id: str):
async with get_db_session() as session:
return (await session.execute(select(ChatStreams).where(ChatStreams.stream_id == s_id))).scalar()
model_instance = await asyncio.to_thread(_db_find_stream_sync, stream_id)
model_instance = await _db_find_stream_async(stream_id)
if model_instance:
# 从 Peewee 模型转换回 ChatStream.from_dict 期望的格式
@@ -344,11 +344,10 @@ class ChatManager:
return
stream_data_dict = stream.to_dict()
def _db_save_stream_sync(s_data_dict: dict):
with get_db_session() as session:
async def _db_save_stream_async(s_data_dict: dict):
async with get_db_session() as session:
user_info_d = s_data_dict.get("user_info")
group_info_d = s_data_dict.get("group_info")
fields_to_save = {
"platform": s_data_dict["platform"],
"create_time": s_data_dict["create_time"],
@@ -364,8 +363,6 @@ class ChatManager:
"sleep_pressure": s_data_dict.get("sleep_pressure", 0.0),
"focus_energy": s_data_dict.get("focus_energy", global_config.chat.focus_value),
}
# 根据数据库类型选择插入语句
if global_config.database.database_type == "sqlite":
stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=fields_to_save)
@@ -375,15 +372,13 @@ class ChatManager:
**{key: value for key, value in fields_to_save.items() if key != "stream_id"}
)
else:
# 默认使用通用插入尝试SQLite语法
stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=fields_to_save)
session.execute(stmt)
session.commit()
await session.execute(stmt)
await session.commit()
try:
await asyncio.to_thread(_db_save_stream_sync, stream_data_dict)
await _db_save_stream_async(stream_data_dict)
stream.saved = True
except Exception as e:
logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (Peewee): {e}", exc_info=True)
@@ -397,10 +392,10 @@ class ChatManager:
"""从数据库加载所有聊天流"""
logger.info("正在从数据库加载所有聊天流")
def _db_load_all_streams_sync():
async def _db_load_all_streams_async():
loaded_streams_data = []
with get_db_session() as session:
for model_instance in session.execute(select(ChatStreams)).scalars():
async with get_db_session() as session:
for model_instance in (await session.execute(select(ChatStreams))).scalars():
user_info_data = {
"platform": model_instance.user_platform,
"user_id": model_instance.user_id,
@@ -414,7 +409,6 @@ class ChatManager:
"group_id": model_instance.group_id,
"group_name": model_instance.group_name,
}
data_for_from_dict = {
"stream_id": model_instance.stream_id,
"platform": model_instance.platform,
@@ -427,11 +421,11 @@ class ChatManager:
"focus_energy": getattr(model_instance, "focus_energy", global_config.chat.focus_value),
}
loaded_streams_data.append(data_for_from_dict)
session.commit()
await session.commit()
return loaded_streams_data
try:
all_streams_data_list = await asyncio.to_thread(_db_load_all_streams_sync)
all_streams_data_list = await _db_load_all_streams_async()
self.streams.clear()
for data in all_streams_data_list:
stream = ChatStream.from_dict(data)