三次修改

This commit is contained in:
tt-P607
2025-09-20 02:21:53 +08:00
committed by Windpicker-owo
parent 635311bc80
commit aba4f1a947
20 changed files with 923 additions and 479 deletions

View File

@@ -246,11 +246,11 @@ class ChatManager:
return stream
# 检查数据库中是否存在
def _db_find_stream_sync(s_id: str):
with get_db_session() as session:
return session.execute(select(ChatStreams).where(ChatStreams.stream_id == s_id)).scalar()
async def _db_find_stream_async(s_id: str):
async with get_db_session() as session:
return (await session.execute(select(ChatStreams).where(ChatStreams.stream_id == s_id))).scalar()
model_instance = await asyncio.to_thread(_db_find_stream_sync, stream_id)
model_instance = await _db_find_stream_async(stream_id)
if model_instance:
# 从 Peewee 模型转换回 ChatStream.from_dict 期望的格式
@@ -344,11 +344,10 @@ class ChatManager:
return
stream_data_dict = stream.to_dict()
def _db_save_stream_sync(s_data_dict: dict):
with get_db_session() as session:
async def _db_save_stream_async(s_data_dict: dict):
async with get_db_session() as session:
user_info_d = s_data_dict.get("user_info")
group_info_d = s_data_dict.get("group_info")
fields_to_save = {
"platform": s_data_dict["platform"],
"create_time": s_data_dict["create_time"],
@@ -364,8 +363,6 @@ class ChatManager:
"sleep_pressure": s_data_dict.get("sleep_pressure", 0.0),
"focus_energy": s_data_dict.get("focus_energy", global_config.chat.focus_value),
}
# 根据数据库类型选择插入语句
if global_config.database.database_type == "sqlite":
stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=fields_to_save)
@@ -375,15 +372,13 @@ class ChatManager:
**{key: value for key, value in fields_to_save.items() if key != "stream_id"}
)
else:
# 默认使用通用插入尝试SQLite语法
stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=fields_to_save)
session.execute(stmt)
session.commit()
await session.execute(stmt)
await session.commit()
try:
await asyncio.to_thread(_db_save_stream_sync, stream_data_dict)
await _db_save_stream_async(stream_data_dict)
stream.saved = True
except Exception as e:
logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (Peewee): {e}", exc_info=True)
@@ -397,10 +392,10 @@ class ChatManager:
"""从数据库加载所有聊天流"""
logger.info("正在从数据库加载所有聊天流")
def _db_load_all_streams_sync():
async def _db_load_all_streams_async():
loaded_streams_data = []
with get_db_session() as session:
for model_instance in session.execute(select(ChatStreams)).scalars():
async with get_db_session() as session:
for model_instance in (await session.execute(select(ChatStreams))).scalars():
user_info_data = {
"platform": model_instance.user_platform,
"user_id": model_instance.user_id,
@@ -414,7 +409,6 @@ class ChatManager:
"group_id": model_instance.group_id,
"group_name": model_instance.group_name,
}
data_for_from_dict = {
"stream_id": model_instance.stream_id,
"platform": model_instance.platform,
@@ -427,11 +421,11 @@ class ChatManager:
"focus_energy": getattr(model_instance, "focus_energy", global_config.chat.focus_value),
}
loaded_streams_data.append(data_for_from_dict)
session.commit()
await session.commit()
return loaded_streams_data
try:
all_streams_data_list = await asyncio.to_thread(_db_load_all_streams_sync)
all_streams_data_list = await _db_load_all_streams_async()
self.streams.clear()
for data in all_streams_data_list:
stream = ChatStream.from_dict(data)

View File

@@ -42,7 +42,7 @@ class MessageStorage:
processed_plain_text = message.processed_plain_text
if processed_plain_text:
processed_plain_text = MessageStorage.replace_image_descriptions(processed_plain_text)
processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text)
filtered_processed_plain_text = re.sub(pattern, "", processed_plain_text, flags=re.DOTALL)
else:
filtered_processed_plain_text = ""
@@ -129,9 +129,9 @@ class MessageStorage:
key_words=key_words,
key_words_lite=key_words_lite,
)
with get_db_session() as session:
async with get_db_session() as session:
session.add(new_message)
session.commit()
await session.commit()
except Exception:
logger.exception("存储消息失败")
@@ -174,16 +174,18 @@ class MessageStorage:
# 使用上下文管理器确保session正确管理
from src.common.database.sqlalchemy_models import get_db_session
with get_db_session() as session:
matched_message = session.execute(
select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time))
async with get_db_session() as session:
matched_message = (
await session.execute(
select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time))
)
).scalar()
if matched_message:
session.execute(
await session.execute(
update(Messages).where(Messages.id == matched_message.id).values(message_id=qq_message_id)
)
session.commit()
await session.commit()
# 会在上下文管理器中自动调用
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
else:
@@ -197,28 +199,36 @@ class MessageStorage:
)
@staticmethod
def replace_image_descriptions(text: str) -> str:
async def replace_image_descriptions(text: str) -> str:
"""将[图片:描述]替换为[picid:image_id]"""
# 先检查文本中是否有图片标记
pattern = r"\[图片:([^\]]+)\]"
matches = re.findall(pattern, text)
matches = list(re.finditer(pattern, text))
if not matches:
logger.debug("文本中没有图片标记,直接返回原文本")
return text
def replace_match(match):
new_text = ""
last_end = 0
for match in matches:
new_text += text[last_end : match.start()]
description = match.group(1).strip()
try:
from src.common.database.sqlalchemy_models import get_db_session
with get_db_session() as session:
image_record = session.execute(
select(Images).where(Images.description == description).order_by(desc(Images.timestamp))
async with get_db_session() as session:
image_record = (
await session.execute(
select(Images).where(Images.description == description).order_by(desc(Images.timestamp))
)
).scalar()
session.commit()
return f"[picid:{image_record.image_id}]" if image_record else match.group(0)
if image_record:
new_text += f"[picid:{image_record.image_id}]"
else:
new_text += match.group(0)
except Exception:
return match.group(0)
return re.sub(r"\[图片:([^\]]+)\]", replace_match, text)
new_text += match.group(0)
last_end = match.end()
new_text += text[last_end:]
return new_text