三次修改
This commit is contained in:
@@ -246,11 +246,11 @@ class ChatManager:
|
||||
return stream
|
||||
|
||||
# 检查数据库中是否存在
|
||||
def _db_find_stream_sync(s_id: str):
|
||||
with get_db_session() as session:
|
||||
return session.execute(select(ChatStreams).where(ChatStreams.stream_id == s_id)).scalar()
|
||||
async def _db_find_stream_async(s_id: str):
|
||||
async with get_db_session() as session:
|
||||
return (await session.execute(select(ChatStreams).where(ChatStreams.stream_id == s_id))).scalar()
|
||||
|
||||
model_instance = await asyncio.to_thread(_db_find_stream_sync, stream_id)
|
||||
model_instance = await _db_find_stream_async(stream_id)
|
||||
|
||||
if model_instance:
|
||||
# 从 Peewee 模型转换回 ChatStream.from_dict 期望的格式
|
||||
@@ -344,11 +344,10 @@ class ChatManager:
|
||||
return
|
||||
stream_data_dict = stream.to_dict()
|
||||
|
||||
def _db_save_stream_sync(s_data_dict: dict):
|
||||
with get_db_session() as session:
|
||||
async def _db_save_stream_async(s_data_dict: dict):
|
||||
async with get_db_session() as session:
|
||||
user_info_d = s_data_dict.get("user_info")
|
||||
group_info_d = s_data_dict.get("group_info")
|
||||
|
||||
fields_to_save = {
|
||||
"platform": s_data_dict["platform"],
|
||||
"create_time": s_data_dict["create_time"],
|
||||
@@ -364,8 +363,6 @@ class ChatManager:
|
||||
"sleep_pressure": s_data_dict.get("sleep_pressure", 0.0),
|
||||
"focus_energy": s_data_dict.get("focus_energy", global_config.chat.focus_value),
|
||||
}
|
||||
|
||||
# 根据数据库类型选择插入语句
|
||||
if global_config.database.database_type == "sqlite":
|
||||
stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
|
||||
stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=fields_to_save)
|
||||
@@ -375,15 +372,13 @@ class ChatManager:
|
||||
**{key: value for key, value in fields_to_save.items() if key != "stream_id"}
|
||||
)
|
||||
else:
|
||||
# 默认使用通用插入,尝试SQLite语法
|
||||
stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
|
||||
stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=fields_to_save)
|
||||
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
try:
|
||||
await asyncio.to_thread(_db_save_stream_sync, stream_data_dict)
|
||||
await _db_save_stream_async(stream_data_dict)
|
||||
stream.saved = True
|
||||
except Exception as e:
|
||||
logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (Peewee): {e}", exc_info=True)
|
||||
@@ -397,10 +392,10 @@ class ChatManager:
|
||||
"""从数据库加载所有聊天流"""
|
||||
logger.info("正在从数据库加载所有聊天流")
|
||||
|
||||
def _db_load_all_streams_sync():
|
||||
async def _db_load_all_streams_async():
|
||||
loaded_streams_data = []
|
||||
with get_db_session() as session:
|
||||
for model_instance in session.execute(select(ChatStreams)).scalars():
|
||||
async with get_db_session() as session:
|
||||
for model_instance in (await session.execute(select(ChatStreams))).scalars():
|
||||
user_info_data = {
|
||||
"platform": model_instance.user_platform,
|
||||
"user_id": model_instance.user_id,
|
||||
@@ -414,7 +409,6 @@ class ChatManager:
|
||||
"group_id": model_instance.group_id,
|
||||
"group_name": model_instance.group_name,
|
||||
}
|
||||
|
||||
data_for_from_dict = {
|
||||
"stream_id": model_instance.stream_id,
|
||||
"platform": model_instance.platform,
|
||||
@@ -427,11 +421,11 @@ class ChatManager:
|
||||
"focus_energy": getattr(model_instance, "focus_energy", global_config.chat.focus_value),
|
||||
}
|
||||
loaded_streams_data.append(data_for_from_dict)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
return loaded_streams_data
|
||||
|
||||
try:
|
||||
all_streams_data_list = await asyncio.to_thread(_db_load_all_streams_sync)
|
||||
all_streams_data_list = await _db_load_all_streams_async()
|
||||
self.streams.clear()
|
||||
for data in all_streams_data_list:
|
||||
stream = ChatStream.from_dict(data)
|
||||
|
||||
Reference in New Issue
Block a user