refactor(chat): 迁移数据库操作为异步模式并修复相关调用

将同步数据库操作全面迁移为异步模式,主要涉及:
- 将 `with get_db_session()` 改为 `async with get_db_session()`
- 修复相关异步调用链,确保 await 正确传递
- 优化消息管理器、上下文管理器等核心组件的异步处理
- 移除同步的 person_id 获取方法,避免协程对象传递问题

修复 deepcopy 在 StreamContext 中的序列化问题,跳过不可序列化的 asyncio.Task 对象

删除无用的测试文件和废弃的插件清单文件
This commit is contained in:
Windpicker-owo
2025-09-28 20:40:46 +08:00
parent 08ef960947
commit fd76e36320
30 changed files with 481 additions and 625 deletions

View File

@@ -261,7 +261,7 @@ class AntiPromptInjector:
logger.warning("无法删除消息缺少message_id") logger.warning("无法删除消息缺少message_id")
return return
with get_db_session() as session: async with get_db_session() as session:
# 删除对应的消息记录 # 删除对应的消息记录
stmt = delete(Messages).where(Messages.message_id == message_id) stmt = delete(Messages).where(Messages.message_id == message_id)
result = session.execute(stmt) result = session.execute(stmt)
@@ -287,7 +287,7 @@ class AntiPromptInjector:
logger.warning("无法更新消息缺少message_id") logger.warning("无法更新消息缺少message_id")
return return
with get_db_session() as session: async with get_db_session() as session:
# 更新消息内容 # 更新消息内容
stmt = ( stmt = (
update(Messages) update(Messages)

View File

@@ -42,7 +42,7 @@ class SingleStreamContextManager:
self._update_access_stats() self._update_access_stats()
return self.context return self.context
def add_message(self, message: DatabaseMessages, skip_energy_update: bool = False) -> bool: async def add_message(self, message: DatabaseMessages, skip_energy_update: bool = False) -> bool:
"""添加消息到上下文 """添加消息到上下文
Args: Args:
@@ -53,30 +53,21 @@ class SingleStreamContextManager:
bool: 是否成功添加 bool: 是否成功添加
""" """
try: try:
# 添加消息到上下文
self.context.add_message(message) self.context.add_message(message)
interest_value = await self._calculate_message_interest(message)
# 计算消息兴趣度
interest_value = self._calculate_message_interest(message)
message.interest_value = interest_value message.interest_value = interest_value
# 更新统计
self.total_messages += 1 self.total_messages += 1
self.last_access_time = time.time() self.last_access_time = time.time()
# 更新能量和分发
if not skip_energy_update: if not skip_energy_update:
self._update_stream_energy() await self._update_stream_energy()
distribution_manager.add_stream_message(self.stream_id, 1) distribution_manager.add_stream_message(self.stream_id, 1)
logger.debug(f"添加消息到单流上下文: {self.stream_id} (兴趣度: {interest_value:.3f})") logger.debug(f"添加消息到单流上下文: {self.stream_id} (兴趣度: {interest_value:.3f})")
return True return True
except Exception as e: except Exception as e:
logger.error(f"添加消息到单流上下文失败 {self.stream_id}: {e}", exc_info=True) logger.error(f"添加消息到单流上下文失败 {self.stream_id}: {e}", exc_info=True)
return False return False
def update_message(self, message_id: str, updates: Dict[str, Any]) -> bool: async def update_message(self, message_id: str, updates: Dict[str, Any]) -> bool:
"""更新上下文中的消息 """更新上下文中的消息
Args: Args:
@@ -87,16 +78,11 @@ class SingleStreamContextManager:
bool: 是否成功更新 bool: 是否成功更新
""" """
try: try:
# 更新消息信息
self.context.update_message_info(message_id, **updates) self.context.update_message_info(message_id, **updates)
# 如果更新了兴趣度,重新计算能量
if "interest_value" in updates: if "interest_value" in updates:
self._update_stream_energy() await self._update_stream_energy()
logger.debug(f"更新单流上下文消息: {self.stream_id}/{message_id}") logger.debug(f"更新单流上下文消息: {self.stream_id}/{message_id}")
return True return True
except Exception as e: except Exception as e:
logger.error(f"更新单流上下文消息失败 {self.stream_id}/{message_id}: {e}", exc_info=True) logger.error(f"更新单流上下文消息失败 {self.stream_id}/{message_id}: {e}", exc_info=True)
return False return False
@@ -164,16 +150,13 @@ class SingleStreamContextManager:
logger.error(f"标记消息已读失败 {self.stream_id}: {e}", exc_info=True) logger.error(f"标记消息已读失败 {self.stream_id}: {e}", exc_info=True)
return False return False
def clear_context(self) -> bool: async def clear_context(self) -> bool:
"""清空上下文""" """清空上下文"""
try: try:
# 清空消息
if hasattr(self.context, "unread_messages"): if hasattr(self.context, "unread_messages"):
self.context.unread_messages.clear() self.context.unread_messages.clear()
if hasattr(self.context, "history_messages"): if hasattr(self.context, "history_messages"):
self.context.history_messages.clear() self.context.history_messages.clear()
# 重置状态
reset_attrs = ["interruption_count", "afc_threshold_adjustment", "last_check_time"] reset_attrs = ["interruption_count", "afc_threshold_adjustment", "last_check_time"]
for attr in reset_attrs: for attr in reset_attrs:
if hasattr(self.context, attr): if hasattr(self.context, attr):
@@ -181,13 +164,9 @@ class SingleStreamContextManager:
setattr(self.context, attr, 0) setattr(self.context, attr, 0)
else: else:
setattr(self.context, attr, time.time()) setattr(self.context, attr, time.time())
await self._update_stream_energy()
# 重新计算能量
self._update_stream_energy()
logger.info(f"清空单流上下文: {self.stream_id}") logger.info(f"清空单流上下文: {self.stream_id}")
return True return True
except Exception as e: except Exception as e:
logger.error(f"清空单流上下文失败 {self.stream_id}: {e}", exc_info=True) logger.error(f"清空单流上下文失败 {self.stream_id}: {e}", exc_info=True)
return False return False
@@ -249,39 +228,115 @@ class SingleStreamContextManager:
self.last_access_time = time.time() self.last_access_time = time.time()
self.access_count += 1 self.access_count += 1
def _calculate_message_interest(self, message: DatabaseMessages) -> float: async def _calculate_message_interest(self, message: DatabaseMessages) -> float:
"""计算消息兴趣度""" """异步实现:使用插件的异步评分器正确 await 计算兴趣度并返回分数。"""
try: try:
# 使用插件内部的兴趣度评分系统
try: try:
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system from src.plugins.built_in.affinity_flow_chatter.interest_scoring import (
chatter_interest_scoring_system,
# 使用插件内部的兴趣度评分系统计算(同步方式) )
try: try:
loop = asyncio.get_event_loop() interest_score = await chatter_interest_scoring_system._calculate_single_message_score(
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
interest_score = loop.run_until_complete(
chatter_interest_scoring_system._calculate_single_message_score(
message=message, bot_nickname=global_config.bot.nickname message=message, bot_nickname=global_config.bot.nickname
) )
) interest_value = interest_score.total_score
interest_value = interest_score.total_score logger.debug(f"使用插件内部系统计算兴趣度: {interest_value:.3f}")
return interest_value
except Exception as e:
logger.warning(f"插件内部兴趣度计算失败: {e}")
return 0.5
except Exception as e:
logger.warning(f"插件内部兴趣度计算加载失败,使用默认值: {e}")
return 0.5
except Exception as e:
logger.error(f"计算消息兴趣度失败: {e}")
return 0.5
logger.debug(f"使用插件内部系统计算兴趣度: {interest_value:.3f}") async def _calculate_message_interest_async(self, message: DatabaseMessages) -> float:
"""异步实现:使用插件的异步评分器正确 await 计算兴趣度并返回分数。"""
try:
try:
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import (
chatter_interest_scoring_system,
)
# 直接 await 插件的异步方法
try:
interest_score = await chatter_interest_scoring_system._calculate_single_message_score(
message=message, bot_nickname=global_config.bot.nickname
)
interest_value = interest_score.total_score
logger.debug(f"使用插件内部系统计算兴趣度: {interest_value:.3f}")
return interest_value
except Exception as e:
logger.warning(f"插件内部兴趣度计算失败: {e}")
return 0.5
except Exception as e: except Exception as e:
logger.warning(f"插件内部兴趣度计算失败,使用默认值: {e}") logger.warning(f"插件内部兴趣度计算加载失败,使用默认值: {e}")
interest_value = 0.5 # 默认中等兴趣度 return 0.5
return interest_value
except Exception as e: except Exception as e:
logger.error(f"计算消息兴趣度失败: {e}") logger.error(f"计算消息兴趣度失败: {e}")
return 0.5 return 0.5
async def add_message_async(self, message: DatabaseMessages, skip_energy_update: bool = False) -> bool:
"""异步实现的 add_message将消息添加到 context并 await 能量更新与分发。"""
try:
self.context.add_message(message)
interest_value = await self._calculate_message_interest_async(message)
message.interest_value = interest_value
self.total_messages += 1
self.last_access_time = time.time()
if not skip_energy_update:
await self._update_stream_energy()
distribution_manager.add_stream_message(self.stream_id, 1)
logger.debug(f"添加消息到单流上下文(异步): {self.stream_id} (兴趣度: {interest_value:.3f})")
return True
except Exception as e:
logger.error(f"添加消息到单流上下文失败 (async) {self.stream_id}: {e}", exc_info=True)
return False
async def update_message_async(self, message_id: str, updates: Dict[str, Any]) -> bool:
"""异步实现的 update_message更新消息并在需要时 await 能量更新。"""
try:
self.context.update_message_info(message_id, **updates)
if "interest_value" in updates:
await self._update_stream_energy()
logger.debug(f"更新单流上下文消息(异步): {self.stream_id}/{message_id}")
return True
except Exception as e:
logger.error(f"更新单流上下文消息失败 (async) {self.stream_id}/{message_id}: {e}", exc_info=True)
return False
async def clear_context_async(self) -> bool:
"""异步实现的 clear_context清空消息并 await 能量重算。"""
try:
if hasattr(self.context, "unread_messages"):
self.context.unread_messages.clear()
if hasattr(self.context, "history_messages"):
self.context.history_messages.clear()
reset_attrs = ["interruption_count", "afc_threshold_adjustment", "last_check_time"]
for attr in reset_attrs:
if hasattr(self.context, attr):
if attr in ["interruption_count", "afc_threshold_adjustment"]:
setattr(self.context, attr, 0)
else:
setattr(self.context, attr, time.time())
await self._update_stream_energy()
logger.info(f"清空单流上下文(异步): {self.stream_id}")
return True
except Exception as e:
logger.error(f"清空单流上下文失败 (async) {self.stream_id}: {e}", exc_info=True)
return False
async def _update_stream_energy(self): async def _update_stream_energy(self):
"""更新流能量""" """更新流能量"""
try: try:
@@ -305,4 +360,4 @@ class SingleStreamContextManager:
distribution_manager.update_stream_energy(self.stream_id, energy) distribution_manager.update_stream_energy(self.stream_id, energy)
except Exception as e: except Exception as e:
logger.error(f"更新单流能量失败 {self.stream_id}: {e}") logger.error(f"更新单流能量失败 {self.stream_id}: {e}")

View File

@@ -75,29 +75,23 @@ class MessageManager:
logger.info("消息管理器已停止") logger.info("消息管理器已停止")
def add_message(self, stream_id: str, message: DatabaseMessages): async def add_message(self, stream_id: str, message: DatabaseMessages):
"""添加消息到指定聊天流""" """添加消息到指定聊天流"""
try: try:
# 通过 ChatManager 获取 ChatStream
chat_manager = get_chat_manager() chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id) chat_stream = chat_manager.get_stream(stream_id)
if not chat_stream: if not chat_stream:
logger.warning(f"MessageManager.add_message: 聊天流 {stream_id} 不存在") logger.warning(f"MessageManager.add_message: 聊天流 {stream_id} 不存在")
return return
success = await chat_stream.context_manager.add_message(message)
# 使用 ChatStream 的 context_manager 添加消息
success = chat_stream.context_manager.add_message(message)
if success: if success:
logger.debug(f"添加消息到聊天流 {stream_id}: {message.message_id}") logger.debug(f"添加消息到聊天流 {stream_id}: {message.message_id}")
else: else:
logger.warning(f"添加消息到聊天流 {stream_id} 失败") logger.warning(f"添加消息到聊天流 {stream_id} 失败")
except Exception as e: except Exception as e:
logger.error(f"添加消息到聊天流 {stream_id} 时发生错误: {e}") logger.error(f"添加消息到聊天流 {stream_id} 时发生错误: {e}")
def update_message( async def update_message(
self, self,
stream_id: str, stream_id: str,
message_id: str, message_id: str,
@@ -107,15 +101,11 @@ class MessageManager:
): ):
"""更新消息信息""" """更新消息信息"""
try: try:
# 通过 ChatManager 获取 ChatStream
chat_manager = get_chat_manager() chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id) chat_stream = chat_manager.get_stream(stream_id)
if not chat_stream: if not chat_stream:
logger.warning(f"MessageManager.update_message: 聊天流 {stream_id} 不存在") logger.warning(f"MessageManager.update_message: 聊天流 {stream_id} 不存在")
return return
# 构建更新字典
updates = {} updates = {}
if interest_value is not None: if interest_value is not None:
updates["interest_value"] = interest_value updates["interest_value"] = interest_value
@@ -123,41 +113,30 @@ class MessageManager:
updates["actions"] = actions updates["actions"] = actions
if should_reply is not None: if should_reply is not None:
updates["should_reply"] = should_reply updates["should_reply"] = should_reply
# 使用 ChatStream 的 context_manager 更新消息
if updates: if updates:
success = chat_stream.context_manager.update_message(message_id, updates) success = await chat_stream.context_manager.update_message(message_id, updates)
if success: if success:
logger.debug(f"更新消息 {message_id} 成功") logger.debug(f"更新消息 {message_id} 成功")
else: else:
logger.warning(f"更新消息 {message_id} 失败") logger.warning(f"更新消息 {message_id} 失败")
except Exception as e: except Exception as e:
logger.error(f"更新消息 {message_id} 时发生错误: {e}") logger.error(f"更新消息 {message_id} 时发生错误: {e}")
def add_action(self, stream_id: str, message_id: str, action: str): async def add_action(self, stream_id: str, message_id: str, action: str):
"""添加动作到消息""" """添加动作到消息"""
try: try:
# 通过 ChatManager 获取 ChatStream
chat_manager = get_chat_manager() chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id) chat_stream = chat_manager.get_stream(stream_id)
if not chat_stream: if not chat_stream:
logger.warning(f"MessageManager.add_action: 聊天流 {stream_id} 不存在") logger.warning(f"MessageManager.add_action: 聊天流 {stream_id} 不存在")
return return
success = await chat_stream.context_manager.update_message(
# 使用 ChatStream 的 context_manager 添加动作
# 注意:这里需要根据实际的 API 调整
# 假设我们可以通过 update_message 来添加动作
success = chat_stream.context_manager.update_message(
message_id, {"actions": [action]} message_id, {"actions": [action]}
) )
if success: if success:
logger.debug(f"为消息 {message_id} 添加动作 {action} 成功") logger.debug(f"为消息 {message_id} 添加动作 {action} 成功")
else: else:
logger.warning(f"为消息 {message_id} 添加动作 {action} 失败") logger.warning(f"为消息 {message_id} 添加动作 {action} 失败")
except Exception as e: except Exception as e:
logger.error(f"为消息 {message_id} 添加动作时发生错误: {e}") logger.error(f"为消息 {message_id} 添加动作时发生错误: {e}")
@@ -382,36 +361,27 @@ class MessageManager:
"start_time": self.stats.start_time, "start_time": self.stats.start_time,
} }
def cleanup_inactive_streams(self, max_inactive_hours: int = 24): async def cleanup_inactive_streams(self, max_inactive_hours: int = 24):
"""清理不活跃的聊天流""" """清理不活跃的聊天流"""
try: try:
# 通过 ChatManager 清理不活跃的流
chat_manager = get_chat_manager() chat_manager = get_chat_manager()
current_time = time.time() current_time = time.time()
max_inactive_seconds = max_inactive_hours * 3600 max_inactive_seconds = max_inactive_hours * 3600
inactive_streams = [] inactive_streams = []
for stream_id, chat_stream in chat_manager.streams.items(): for stream_id, chat_stream in chat_manager.streams.items():
# 检查最后活跃时间
if current_time - chat_stream.last_active_time > max_inactive_seconds: if current_time - chat_stream.last_active_time > max_inactive_seconds:
inactive_streams.append(stream_id) inactive_streams.append(stream_id)
# 清理不活跃的流
for stream_id in inactive_streams: for stream_id in inactive_streams:
try: try:
# 清理流的内容 await chat_stream.context_manager.clear_context()
chat_stream.context_manager.clear_context()
# 从 ChatManager 中移除
del chat_manager.streams[stream_id] del chat_manager.streams[stream_id]
logger.info(f"清理不活跃聊天流: {stream_id}") logger.info(f"清理不活跃聊天流: {stream_id}")
except Exception as e: except Exception as e:
logger.error(f"清理聊天流 {stream_id} 失败: {e}") logger.error(f"清理聊天流 {stream_id} 失败: {e}")
if inactive_streams: if inactive_streams:
logger.info(f"已清理 {len(inactive_streams)} 个不活跃聊天流") logger.info(f"已清理 {len(inactive_streams)} 个不活跃聊天流")
else: else:
logger.debug("没有需要清理的不活跃聊天流") logger.debug("没有需要清理的不活跃聊天流")
except Exception as e: except Exception as e:
logger.error(f"清理不活跃聊天流时发生错误: {e}") logger.error(f"清理不活跃聊天流时发生错误: {e}")

View File

@@ -514,7 +514,7 @@ class ChatBot:
db_message.chat_info_group_platform = message.chat_stream.group_info.platform db_message.chat_info_group_platform = message.chat_stream.group_info.platform
# 添加消息到消息管理器 # 添加消息到消息管理器
message_manager.add_message(message.chat_stream.stream_id, db_message) await message_manager.add_message(message.chat_stream.stream_id, db_message)
logger.debug(f"消息已添加到消息管理器: {message.chat_stream.stream_id}") logger.debug(f"消息已添加到消息管理器: {message.chat_stream.stream_id}")
if template_group_name: if template_group_name:

View File

@@ -389,94 +389,105 @@ class ChatStream:
from sqlalchemy import select, desc from sqlalchemy import select, desc
import asyncio import asyncio
async def _load_messages(): async def _load_history_messages_async():
def _db_query(): """异步加载并转换历史消息到 stream_context在事件循环中运行"""
with get_db_session() as session: try:
# 查询该stream_id的最近20条消息 async with get_db_session() as session:
stmt = ( stmt = (
select(Messages) select(Messages)
.where(Messages.chat_info_stream_id == self.stream_id) .where(Messages.chat_info_stream_id == self.stream_id)
.order_by(desc(Messages.time)) .order_by(desc(Messages.time))
.limit(global_config.chat.max_context_size) .limit(global_config.chat.max_context_size)
) )
result = session.execute(stmt) result = await session.execute(stmt)
results = result.scalars().all() db_messages = result.scalars().all()
return results
# 在线程中执行数据库查询 # 转换为DatabaseMessages对象并添加到StreamContext
db_messages = await asyncio.to_thread(_db_query) for db_msg in db_messages:
try:
import orjson
# 转换为DatabaseMessages对象并添加到StreamContext actions = None
for db_msg in db_messages: if db_msg.actions:
try:
actions = orjson.loads(db_msg.actions)
except (orjson.JSONDecodeError, TypeError):
actions = None
db_message = DatabaseMessages(
message_id=db_msg.message_id,
time=db_msg.time,
chat_id=db_msg.chat_id,
reply_to=db_msg.reply_to,
interest_value=db_msg.interest_value,
key_words=db_msg.key_words,
key_words_lite=db_msg.key_words_lite,
is_mentioned=db_msg.is_mentioned,
processed_plain_text=db_msg.processed_plain_text,
display_message=db_msg.display_message,
priority_mode=db_msg.priority_mode,
priority_info=db_msg.priority_info,
additional_config=db_msg.additional_config,
is_emoji=db_msg.is_emoji,
is_picid=db_msg.is_picid,
is_command=db_msg.is_command,
is_notify=db_msg.is_notify,
user_id=db_msg.user_id,
user_nickname=db_msg.user_nickname,
user_cardname=db_msg.user_cardname,
user_platform=db_msg.user_platform,
chat_info_group_id=db_msg.chat_info_group_id,
chat_info_group_name=db_msg.chat_info_group_name,
chat_info_group_platform=db_msg.chat_info_group_platform,
chat_info_user_id=db_msg.chat_info_user_id,
chat_info_user_nickname=db_msg.chat_info_user_nickname,
chat_info_user_cardname=db_msg.chat_info_user_cardname,
chat_info_user_platform=db_msg.chat_info_user_platform,
chat_info_stream_id=db_msg.chat_info_stream_id,
chat_info_platform=db_msg.chat_info_platform,
chat_info_create_time=db_msg.chat_info_create_time,
chat_info_last_active_time=db_msg.chat_info_last_active_time,
actions=actions,
should_reply=getattr(db_msg, "should_reply", False) or False,
)
logger.debug(
f"加载历史消息 {db_message.message_id} - interest_value: {db_message.interest_value}"
)
db_message.is_read = True
self.stream_context.history_messages.append(db_message)
except Exception as e:
logger.warning(f"转换消息 {getattr(db_msg, 'message_id', '<unknown>')} 失败: {e}")
continue
if self.stream_context.history_messages:
logger.info(
f"已从数据库加载 {len(self.stream_context.history_messages)} 条历史消息到聊天流 {self.stream_id}"
)
except Exception as e:
logger.warning(f"异步加载历史消息失败: {e}")
# 在已有事件循环中,避免调用 asyncio.run()
try:
loop = asyncio.get_running_loop()
except RuntimeError:
# 没有运行的事件循环,安全地运行并等待完成
asyncio.run(_load_history_messages_async())
else:
# 如果事件循环正在运行,在后台创建任务
if loop.is_running():
try: try:
# 从SQLAlchemy模型转换为DatabaseMessages数据模型 asyncio.create_task(_load_history_messages_async())
import orjson
# 解析actions字段JSON格式
actions = None
if db_msg.actions:
try:
actions = orjson.loads(db_msg.actions)
except (orjson.JSONDecodeError, TypeError):
actions = None
db_message = DatabaseMessages(
message_id=db_msg.message_id,
time=db_msg.time,
chat_id=db_msg.chat_id,
reply_to=db_msg.reply_to,
interest_value=db_msg.interest_value,
key_words=db_msg.key_words,
key_words_lite=db_msg.key_words_lite,
is_mentioned=db_msg.is_mentioned,
processed_plain_text=db_msg.processed_plain_text,
display_message=db_msg.display_message,
priority_mode=db_msg.priority_mode,
priority_info=db_msg.priority_info,
additional_config=db_msg.additional_config,
is_emoji=db_msg.is_emoji,
is_picid=db_msg.is_picid,
is_command=db_msg.is_command,
is_notify=db_msg.is_notify,
user_id=db_msg.user_id,
user_nickname=db_msg.user_nickname,
user_cardname=db_msg.user_cardname,
user_platform=db_msg.user_platform,
chat_info_group_id=db_msg.chat_info_group_id,
chat_info_group_name=db_msg.chat_info_group_name,
chat_info_group_platform=db_msg.chat_info_group_platform,
chat_info_user_id=db_msg.chat_info_user_id,
chat_info_user_nickname=db_msg.chat_info_user_nickname,
chat_info_user_cardname=db_msg.chat_info_user_cardname,
chat_info_user_platform=db_msg.chat_info_user_platform,
chat_info_stream_id=db_msg.chat_info_stream_id,
chat_info_platform=db_msg.chat_info_platform,
chat_info_create_time=db_msg.chat_info_create_time,
chat_info_last_active_time=db_msg.chat_info_last_active_time,
actions=actions,
should_reply=getattr(db_msg, "should_reply", False) or False,
)
# 添加调试日志检查从数据库加载的interest_value
logger.debug(
f"加载历史消息 {db_message.message_id} - interest_value: {db_message.interest_value}"
)
# 标记为已读并添加到历史消息
db_message.is_read = True
self.stream_context.history_messages.append(db_message)
except Exception as e: except Exception as e:
logger.warning(f"转换消息 {db_msg.message_id} 失败: {e}") # 如果无法创建任务,退回到阻塞运行
continue logger.warning(f"无法在事件循环中创建后台任务,尝试阻塞运行: {e}")
asyncio.run(_load_history_messages_async())
if self.stream_context.history_messages: else:
logger.info( # loop 存在但未运行,使用 asyncio.run
f"已从数据库加载 {len(self.stream_context.history_messages)} 条历史消息到聊天流 {self.stream_id}" asyncio.run(_load_history_messages_async())
)
# 创建任务来加载历史消息
asyncio.create_task(_load_messages())
except Exception as e: except Exception as e:
logger.error(f"加载历史消息失败: {e}") logger.error(f"加载历史消息失败: {e}")
@@ -498,7 +509,7 @@ class ChatManager:
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
self.last_messages: Dict[str, "MessageRecv"] = {} # stream_id -> last_message self.last_messages: Dict[str, "MessageRecv"] = {} # stream_id -> last_message
# try: # try:
# with get_db_session() as session: # async with get_db_session() as session:
# db.connect(reuse_if_open=True) # db.connect(reuse_if_open=True)
# # 确保 ChatStreams 表存在 # # 确保 ChatStreams 表存在
# session.execute(text("CREATE TABLE IF NOT EXISTS chat_streams (stream_id TEXT PRIMARY KEY, platform TEXT, create_time REAL, last_active_time REAL, user_platform TEXT, user_id TEXT, user_nickname TEXT, user_cardname TEXT, group_platform TEXT, group_id TEXT, group_name TEXT)")) # session.execute(text("CREATE TABLE IF NOT EXISTS chat_streams (stream_id TEXT PRIMARY KEY, platform TEXT, create_time REAL, last_active_time REAL, user_platform TEXT, user_id TEXT, user_nickname TEXT, user_cardname TEXT, group_platform TEXT, group_id TEXT, group_name TEXT)"))

View File

@@ -219,7 +219,7 @@ class MessageStorage:
return match.group(0) return match.group(0)
@staticmethod @staticmethod
def update_message_interest_value(message_id: str, interest_value: float) -> None: async def update_message_interest_value(message_id: str, interest_value: float) -> None:
""" """
更新数据库中消息的interest_value字段 更新数据库中消息的interest_value字段
@@ -228,11 +228,11 @@ class MessageStorage:
interest_value: 兴趣度值 interest_value: 兴趣度值
""" """
try: try:
with get_db_session() as session: async with get_db_session() as session:
# 更新消息的interest_value字段 # 更新消息的interest_value字段
stmt = update(Messages).where(Messages.message_id == message_id).values(interest_value=interest_value) stmt = update(Messages).where(Messages.message_id == message_id).values(interest_value=interest_value)
result = session.execute(stmt) result = await session.execute(stmt)
session.commit() await session.commit()
if result.rowcount > 0: if result.rowcount > 0:
logger.debug(f"成功更新消息 {message_id} 的interest_value为 {interest_value}") logger.debug(f"成功更新消息 {message_id} 的interest_value为 {interest_value}")
@@ -244,7 +244,7 @@ class MessageStorage:
raise raise
@staticmethod @staticmethod
def fix_zero_interest_values(chat_id: str, since_time: float) -> int: async def fix_zero_interest_values(chat_id: str, since_time: float) -> int:
""" """
修复指定聊天中interest_value为0或null的历史消息记录 修复指定聊天中interest_value为0或null的历史消息记录
@@ -256,7 +256,7 @@ class MessageStorage:
修复的记录数量 修复的记录数量
""" """
try: try:
with get_db_session() as session: async with get_db_session() as session:
from sqlalchemy import select, update from sqlalchemy import select, update
from src.common.database.sqlalchemy_models import Messages from src.common.database.sqlalchemy_models import Messages
@@ -271,7 +271,7 @@ class MessageStorage:
) )
).limit(50) # 限制每次修复的数量,避免性能问题 ).limit(50) # 限制每次修复的数量,避免性能问题
result = session.execute(query) result = await session.execute(query)
messages_to_fix = result.scalars().all() messages_to_fix = result.scalars().all()
fixed_count = 0 fixed_count = 0
@@ -297,12 +297,12 @@ class MessageStorage:
Messages.message_id == msg.message_id Messages.message_id == msg.message_id
).values(interest_value=default_interest) ).values(interest_value=default_interest)
result = session.execute(update_stmt) result = await session.execute(update_stmt)
if result.rowcount > 0: if result.rowcount > 0:
fixed_count += 1 fixed_count += 1
logger.debug(f"修复消息 {msg.message_id} 的interest_value为 {default_interest}") logger.debug(f"修复消息 {msg.message_id} 的interest_value为 {default_interest}")
session.commit() await session.commit()
logger.info(f"共修复了 {fixed_count} 条历史消息的interest_value值") logger.info(f"共修复了 {fixed_count} 条历史消息的interest_value值")
return fixed_count return fixed_count

View File

@@ -297,15 +297,12 @@ class ChatterActionManager:
return return
# 通过message_manager更新消息的动作记录并刷新focus_energy # 通过message_manager更新消息的动作记录并刷新focus_energy
if chat_stream.stream_id in message_manager.stream_contexts: await message_manager.add_action(
message_manager.add_action( stream_id=chat_stream.stream_id,
stream_id=chat_stream.stream_id, message_id=target_message_id,
message_id=target_message_id, action=action_name
action=action_name )
) logger.debug(f"已记录动作 {action_name} 到消息 {target_message_id} 并更新focus_energy")
logger.debug(f"已记录动作 {action_name} 到消息 {target_message_id} 并更新focus_energy")
else:
logger.debug(f"未找到stream_context: {chat_stream.stream_id}")
except Exception as e: except Exception as e:
logger.error(f"记录动作到消息失败: {e}") logger.error(f"记录动作到消息失败: {e}")
@@ -315,8 +312,11 @@ class ChatterActionManager:
"""在动作执行成功后重置打断计数""" """在动作执行成功后重置打断计数"""
from src.chat.message_manager.message_manager import message_manager from src.chat.message_manager.message_manager import message_manager
try: try:
if stream_id in message_manager.stream_contexts: from src.plugin_system.apis.chat_api import get_chat_manager
context = message_manager.stream_contexts[stream_id] chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
if chat_stream:
context = chat_stream.context_manager
if context.interruption_count > 0: if context.interruption_count > 0:
old_count = context.interruption_count old_count = context.interruption_count
old_afc_adjustment = context.get_afc_threshold_adjustment() old_afc_adjustment = context.get_afc_threshold_adjustment()

View File

@@ -73,7 +73,7 @@ class ActionModifier:
from src.chat.utils.utils import get_chat_type_and_target_info from src.chat.utils.utils import get_chat_type_and_target_info
# 获取聊天类型 # 获取聊天类型
is_group_chat, _ = await get_chat_type_and_target_info(self.chat_id) is_group_chat, _ = get_chat_type_and_target_info(self.chat_id)
all_registered_actions = component_registry.get_components_by_type(ComponentType.ACTION) all_registered_actions = component_registry.get_components_by_type(ComponentType.ACTION)
chat_type_removals = [] chat_type_removals = []

View File

@@ -684,8 +684,11 @@ class DefaultReplyer:
from src.chat.message_manager.message_manager import message_manager from src.chat.message_manager.message_manager import message_manager
# 获取聊天流的上下文 # 获取聊天流的上下文
stream_context = message_manager.stream_contexts.get(chat_id) from src.plugin_system.apis.chat_api import get_chat_manager
if stream_context: chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(chat_id)
if chat_stream:
stream_context = chat_stream.context_manager
# 使用真正的已读和未读消息 # 使用真正的已读和未读消息
read_messages = stream_context.history_messages # 已读消息 read_messages = stream_context.history_messages # 已读消息
unread_messages = stream_context.get_unread_messages() # 未读消息 unread_messages = stream_context.get_unread_messages() # 未读消息
@@ -693,7 +696,7 @@ class DefaultReplyer:
# 构建已读历史消息 prompt # 构建已读历史消息 prompt
read_history_prompt = "" read_history_prompt = ""
if read_messages: if read_messages:
read_content = build_readable_messages( read_content = await build_readable_messages(
[msg.flatten() for msg in read_messages[-50:]], # 限制数量 [msg.flatten() for msg in read_messages[-50:]], # 限制数量
replace_bot_name=True, replace_bot_name=True,
timestamp_mode="normal_no_YMD", timestamp_mode="normal_no_YMD",
@@ -716,7 +719,7 @@ class DefaultReplyer:
] ]
if filtered_fallback_messages: if filtered_fallback_messages:
read_content = build_readable_messages( read_content = await build_readable_messages(
filtered_fallback_messages, filtered_fallback_messages,
replace_bot_name=True, replace_bot_name=True,
timestamp_mode="normal_no_YMD", timestamp_mode="normal_no_YMD",
@@ -754,7 +757,7 @@ class DefaultReplyer:
if platform and user_id: if platform and user_id:
person_id = PersonInfoManager.get_person_id(platform, user_id) person_id = PersonInfoManager.get_person_id(platform, user_id)
person_info_manager = get_person_info_manager() person_info_manager = get_person_info_manager()
sender_name = person_info_manager.get_value_sync(person_id, "person_name") or "未知用户" sender_name = person_info_manager.get_value(person_id, "person_name") or "未知用户"
else: else:
sender_name = "未知用户" sender_name = "未知用户"
@@ -819,7 +822,7 @@ class DefaultReplyer:
# 构建已读历史消息 prompt # 构建已读历史消息 prompt
read_history_prompt = "" read_history_prompt = ""
if read_messages: if read_messages:
read_content = build_readable_messages( read_content = await build_readable_messages(
read_messages[-50:], read_messages[-50:],
replace_bot_name=True, replace_bot_name=True,
timestamp_mode="normal_no_YMD", timestamp_mode="normal_no_YMD",
@@ -853,7 +856,7 @@ class DefaultReplyer:
if platform and user_id: if platform and user_id:
person_id = PersonInfoManager.get_person_id(platform, user_id) person_id = PersonInfoManager.get_person_id(platform, user_id)
person_info_manager = get_person_info_manager() person_info_manager = get_person_info_manager()
sender_name = person_info_manager.get_value_sync(person_id, "person_name") or "未知用户" sender_name = person_info_manager.get_value(person_id, "person_name") or "未知用户"
else: else:
sender_name = "未知用户" sender_name = "未知用户"
@@ -1027,7 +1030,7 @@ class DefaultReplyer:
# 检查是否是bot自己的名字如果是则替换为"(你)" # 检查是否是bot自己的名字如果是则替换为"(你)"
bot_user_id = str(global_config.bot.qq_account) bot_user_id = str(global_config.bot.qq_account)
current_user_id = person_info_manager.get_value_sync(person_id, "user_id") current_user_id = person_info_manager.get_value(person_id, "user_id")
current_platform = reply_message.get("chat_info_platform") current_platform = reply_message.get("chat_info_platform")
if current_user_id == bot_user_id and current_platform == global_config.bot.platform: if current_user_id == bot_user_id and current_platform == global_config.bot.platform:
@@ -1046,7 +1049,7 @@ class DefaultReplyer:
target = "(无消息内容)" target = "(无消息内容)"
person_info_manager = get_person_info_manager() person_info_manager = get_person_info_manager()
person_id = person_info_manager.get_person_id_by_person_name(sender) person_id = await person_info_manager.get_person_id_by_person_name(sender)
platform = chat_stream.platform platform = chat_stream.platform
target = replace_user_references_sync(target, chat_stream.platform, replace_bot_name=True) target = replace_user_references_sync(target, chat_stream.platform, replace_bot_name=True)
@@ -1071,7 +1074,7 @@ class DefaultReplyer:
timestamp=time.time(), timestamp=time.time(),
limit=int(global_config.chat.max_context_size * 0.33), limit=int(global_config.chat.max_context_size * 0.33),
) )
chat_talking_prompt_short = build_readable_messages( chat_talking_prompt_short = await build_readable_messages(
message_list_before_short, message_list_before_short,
replace_bot_name=True, replace_bot_name=True,
merge_messages=False, merge_messages=False,
@@ -1324,7 +1327,7 @@ class DefaultReplyer:
timestamp=time.time(), timestamp=time.time(),
limit=min(int(global_config.chat.max_context_size * 0.33), 15), limit=min(int(global_config.chat.max_context_size * 0.33), 15),
) )
chat_talking_prompt_half = build_readable_messages( chat_talking_prompt_half = await build_readable_messages(
message_list_before_now_half, message_list_before_now_half,
replace_bot_name=True, replace_bot_name=True,
merge_messages=False, merge_messages=False,
@@ -1523,7 +1526,7 @@ class DefaultReplyer:
# 获取用户ID # 获取用户ID
person_info_manager = get_person_info_manager() person_info_manager = get_person_info_manager()
person_id = person_info_manager.get_person_id_by_person_name(sender) person_id = await person_info_manager.get_person_id_by_person_name(sender)
if not person_id: if not person_id:
logger.warning(f"未找到用户 {sender} 的ID跳过信息提取") logger.warning(f"未找到用户 {sender} 的ID跳过信息提取")
return f"你完全不认识{sender}不理解ta的相关信息。" return f"你完全不认识{sender}不理解ta的相关信息。"

View File

@@ -46,7 +46,7 @@ def replace_user_references_sync(
if replace_bot_name and user_id == global_config.bot.qq_account: if replace_bot_name and user_id == global_config.bot.qq_account:
return f"{global_config.bot.nickname}(你)" return f"{global_config.bot.nickname}(你)"
person_id = PersonInfoManager.get_person_id(platform, user_id) person_id = PersonInfoManager.get_person_id(platform, user_id)
return person_info_manager.get_value_sync(person_id, "person_name") or user_id # type: ignore return person_info_manager.get_value(person_id, "person_name") or user_id # type: ignore
name_resolver = default_resolver name_resolver = default_resolver
@@ -254,7 +254,7 @@ def get_raw_msg_by_timestamp_with_chat_users(
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode) return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
def get_actions_by_timestamp_with_chat( async def get_actions_by_timestamp_with_chat(
chat_id: str, chat_id: str,
timestamp_start: float = 0, timestamp_start: float = 0,
timestamp_end: float = time.time(), timestamp_end: float = time.time(),
@@ -273,22 +273,21 @@ def get_actions_by_timestamp_with_chat(
f"limit={limit}, limit_mode={limit_mode}" f"limit={limit}, limit_mode={limit_mode}"
) )
with get_db_session() as session: async with get_db_session() as session:
if limit > 0: if limit > 0:
if limit_mode == "latest": result = await session.execute(
query = session.execute(
select(ActionRecords) select(ActionRecords)
.where( .where(
and_( and_(
ActionRecords.chat_id == chat_id, ActionRecords.chat_id == chat_id,
ActionRecords.time > timestamp_start, ActionRecords.time >= timestamp_start,
ActionRecords.time < timestamp_end, ActionRecords.time <= timestamp_end,
) )
) )
.order_by(ActionRecords.time.desc()) .order_by(ActionRecords.time.desc())
.limit(limit) .limit(limit)
) )
actions = list(query.scalars()) actions = list(result.scalars())
actions_result = [] actions_result = []
for action in reversed(actions): for action in reversed(actions):
action_dict = { action_dict = {
@@ -305,38 +304,39 @@ def get_actions_by_timestamp_with_chat(
"chat_info_platform": action.chat_info_platform, "chat_info_platform": action.chat_info_platform,
} }
actions_result.append(action_dict) actions_result.append(action_dict)
else: # earliest
query = session.execute(
select(ActionRecords)
.where(
and_(
ActionRecords.chat_id == chat_id,
ActionRecords.time > timestamp_start,
ActionRecords.time < timestamp_end,
)
)
.order_by(ActionRecords.time.asc())
.limit(limit)
)
actions = list(query.scalars())
actions_result = []
for action in actions:
action_dict = {
"id": action.id,
"action_id": action.action_id,
"time": action.time,
"action_name": action.action_name,
"action_data": action.action_data,
"action_done": action.action_done,
"action_build_into_prompt": action.action_build_into_prompt,
"action_prompt_display": action.action_prompt_display,
"chat_id": action.chat_id,
"chat_info_stream_id": action.chat_info_stream_id,
"chat_info_platform": action.chat_info_platform,
}
actions_result.append(action_dict) actions_result.append(action_dict)
else: # earliest
result = await session.execute(
select(ActionRecords)
.where(
and_(
ActionRecords.chat_id == chat_id,
ActionRecords.time > timestamp_start,
ActionRecords.time < timestamp_end,
)
)
.order_by(ActionRecords.time.asc())
.limit(limit)
)
actions = list(result.scalars())
actions_result = []
for action in actions:
action_dict = {
"id": action.id,
"action_id": action.action_id,
"time": action.time,
"action_name": action.action_name,
"action_data": action.action_data,
"action_done": action.action_done,
"action_build_into_prompt": action.action_build_into_prompt,
"action_prompt_display": action.action_prompt_display,
"chat_id": action.chat_id,
"chat_info_stream_id": action.chat_info_stream_id,
"chat_info_platform": action.chat_info_platform,
}
actions_result.append(action_dict)
else: else:
query = session.execute( result = await session.execute(
select(ActionRecords) select(ActionRecords)
.where( .where(
and_( and_(
@@ -347,7 +347,7 @@ def get_actions_by_timestamp_with_chat(
) )
.order_by(ActionRecords.time.asc()) .order_by(ActionRecords.time.asc())
) )
actions = list(query.scalars()) actions = list(result.scalars())
actions_result = [] actions_result = []
for action in actions: for action in actions:
action_dict = { action_dict = {
@@ -367,14 +367,14 @@ def get_actions_by_timestamp_with_chat(
return actions_result return actions_result
def get_actions_by_timestamp_with_chat_inclusive( async def get_actions_by_timestamp_with_chat_inclusive(
chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest" chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表""" """获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表"""
with get_db_session() as session: async with get_db_session() as session:
if limit > 0: if limit > 0:
if limit_mode == "latest": if limit_mode == "latest":
query = session.execute( result = await session.execute(
select(ActionRecords) select(ActionRecords)
.where( .where(
and_( and_(
@@ -386,10 +386,10 @@ def get_actions_by_timestamp_with_chat_inclusive(
.order_by(ActionRecords.time.desc()) .order_by(ActionRecords.time.desc())
.limit(limit) .limit(limit)
) )
actions = list(query.scalars()) actions = list(result.scalars())
return [action.__dict__ for action in reversed(actions)] return [action.__dict__ for action in reversed(actions)]
else: # earliest else: # earliest
query = session.execute( result = await session.execute(
select(ActionRecords) select(ActionRecords)
.where( .where(
and_( and_(
@@ -402,7 +402,7 @@ def get_actions_by_timestamp_with_chat_inclusive(
.limit(limit) .limit(limit)
) )
else: else:
query = session.execute( query = await session.execute(
select(ActionRecords) select(ActionRecords)
.where( .where(
and_( and_(
@@ -507,7 +507,7 @@ def num_new_messages_since_with_users(
return count_messages(message_filter=filter_query) return count_messages(message_filter=filter_query)
def _build_readable_messages_internal( async def _build_readable_messages_internal(
messages: List[Dict[str, Any]], messages: List[Dict[str, Any]],
replace_bot_name: bool = True, replace_bot_name: bool = True,
merge_messages: bool = False, merge_messages: bool = False,
@@ -627,7 +627,7 @@ def _build_readable_messages_internal(
if replace_bot_name and user_id == global_config.bot.qq_account: if replace_bot_name and user_id == global_config.bot.qq_account:
person_name = f"{global_config.bot.nickname}(你)" person_name = f"{global_config.bot.nickname}(你)"
else: else:
person_name = person_info_manager.get_value_sync(person_id, "person_name") # type: ignore person_name = await person_info_manager.get_value(person_id, "person_name") # type: ignore
# 如果 person_name 未设置,则使用消息中的 nickname 或默认名称 # 如果 person_name 未设置,则使用消息中的 nickname 或默认名称
if not person_name: if not person_name:
@@ -800,7 +800,7 @@ def _build_readable_messages_internal(
) )
def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str: async def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
# sourcery skip: use-contextlib-suppress # sourcery skip: use-contextlib-suppress
""" """
构建图片映射信息字符串,显示图片的具体描述内容 构建图片映射信息字符串,显示图片的具体描述内容
@@ -823,8 +823,8 @@ def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
# 从数据库中获取图片描述 # 从数据库中获取图片描述
description = "[图片内容未知]" # 默认描述 description = "[图片内容未知]" # 默认描述
try: try:
with get_db_session() as session: async with get_db_session() as session:
result = session.execute(select(Images).where(Images.image_id == pic_id)) result = await session.execute(select(Images).where(Images.image_id == pic_id))
image = result.scalar_one_or_none() image = result.scalar_one_or_none()
if image and image.description: # type: ignore if image and image.description: # type: ignore
description = image.description description = image.description
@@ -922,17 +922,17 @@ async def build_readable_messages_with_list(
将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。 将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。
允许通过参数控制格式化行为。 允许通过参数控制格式化行为。
""" """
formatted_string, details_list, pic_id_mapping, _ = _build_readable_messages_internal( formatted_string, details_list, pic_id_mapping, _ = await _build_readable_messages_internal(
messages, replace_bot_name, merge_messages, timestamp_mode, truncate messages, replace_bot_name, merge_messages, timestamp_mode, truncate
) )
if pic_mapping_info := build_pic_mapping_info(pic_id_mapping): if pic_mapping_info := await build_pic_mapping_info(pic_id_mapping):
formatted_string = f"{pic_mapping_info}\n\n{formatted_string}" formatted_string = f"{pic_mapping_info}\n\n{formatted_string}"
return formatted_string, details_list return formatted_string, details_list
def build_readable_messages_with_id( async def build_readable_messages_with_id(
messages: List[Dict[str, Any]], messages: List[Dict[str, Any]],
replace_bot_name: bool = True, replace_bot_name: bool = True,
merge_messages: bool = False, merge_messages: bool = False,
@@ -948,7 +948,7 @@ def build_readable_messages_with_id(
""" """
message_id_list = assign_message_ids(messages) message_id_list = assign_message_ids(messages)
formatted_string = build_readable_messages( formatted_string = await build_readable_messages(
messages=messages, messages=messages,
replace_bot_name=replace_bot_name, replace_bot_name=replace_bot_name,
merge_messages=merge_messages, merge_messages=merge_messages,
@@ -960,10 +960,16 @@ def build_readable_messages_with_id(
message_id_list=message_id_list, message_id_list=message_id_list,
) )
# 如果存在图片映射信息,附加之
if pic_mapping_info := await build_pic_mapping_info({}):
# 如果当前没有图片映射则不附加
if pic_mapping_info:
formatted_string = f"{pic_mapping_info}\n\n{formatted_string}"
return formatted_string, message_id_list return formatted_string, message_id_list
def build_readable_messages( async def build_readable_messages(
messages: List[Dict[str, Any]], messages: List[Dict[str, Any]],
replace_bot_name: bool = True, replace_bot_name: bool = True,
merge_messages: bool = False, merge_messages: bool = False,
@@ -1004,9 +1010,9 @@ def build_readable_messages(
from src.common.database.sqlalchemy_database_api import get_db_session from src.common.database.sqlalchemy_database_api import get_db_session
with get_db_session() as session: async with get_db_session() as session:
# 获取这个时间范围内的动作记录并匹配chat_id # 获取这个时间范围内的动作记录并匹配chat_id
actions_in_range = session.execute( actions_in_range = (await session.execute(
select(ActionRecords) select(ActionRecords)
.where( .where(
and_( and_(
@@ -1014,15 +1020,15 @@ def build_readable_messages(
) )
) )
.order_by(ActionRecords.time) .order_by(ActionRecords.time)
).scalars() )).scalars()
# 获取最新消息之后的第一个动作记录 # 获取最新消息之后的第一个动作记录
action_after_latest = session.execute( action_after_latest = (await session.execute(
select(ActionRecords) select(ActionRecords)
.where(and_(ActionRecords.time > max_time, ActionRecords.chat_id == chat_id)) .where(and_(ActionRecords.time > max_time, ActionRecords.chat_id == chat_id))
.order_by(ActionRecords.time) .order_by(ActionRecords.time)
.limit(1) .limit(1)
).scalars() )).scalars()
# 合并两部分动作记录,并转为 dict避免 DetachedInstanceError # 合并两部分动作记录,并转为 dict避免 DetachedInstanceError
actions = [ actions = [
@@ -1053,7 +1059,7 @@ def build_readable_messages(
if read_mark <= 0: if read_mark <= 0:
# 没有有效的 read_mark直接格式化所有消息 # 没有有效的 read_mark直接格式化所有消息
formatted_string, _, pic_id_mapping, _ = _build_readable_messages_internal( formatted_string, _, pic_id_mapping, _ = await _build_readable_messages_internal(
copy_messages, copy_messages,
replace_bot_name, replace_bot_name,
merge_messages, merge_messages,
@@ -1064,7 +1070,7 @@ def build_readable_messages(
) )
# 生成图片映射信息并添加到最前面 # 生成图片映射信息并添加到最前面
pic_mapping_info = build_pic_mapping_info(pic_id_mapping) pic_mapping_info = await build_pic_mapping_info(pic_id_mapping)
if pic_mapping_info: if pic_mapping_info:
return f"{pic_mapping_info}\n\n{formatted_string}" return f"{pic_mapping_info}\n\n{formatted_string}"
else: else:
@@ -1079,7 +1085,7 @@ def build_readable_messages(
pic_counter = 1 pic_counter = 1
# 分别格式化,但使用共享的图片映射 # 分别格式化,但使用共享的图片映射
formatted_before, _, pic_id_mapping, pic_counter = _build_readable_messages_internal( formatted_before, _, pic_id_mapping, pic_counter = await _build_readable_messages_internal(
messages_before_mark, messages_before_mark,
replace_bot_name, replace_bot_name,
merge_messages, merge_messages,
@@ -1090,7 +1096,7 @@ def build_readable_messages(
show_pic=show_pic, show_pic=show_pic,
message_id_list=message_id_list, message_id_list=message_id_list,
) )
formatted_after, _, pic_id_mapping, _ = _build_readable_messages_internal( formatted_after, _, pic_id_mapping, _ = await _build_readable_messages_internal(
messages_after_mark, messages_after_mark,
replace_bot_name, replace_bot_name,
merge_messages, merge_messages,
@@ -1106,7 +1112,7 @@ def build_readable_messages(
# 生成图片映射信息 # 生成图片映射信息
if pic_id_mapping: if pic_id_mapping:
pic_mapping_info = f"图片信息:\n{build_pic_mapping_info(pic_id_mapping)}\n聊天记录信息:\n" pic_mapping_info = f"图片信息:\n{await build_pic_mapping_info(pic_id_mapping)}\n聊天记录信息:\n"
else: else:
pic_mapping_info = "聊天记录信息:\n" pic_mapping_info = "聊天记录信息:\n"
@@ -1229,7 +1235,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
# 在最前面添加图片映射信息 # 在最前面添加图片映射信息
final_output_lines = [] final_output_lines = []
pic_mapping_info = build_pic_mapping_info(pic_id_mapping) pic_mapping_info = await build_pic_mapping_info(pic_id_mapping)
if pic_mapping_info: if pic_mapping_info:
final_output_lines.append(pic_mapping_info) final_output_lines.append(pic_mapping_info)
final_output_lines.append("\n\n") final_output_lines.append("\n\n")

View File

@@ -494,7 +494,7 @@ class Prompt:
chat_history = "" chat_history = ""
if self.parameters.message_list_before_now_long: if self.parameters.message_list_before_now_long:
recent_messages = self.parameters.message_list_before_now_long[-10:] recent_messages = self.parameters.message_list_before_now_long[-10:]
chat_history = build_readable_messages( chat_history = await build_readable_messages(
recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True
) )
@@ -535,7 +535,7 @@ class Prompt:
chat_history = "" chat_history = ""
if self.parameters.message_list_before_now_long: if self.parameters.message_list_before_now_long:
recent_messages = self.parameters.message_list_before_now_long[-20:] recent_messages = self.parameters.message_list_before_now_long[-20:]
chat_history = build_readable_messages( chat_history = await build_readable_messages(
recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True
) )
@@ -589,7 +589,7 @@ class Prompt:
chat_history = "" chat_history = ""
if self.parameters.message_list_before_now_long: if self.parameters.message_list_before_now_long:
recent_messages = self.parameters.message_list_before_now_long[-15:] recent_messages = self.parameters.message_list_before_now_long[-15:]
chat_history = build_readable_messages( chat_history = await build_readable_messages(
recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True
) )
@@ -863,7 +863,7 @@ class Prompt:
# 获取用户ID # 获取用户ID
person_info_manager = get_person_info_manager() person_info_manager = get_person_info_manager()
person_id = person_info_manager.get_person_id_by_person_name(sender) person_id = await person_info_manager.get_person_id_by_person_name(sender)
if not person_id: if not person_id:
logger.warning(f"未找到用户 {sender} 的ID跳过信息提取") logger.warning(f"未找到用户 {sender} 的ID跳过信息提取")
return f"你完全不认识{sender}不理解ta的相关信息。" return f"你完全不认识{sender}不理解ta的相关信息。"
@@ -904,7 +904,7 @@ class Prompt:
return "" return ""
@staticmethod @staticmethod
def parse_reply_target_id(reply_to: str) -> str: async def parse_reply_target_id(reply_to: str) -> str:
""" """
解析回复目标中的用户ID 解析回复目标中的用户ID
@@ -924,9 +924,9 @@ class Prompt:
# 获取用户ID # 获取用户ID
person_info_manager = get_person_info_manager() person_info_manager = get_person_info_manager()
person_id = person_info_manager.get_person_id_by_person_name(sender) person_id = await person_info_manager.get_person_id_by_person_name(sender)
if person_id: if person_id:
user_id = person_info_manager.get_value_sync(person_id, "user_id") user_id = person_info_manager.get_value(person_id, "user_id")
return str(user_id) if user_id else "" return str(user_id) if user_id else ""
return "" return ""

View File

@@ -1,3 +1,4 @@
import asyncio
import random import random
import re import re
import string import string
@@ -662,9 +663,32 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
person_id = PersonInfoManager.get_person_id(platform, user_id) person_id = PersonInfoManager.get_person_id(platform, user_id)
person_name = None person_name = None
if person_id: if person_id:
# get_value is async, so await it directly
person_info_manager = get_person_info_manager() person_info_manager = get_person_info_manager()
person_name = person_info_manager.get_value_sync(person_id, "person_name") try:
# 如果没有运行的事件循环,直接 asyncio.run
loop = asyncio.get_event_loop()
if loop.is_running():
# 如果事件循环在运行,从其他线程提交并等待结果
try:
from concurrent.futures import TimeoutError
fut = asyncio.run_coroutine_threadsafe(
person_info_manager.get_value(person_id, "person_name"), loop
)
person_name = fut.result(timeout=2)
except Exception as e:
# 无法在运行循环上安全等待,退回为 None
logger.debug(f"无法通过运行的事件循环获取 person_name: {e}")
person_name = None
else:
person_name = asyncio.run(person_info_manager.get_value(person_id, "person_name"))
except RuntimeError:
# get_event_loop 在某些上下文可能抛出 RuntimeError退回到 asyncio.run
try:
person_name = asyncio.run(person_info_manager.get_value(person_id, "person_name"))
except Exception as e:
logger.debug(f"获取 person_name 失败: {e}")
person_name = None
target_info["person_id"] = person_id target_info["person_id"] = person_id
target_info["person_name"] = person_name target_info["person_name"] = person_name

View File

@@ -344,6 +344,39 @@ class StreamContext(BaseDataModel):
"""获取优先级信息""" """获取优先级信息"""
return self.priority_info return self.priority_info
def __deepcopy__(self, memo):
"""自定义深拷贝,跳过不可序列化的 asyncio.Task (processing_task)。
deepcopy 在内部可能会尝试 pickle 某些对象(如 asyncio.Task
这会在多线程或运行时事件循环中导致 TypeError。这里我们手动复制
__dict__ 中的字段,确保 processing_task 被设置为 None其他字段使用
copy.deepcopy 递归复制。
"""
import copy
# 如果已经复制过,直接返回缓存结果
obj_id = id(self)
if obj_id in memo:
return memo[obj_id]
# 创建一个未初始化的新实例,然后逐个字段深拷贝
cls = self.__class__
new = cls.__new__(cls)
memo[obj_id] = new
for k, v in self.__dict__.items():
if k == "processing_task":
# 不复制 asyncio.Task避免无法 pickling
setattr(new, k, None)
else:
try:
setattr(new, k, copy.deepcopy(v, memo))
except Exception:
# 如果某个字段无法深拷贝,退回到原始引用(安全性谨慎)
setattr(new, k, v)
return new
@dataclass @dataclass
class MessageManagerStats(BaseDataModel): class MessageManagerStats(BaseDataModel):

View File

@@ -30,7 +30,7 @@ def _model_to_dict(instance: Base) -> Dict[str, Any]:
return {col.name: instance.__dict__.get(col.name) for col in instance.__table__.columns} return {col.name: instance.__dict__.get(col.name) for col in instance.__table__.columns}
def find_messages( async def find_messages(
message_filter: dict[str, Any], message_filter: dict[str, Any],
sort: Optional[List[tuple[str, int]]] = None, sort: Optional[List[tuple[str, int]]] = None,
limit: int = 0, limit: int = 0,
@@ -51,7 +51,7 @@ def find_messages(
消息字典列表,如果出错则返回空列表。 消息字典列表,如果出错则返回空列表。
""" """
try: try:
with get_db_session() as session: async with get_db_session() as session:
query = select(Messages) query = select(Messages)
# 应用过滤器 # 应用过滤器
@@ -101,8 +101,8 @@ def find_messages(
# 获取时间最早的 limit 条记录,已经是正序 # 获取时间最早的 limit 条记录,已经是正序
query = query.order_by(Messages.time.asc()).limit(limit) query = query.order_by(Messages.time.asc()).limit(limit)
try: try:
results = result = session.execute(query) result = await session.execute(query)
result.scalars().all() results = result.scalars().all()
except Exception as e: except Exception as e:
logger.error(f"执行earliest查询失败: {e}") logger.error(f"执行earliest查询失败: {e}")
results = [] results = []
@@ -110,8 +110,8 @@ def find_messages(
# 获取时间最晚的 limit 条记录 # 获取时间最晚的 limit 条记录
query = query.order_by(Messages.time.desc()).limit(limit) query = query.order_by(Messages.time.desc()).limit(limit)
try: try:
latest_results = result = session.execute(query) result = await session.execute(query)
result.scalars().all() latest_results = result.scalars().all()
# 将结果按时间正序排列 # 将结果按时间正序排列
results = sorted(latest_results, key=lambda msg: msg.time) results = sorted(latest_results, key=lambda msg: msg.time)
except Exception as e: except Exception as e:
@@ -135,8 +135,8 @@ def find_messages(
if sort_terms: if sort_terms:
query = query.order_by(*sort_terms) query = query.order_by(*sort_terms)
try: try:
results = result = session.execute(query) result = await session.execute(query)
result.scalars().all() results = result.scalars().all()
except Exception as e: except Exception as e:
logger.error(f"执行无限制查询失败: {e}") logger.error(f"执行无限制查询失败: {e}")
results = [] results = []
@@ -152,7 +152,7 @@ def find_messages(
return [] return []
def count_messages(message_filter: dict[str, Any]) -> int: async def count_messages(message_filter: dict[str, Any]) -> int:
""" """
根据提供的过滤器计算消息数量。 根据提供的过滤器计算消息数量。
@@ -163,7 +163,7 @@ def count_messages(message_filter: dict[str, Any]) -> int:
符合条件的消息数量,如果出错则返回 0。 符合条件的消息数量,如果出错则返回 0。
""" """
try: try:
with get_db_session() as session: async with get_db_session() as session:
query = select(func.count(Messages.id)) query = select(func.count(Messages.id))
# 应用过滤器 # 应用过滤器
@@ -201,7 +201,7 @@ def count_messages(message_filter: dict[str, Any]) -> int:
if conditions: if conditions:
query = query.where(*conditions) query = query.where(*conditions)
count = session.execute(query).scalar() count = (await session.execute(query)).scalar()
return count or 0 return count or 0
except Exception as e: except Exception as e:
log_message = f"使用 SQLAlchemy 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}" log_message = f"使用 SQLAlchemy 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"

View File

@@ -148,8 +148,8 @@ class MainSystem:
# 停止消息重组器 # 停止消息重组器
from src.plugin_system.core.event_manager import event_manager from src.plugin_system.core.event_manager import event_manager
from src.plugin_system import EventType from src.plugin_system import EventType
asyncio.run(event_manager.trigger_event(EventType.ON_STOP, permission_group="SYSTEM")) asyncio.run(event_manager.trigger_event(EventType.ON_STOP, permission_group="SYSTEM"))
from src.utils.message_chunker import reassembler from src.utils.message_chunker import reassembler
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()

View File

@@ -110,7 +110,7 @@ class ChatMood:
limit=int(global_config.chat.max_context_size / 3), limit=int(global_config.chat.max_context_size / 3),
limit_mode="last", limit_mode="last",
) )
chat_talking_prompt = build_readable_messages( chat_talking_prompt = await build_readable_messages(
message_list_before_now, message_list_before_now,
replace_bot_name=True, replace_bot_name=True,
merge_messages=False, merge_messages=False,
@@ -159,7 +159,7 @@ class ChatMood:
limit=15, limit=15,
limit_mode="last", limit_mode="last",
) )
chat_talking_prompt = build_readable_messages( chat_talking_prompt = await build_readable_messages(
message_list_before_now, message_list_before_now,
replace_bot_name=True, replace_bot_name=True,
merge_messages=False, merge_messages=False,

View File

@@ -1,3 +1,4 @@
import asyncio
import copy import copy
import datetime import datetime
import hashlib import hashlib
@@ -57,7 +58,7 @@ class PersonInfoManager:
self.person_name_list = {} self.person_name_list = {}
self.qv_name_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="relation.qv_name") self.qv_name_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="relation.qv_name")
# try: # try:
# with get_db_session() as session: # async with get_db_session() as session:
# db.connect(reuse_if_open=True) # db.connect(reuse_if_open=True)
# # 设置连接池参数仅对SQLite有效 # # 设置连接池参数仅对SQLite有效
# if hasattr(db, "execute_sql"): # if hasattr(db, "execute_sql"):
@@ -75,7 +76,7 @@ class PersonInfoManager:
try: try:
pass pass
# 在这里获取会话 # 在这里获取会话
# with get_db_session() as session: # async with get_db_session() as session:
# for record in session.execute( # for record in session.execute(
# select(PersonInfo.person_id, PersonInfo.person_name).where(PersonInfo.person_name.is_not(None)) # select(PersonInfo.person_id, PersonInfo.person_name).where(PersonInfo.person_name.is_not(None))
# ).fetchall(): # ).fetchall():
@@ -87,58 +88,25 @@ class PersonInfoManager:
@staticmethod @staticmethod
def get_person_id(platform: str, user_id: Union[int, str]) -> str: def get_person_id(platform: str, user_id: Union[int, str]) -> str:
"""获取唯一id""" """获取唯一id(同步)
说明: 原来该方法为异步并在内部尝试执行数据库检查/迁移,导致在许多调用处未 await 时返回 coroutine 对象。
为了避免将 coroutine 传递到其它同步调用(例如数据库查询条件)中,这里将方法改为同步并仅返回基于 platform 和 user_id 的 MD5 哈希值。
注意: 这会跳过原有的 napcat->qq 迁移检查逻辑。如需保留迁移,请使用显式的、在合适时机执行的迁移任务。
"""
# 检查platform是否为None或空 # 检查platform是否为None或空
if platform is None: if platform is None:
platform = "unknown" platform = "unknown"
if "-" in platform: if "-" in platform:
platform = platform.split("-")[1] platform = platform.split("-")[1]
# 在此处打一个补丁如果platform为qq尝试生成id后检查是否存在如果不存在则将平台换为napcat后再次检查如果存在则更新原id为platform为qq的id
components = [platform, str(user_id)] components = [platform, str(user_id)]
key = "_".join(components) key = "_".join(components)
# 如果不是 qq 平台,直接返回计算的 id # 直接返回计算的 id(同步)
if platform != "qq": return hashlib.md5(key.encode()).hexdigest()
return hashlib.md5(key.encode()).hexdigest()
qq_id = hashlib.md5(key.encode()).hexdigest()
# 对于 qq 平台,先检查该 person_id 是否已存在;如果存在直接返回
def _db_check_and_migrate_sync(p_id: str, raw_user_id: str):
try:
with get_db_session() as session:
# 检查 qq_id 是否存在
existing_qq = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
if existing_qq:
return p_id
# 如果 qq_id 不存在,尝试使用 napcat 作为平台生成对应 id 并检查
nap_components = ["napcat", str(raw_user_id)]
nap_key = "_".join(nap_components)
nap_id = hashlib.md5(nap_key.encode()).hexdigest()
existing_nap = session.execute(select(PersonInfo).where(PersonInfo.person_id == nap_id)).scalar()
if not existing_nap:
# napcat 也不存在,返回 qq_id未命中
return p_id
# napcat 存在,迁移该记录:更新 person_id 与 platform -> qq
try:
# 更新现有 napcat 记录
existing_nap.person_id = p_id
existing_nap.platform = "qq"
existing_nap.user_id = str(raw_user_id)
session.commit()
return p_id
except Exception:
session.rollback()
return p_id
except Exception as e:
logger.error(f"检查/迁移 napcat->qq 时出错: {e}")
return p_id
return _db_check_and_migrate_sync(qq_id, user_id)
async def is_person_known(self, platform: str, user_id: int): async def is_person_known(self, platform: str, user_id: int):
"""判断是否认识某人""" """判断是否认识某人"""
@@ -157,17 +125,25 @@ class PersonInfoManager:
logger.error(f"检查用户 {person_id} 是否已知时出错 (SQLAlchemy): {e}") logger.error(f"检查用户 {person_id} 是否已知时出错 (SQLAlchemy): {e}")
return False return False
@staticmethod async def get_person_id_by_person_name(self, person_name: str) -> str:
async def get_person_id_by_person_name(person_name: str) -> str: """
"""根据用户名获取用户ID""" 根据用户名获取用户ID(同步)
说明: 为了避免在多个调用点将 coroutine 误传递到数据库查询中,
此处提供一个同步实现。优先在内存缓存 `self.person_name_list` 中查找,
若未命中则返回空字符串。若后续需要更强的一致性,可在异步上下文
额外实现带 await 的查询方法。
"""
try: try:
# 在需要时获取会话 # 优先使用内存缓存加速查找self.person_name_list maps person_id -> person_name
async with get_db_session() as session: for pid, pname in self.person_name_list.items():
record = result = await session.execute(select(PersonInfo).where(PersonInfo.person_name == person_name)) if pname == person_name:
result.scalar() return pid
return record.person_id if record else ""
# 未找到缓存命中,避免在同步路径中进行阻塞的数据库查询,直接返回空字符串
return ""
except Exception as e: except Exception as e:
logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (SQLAlchemy): {e}") logger.error(f"根据用户名 {person_name} 获取用户ID时出错: {e}")
return "" return ""
@staticmethod @staticmethod
@@ -578,26 +554,15 @@ class PersonInfoManager:
@staticmethod @staticmethod
def get_value(person_id: str, field_name: str) -> Any: async def get_value(person_id: str, field_name: str) -> Any:
"""获取单个字段值(同步版本)""" """获取单个字段值(同步版本)"""
if not person_id: if not person_id:
logger.debug("get_value获取失败person_id不能为空") logger.debug("get_value获取失败person_id不能为空")
return None return None
import asyncio async with get_db_session() as session:
result = await session.execute(select(PersonInfo).where(PersonInfo.person_id == person_id))
async def _get_record_sync(): record = result.scalar()
async with get_db_session() as session:
result = await session.execute(select(PersonInfo).where(PersonInfo.person_id == person_id))
record = result.scalar()
return record
try:
record = asyncio.run(_get_record_sync())
except RuntimeError:
# 如果当前线程已经有事件循环在运行,则使用现有的循环
loop = asyncio.get_running_loop()
record = loop.run_until_complete(_get_record_sync())
model_fields = [column.name for column in PersonInfo.__table__.columns] model_fields = [column.name for column in PersonInfo.__table__.columns]

View File

@@ -176,7 +176,7 @@ class RelationshipFetcher:
# 查询用户关系数据 # 查询用户关系数据
relationships = await db_query( relationships = await db_query(
UserRelationships, UserRelationships,
filters=[UserRelationships.user_id == str(person_info_manager.get_value_sync(person_id, "user_id"))], filters=[UserRelationships.user_id == str(person_info_manager.get_value(person_id, "user_id"))],
limit=1, limit=1,
) )
@@ -259,7 +259,7 @@ class RelationshipFetcher:
# 记录信息获取请求 # 记录信息获取请求
self.info_fetching_cache.append( self.info_fetching_cache.append(
{ {
"person_id": get_person_info_manager().get_person_id_by_person_name(person_name), "person_id": await get_person_info_manager().get_person_id_by_person_name(person_name),
"person_name": person_name, "person_name": person_name,
"info_type": info_type, "info_type": info_type,
"start_time": time.time(), "start_time": time.time(),

View File

@@ -412,7 +412,7 @@ def count_new_messages_for_users(chat_id: str, start_time: float, end_time: floa
# ============================================================================= # =============================================================================
def build_readable_messages_to_str( async def build_readable_messages_to_str(
messages: List[Dict[str, Any]], messages: List[Dict[str, Any]],
replace_bot_name: bool = True, replace_bot_name: bool = True,
merge_messages: bool = False, merge_messages: bool = False,
@@ -436,7 +436,7 @@ def build_readable_messages_to_str(
Returns: Returns:
格式化后的可读字符串 格式化后的可读字符串
""" """
return build_readable_messages( return await build_readable_messages(
messages, replace_bot_name, merge_messages, timestamp_mode, read_mark, truncate, show_actions messages, replace_bot_name, merge_messages, timestamp_mode, read_mark, truncate, show_actions
) )

View File

@@ -134,7 +134,7 @@ async def is_person_known(platform: str, user_id: int) -> bool:
return False return False
def get_person_id_by_name(person_name: str) -> str: async def get_person_id_by_name(person_name: str) -> str:
"""根据用户名获取person_id """根据用户名获取person_id
Args: Args:
@@ -148,7 +148,7 @@ def get_person_id_by_name(person_name: str) -> str:
""" """
try: try:
person_info_manager = get_person_info_manager() person_info_manager = get_person_info_manager()
return person_info_manager.get_person_id_by_person_name(person_name) return await person_info_manager.get_person_id_by_person_name(person_name)
except Exception as e: except Exception as e:
logger.error(f"[PersonAPI] 根据用户名获取person_id失败: person_name={person_name}, error={e}") logger.error(f"[PersonAPI] 根据用户名获取person_id失败: person_name={person_name}, error={e}")
return "" return ""

View File

@@ -542,7 +542,22 @@ class PluginManager:
plugin_instance.on_unload() plugin_instance.on_unload()
# 从组件注册表中移除插件的所有组件 # 从组件注册表中移除插件的所有组件
asyncio.run(component_registry.unregister_plugin(plugin_name)) try:
loop = asyncio.get_event_loop()
if loop.is_running():
fut = asyncio.run_coroutine_threadsafe(
component_registry.unregister_plugin(plugin_name), loop
)
fut.result(timeout=5)
else:
asyncio.run(component_registry.unregister_plugin(plugin_name))
except Exception:
# 最后兜底:直接同步调用(如果 unregister_plugin 为非协程)或忽略错误
try:
# 如果 unregister_plugin 是普通函数
component_registry.unregister_plugin(plugin_name)
except Exception as e:
logger.debug(f"卸载插件时调用 component_registry.unregister_plugin 失败: {e}")
# 从已加载插件中移除 # 从已加载插件中移除
del self.loaded_plugins[plugin_name] del self.loaded_plugins[plugin_name]

View File

@@ -199,7 +199,7 @@ class ChatterInterestScoringSystem:
# 如果内存中没有,尝试从关系追踪器获取 # 如果内存中没有,尝试从关系追踪器获取
if hasattr(self, "relationship_tracker") and self.relationship_tracker: if hasattr(self, "relationship_tracker") and self.relationship_tracker:
try: try:
relationship_score = self.relationship_tracker.get_user_relationship_score(user_id) relationship_score = await self.relationship_tracker.get_user_relationship_score(user_id)
# 同时更新内存缓存 # 同时更新内存缓存
self.user_relationships[user_id] = relationship_score self.user_relationships[user_id] = relationship_score
return relationship_score return relationship_score

View File

@@ -182,7 +182,7 @@ class ChatterPlanFilter:
if plan.mode == ChatMode.PROACTIVE: if plan.mode == ChatMode.PROACTIVE:
long_term_memory_block = await self._get_long_term_memory_context() long_term_memory_block = await self._get_long_term_memory_context()
chat_content_block, message_id_list = build_readable_messages_with_id( chat_content_block, message_id_list = await build_readable_messages_with_id(
messages=[msg.flatten() for msg in plan.chat_history], messages=[msg.flatten() for msg in plan.chat_history],
timestamp_mode="normal", timestamp_mode="normal",
truncate=False, truncate=False,
@@ -190,7 +190,7 @@ class ChatterPlanFilter:
) )
prompt_template = await global_prompt_manager.get_prompt_async("proactive_planner_prompt") prompt_template = await global_prompt_manager.get_prompt_async("proactive_planner_prompt")
actions_before_now = get_actions_by_timestamp_with_chat( actions_before_now = await get_actions_by_timestamp_with_chat(
chat_id=plan.chat_id, chat_id=plan.chat_id,
timestamp_start=time.time() - 3600, timestamp_start=time.time() - 3600,
timestamp_end=time.time(), timestamp_end=time.time(),
@@ -216,7 +216,7 @@ class ChatterPlanFilter:
) )
# 为了兼容性保留原有的chat_content_block # 为了兼容性保留原有的chat_content_block
chat_content_block, _ = build_readable_messages_with_id( chat_content_block, _ = await build_readable_messages_with_id(
messages=[msg.flatten() for msg in plan.chat_history], messages=[msg.flatten() for msg in plan.chat_history],
timestamp_mode="normal", timestamp_mode="normal",
read_mark=self.last_obs_time_mark, read_mark=self.last_obs_time_mark,
@@ -224,7 +224,7 @@ class ChatterPlanFilter:
show_actions=True, show_actions=True,
) )
actions_before_now = get_actions_by_timestamp_with_chat( actions_before_now = await get_actions_by_timestamp_with_chat(
chat_id=plan.chat_id, chat_id=plan.chat_id,
timestamp_start=time.time() - 3600, timestamp_start=time.time() - 3600,
timestamp_end=time.time(), timestamp_end=time.time(),
@@ -319,7 +319,14 @@ class ChatterPlanFilter:
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat
# 获取聊天流的上下文 # 获取聊天流的上下文
stream_context = message_manager.stream_contexts.get(plan.chat_id) from src.plugin_system.apis.chat_api import get_chat_manager
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(plan.chat_id)
if not chat_stream:
logger.warning(f"[plan_filter] 聊天流 {plan.chat_id} 不存在")
return "最近没有聊天内容。", "没有未读消息。", []
stream_context = chat_stream.context_manager
# 获取真正的已读和未读消息 # 获取真正的已读和未读消息
read_messages = stream_context.history_messages # 已读消息存储在history_messages中 read_messages = stream_context.history_messages # 已读消息存储在history_messages中
@@ -338,7 +345,7 @@ class ChatterPlanFilter:
# 构建已读历史消息块 # 构建已读历史消息块
if read_messages: if read_messages:
read_content, read_ids = build_readable_messages_with_id( read_content, read_ids = await build_readable_messages_with_id(
messages=[msg.flatten() for msg in read_messages[-50:]], # 限制数量 messages=[msg.flatten() for msg in read_messages[-50:]], # 限制数量
timestamp_mode="normal_no_YMD", timestamp_mode="normal_no_YMD",
truncate=False, truncate=False,

View File

@@ -138,7 +138,7 @@ class ChatterActionPlanner:
# 更新StreamContext中的消息信息并刷新focus_energy # 更新StreamContext中的消息信息并刷新focus_energy
if context: if context:
from src.chat.message_manager.message_manager import message_manager from src.chat.message_manager.message_manager import message_manager
message_manager.update_message( await message_manager.update_message(
stream_id=self.chat_id, stream_id=self.chat_id,
message_id=message.message_id, message_id=message.message_id,
interest_value=message_interest, interest_value=message_interest,
@@ -148,7 +148,7 @@ class ChatterActionPlanner:
# 更新数据库中的消息记录 # 更新数据库中的消息记录
try: try:
from src.chat.message_receive.storage import MessageStorage from src.chat.message_receive.storage import MessageStorage
MessageStorage.update_message_interest_value(message.message_id, message_interest) await MessageStorage.update_message_interest_value(message.message_id, message_interest)
logger.debug(f"已更新数据库中消息 {message.message_id} 的兴趣度为: {message_interest:.3f}") logger.debug(f"已更新数据库中消息 {message.message_id} 的兴趣度为: {message_interest:.3f}")
except Exception as e: except Exception as e:
logger.warning(f"更新数据库消息兴趣度失败: {e}") logger.warning(f"更新数据库消息兴趣度失败: {e}")

View File

@@ -124,10 +124,10 @@ class EmojiAction(BaseAction):
emoji_base64, emoji_description = random.choice(all_emojis_data) emoji_base64, emoji_description = random.choice(all_emojis_data)
else: else:
# 获取最近的5条消息内容用于判断 # 获取最近的5条消息内容用于判断
recent_messages = message_api.get_recent_messages(chat_id=self.chat_id, limit=5) recent_messages = await message_api.get_recent_messages(chat_id=self.chat_id, limit=5)
messages_text = "" messages_text = ""
if recent_messages: if recent_messages:
messages_text = message_api.build_readable_messages( messages_text = await message_api.build_readable_messages(
messages=recent_messages, messages=recent_messages,
timestamp_mode="normal_no_YMD", timestamp_mode="normal_no_YMD",
truncate=False, truncate=False,
@@ -185,10 +185,10 @@ class EmojiAction(BaseAction):
elif global_config.emoji.emoji_selection_mode == "description": elif global_config.emoji.emoji_selection_mode == "description":
# --- 详细描述选择模式 --- # --- 详细描述选择模式 ---
# 获取最近的5条消息内容用于判断 # 获取最近的5条消息内容用于判断
recent_messages = message_api.get_recent_messages(chat_id=self.chat_id, limit=5) recent_messages = await message_api.get_recent_messages(chat_id=self.chat_id, limit=5)
messages_text = "" messages_text = ""
if recent_messages: if recent_messages:
messages_text = message_api.build_readable_messages( messages_text = await message_api.build_readable_messages(
messages=recent_messages, messages=recent_messages,
timestamp_mode="normal_no_YMD", timestamp_mode="normal_no_YMD",
truncate=False, truncate=False,

View File

@@ -118,7 +118,7 @@ class QZoneService:
async def read_and_process_feeds(self, target_name: str, stream_id: Optional[str]) -> Dict[str, Any]: async def read_and_process_feeds(self, target_name: str, stream_id: Optional[str]) -> Dict[str, Any]:
"""读取并处理指定好友的说说""" """读取并处理指定好友的说说"""
target_person_id = person_api.get_person_id_by_name(target_name) target_person_id = await person_api.get_person_id_by_name(target_name)
if not target_person_id: if not target_person_id:
return {"success": False, "message": f"找不到名为'{target_name}'的好友"} return {"success": False, "message": f"找不到名为'{target_name}'的好友"}
target_qq = await person_api.get_person_value(target_person_id, "user_id") target_qq = await person_api.get_person_value(target_person_id, "user_id")

View File

@@ -331,6 +331,7 @@ class NoticeHandler:
like_emoji_id = raw_message.get("likes")[0].get("emoji_id") like_emoji_id = raw_message.get("likes")[0].get("emoji_id")
await event_manager.trigger_event( await event_manager.trigger_event(
<<<<<<< HEAD
NapcatEvent.ON_RECEIVED.EMOJI_LIEK, NapcatEvent.ON_RECEIVED.EMOJI_LIEK,
permission_group=PLUGIN_NAME, permission_group=PLUGIN_NAME,
group_id=group_id, group_id=group_id,
@@ -342,6 +343,16 @@ class NoticeHandler:
type="text", type="text",
data=f"{user_name}使用Emoji表情{QQ_FACE.get(like_emoji_id, '')}回复了你的消息[{target_message_text}]", data=f"{user_name}使用Emoji表情{QQ_FACE.get(like_emoji_id, '')}回复了你的消息[{target_message_text}]",
) )
=======
NapcatEvent.ON_RECEIVED.EMOJI_LIEK,
permission_group=PLUGIN_NAME,
group_id=group_id,
user_id=user_id,
message_id=raw_message.get("message_id",""),
emoji_id=like_emoji_id
)
seg_data = Seg(type="text",data=f"{user_name}使用Emoji表情{QQ_FACE.get(like_emoji_id, '')}回复了你的消息[{target_message_text}]")
>>>>>>> 9912d7f643d347cbadcf1e3d618aa78bcbf89cc4
return seg_data, user_info return seg_data, user_info
async def handle_ban_notify(self, raw_message: dict, group_id: int) -> Tuple[Seg, UserInfo] | Tuple[None, None]: async def handle_ban_notify(self, raw_message: dict, group_id: int) -> Tuple[Seg, UserInfo] | Tuple[None, None]:

View File

@@ -1,25 +0,0 @@
{
"manifest_version": 1,
"name": "web_search_tool",
"version": "1.0.0",
"description": "一个用于在互联网上搜索信息的工具",
"author": {
"name": "MoFox-Studio",
"url": "https://github.com/MoFox-Studio"
},
"license": "GPL-v3.0-or-later",
"host_application": {
"min_version": "0.10.0"
},
"keywords": ["web_search", "url_parser"],
"categories": ["web_search", "url_parser"],
"default_locale": "zh-CN",
"locales_path": "_locales",
"plugin_info": {
"is_built_in": false,
"plugin_type": "web_search"
}
}

View File

@@ -1,110 +0,0 @@
#!/usr/bin/env python3
"""
测试 ChatStream 的 deepcopy 功能
验证 asyncio.Task 序列化问题是否已解决
"""
import asyncio
import sys
import os
import copy
# 添加项目根目录到 Python 路径
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from src.chat.message_receive.chat_stream import ChatStream
from maim_message import UserInfo, GroupInfo
async def test_chat_stream_deepcopy():
"""测试 ChatStream 的 deepcopy 功能"""
print("[TEST] 开始测试 ChatStream deepcopy 功能...")
try:
# 创建测试用的用户和群组信息
user_info = UserInfo(
platform="test_platform",
user_id="test_user_123",
user_nickname="测试用户",
user_cardname="测试卡片名"
)
group_info = GroupInfo(
platform="test_platform",
group_id="test_group_456",
group_name="测试群组"
)
# 创建 ChatStream 实例
print("📝 创建 ChatStream 实例...")
stream_id = "test_stream_789"
platform = "test_platform"
chat_stream = ChatStream(
stream_id=stream_id,
platform=platform,
user_info=user_info,
group_info=group_info
)
print(f"[SUCCESS] ChatStream 创建成功: {chat_stream.stream_id}")
# 等待一下,让异步任务有机会创建
await asyncio.sleep(0.1)
# 尝试进行 deepcopy
print("[INFO] 尝试进行 deepcopy...")
copied_stream = copy.deepcopy(chat_stream)
print("[SUCCESS] deepcopy 成功!")
# 验证复制后的对象属性
print("\n[CHECK] 验证复制后的对象属性:")
print(f" - stream_id: {copied_stream.stream_id}")
print(f" - platform: {copied_stream.platform}")
print(f" - user_info: {copied_stream.user_info.user_nickname}")
print(f" - group_info: {copied_stream.group_info.group_name}")
# 检查 processing_task 是否被正确处理
if hasattr(copied_stream.stream_context, 'processing_task'):
print(f" - processing_task: {copied_stream.stream_context.processing_task}")
if copied_stream.stream_context.processing_task is None:
print(" [SUCCESS] processing_task 已被正确设置为 None")
else:
print(" [WARNING] processing_task 不为 None")
else:
print(" [SUCCESS] stream_context 没有 processing_task 属性")
# 验证原始对象和复制对象是不同的实例
if id(chat_stream) != id(copied_stream):
print("[SUCCESS] 原始对象和复制对象是不同的实例")
else:
print("[ERROR] 原始对象和复制对象是同一个实例")
# 验证基本属性是否正确复制
if (chat_stream.stream_id == copied_stream.stream_id and
chat_stream.platform == copied_stream.platform):
print("[SUCCESS] 基本属性正确复制")
else:
print("[ERROR] 基本属性复制失败")
print("\n[COMPLETE] 测试完成deepcopy 功能修复成功!")
return True
except Exception as e:
print(f"[ERROR] 测试失败: {e}")
import traceback
traceback.print_exc()
return False
if __name__ == "__main__":
# 运行测试
result = asyncio.run(test_chat_stream_deepcopy())
if result:
print("\n[SUCCESS] 所有测试通过!")
sys.exit(0)
else:
print("\n[ERROR] 测试失败!")
sys.exit(1)

View File

@@ -1,109 +0,0 @@
#!/usr/bin/env python3
"""
简单的 ChatStream deepcopy 测试
"""
import asyncio
import sys
import os
import copy
# 添加项目根目录到 Python 路径
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from src.chat.message_receive.chat_stream import ChatStream
from maim_message import UserInfo, GroupInfo
async def test_deepcopy():
"""测试 deepcopy 功能"""
print("开始测试 ChatStream deepcopy 功能...")
try:
# 创建测试用的用户和群组信息
user_info = UserInfo(
platform="test_platform",
user_id="test_user_123",
user_nickname="测试用户",
user_cardname="测试卡片名"
)
group_info = GroupInfo(
platform="test_platform",
group_id="test_group_456",
group_name="测试群组"
)
# 创建 ChatStream 实例
print("创建 ChatStream 实例...")
stream_id = "test_stream_789"
platform = "test_platform"
chat_stream = ChatStream(
stream_id=stream_id,
platform=platform,
user_info=user_info,
group_info=group_info
)
print(f"ChatStream 创建成功: {chat_stream.stream_id}")
# 等待一下,让异步任务有机会创建
await asyncio.sleep(0.1)
# 尝试进行 deepcopy
print("尝试进行 deepcopy...")
copied_stream = copy.deepcopy(chat_stream)
print("deepcopy 成功!")
# 验证复制后的对象属性
print("\n验证复制后的对象属性:")
print(f" - stream_id: {copied_stream.stream_id}")
print(f" - platform: {copied_stream.platform}")
print(f" - user_info: {copied_stream.user_info.user_nickname}")
print(f" - group_info: {copied_stream.group_info.group_name}")
# 检查 processing_task 是否被正确处理
if hasattr(copied_stream.stream_context, 'processing_task'):
print(f" - processing_task: {copied_stream.stream_context.processing_task}")
if copied_stream.stream_context.processing_task is None:
print(" SUCCESS: processing_task 已被正确设置为 None")
else:
print(" WARNING: processing_task 不为 None")
else:
print(" SUCCESS: stream_context 没有 processing_task 属性")
# 验证原始对象和复制对象是不同的实例
if id(chat_stream) != id(copied_stream):
print("SUCCESS: 原始对象和复制对象是不同的实例")
else:
print("ERROR: 原始对象和复制对象是同一个实例")
# 验证基本属性是否正确复制
if (chat_stream.stream_id == copied_stream.stream_id and
chat_stream.platform == copied_stream.platform):
print("SUCCESS: 基本属性正确复制")
else:
print("ERROR: 基本属性复制失败")
print("\n测试完成deepcopy 功能修复成功!")
return True
except Exception as e:
print(f"ERROR: 测试失败: {e}")
import traceback
traceback.print_exc()
return False
if __name__ == "__main__":
# 运行测试
result = asyncio.run(test_deepcopy())
if result:
print("\n所有测试通过!")
sys.exit(0)
else:
print("\n测试失败!")
sys.exit(1)