refactor(chat): 移除 ChatStream 的历史消息自动加载功能
移除 ChatStream 初始化时的 `_load_history_messages()` 调用,改为按需异步加载历史消息。这解决了启动时阻塞事件循环的问题,并提高了聊天流初始化的性能。 主要变更: - 删除 `ChatStream._load_history_messages()` 方法及相关代码 - 将多个模块中的同步数据库查询函数改为异步版本 - 修复相关调用处的异步调用方式 - 优化图片描述查询的错误处理 BREAKING CHANGE: `get_raw_msg_before_timestamp_with_chat` 和相关消息查询函数现在改为异步操作,需要调用处使用 await
This commit is contained in:
@@ -174,8 +174,8 @@ class InstantMemory:
|
|||||||
)
|
)
|
||||||
)).scalars()
|
)).scalars()
|
||||||
else:
|
else:
|
||||||
query = result = await session.execute(select(Memory).where(Memory.chat_id == self.chat_id))
|
result = await session.execute(select(Memory).where(Memory.chat_id == self.chat_id))
|
||||||
result.scalars()
|
query = result.scalars()
|
||||||
for mem in query:
|
for mem in query:
|
||||||
# 对每条记忆
|
# 对每条记忆
|
||||||
mem_keywords_str = mem.keywords or "[]"
|
mem_keywords_str = mem.keywords or "[]"
|
||||||
|
|||||||
@@ -66,9 +66,6 @@ class ChatStream:
|
|||||||
self._focus_energy = 0.5 # 内部存储的focus_energy值
|
self._focus_energy = 0.5 # 内部存储的focus_energy值
|
||||||
self.no_reply_consecutive = 0
|
self.no_reply_consecutive = 0
|
||||||
|
|
||||||
# 自动加载历史消息
|
|
||||||
self._load_history_messages()
|
|
||||||
|
|
||||||
def __deepcopy__(self, memo):
|
def __deepcopy__(self, memo):
|
||||||
"""自定义深拷贝方法,避免复制不可序列化的 asyncio.Task 对象"""
|
"""自定义深拷贝方法,避免复制不可序列化的 asyncio.Task 对象"""
|
||||||
import copy
|
import copy
|
||||||
@@ -380,119 +377,6 @@ class ChatStream:
|
|||||||
# 默认基础分
|
# 默认基础分
|
||||||
return 0.3
|
return 0.3
|
||||||
|
|
||||||
def _load_history_messages(self):
|
|
||||||
"""从数据库加载历史消息到StreamContext"""
|
|
||||||
try:
|
|
||||||
from src.common.database.sqlalchemy_models import Messages
|
|
||||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
|
||||||
from sqlalchemy import select, desc
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
async def _load_history_messages_async():
|
|
||||||
"""异步加载并转换历史消息到 stream_context(在事件循环中运行)。"""
|
|
||||||
try:
|
|
||||||
async with get_db_session() as session:
|
|
||||||
stmt = (
|
|
||||||
select(Messages)
|
|
||||||
.where(Messages.chat_info_stream_id == self.stream_id)
|
|
||||||
.order_by(desc(Messages.time))
|
|
||||||
.limit(global_config.chat.max_context_size)
|
|
||||||
)
|
|
||||||
result = await session.execute(stmt)
|
|
||||||
db_messages = result.scalars().all()
|
|
||||||
|
|
||||||
# 转换为DatabaseMessages对象并添加到StreamContext
|
|
||||||
for db_msg in db_messages:
|
|
||||||
try:
|
|
||||||
import orjson
|
|
||||||
|
|
||||||
actions = None
|
|
||||||
if db_msg.actions:
|
|
||||||
try:
|
|
||||||
actions = orjson.loads(db_msg.actions)
|
|
||||||
except (orjson.JSONDecodeError, TypeError):
|
|
||||||
actions = None
|
|
||||||
|
|
||||||
db_message = DatabaseMessages(
|
|
||||||
message_id=db_msg.message_id,
|
|
||||||
time=db_msg.time,
|
|
||||||
chat_id=db_msg.chat_id,
|
|
||||||
reply_to=db_msg.reply_to,
|
|
||||||
interest_value=db_msg.interest_value,
|
|
||||||
key_words=db_msg.key_words,
|
|
||||||
key_words_lite=db_msg.key_words_lite,
|
|
||||||
is_mentioned=db_msg.is_mentioned,
|
|
||||||
processed_plain_text=db_msg.processed_plain_text,
|
|
||||||
display_message=db_msg.display_message,
|
|
||||||
priority_mode=db_msg.priority_mode,
|
|
||||||
priority_info=db_msg.priority_info,
|
|
||||||
additional_config=db_msg.additional_config,
|
|
||||||
is_emoji=db_msg.is_emoji,
|
|
||||||
is_picid=db_msg.is_picid,
|
|
||||||
is_command=db_msg.is_command,
|
|
||||||
is_notify=db_msg.is_notify,
|
|
||||||
user_id=db_msg.user_id,
|
|
||||||
user_nickname=db_msg.user_nickname,
|
|
||||||
user_cardname=db_msg.user_cardname,
|
|
||||||
user_platform=db_msg.user_platform,
|
|
||||||
chat_info_group_id=db_msg.chat_info_group_id,
|
|
||||||
chat_info_group_name=db_msg.chat_info_group_name,
|
|
||||||
chat_info_group_platform=db_msg.chat_info_group_platform,
|
|
||||||
chat_info_user_id=db_msg.chat_info_user_id,
|
|
||||||
chat_info_user_nickname=db_msg.chat_info_user_nickname,
|
|
||||||
chat_info_user_cardname=db_msg.chat_info_user_cardname,
|
|
||||||
chat_info_user_platform=db_msg.chat_info_user_platform,
|
|
||||||
chat_info_stream_id=db_msg.chat_info_stream_id,
|
|
||||||
chat_info_platform=db_msg.chat_info_platform,
|
|
||||||
chat_info_create_time=db_msg.chat_info_create_time,
|
|
||||||
chat_info_last_active_time=db_msg.chat_info_last_active_time,
|
|
||||||
actions=actions,
|
|
||||||
should_reply=getattr(db_msg, "should_reply", False) or False,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
f"加载历史消息 {db_message.message_id} - interest_value: {db_message.interest_value}"
|
|
||||||
)
|
|
||||||
|
|
||||||
db_message.is_read = True
|
|
||||||
self.stream_context.history_messages.append(db_message)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"转换消息 {getattr(db_msg, 'message_id', '<unknown>')} 失败: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if self.stream_context.history_messages:
|
|
||||||
logger.info(
|
|
||||||
f"已从数据库加载 {len(self.stream_context.history_messages)} 条历史消息到聊天流 {self.stream_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"异步加载历史消息失败: {e}")
|
|
||||||
|
|
||||||
# 在已有事件循环中,避免调用 asyncio.run()
|
|
||||||
try:
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
except RuntimeError:
|
|
||||||
# 没有运行的事件循环,安全地运行并等待完成
|
|
||||||
asyncio.run(_load_history_messages_async())
|
|
||||||
else:
|
|
||||||
# 如果事件循环正在运行,在后台创建任务
|
|
||||||
if loop.is_running():
|
|
||||||
try:
|
|
||||||
asyncio.create_task(_load_history_messages_async())
|
|
||||||
except Exception as e:
|
|
||||||
# 如果无法创建任务,退回到阻塞运行
|
|
||||||
logger.warning(f"无法在事件循环中创建后台任务,尝试阻塞运行: {e}")
|
|
||||||
asyncio.run(_load_history_messages_async())
|
|
||||||
else:
|
|
||||||
# loop 存在但未运行,使用 asyncio.run
|
|
||||||
asyncio.run(_load_history_messages_async())
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"加载历史消息失败: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
class ChatManager:
|
class ChatManager:
|
||||||
"""聊天管理器,管理所有聊天流"""
|
"""聊天管理器,管理所有聊天流"""
|
||||||
|
|
||||||
|
|||||||
@@ -317,10 +317,10 @@ class ChatterActionManager:
|
|||||||
chat_stream = chat_manager.get_stream(stream_id)
|
chat_stream = chat_manager.get_stream(stream_id)
|
||||||
if chat_stream:
|
if chat_stream:
|
||||||
context = chat_stream.context_manager
|
context = chat_stream.context_manager
|
||||||
if context.interruption_count > 0:
|
if context.context.interruption_count > 0:
|
||||||
old_count = context.interruption_count
|
old_count = context.context.interruption_count
|
||||||
old_afc_adjustment = context.get_afc_threshold_adjustment()
|
old_afc_adjustment = context.context.get_afc_threshold_adjustment()
|
||||||
context.reset_interruption_count()
|
context.context.reset_interruption_count()
|
||||||
logger.debug(f"动作执行成功,重置聊天流 {stream_id} 的打断计数: {old_count} -> 0, afc调整: {old_afc_adjustment} -> 0")
|
logger.debug(f"动作执行成功,重置聊天流 {stream_id} 的打断计数: {old_count} -> 0, afc调整: {old_afc_adjustment} -> 0")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"重置打断计数时出错: {e}")
|
logger.warning(f"重置打断计数时出错: {e}")
|
||||||
|
|||||||
@@ -712,7 +712,7 @@ class DefaultReplyer:
|
|||||||
if chat_stream:
|
if chat_stream:
|
||||||
stream_context = chat_stream.context_manager
|
stream_context = chat_stream.context_manager
|
||||||
# 使用真正的已读和未读消息
|
# 使用真正的已读和未读消息
|
||||||
read_messages = stream_context.history_messages # 已读消息
|
read_messages = stream_context.context.history_messages # 已读消息
|
||||||
unread_messages = stream_context.get_unread_messages() # 未读消息
|
unread_messages = stream_context.get_unread_messages() # 未读消息
|
||||||
|
|
||||||
# 构建已读历史消息 prompt
|
# 构建已读历史消息 prompt
|
||||||
@@ -728,7 +728,7 @@ class DefaultReplyer:
|
|||||||
else:
|
else:
|
||||||
# 如果没有已读消息,则从数据库加载最近的上下文
|
# 如果没有已读消息,则从数据库加载最近的上下文
|
||||||
logger.info("暂无已读历史消息,正在从数据库加载上下文...")
|
logger.info("暂无已读历史消息,正在从数据库加载上下文...")
|
||||||
fallback_messages = get_raw_msg_before_timestamp_with_chat(
|
fallback_messages = await get_raw_msg_before_timestamp_with_chat(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
timestamp=time.time(),
|
timestamp=time.time(),
|
||||||
limit=global_config.chat.max_context_size,
|
limit=global_config.chat.max_context_size,
|
||||||
|
|||||||
@@ -12,6 +12,8 @@ from src.person_info.person_info import PersonInfoManager, get_person_info_manag
|
|||||||
from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids
|
from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids
|
||||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||||
from sqlalchemy import select, and_
|
from sqlalchemy import select, and_
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
logger = get_logger("chat_message_builder")
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
||||||
@@ -830,10 +832,11 @@ async def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
|
|||||||
async with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
result = await session.execute(select(Images).where(Images.image_id == pic_id))
|
result = await session.execute(select(Images).where(Images.image_id == pic_id))
|
||||||
image = result.scalar_one_or_none()
|
image = result.scalar_one_or_none()
|
||||||
if image and image.description: # type: ignore
|
if image and hasattr(image, 'description') and image.description:
|
||||||
description = image.description
|
description = image.description
|
||||||
except Exception:
|
except Exception as e:
|
||||||
# 如果查询失败,保持默认描述
|
# 如果查询失败,保持默认描述
|
||||||
|
logger.debug(f"[chat_message_builder] 查询图片描述失败: {e}")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
mapping_lines.append(f"[{display_name}] 的内容:{description}")
|
mapping_lines.append(f"[{display_name}] 的内容:{description}")
|
||||||
|
|||||||
@@ -308,8 +308,8 @@ class ImageManager:
|
|||||||
|
|
||||||
async with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
# 优先检查Images表中是否已有完整的描述
|
# 优先检查Images表中是否已有完整的描述
|
||||||
existing_image = result = await session.execute(select(Images).where(Images.emoji_hash == image_hash))
|
result = await session.execute(select(Images).where(Images.emoji_hash == image_hash))
|
||||||
result.scalar()
|
existing_image = result.scalar()
|
||||||
if existing_image:
|
if existing_image:
|
||||||
# 更新计数
|
# 更新计数
|
||||||
if hasattr(existing_image, "count") and existing_image.count is not None:
|
if hasattr(existing_image, "count") and existing_image.count is not None:
|
||||||
@@ -528,8 +528,8 @@ class ImageManager:
|
|||||||
image_bytes = base64.b64decode(image_base64)
|
image_bytes = base64.b64decode(image_base64)
|
||||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||||
async with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
existing_image = result = await session.execute(select(Images).where(Images.emoji_hash == image_hash))
|
result = await session.execute(select(Images).where(Images.emoji_hash == image_hash))
|
||||||
result.scalar()
|
existing_image = result.scalar()
|
||||||
if existing_image:
|
if existing_image:
|
||||||
# 检查是否缺少必要字段,如果缺少则创建新记录
|
# 检查是否缺少必要字段,如果缺少则创建新记录
|
||||||
if (
|
if (
|
||||||
|
|||||||
@@ -174,7 +174,7 @@ class PromptBuilder:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def build_chat_history_prompts(chat_stream: ChatStream, message: MessageRecvS4U):
|
async def build_chat_history_prompts(chat_stream: ChatStream, message: MessageRecvS4U):
|
||||||
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
message_list_before_now = await get_raw_msg_before_timestamp_with_chat(
|
||||||
chat_id=chat_stream.stream_id,
|
chat_id=chat_stream.stream_id,
|
||||||
timestamp=time.time(),
|
timestamp=time.time(),
|
||||||
limit=300,
|
limit=300,
|
||||||
@@ -251,7 +251,7 @@ class PromptBuilder:
|
|||||||
for msg in all_msg_seg_list:
|
for msg in all_msg_seg_list:
|
||||||
core_msg_str += msg
|
core_msg_str += msg
|
||||||
|
|
||||||
all_dialogue_prompt = get_raw_msg_before_timestamp_with_chat(
|
all_dialogue_prompt = await get_raw_msg_before_timestamp_with_chat(
|
||||||
chat_id=chat_stream.stream_id,
|
chat_id=chat_stream.stream_id,
|
||||||
timestamp=time.time(),
|
timestamp=time.time(),
|
||||||
limit=20,
|
limit=20,
|
||||||
|
|||||||
@@ -263,9 +263,9 @@ def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool
|
|||||||
return get_raw_msg_before_timestamp(timestamp, limit)
|
return get_raw_msg_before_timestamp(timestamp, limit)
|
||||||
|
|
||||||
|
|
||||||
def get_messages_before_time_in_chat(
|
async def get_messages_before_time_in_chat(
|
||||||
chat_id: str, timestamp: float, limit: int = 0, filter_mai: bool = False
|
chat_id: str, timestamp: float, limit: int = 0, filter_mai: bool = False
|
||||||
) -> Coroutine[Any, Any, list[dict[str, Any]]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
获取指定聊天中指定时间戳之前的消息
|
获取指定聊天中指定时间戳之前的消息
|
||||||
|
|
||||||
@@ -290,8 +290,8 @@ def get_messages_before_time_in_chat(
|
|||||||
if not isinstance(chat_id, str):
|
if not isinstance(chat_id, str):
|
||||||
raise ValueError("chat_id 必须是字符串类型")
|
raise ValueError("chat_id 必须是字符串类型")
|
||||||
if filter_mai:
|
if filter_mai:
|
||||||
return filter_mai_messages(get_raw_msg_before_timestamp_with_chat(chat_id, timestamp, limit))
|
return await filter_mai_messages(await get_raw_msg_before_timestamp_with_chat(chat_id, timestamp, limit))
|
||||||
return get_raw_msg_before_timestamp_with_chat(chat_id, timestamp, limit)
|
return await get_raw_msg_before_timestamp_with_chat(chat_id, timestamp, limit)
|
||||||
|
|
||||||
|
|
||||||
def get_messages_before_time_for_users(timestamp: float, person_ids: List[str], limit: int = 0) -> Coroutine[
|
def get_messages_before_time_for_users(timestamp: float, person_ids: List[str], limit: int = 0) -> Coroutine[
|
||||||
|
|||||||
@@ -329,11 +329,11 @@ class ChatterPlanFilter:
|
|||||||
stream_context = chat_stream.context_manager
|
stream_context = chat_stream.context_manager
|
||||||
|
|
||||||
# 获取真正的已读和未读消息
|
# 获取真正的已读和未读消息
|
||||||
read_messages = stream_context.history_messages # 已读消息存储在history_messages中
|
read_messages = stream_context.context.history_messages # 已读消息存储在history_messages中
|
||||||
if not read_messages:
|
if not read_messages:
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
# 如果内存中没有已读消息(比如刚启动),则从数据库加载最近的上下文
|
# 如果内存中没有已读消息(比如刚启动),则从数据库加载最近的上下文
|
||||||
fallback_messages_dicts = get_raw_msg_before_timestamp_with_chat(
|
fallback_messages_dicts = await get_raw_msg_before_timestamp_with_chat(
|
||||||
chat_id=plan.chat_id,
|
chat_id=plan.chat_id,
|
||||||
timestamp=time.time(),
|
timestamp=time.time(),
|
||||||
limit=global_config.chat.max_context_size,
|
limit=global_config.chat.max_context_size,
|
||||||
|
|||||||
@@ -124,7 +124,7 @@ class ChatterPlanGenerator:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 获取最近的消息记录
|
# 获取最近的消息记录
|
||||||
raw_messages = get_raw_msg_before_timestamp_with_chat(
|
raw_messages = await get_raw_msg_before_timestamp_with_chat(
|
||||||
chat_id=self.chat_id, timestamp=time.time(), limit=global_config.memory.short_memory_length
|
chat_id=self.chat_id, timestamp=time.time(), limit=global_config.memory.short_memory_length
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -338,7 +338,6 @@ class NoticeHandler:
|
|||||||
|
|
||||||
like_emoji_id = raw_message.get("likes")[0].get("emoji_id")
|
like_emoji_id = raw_message.get("likes")[0].get("emoji_id")
|
||||||
await event_manager.trigger_event(
|
await event_manager.trigger_event(
|
||||||
<<<<<<< HEAD
|
|
||||||
NapcatEvent.ON_RECEIVED.EMOJI_LIEK,
|
NapcatEvent.ON_RECEIVED.EMOJI_LIEK,
|
||||||
permission_group=PLUGIN_NAME,
|
permission_group=PLUGIN_NAME,
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
@@ -350,16 +349,6 @@ class NoticeHandler:
|
|||||||
type="text",
|
type="text",
|
||||||
data=f"{user_name}使用Emoji表情{QQ_FACE.get(like_emoji_id, '')}回复了你的消息[{target_message_text}]",
|
data=f"{user_name}使用Emoji表情{QQ_FACE.get(like_emoji_id, '')}回复了你的消息[{target_message_text}]",
|
||||||
)
|
)
|
||||||
=======
|
|
||||||
NapcatEvent.ON_RECEIVED.EMOJI_LIEK,
|
|
||||||
permission_group=PLUGIN_NAME,
|
|
||||||
group_id=group_id,
|
|
||||||
user_id=user_id,
|
|
||||||
message_id=raw_message.get("message_id",""),
|
|
||||||
emoji_id=like_emoji_id
|
|
||||||
)
|
|
||||||
seg_data = Seg(type="text",data=f"{user_name}使用Emoji表情{QQ_FACE.get(like_emoji_id, '')}回复了你的消息[{target_message_text}]")
|
|
||||||
>>>>>>> 9912d7f643d347cbadcf1e3d618aa78bcbf89cc4
|
|
||||||
return seg_data, user_info
|
return seg_data, user_info
|
||||||
|
|
||||||
async def handle_ban_notify(self, raw_message: dict, group_id: int) -> Tuple[Seg, UserInfo] | Tuple[None, None]:
|
async def handle_ban_notify(self, raw_message: dict, group_id: int) -> Tuple[Seg, UserInfo] | Tuple[None, None]:
|
||||||
|
|||||||
Reference in New Issue
Block a user