refactor(chat): 移除 ChatStream 的历史消息自动加载功能

移除 ChatStream 初始化时的 `_load_history_messages()` 调用,改为按需异步加载历史消息。这解决了启动时阻塞事件循环的问题,并提高了聊天流初始化的性能。

主要变更:
- 删除 `ChatStream._load_history_messages()` 方法及相关代码
- 将多个模块中的同步数据库查询函数改为异步版本
- 修复相关调用处的异步调用方式
- 优化图片描述查询的错误处理

BREAKING CHANGE: `get_raw_msg_before_timestamp_with_chat` 和相关消息查询函数现在改为异步操作,需要调用处使用 await
This commit is contained in:
Windpicker-owo
2025-09-28 21:31:49 +08:00
parent fd76e36320
commit 28bce19d27
12 changed files with 32 additions and 156 deletions

View File

@@ -174,8 +174,8 @@ class InstantMemory:
)
)).scalars()
else:
query = result = await session.execute(select(Memory).where(Memory.chat_id == self.chat_id))
result.scalars()
result = await session.execute(select(Memory).where(Memory.chat_id == self.chat_id))
query = result.scalars()
for mem in query:
# 对每条记忆
mem_keywords_str = mem.keywords or "[]"

View File

@@ -212,7 +212,7 @@ class MessageManager:
return
context = chat_stream.stream_context
# 获取未读消息
unread_messages = chat_stream.context_manager.get_unread_messages()
if not unread_messages:

View File

@@ -66,9 +66,6 @@ class ChatStream:
self._focus_energy = 0.5 # 内部存储的focus_energy值
self.no_reply_consecutive = 0
# 自动加载历史消息
self._load_history_messages()
def __deepcopy__(self, memo):
"""自定义深拷贝方法,避免复制不可序列化的 asyncio.Task 对象"""
import copy
@@ -380,119 +377,6 @@ class ChatStream:
# 默认基础分
return 0.3
def _load_history_messages(self):
"""从数据库加载历史消息到StreamContext"""
try:
from src.common.database.sqlalchemy_models import Messages
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.data_models.database_data_model import DatabaseMessages
from sqlalchemy import select, desc
import asyncio
async def _load_history_messages_async():
"""异步加载并转换历史消息到 stream_context在事件循环中运行"""
try:
async with get_db_session() as session:
stmt = (
select(Messages)
.where(Messages.chat_info_stream_id == self.stream_id)
.order_by(desc(Messages.time))
.limit(global_config.chat.max_context_size)
)
result = await session.execute(stmt)
db_messages = result.scalars().all()
# 转换为DatabaseMessages对象并添加到StreamContext
for db_msg in db_messages:
try:
import orjson
actions = None
if db_msg.actions:
try:
actions = orjson.loads(db_msg.actions)
except (orjson.JSONDecodeError, TypeError):
actions = None
db_message = DatabaseMessages(
message_id=db_msg.message_id,
time=db_msg.time,
chat_id=db_msg.chat_id,
reply_to=db_msg.reply_to,
interest_value=db_msg.interest_value,
key_words=db_msg.key_words,
key_words_lite=db_msg.key_words_lite,
is_mentioned=db_msg.is_mentioned,
processed_plain_text=db_msg.processed_plain_text,
display_message=db_msg.display_message,
priority_mode=db_msg.priority_mode,
priority_info=db_msg.priority_info,
additional_config=db_msg.additional_config,
is_emoji=db_msg.is_emoji,
is_picid=db_msg.is_picid,
is_command=db_msg.is_command,
is_notify=db_msg.is_notify,
user_id=db_msg.user_id,
user_nickname=db_msg.user_nickname,
user_cardname=db_msg.user_cardname,
user_platform=db_msg.user_platform,
chat_info_group_id=db_msg.chat_info_group_id,
chat_info_group_name=db_msg.chat_info_group_name,
chat_info_group_platform=db_msg.chat_info_group_platform,
chat_info_user_id=db_msg.chat_info_user_id,
chat_info_user_nickname=db_msg.chat_info_user_nickname,
chat_info_user_cardname=db_msg.chat_info_user_cardname,
chat_info_user_platform=db_msg.chat_info_user_platform,
chat_info_stream_id=db_msg.chat_info_stream_id,
chat_info_platform=db_msg.chat_info_platform,
chat_info_create_time=db_msg.chat_info_create_time,
chat_info_last_active_time=db_msg.chat_info_last_active_time,
actions=actions,
should_reply=getattr(db_msg, "should_reply", False) or False,
)
logger.debug(
f"加载历史消息 {db_message.message_id} - interest_value: {db_message.interest_value}"
)
db_message.is_read = True
self.stream_context.history_messages.append(db_message)
except Exception as e:
logger.warning(f"转换消息 {getattr(db_msg, 'message_id', '<unknown>')} 失败: {e}")
continue
if self.stream_context.history_messages:
logger.info(
f"已从数据库加载 {len(self.stream_context.history_messages)} 条历史消息到聊天流 {self.stream_id}"
)
except Exception as e:
logger.warning(f"异步加载历史消息失败: {e}")
# 在已有事件循环中,避免调用 asyncio.run()
try:
loop = asyncio.get_running_loop()
except RuntimeError:
# 没有运行的事件循环,安全地运行并等待完成
asyncio.run(_load_history_messages_async())
else:
# 如果事件循环正在运行,在后台创建任务
if loop.is_running():
try:
asyncio.create_task(_load_history_messages_async())
except Exception as e:
# 如果无法创建任务,退回到阻塞运行
logger.warning(f"无法在事件循环中创建后台任务,尝试阻塞运行: {e}")
asyncio.run(_load_history_messages_async())
else:
# loop 存在但未运行,使用 asyncio.run
asyncio.run(_load_history_messages_async())
except Exception as e:
logger.error(f"加载历史消息失败: {e}")
class ChatManager:
"""聊天管理器,管理所有聊天流"""

View File

@@ -317,10 +317,10 @@ class ChatterActionManager:
chat_stream = chat_manager.get_stream(stream_id)
if chat_stream:
context = chat_stream.context_manager
if context.interruption_count > 0:
old_count = context.interruption_count
old_afc_adjustment = context.get_afc_threshold_adjustment()
context.reset_interruption_count()
if context.context.interruption_count > 0:
old_count = context.context.interruption_count
old_afc_adjustment = context.context.get_afc_threshold_adjustment()
context.context.reset_interruption_count()
logger.debug(f"动作执行成功,重置聊天流 {stream_id} 的打断计数: {old_count} -> 0, afc调整: {old_afc_adjustment} -> 0")
except Exception as e:
logger.warning(f"重置打断计数时出错: {e}")

View File

@@ -690,7 +690,7 @@ class DefaultReplyer:
if chat_stream:
stream_context = chat_stream.context_manager
# 使用真正的已读和未读消息
read_messages = stream_context.history_messages # 已读消息
read_messages = stream_context.context.history_messages # 已读消息
unread_messages = stream_context.get_unread_messages() # 未读消息
# 构建已读历史消息 prompt
@@ -706,7 +706,7 @@ class DefaultReplyer:
else:
# 如果没有已读消息,则从数据库加载最近的上下文
logger.info("暂无已读历史消息,正在从数据库加载上下文...")
fallback_messages = get_raw_msg_before_timestamp_with_chat(
fallback_messages = await get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id,
timestamp=time.time(),
limit=global_config.chat.max_context_size,
@@ -1063,13 +1063,13 @@ class DefaultReplyer:
action_descriptions += f"- {action_name}: {action_description}\n"
action_descriptions += "\n"
message_list_before_now_long = get_raw_msg_before_timestamp_with_chat(
message_list_before_now_long = await get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id,
timestamp=time.time(),
limit=global_config.chat.max_context_size * 2,
)
message_list_before_short = get_raw_msg_before_timestamp_with_chat(
message_list_before_short = await get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id,
timestamp=time.time(),
limit=int(global_config.chat.max_context_size * 0.33),
@@ -1322,7 +1322,7 @@ class DefaultReplyer:
else:
mood_prompt = ""
message_list_before_now_half = get_raw_msg_before_timestamp_with_chat(
message_list_before_now_half = await get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id,
timestamp=time.time(),
limit=min(int(global_config.chat.max_context_size * 0.33), 15),

View File

@@ -12,6 +12,8 @@ from src.person_info.person_info import PersonInfoManager, get_person_info_manag
from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids
from src.common.database.sqlalchemy_database_api import get_db_session
from sqlalchemy import select, and_
from src.common.logger import get_logger
logger = get_logger("chat_message_builder")
install(extra_lines=3)
@@ -458,13 +460,13 @@ def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[Dict[
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float, limit: int = 0) -> List[Dict[str, Any]]:
async def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float, limit: int = 0) -> List[Dict[str, Any]]:
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
"""
filter_query = {"chat_id": chat_id, "time": {"$lt": timestamp}}
sort_order = [("time", 1)]
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: list, limit: int = 0) -> List[Dict[str, Any]]:
@@ -826,10 +828,11 @@ async def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
async with get_db_session() as session:
result = await session.execute(select(Images).where(Images.image_id == pic_id))
image = result.scalar_one_or_none()
if image and image.description: # type: ignore
if image and hasattr(image, 'description') and image.description:
description = image.description
except Exception:
except Exception as e:
# 如果查询失败,保持默认描述
logger.debug(f"[chat_message_builder] 查询图片描述失败: {e}")
pass
mapping_lines.append(f"[{display_name}] 的内容:{description}")

View File

@@ -308,8 +308,8 @@ class ImageManager:
async with get_db_session() as session:
# 优先检查Images表中是否已有完整的描述
existing_image = result = await session.execute(select(Images).where(Images.emoji_hash == image_hash))
result.scalar()
result = await session.execute(select(Images).where(Images.emoji_hash == image_hash))
existing_image = result.scalar()
if existing_image:
# 更新计数
if hasattr(existing_image, "count") and existing_image.count is not None:
@@ -528,8 +528,8 @@ class ImageManager:
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
async with get_db_session() as session:
existing_image = result = await session.execute(select(Images).where(Images.emoji_hash == image_hash))
result.scalar()
result = await session.execute(select(Images).where(Images.emoji_hash == image_hash))
existing_image = result.scalar()
if existing_image:
# 检查是否缺少必要字段,如果缺少则创建新记录
if (