refactor(chat): 迁移数据库操作为异步模式并修复相关调用

将同步数据库操作全面迁移为异步模式,主要涉及:
- 将 `with get_db_session()` 改为 `async with get_db_session()`
- 修复相关异步调用链,确保 await 正确传递
- 优化消息管理器、上下文管理器等核心组件的异步处理
- 移除同步的 person_id 获取方法,避免协程对象传递问题

修复 deepcopy 在 StreamContext 中的序列化问题,跳过不可序列化的 asyncio.Task 对象

删除无用的测试文件和废弃的插件清单文件
This commit is contained in:
Windpicker-owo
2025-09-28 20:40:46 +08:00
parent 08ef960947
commit fd76e36320
30 changed files with 481 additions and 625 deletions

View File

@@ -261,7 +261,7 @@ class AntiPromptInjector:
logger.warning("无法删除消息缺少message_id")
return
with get_db_session() as session:
async with get_db_session() as session:
# 删除对应的消息记录
stmt = delete(Messages).where(Messages.message_id == message_id)
result = session.execute(stmt)
@@ -287,7 +287,7 @@ class AntiPromptInjector:
logger.warning("无法更新消息缺少message_id")
return
with get_db_session() as session:
async with get_db_session() as session:
# 更新消息内容
stmt = (
update(Messages)

View File

@@ -42,7 +42,7 @@ class SingleStreamContextManager:
self._update_access_stats()
return self.context
def add_message(self, message: DatabaseMessages, skip_energy_update: bool = False) -> bool:
async def add_message(self, message: DatabaseMessages, skip_energy_update: bool = False) -> bool:
"""添加消息到上下文
Args:
@@ -53,30 +53,21 @@ class SingleStreamContextManager:
bool: 是否成功添加
"""
try:
# 添加消息到上下文
self.context.add_message(message)
# 计算消息兴趣度
interest_value = self._calculate_message_interest(message)
interest_value = await self._calculate_message_interest(message)
message.interest_value = interest_value
# 更新统计
self.total_messages += 1
self.last_access_time = time.time()
# 更新能量和分发
if not skip_energy_update:
self._update_stream_energy()
await self._update_stream_energy()
distribution_manager.add_stream_message(self.stream_id, 1)
logger.debug(f"添加消息到单流上下文: {self.stream_id} (兴趣度: {interest_value:.3f})")
return True
except Exception as e:
logger.error(f"添加消息到单流上下文失败 {self.stream_id}: {e}", exc_info=True)
return False
def update_message(self, message_id: str, updates: Dict[str, Any]) -> bool:
async def update_message(self, message_id: str, updates: Dict[str, Any]) -> bool:
"""更新上下文中的消息
Args:
@@ -87,16 +78,11 @@ class SingleStreamContextManager:
bool: 是否成功更新
"""
try:
# 更新消息信息
self.context.update_message_info(message_id, **updates)
# 如果更新了兴趣度,重新计算能量
if "interest_value" in updates:
self._update_stream_energy()
await self._update_stream_energy()
logger.debug(f"更新单流上下文消息: {self.stream_id}/{message_id}")
return True
except Exception as e:
logger.error(f"更新单流上下文消息失败 {self.stream_id}/{message_id}: {e}", exc_info=True)
return False
@@ -164,16 +150,13 @@ class SingleStreamContextManager:
logger.error(f"标记消息已读失败 {self.stream_id}: {e}", exc_info=True)
return False
def clear_context(self) -> bool:
async def clear_context(self) -> bool:
"""清空上下文"""
try:
# 清空消息
if hasattr(self.context, "unread_messages"):
self.context.unread_messages.clear()
if hasattr(self.context, "history_messages"):
self.context.history_messages.clear()
# 重置状态
reset_attrs = ["interruption_count", "afc_threshold_adjustment", "last_check_time"]
for attr in reset_attrs:
if hasattr(self.context, attr):
@@ -181,13 +164,9 @@ class SingleStreamContextManager:
setattr(self.context, attr, 0)
else:
setattr(self.context, attr, time.time())
# 重新计算能量
self._update_stream_energy()
await self._update_stream_energy()
logger.info(f"清空单流上下文: {self.stream_id}")
return True
except Exception as e:
logger.error(f"清空单流上下文失败 {self.stream_id}: {e}", exc_info=True)
return False
@@ -249,39 +228,115 @@ class SingleStreamContextManager:
self.last_access_time = time.time()
self.access_count += 1
def _calculate_message_interest(self, message: DatabaseMessages) -> float:
"""计算消息兴趣度"""
async def _calculate_message_interest(self, message: DatabaseMessages) -> float:
"""异步实现:使用插件的异步评分器正确 await 计算兴趣度并返回分数。"""
try:
# 使用插件内部的兴趣度评分系统
try:
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
# 使用插件内部的兴趣度评分系统计算(同步方式)
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import (
chatter_interest_scoring_system,
)
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
interest_score = loop.run_until_complete(
chatter_interest_scoring_system._calculate_single_message_score(
interest_score = await chatter_interest_scoring_system._calculate_single_message_score(
message=message, bot_nickname=global_config.bot.nickname
)
interest_value = interest_score.total_score
logger.debug(f"使用插件内部系统计算兴趣度: {interest_value:.3f}")
return interest_value
except Exception as e:
logger.warning(f"插件内部兴趣度计算失败: {e}")
return 0.5
except Exception as e:
logger.warning(f"插件内部兴趣度计算加载失败,使用默认值: {e}")
return 0.5
except Exception as e:
logger.error(f"计算消息兴趣度失败: {e}")
return 0.5
async def _calculate_message_interest_async(self, message: DatabaseMessages) -> float:
"""异步实现:使用插件的异步评分器正确 await 计算兴趣度并返回分数。"""
try:
try:
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import (
chatter_interest_scoring_system,
)
# 直接 await 插件的异步方法
try:
interest_score = await chatter_interest_scoring_system._calculate_single_message_score(
message=message, bot_nickname=global_config.bot.nickname
)
interest_value = interest_score.total_score
logger.debug(f"使用插件内部系统计算兴趣度: {interest_value:.3f}")
return interest_value
except Exception as e:
logger.warning(f"插件内部兴趣度计算失败: {e}")
return 0.5
except Exception as e:
logger.warning(f"插件内部兴趣度计算失败,使用默认值: {e}")
interest_value = 0.5 # 默认中等兴趣度
return interest_value
logger.warning(f"插件内部兴趣度计算加载失败,使用默认值: {e}")
return 0.5
except Exception as e:
logger.error(f"计算消息兴趣度失败: {e}")
return 0.5
async def add_message_async(self, message: DatabaseMessages, skip_energy_update: bool = False) -> bool:
"""异步实现的 add_message将消息添加到 context并 await 能量更新与分发。"""
try:
self.context.add_message(message)
interest_value = await self._calculate_message_interest_async(message)
message.interest_value = interest_value
self.total_messages += 1
self.last_access_time = time.time()
if not skip_energy_update:
await self._update_stream_energy()
distribution_manager.add_stream_message(self.stream_id, 1)
logger.debug(f"添加消息到单流上下文(异步): {self.stream_id} (兴趣度: {interest_value:.3f})")
return True
except Exception as e:
logger.error(f"添加消息到单流上下文失败 (async) {self.stream_id}: {e}", exc_info=True)
return False
async def update_message_async(self, message_id: str, updates: Dict[str, Any]) -> bool:
"""异步实现的 update_message更新消息并在需要时 await 能量更新。"""
try:
self.context.update_message_info(message_id, **updates)
if "interest_value" in updates:
await self._update_stream_energy()
logger.debug(f"更新单流上下文消息(异步): {self.stream_id}/{message_id}")
return True
except Exception as e:
logger.error(f"更新单流上下文消息失败 (async) {self.stream_id}/{message_id}: {e}", exc_info=True)
return False
async def clear_context_async(self) -> bool:
"""异步实现的 clear_context清空消息并 await 能量重算。"""
try:
if hasattr(self.context, "unread_messages"):
self.context.unread_messages.clear()
if hasattr(self.context, "history_messages"):
self.context.history_messages.clear()
reset_attrs = ["interruption_count", "afc_threshold_adjustment", "last_check_time"]
for attr in reset_attrs:
if hasattr(self.context, attr):
if attr in ["interruption_count", "afc_threshold_adjustment"]:
setattr(self.context, attr, 0)
else:
setattr(self.context, attr, time.time())
await self._update_stream_energy()
logger.info(f"清空单流上下文(异步): {self.stream_id}")
return True
except Exception as e:
logger.error(f"清空单流上下文失败 (async) {self.stream_id}: {e}", exc_info=True)
return False
async def _update_stream_energy(self):
"""更新流能量"""
try:

View File

@@ -75,29 +75,23 @@ class MessageManager:
logger.info("消息管理器已停止")
def add_message(self, stream_id: str, message: DatabaseMessages):
async def add_message(self, stream_id: str, message: DatabaseMessages):
"""添加消息到指定聊天流"""
try:
# 通过 ChatManager 获取 ChatStream
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
if not chat_stream:
logger.warning(f"MessageManager.add_message: 聊天流 {stream_id} 不存在")
return
# 使用 ChatStream 的 context_manager 添加消息
success = chat_stream.context_manager.add_message(message)
success = await chat_stream.context_manager.add_message(message)
if success:
logger.debug(f"添加消息到聊天流 {stream_id}: {message.message_id}")
else:
logger.warning(f"添加消息到聊天流 {stream_id} 失败")
except Exception as e:
logger.error(f"添加消息到聊天流 {stream_id} 时发生错误: {e}")
def update_message(
async def update_message(
self,
stream_id: str,
message_id: str,
@@ -107,15 +101,11 @@ class MessageManager:
):
"""更新消息信息"""
try:
# 通过 ChatManager 获取 ChatStream
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
if not chat_stream:
logger.warning(f"MessageManager.update_message: 聊天流 {stream_id} 不存在")
return
# 构建更新字典
updates = {}
if interest_value is not None:
updates["interest_value"] = interest_value
@@ -123,41 +113,30 @@ class MessageManager:
updates["actions"] = actions
if should_reply is not None:
updates["should_reply"] = should_reply
# 使用 ChatStream 的 context_manager 更新消息
if updates:
success = chat_stream.context_manager.update_message(message_id, updates)
success = await chat_stream.context_manager.update_message(message_id, updates)
if success:
logger.debug(f"更新消息 {message_id} 成功")
else:
logger.warning(f"更新消息 {message_id} 失败")
except Exception as e:
logger.error(f"更新消息 {message_id} 时发生错误: {e}")
def add_action(self, stream_id: str, message_id: str, action: str):
async def add_action(self, stream_id: str, message_id: str, action: str):
"""添加动作到消息"""
try:
# 通过 ChatManager 获取 ChatStream
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
if not chat_stream:
logger.warning(f"MessageManager.add_action: 聊天流 {stream_id} 不存在")
return
# 使用 ChatStream 的 context_manager 添加动作
# 注意:这里需要根据实际的 API 调整
# 假设我们可以通过 update_message 来添加动作
success = chat_stream.context_manager.update_message(
success = await chat_stream.context_manager.update_message(
message_id, {"actions": [action]}
)
if success:
logger.debug(f"为消息 {message_id} 添加动作 {action} 成功")
else:
logger.warning(f"为消息 {message_id} 添加动作 {action} 失败")
except Exception as e:
logger.error(f"为消息 {message_id} 添加动作时发生错误: {e}")
@@ -382,36 +361,27 @@ class MessageManager:
"start_time": self.stats.start_time,
}
def cleanup_inactive_streams(self, max_inactive_hours: int = 24):
async def cleanup_inactive_streams(self, max_inactive_hours: int = 24):
"""清理不活跃的聊天流"""
try:
# 通过 ChatManager 清理不活跃的流
chat_manager = get_chat_manager()
current_time = time.time()
max_inactive_seconds = max_inactive_hours * 3600
inactive_streams = []
for stream_id, chat_stream in chat_manager.streams.items():
# 检查最后活跃时间
if current_time - chat_stream.last_active_time > max_inactive_seconds:
inactive_streams.append(stream_id)
# 清理不活跃的流
for stream_id in inactive_streams:
try:
# 清理流的内容
chat_stream.context_manager.clear_context()
# 从 ChatManager 中移除
await chat_stream.context_manager.clear_context()
del chat_manager.streams[stream_id]
logger.info(f"清理不活跃聊天流: {stream_id}")
except Exception as e:
logger.error(f"清理聊天流 {stream_id} 失败: {e}")
if inactive_streams:
logger.info(f"已清理 {len(inactive_streams)} 个不活跃聊天流")
else:
logger.debug("没有需要清理的不活跃聊天流")
except Exception as e:
logger.error(f"清理不活跃聊天流时发生错误: {e}")

View File

@@ -514,7 +514,7 @@ class ChatBot:
db_message.chat_info_group_platform = message.chat_stream.group_info.platform
# 添加消息到消息管理器
message_manager.add_message(message.chat_stream.stream_id, db_message)
await message_manager.add_message(message.chat_stream.stream_id, db_message)
logger.debug(f"消息已添加到消息管理器: {message.chat_stream.stream_id}")
if template_group_name:

View File

@@ -389,30 +389,24 @@ class ChatStream:
from sqlalchemy import select, desc
import asyncio
async def _load_messages():
def _db_query():
with get_db_session() as session:
# 查询该stream_id的最近20条消息
async def _load_history_messages_async():
"""异步加载并转换历史消息到 stream_context在事件循环中运行"""
try:
async with get_db_session() as session:
stmt = (
select(Messages)
.where(Messages.chat_info_stream_id == self.stream_id)
.order_by(desc(Messages.time))
.limit(global_config.chat.max_context_size)
)
result = session.execute(stmt)
results = result.scalars().all()
return results
# 在线程中执行数据库查询
db_messages = await asyncio.to_thread(_db_query)
result = await session.execute(stmt)
db_messages = result.scalars().all()
# 转换为DatabaseMessages对象并添加到StreamContext
for db_msg in db_messages:
try:
# 从SQLAlchemy模型转换为DatabaseMessages数据模型
import orjson
# 解析actions字段JSON格式
actions = None
if db_msg.actions:
try:
@@ -457,17 +451,15 @@ class ChatStream:
should_reply=getattr(db_msg, "should_reply", False) or False,
)
# 添加调试日志检查从数据库加载的interest_value
logger.debug(
f"加载历史消息 {db_message.message_id} - interest_value: {db_message.interest_value}"
)
# 标记为已读并添加到历史消息
db_message.is_read = True
self.stream_context.history_messages.append(db_message)
except Exception as e:
logger.warning(f"转换消息 {db_msg.message_id} 失败: {e}")
logger.warning(f"转换消息 {getattr(db_msg, 'message_id', '<unknown>')} 失败: {e}")
continue
if self.stream_context.history_messages:
@@ -475,8 +467,27 @@ class ChatStream:
f"已从数据库加载 {len(self.stream_context.history_messages)} 条历史消息到聊天流 {self.stream_id}"
)
# 创建任务来加载历史消息
asyncio.create_task(_load_messages())
except Exception as e:
logger.warning(f"异步加载历史消息失败: {e}")
# 在已有事件循环中,避免调用 asyncio.run()
try:
loop = asyncio.get_running_loop()
except RuntimeError:
# 没有运行的事件循环,安全地运行并等待完成
asyncio.run(_load_history_messages_async())
else:
# 如果事件循环正在运行,在后台创建任务
if loop.is_running():
try:
asyncio.create_task(_load_history_messages_async())
except Exception as e:
# 如果无法创建任务,退回到阻塞运行
logger.warning(f"无法在事件循环中创建后台任务,尝试阻塞运行: {e}")
asyncio.run(_load_history_messages_async())
else:
# loop 存在但未运行,使用 asyncio.run
asyncio.run(_load_history_messages_async())
except Exception as e:
logger.error(f"加载历史消息失败: {e}")
@@ -498,7 +509,7 @@ class ChatManager:
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
self.last_messages: Dict[str, "MessageRecv"] = {} # stream_id -> last_message
# try:
# with get_db_session() as session:
# async with get_db_session() as session:
# db.connect(reuse_if_open=True)
# # 确保 ChatStreams 表存在
# session.execute(text("CREATE TABLE IF NOT EXISTS chat_streams (stream_id TEXT PRIMARY KEY, platform TEXT, create_time REAL, last_active_time REAL, user_platform TEXT, user_id TEXT, user_nickname TEXT, user_cardname TEXT, group_platform TEXT, group_id TEXT, group_name TEXT)"))

View File

@@ -219,7 +219,7 @@ class MessageStorage:
return match.group(0)
@staticmethod
def update_message_interest_value(message_id: str, interest_value: float) -> None:
async def update_message_interest_value(message_id: str, interest_value: float) -> None:
"""
更新数据库中消息的interest_value字段
@@ -228,11 +228,11 @@ class MessageStorage:
interest_value: 兴趣度值
"""
try:
with get_db_session() as session:
async with get_db_session() as session:
# 更新消息的interest_value字段
stmt = update(Messages).where(Messages.message_id == message_id).values(interest_value=interest_value)
result = session.execute(stmt)
session.commit()
result = await session.execute(stmt)
await session.commit()
if result.rowcount > 0:
logger.debug(f"成功更新消息 {message_id} 的interest_value为 {interest_value}")
@@ -244,7 +244,7 @@ class MessageStorage:
raise
@staticmethod
def fix_zero_interest_values(chat_id: str, since_time: float) -> int:
async def fix_zero_interest_values(chat_id: str, since_time: float) -> int:
"""
修复指定聊天中interest_value为0或null的历史消息记录
@@ -256,7 +256,7 @@ class MessageStorage:
修复的记录数量
"""
try:
with get_db_session() as session:
async with get_db_session() as session:
from sqlalchemy import select, update
from src.common.database.sqlalchemy_models import Messages
@@ -271,7 +271,7 @@ class MessageStorage:
)
).limit(50) # 限制每次修复的数量,避免性能问题
result = session.execute(query)
result = await session.execute(query)
messages_to_fix = result.scalars().all()
fixed_count = 0
@@ -297,12 +297,12 @@ class MessageStorage:
Messages.message_id == msg.message_id
).values(interest_value=default_interest)
result = session.execute(update_stmt)
result = await session.execute(update_stmt)
if result.rowcount > 0:
fixed_count += 1
logger.debug(f"修复消息 {msg.message_id} 的interest_value为 {default_interest}")
session.commit()
await session.commit()
logger.info(f"共修复了 {fixed_count} 条历史消息的interest_value值")
return fixed_count

View File

@@ -297,15 +297,12 @@ class ChatterActionManager:
return
# 通过message_manager更新消息的动作记录并刷新focus_energy
if chat_stream.stream_id in message_manager.stream_contexts:
message_manager.add_action(
await message_manager.add_action(
stream_id=chat_stream.stream_id,
message_id=target_message_id,
action=action_name
)
logger.debug(f"已记录动作 {action_name} 到消息 {target_message_id} 并更新focus_energy")
else:
logger.debug(f"未找到stream_context: {chat_stream.stream_id}")
except Exception as e:
logger.error(f"记录动作到消息失败: {e}")
@@ -315,8 +312,11 @@ class ChatterActionManager:
"""在动作执行成功后重置打断计数"""
from src.chat.message_manager.message_manager import message_manager
try:
if stream_id in message_manager.stream_contexts:
context = message_manager.stream_contexts[stream_id]
from src.plugin_system.apis.chat_api import get_chat_manager
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
if chat_stream:
context = chat_stream.context_manager
if context.interruption_count > 0:
old_count = context.interruption_count
old_afc_adjustment = context.get_afc_threshold_adjustment()

View File

@@ -73,7 +73,7 @@ class ActionModifier:
from src.chat.utils.utils import get_chat_type_and_target_info
# 获取聊天类型
is_group_chat, _ = await get_chat_type_and_target_info(self.chat_id)
is_group_chat, _ = get_chat_type_and_target_info(self.chat_id)
all_registered_actions = component_registry.get_components_by_type(ComponentType.ACTION)
chat_type_removals = []

View File

@@ -684,8 +684,11 @@ class DefaultReplyer:
from src.chat.message_manager.message_manager import message_manager
# 获取聊天流的上下文
stream_context = message_manager.stream_contexts.get(chat_id)
if stream_context:
from src.plugin_system.apis.chat_api import get_chat_manager
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(chat_id)
if chat_stream:
stream_context = chat_stream.context_manager
# 使用真正的已读和未读消息
read_messages = stream_context.history_messages # 已读消息
unread_messages = stream_context.get_unread_messages() # 未读消息
@@ -693,7 +696,7 @@ class DefaultReplyer:
# 构建已读历史消息 prompt
read_history_prompt = ""
if read_messages:
read_content = build_readable_messages(
read_content = await build_readable_messages(
[msg.flatten() for msg in read_messages[-50:]], # 限制数量
replace_bot_name=True,
timestamp_mode="normal_no_YMD",
@@ -716,7 +719,7 @@ class DefaultReplyer:
]
if filtered_fallback_messages:
read_content = build_readable_messages(
read_content = await build_readable_messages(
filtered_fallback_messages,
replace_bot_name=True,
timestamp_mode="normal_no_YMD",
@@ -754,7 +757,7 @@ class DefaultReplyer:
if platform and user_id:
person_id = PersonInfoManager.get_person_id(platform, user_id)
person_info_manager = get_person_info_manager()
sender_name = person_info_manager.get_value_sync(person_id, "person_name") or "未知用户"
sender_name = person_info_manager.get_value(person_id, "person_name") or "未知用户"
else:
sender_name = "未知用户"
@@ -819,7 +822,7 @@ class DefaultReplyer:
# 构建已读历史消息 prompt
read_history_prompt = ""
if read_messages:
read_content = build_readable_messages(
read_content = await build_readable_messages(
read_messages[-50:],
replace_bot_name=True,
timestamp_mode="normal_no_YMD",
@@ -853,7 +856,7 @@ class DefaultReplyer:
if platform and user_id:
person_id = PersonInfoManager.get_person_id(platform, user_id)
person_info_manager = get_person_info_manager()
sender_name = person_info_manager.get_value_sync(person_id, "person_name") or "未知用户"
sender_name = person_info_manager.get_value(person_id, "person_name") or "未知用户"
else:
sender_name = "未知用户"
@@ -1027,7 +1030,7 @@ class DefaultReplyer:
# 检查是否是bot自己的名字如果是则替换为"(你)"
bot_user_id = str(global_config.bot.qq_account)
current_user_id = person_info_manager.get_value_sync(person_id, "user_id")
current_user_id = person_info_manager.get_value(person_id, "user_id")
current_platform = reply_message.get("chat_info_platform")
if current_user_id == bot_user_id and current_platform == global_config.bot.platform:
@@ -1046,7 +1049,7 @@ class DefaultReplyer:
target = "(无消息内容)"
person_info_manager = get_person_info_manager()
person_id = person_info_manager.get_person_id_by_person_name(sender)
person_id = await person_info_manager.get_person_id_by_person_name(sender)
platform = chat_stream.platform
target = replace_user_references_sync(target, chat_stream.platform, replace_bot_name=True)
@@ -1071,7 +1074,7 @@ class DefaultReplyer:
timestamp=time.time(),
limit=int(global_config.chat.max_context_size * 0.33),
)
chat_talking_prompt_short = build_readable_messages(
chat_talking_prompt_short = await build_readable_messages(
message_list_before_short,
replace_bot_name=True,
merge_messages=False,
@@ -1324,7 +1327,7 @@ class DefaultReplyer:
timestamp=time.time(),
limit=min(int(global_config.chat.max_context_size * 0.33), 15),
)
chat_talking_prompt_half = build_readable_messages(
chat_talking_prompt_half = await build_readable_messages(
message_list_before_now_half,
replace_bot_name=True,
merge_messages=False,
@@ -1523,7 +1526,7 @@ class DefaultReplyer:
# 获取用户ID
person_info_manager = get_person_info_manager()
person_id = person_info_manager.get_person_id_by_person_name(sender)
person_id = await person_info_manager.get_person_id_by_person_name(sender)
if not person_id:
logger.warning(f"未找到用户 {sender} 的ID跳过信息提取")
return f"你完全不认识{sender}不理解ta的相关信息。"

View File

@@ -46,7 +46,7 @@ def replace_user_references_sync(
if replace_bot_name and user_id == global_config.bot.qq_account:
return f"{global_config.bot.nickname}(你)"
person_id = PersonInfoManager.get_person_id(platform, user_id)
return person_info_manager.get_value_sync(person_id, "person_name") or user_id # type: ignore
return person_info_manager.get_value(person_id, "person_name") or user_id # type: ignore
name_resolver = default_resolver
@@ -254,7 +254,7 @@ def get_raw_msg_by_timestamp_with_chat_users(
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
def get_actions_by_timestamp_with_chat(
async def get_actions_by_timestamp_with_chat(
chat_id: str,
timestamp_start: float = 0,
timestamp_end: float = time.time(),
@@ -273,22 +273,21 @@ def get_actions_by_timestamp_with_chat(
f"limit={limit}, limit_mode={limit_mode}"
)
with get_db_session() as session:
async with get_db_session() as session:
if limit > 0:
if limit_mode == "latest":
query = session.execute(
result = await session.execute(
select(ActionRecords)
.where(
and_(
ActionRecords.chat_id == chat_id,
ActionRecords.time > timestamp_start,
ActionRecords.time < timestamp_end,
ActionRecords.time >= timestamp_start,
ActionRecords.time <= timestamp_end,
)
)
.order_by(ActionRecords.time.desc())
.limit(limit)
)
actions = list(query.scalars())
actions = list(result.scalars())
actions_result = []
for action in reversed(actions):
action_dict = {
@@ -305,8 +304,9 @@ def get_actions_by_timestamp_with_chat(
"chat_info_platform": action.chat_info_platform,
}
actions_result.append(action_dict)
actions_result.append(action_dict)
else: # earliest
query = session.execute(
result = await session.execute(
select(ActionRecords)
.where(
and_(
@@ -318,7 +318,7 @@ def get_actions_by_timestamp_with_chat(
.order_by(ActionRecords.time.asc())
.limit(limit)
)
actions = list(query.scalars())
actions = list(result.scalars())
actions_result = []
for action in actions:
action_dict = {
@@ -336,7 +336,7 @@ def get_actions_by_timestamp_with_chat(
}
actions_result.append(action_dict)
else:
query = session.execute(
result = await session.execute(
select(ActionRecords)
.where(
and_(
@@ -347,7 +347,7 @@ def get_actions_by_timestamp_with_chat(
)
.order_by(ActionRecords.time.asc())
)
actions = list(query.scalars())
actions = list(result.scalars())
actions_result = []
for action in actions:
action_dict = {
@@ -367,14 +367,14 @@ def get_actions_by_timestamp_with_chat(
return actions_result
def get_actions_by_timestamp_with_chat_inclusive(
async def get_actions_by_timestamp_with_chat_inclusive(
chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表"""
with get_db_session() as session:
async with get_db_session() as session:
if limit > 0:
if limit_mode == "latest":
query = session.execute(
result = await session.execute(
select(ActionRecords)
.where(
and_(
@@ -386,10 +386,10 @@ def get_actions_by_timestamp_with_chat_inclusive(
.order_by(ActionRecords.time.desc())
.limit(limit)
)
actions = list(query.scalars())
actions = list(result.scalars())
return [action.__dict__ for action in reversed(actions)]
else: # earliest
query = session.execute(
result = await session.execute(
select(ActionRecords)
.where(
and_(
@@ -402,7 +402,7 @@ def get_actions_by_timestamp_with_chat_inclusive(
.limit(limit)
)
else:
query = session.execute(
query = await session.execute(
select(ActionRecords)
.where(
and_(
@@ -507,7 +507,7 @@ def num_new_messages_since_with_users(
return count_messages(message_filter=filter_query)
def _build_readable_messages_internal(
async def _build_readable_messages_internal(
messages: List[Dict[str, Any]],
replace_bot_name: bool = True,
merge_messages: bool = False,
@@ -627,7 +627,7 @@ def _build_readable_messages_internal(
if replace_bot_name and user_id == global_config.bot.qq_account:
person_name = f"{global_config.bot.nickname}(你)"
else:
person_name = person_info_manager.get_value_sync(person_id, "person_name") # type: ignore
person_name = await person_info_manager.get_value(person_id, "person_name") # type: ignore
# 如果 person_name 未设置,则使用消息中的 nickname 或默认名称
if not person_name:
@@ -800,7 +800,7 @@ def _build_readable_messages_internal(
)
def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
async def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
# sourcery skip: use-contextlib-suppress
"""
构建图片映射信息字符串,显示图片的具体描述内容
@@ -823,8 +823,8 @@ def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
# 从数据库中获取图片描述
description = "[图片内容未知]" # 默认描述
try:
with get_db_session() as session:
result = session.execute(select(Images).where(Images.image_id == pic_id))
async with get_db_session() as session:
result = await session.execute(select(Images).where(Images.image_id == pic_id))
image = result.scalar_one_or_none()
if image and image.description: # type: ignore
description = image.description
@@ -922,17 +922,17 @@ async def build_readable_messages_with_list(
将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。
允许通过参数控制格式化行为。
"""
formatted_string, details_list, pic_id_mapping, _ = _build_readable_messages_internal(
formatted_string, details_list, pic_id_mapping, _ = await _build_readable_messages_internal(
messages, replace_bot_name, merge_messages, timestamp_mode, truncate
)
if pic_mapping_info := build_pic_mapping_info(pic_id_mapping):
if pic_mapping_info := await build_pic_mapping_info(pic_id_mapping):
formatted_string = f"{pic_mapping_info}\n\n{formatted_string}"
return formatted_string, details_list
def build_readable_messages_with_id(
async def build_readable_messages_with_id(
messages: List[Dict[str, Any]],
replace_bot_name: bool = True,
merge_messages: bool = False,
@@ -948,7 +948,7 @@ def build_readable_messages_with_id(
"""
message_id_list = assign_message_ids(messages)
formatted_string = build_readable_messages(
formatted_string = await build_readable_messages(
messages=messages,
replace_bot_name=replace_bot_name,
merge_messages=merge_messages,
@@ -960,10 +960,16 @@ def build_readable_messages_with_id(
message_id_list=message_id_list,
)
# 如果存在图片映射信息,附加之
if pic_mapping_info := await build_pic_mapping_info({}):
# 如果当前没有图片映射则不附加
if pic_mapping_info:
formatted_string = f"{pic_mapping_info}\n\n{formatted_string}"
return formatted_string, message_id_list
def build_readable_messages(
async def build_readable_messages(
messages: List[Dict[str, Any]],
replace_bot_name: bool = True,
merge_messages: bool = False,
@@ -1004,9 +1010,9 @@ def build_readable_messages(
from src.common.database.sqlalchemy_database_api import get_db_session
with get_db_session() as session:
async with get_db_session() as session:
# 获取这个时间范围内的动作记录并匹配chat_id
actions_in_range = session.execute(
actions_in_range = (await session.execute(
select(ActionRecords)
.where(
and_(
@@ -1014,15 +1020,15 @@ def build_readable_messages(
)
)
.order_by(ActionRecords.time)
).scalars()
)).scalars()
# 获取最新消息之后的第一个动作记录
action_after_latest = session.execute(
action_after_latest = (await session.execute(
select(ActionRecords)
.where(and_(ActionRecords.time > max_time, ActionRecords.chat_id == chat_id))
.order_by(ActionRecords.time)
.limit(1)
).scalars()
)).scalars()
# 合并两部分动作记录,并转为 dict避免 DetachedInstanceError
actions = [
@@ -1053,7 +1059,7 @@ def build_readable_messages(
if read_mark <= 0:
# 没有有效的 read_mark直接格式化所有消息
formatted_string, _, pic_id_mapping, _ = _build_readable_messages_internal(
formatted_string, _, pic_id_mapping, _ = await _build_readable_messages_internal(
copy_messages,
replace_bot_name,
merge_messages,
@@ -1064,7 +1070,7 @@ def build_readable_messages(
)
# 生成图片映射信息并添加到最前面
pic_mapping_info = build_pic_mapping_info(pic_id_mapping)
pic_mapping_info = await build_pic_mapping_info(pic_id_mapping)
if pic_mapping_info:
return f"{pic_mapping_info}\n\n{formatted_string}"
else:
@@ -1079,7 +1085,7 @@ def build_readable_messages(
pic_counter = 1
# 分别格式化,但使用共享的图片映射
formatted_before, _, pic_id_mapping, pic_counter = _build_readable_messages_internal(
formatted_before, _, pic_id_mapping, pic_counter = await _build_readable_messages_internal(
messages_before_mark,
replace_bot_name,
merge_messages,
@@ -1090,7 +1096,7 @@ def build_readable_messages(
show_pic=show_pic,
message_id_list=message_id_list,
)
formatted_after, _, pic_id_mapping, _ = _build_readable_messages_internal(
formatted_after, _, pic_id_mapping, _ = await _build_readable_messages_internal(
messages_after_mark,
replace_bot_name,
merge_messages,
@@ -1106,7 +1112,7 @@ def build_readable_messages(
# 生成图片映射信息
if pic_id_mapping:
pic_mapping_info = f"图片信息:\n{build_pic_mapping_info(pic_id_mapping)}\n聊天记录信息:\n"
pic_mapping_info = f"图片信息:\n{await build_pic_mapping_info(pic_id_mapping)}\n聊天记录信息:\n"
else:
pic_mapping_info = "聊天记录信息:\n"
@@ -1229,7 +1235,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
# 在最前面添加图片映射信息
final_output_lines = []
pic_mapping_info = build_pic_mapping_info(pic_id_mapping)
pic_mapping_info = await build_pic_mapping_info(pic_id_mapping)
if pic_mapping_info:
final_output_lines.append(pic_mapping_info)
final_output_lines.append("\n\n")

View File

@@ -494,7 +494,7 @@ class Prompt:
chat_history = ""
if self.parameters.message_list_before_now_long:
recent_messages = self.parameters.message_list_before_now_long[-10:]
chat_history = build_readable_messages(
chat_history = await build_readable_messages(
recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True
)
@@ -535,7 +535,7 @@ class Prompt:
chat_history = ""
if self.parameters.message_list_before_now_long:
recent_messages = self.parameters.message_list_before_now_long[-20:]
chat_history = build_readable_messages(
chat_history = await build_readable_messages(
recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True
)
@@ -589,7 +589,7 @@ class Prompt:
chat_history = ""
if self.parameters.message_list_before_now_long:
recent_messages = self.parameters.message_list_before_now_long[-15:]
chat_history = build_readable_messages(
chat_history = await build_readable_messages(
recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True
)
@@ -863,7 +863,7 @@ class Prompt:
# 获取用户ID
person_info_manager = get_person_info_manager()
person_id = person_info_manager.get_person_id_by_person_name(sender)
person_id = await person_info_manager.get_person_id_by_person_name(sender)
if not person_id:
logger.warning(f"未找到用户 {sender} 的ID跳过信息提取")
return f"你完全不认识{sender}不理解ta的相关信息。"
@@ -904,7 +904,7 @@ class Prompt:
return ""
@staticmethod
def parse_reply_target_id(reply_to: str) -> str:
async def parse_reply_target_id(reply_to: str) -> str:
"""
解析回复目标中的用户ID
@@ -924,9 +924,9 @@ class Prompt:
# 获取用户ID
person_info_manager = get_person_info_manager()
person_id = person_info_manager.get_person_id_by_person_name(sender)
person_id = await person_info_manager.get_person_id_by_person_name(sender)
if person_id:
user_id = person_info_manager.get_value_sync(person_id, "user_id")
user_id = person_info_manager.get_value(person_id, "user_id")
return str(user_id) if user_id else ""
return ""

View File

@@ -1,3 +1,4 @@
import asyncio
import random
import re
import string
@@ -662,9 +663,32 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
person_id = PersonInfoManager.get_person_id(platform, user_id)
person_name = None
if person_id:
# get_value is async, so await it directly
person_info_manager = get_person_info_manager()
person_name = person_info_manager.get_value_sync(person_id, "person_name")
try:
# 如果没有运行的事件循环,直接 asyncio.run
loop = asyncio.get_event_loop()
if loop.is_running():
# 如果事件循环在运行,从其他线程提交并等待结果
try:
from concurrent.futures import TimeoutError
fut = asyncio.run_coroutine_threadsafe(
person_info_manager.get_value(person_id, "person_name"), loop
)
person_name = fut.result(timeout=2)
except Exception as e:
# 无法在运行循环上安全等待,退回为 None
logger.debug(f"无法通过运行的事件循环获取 person_name: {e}")
person_name = None
else:
person_name = asyncio.run(person_info_manager.get_value(person_id, "person_name"))
except RuntimeError:
# get_event_loop 在某些上下文可能抛出 RuntimeError退回到 asyncio.run
try:
person_name = asyncio.run(person_info_manager.get_value(person_id, "person_name"))
except Exception as e:
logger.debug(f"获取 person_name 失败: {e}")
person_name = None
target_info["person_id"] = person_id
target_info["person_name"] = person_name

View File

@@ -344,6 +344,39 @@ class StreamContext(BaseDataModel):
"""获取优先级信息"""
return self.priority_info
def __deepcopy__(self, memo):
"""自定义深拷贝,跳过不可序列化的 asyncio.Task (processing_task)。
deepcopy 在内部可能会尝试 pickle 某些对象(如 asyncio.Task
这会在多线程或运行时事件循环中导致 TypeError。这里我们手动复制
__dict__ 中的字段,确保 processing_task 被设置为 None其他字段使用
copy.deepcopy 递归复制。
"""
import copy
# 如果已经复制过,直接返回缓存结果
obj_id = id(self)
if obj_id in memo:
return memo[obj_id]
# 创建一个未初始化的新实例,然后逐个字段深拷贝
cls = self.__class__
new = cls.__new__(cls)
memo[obj_id] = new
for k, v in self.__dict__.items():
if k == "processing_task":
# 不复制 asyncio.Task避免无法 pickling
setattr(new, k, None)
else:
try:
setattr(new, k, copy.deepcopy(v, memo))
except Exception:
# 如果某个字段无法深拷贝,退回到原始引用(安全性谨慎)
setattr(new, k, v)
return new
@dataclass
class MessageManagerStats(BaseDataModel):

View File

@@ -30,7 +30,7 @@ def _model_to_dict(instance: Base) -> Dict[str, Any]:
return {col.name: instance.__dict__.get(col.name) for col in instance.__table__.columns}
def find_messages(
async def find_messages(
message_filter: dict[str, Any],
sort: Optional[List[tuple[str, int]]] = None,
limit: int = 0,
@@ -51,7 +51,7 @@ def find_messages(
消息字典列表,如果出错则返回空列表。
"""
try:
with get_db_session() as session:
async with get_db_session() as session:
query = select(Messages)
# 应用过滤器
@@ -101,8 +101,8 @@ def find_messages(
# 获取时间最早的 limit 条记录,已经是正序
query = query.order_by(Messages.time.asc()).limit(limit)
try:
results = result = session.execute(query)
result.scalars().all()
result = await session.execute(query)
results = result.scalars().all()
except Exception as e:
logger.error(f"执行earliest查询失败: {e}")
results = []
@@ -110,8 +110,8 @@ def find_messages(
# 获取时间最晚的 limit 条记录
query = query.order_by(Messages.time.desc()).limit(limit)
try:
latest_results = result = session.execute(query)
result.scalars().all()
result = await session.execute(query)
latest_results = result.scalars().all()
# 将结果按时间正序排列
results = sorted(latest_results, key=lambda msg: msg.time)
except Exception as e:
@@ -135,8 +135,8 @@ def find_messages(
if sort_terms:
query = query.order_by(*sort_terms)
try:
results = result = session.execute(query)
result.scalars().all()
result = await session.execute(query)
results = result.scalars().all()
except Exception as e:
logger.error(f"执行无限制查询失败: {e}")
results = []
@@ -152,7 +152,7 @@ def find_messages(
return []
def count_messages(message_filter: dict[str, Any]) -> int:
async def count_messages(message_filter: dict[str, Any]) -> int:
"""
根据提供的过滤器计算消息数量。
@@ -163,7 +163,7 @@ def count_messages(message_filter: dict[str, Any]) -> int:
符合条件的消息数量,如果出错则返回 0。
"""
try:
with get_db_session() as session:
async with get_db_session() as session:
query = select(func.count(Messages.id))
# 应用过滤器
@@ -201,7 +201,7 @@ def count_messages(message_filter: dict[str, Any]) -> int:
if conditions:
query = query.where(*conditions)
count = session.execute(query).scalar()
count = (await session.execute(query)).scalar()
return count or 0
except Exception as e:
log_message = f"使用 SQLAlchemy 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"

View File

@@ -148,8 +148,8 @@ class MainSystem:
# 停止消息重组器
from src.plugin_system.core.event_manager import event_manager
from src.plugin_system import EventType
asyncio.run(event_manager.trigger_event(EventType.ON_STOP, permission_group="SYSTEM"))
from src.utils.message_chunker import reassembler
loop = asyncio.get_event_loop()

View File

@@ -110,7 +110,7 @@ class ChatMood:
limit=int(global_config.chat.max_context_size / 3),
limit_mode="last",
)
chat_talking_prompt = build_readable_messages(
chat_talking_prompt = await build_readable_messages(
message_list_before_now,
replace_bot_name=True,
merge_messages=False,
@@ -159,7 +159,7 @@ class ChatMood:
limit=15,
limit_mode="last",
)
chat_talking_prompt = build_readable_messages(
chat_talking_prompt = await build_readable_messages(
message_list_before_now,
replace_bot_name=True,
merge_messages=False,

View File

@@ -1,3 +1,4 @@
import asyncio
import copy
import datetime
import hashlib
@@ -57,7 +58,7 @@ class PersonInfoManager:
self.person_name_list = {}
self.qv_name_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="relation.qv_name")
# try:
# with get_db_session() as session:
# async with get_db_session() as session:
# db.connect(reuse_if_open=True)
# # 设置连接池参数仅对SQLite有效
# if hasattr(db, "execute_sql"):
@@ -75,7 +76,7 @@ class PersonInfoManager:
try:
pass
# 在这里获取会话
# with get_db_session() as session:
# async with get_db_session() as session:
# for record in session.execute(
# select(PersonInfo.person_id, PersonInfo.person_name).where(PersonInfo.person_name.is_not(None))
# ).fetchall():
@@ -87,59 +88,26 @@ class PersonInfoManager:
@staticmethod
def get_person_id(platform: str, user_id: Union[int, str]) -> str:
"""获取唯一id"""
"""获取唯一id(同步)
说明: 原来该方法为异步并在内部尝试执行数据库检查/迁移,导致在许多调用处未 await 时返回 coroutine 对象。
为了避免将 coroutine 传递到其它同步调用(例如数据库查询条件)中,这里将方法改为同步并仅返回基于 platform 和 user_id 的 MD5 哈希值。
注意: 这会跳过原有的 napcat->qq 迁移检查逻辑。如需保留迁移,请使用显式的、在合适时机执行的迁移任务。
"""
# 检查platform是否为None或空
if platform is None:
platform = "unknown"
if "-" in platform:
platform = platform.split("-")[1]
# 在此处打一个补丁如果platform为qq尝试生成id后检查是否存在如果不存在则将平台换为napcat后再次检查如果存在则更新原id为platform为qq的id
components = [platform, str(user_id)]
key = "_".join(components)
# 如果不是 qq 平台,直接返回计算的 id
if platform != "qq":
# 直接返回计算的 id(同步)
return hashlib.md5(key.encode()).hexdigest()
qq_id = hashlib.md5(key.encode()).hexdigest()
# 对于 qq 平台,先检查该 person_id 是否已存在;如果存在直接返回
def _db_check_and_migrate_sync(p_id: str, raw_user_id: str):
try:
with get_db_session() as session:
# 检查 qq_id 是否存在
existing_qq = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
if existing_qq:
return p_id
# 如果 qq_id 不存在,尝试使用 napcat 作为平台生成对应 id 并检查
nap_components = ["napcat", str(raw_user_id)]
nap_key = "_".join(nap_components)
nap_id = hashlib.md5(nap_key.encode()).hexdigest()
existing_nap = session.execute(select(PersonInfo).where(PersonInfo.person_id == nap_id)).scalar()
if not existing_nap:
# napcat 也不存在,返回 qq_id未命中
return p_id
# napcat 存在,迁移该记录:更新 person_id 与 platform -> qq
try:
# 更新现有 napcat 记录
existing_nap.person_id = p_id
existing_nap.platform = "qq"
existing_nap.user_id = str(raw_user_id)
session.commit()
return p_id
except Exception:
session.rollback()
return p_id
except Exception as e:
logger.error(f"检查/迁移 napcat->qq 时出错: {e}")
return p_id
return _db_check_and_migrate_sync(qq_id, user_id)
async def is_person_known(self, platform: str, user_id: int):
"""判断是否认识某人"""
person_id = self.get_person_id(platform, user_id)
@@ -157,17 +125,25 @@ class PersonInfoManager:
logger.error(f"检查用户 {person_id} 是否已知时出错 (SQLAlchemy): {e}")
return False
@staticmethod
async def get_person_id_by_person_name(person_name: str) -> str:
"""根据用户名获取用户ID"""
async def get_person_id_by_person_name(self, person_name: str) -> str:
"""
根据用户名获取用户ID(同步)
说明: 为了避免在多个调用点将 coroutine 误传递到数据库查询中,
此处提供一个同步实现。优先在内存缓存 `self.person_name_list` 中查找,
若未命中则返回空字符串。若后续需要更强的一致性,可在异步上下文
额外实现带 await 的查询方法。
"""
try:
# 在需要时获取会话
async with get_db_session() as session:
record = result = await session.execute(select(PersonInfo).where(PersonInfo.person_name == person_name))
result.scalar()
return record.person_id if record else ""
# 优先使用内存缓存加速查找self.person_name_list maps person_id -> person_name
for pid, pname in self.person_name_list.items():
if pname == person_name:
return pid
# 未找到缓存命中,避免在同步路径中进行阻塞的数据库查询,直接返回空字符串
return ""
except Exception as e:
logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (SQLAlchemy): {e}")
logger.error(f"根据用户名 {person_name} 获取用户ID时出错: {e}")
return ""
@staticmethod
@@ -578,26 +554,15 @@ class PersonInfoManager:
@staticmethod
def get_value(person_id: str, field_name: str) -> Any:
async def get_value(person_id: str, field_name: str) -> Any:
"""获取单个字段值(同步版本)"""
if not person_id:
logger.debug("get_value获取失败person_id不能为空")
return None
import asyncio
async def _get_record_sync():
async with get_db_session() as session:
result = await session.execute(select(PersonInfo).where(PersonInfo.person_id == person_id))
record = result.scalar()
return record
try:
record = asyncio.run(_get_record_sync())
except RuntimeError:
# 如果当前线程已经有事件循环在运行,则使用现有的循环
loop = asyncio.get_running_loop()
record = loop.run_until_complete(_get_record_sync())
model_fields = [column.name for column in PersonInfo.__table__.columns]

View File

@@ -176,7 +176,7 @@ class RelationshipFetcher:
# 查询用户关系数据
relationships = await db_query(
UserRelationships,
filters=[UserRelationships.user_id == str(person_info_manager.get_value_sync(person_id, "user_id"))],
filters=[UserRelationships.user_id == str(person_info_manager.get_value(person_id, "user_id"))],
limit=1,
)
@@ -259,7 +259,7 @@ class RelationshipFetcher:
# 记录信息获取请求
self.info_fetching_cache.append(
{
"person_id": get_person_info_manager().get_person_id_by_person_name(person_name),
"person_id": await get_person_info_manager().get_person_id_by_person_name(person_name),
"person_name": person_name,
"info_type": info_type,
"start_time": time.time(),

View File

@@ -412,7 +412,7 @@ def count_new_messages_for_users(chat_id: str, start_time: float, end_time: floa
# =============================================================================
def build_readable_messages_to_str(
async def build_readable_messages_to_str(
messages: List[Dict[str, Any]],
replace_bot_name: bool = True,
merge_messages: bool = False,
@@ -436,7 +436,7 @@ def build_readable_messages_to_str(
Returns:
格式化后的可读字符串
"""
return build_readable_messages(
return await build_readable_messages(
messages, replace_bot_name, merge_messages, timestamp_mode, read_mark, truncate, show_actions
)

View File

@@ -134,7 +134,7 @@ async def is_person_known(platform: str, user_id: int) -> bool:
return False
def get_person_id_by_name(person_name: str) -> str:
async def get_person_id_by_name(person_name: str) -> str:
"""根据用户名获取person_id
Args:
@@ -148,7 +148,7 @@ def get_person_id_by_name(person_name: str) -> str:
"""
try:
person_info_manager = get_person_info_manager()
return person_info_manager.get_person_id_by_person_name(person_name)
return await person_info_manager.get_person_id_by_person_name(person_name)
except Exception as e:
logger.error(f"[PersonAPI] 根据用户名获取person_id失败: person_name={person_name}, error={e}")
return ""

View File

@@ -542,7 +542,22 @@ class PluginManager:
plugin_instance.on_unload()
# 从组件注册表中移除插件的所有组件
try:
loop = asyncio.get_event_loop()
if loop.is_running():
fut = asyncio.run_coroutine_threadsafe(
component_registry.unregister_plugin(plugin_name), loop
)
fut.result(timeout=5)
else:
asyncio.run(component_registry.unregister_plugin(plugin_name))
except Exception:
# 最后兜底:直接同步调用(如果 unregister_plugin 为非协程)或忽略错误
try:
# 如果 unregister_plugin 是普通函数
component_registry.unregister_plugin(plugin_name)
except Exception as e:
logger.debug(f"卸载插件时调用 component_registry.unregister_plugin 失败: {e}")
# 从已加载插件中移除
del self.loaded_plugins[plugin_name]

View File

@@ -199,7 +199,7 @@ class ChatterInterestScoringSystem:
# 如果内存中没有,尝试从关系追踪器获取
if hasattr(self, "relationship_tracker") and self.relationship_tracker:
try:
relationship_score = self.relationship_tracker.get_user_relationship_score(user_id)
relationship_score = await self.relationship_tracker.get_user_relationship_score(user_id)
# 同时更新内存缓存
self.user_relationships[user_id] = relationship_score
return relationship_score

View File

@@ -182,7 +182,7 @@ class ChatterPlanFilter:
if plan.mode == ChatMode.PROACTIVE:
long_term_memory_block = await self._get_long_term_memory_context()
chat_content_block, message_id_list = build_readable_messages_with_id(
chat_content_block, message_id_list = await build_readable_messages_with_id(
messages=[msg.flatten() for msg in plan.chat_history],
timestamp_mode="normal",
truncate=False,
@@ -190,7 +190,7 @@ class ChatterPlanFilter:
)
prompt_template = await global_prompt_manager.get_prompt_async("proactive_planner_prompt")
actions_before_now = get_actions_by_timestamp_with_chat(
actions_before_now = await get_actions_by_timestamp_with_chat(
chat_id=plan.chat_id,
timestamp_start=time.time() - 3600,
timestamp_end=time.time(),
@@ -216,7 +216,7 @@ class ChatterPlanFilter:
)
# 为了兼容性保留原有的chat_content_block
chat_content_block, _ = build_readable_messages_with_id(
chat_content_block, _ = await build_readable_messages_with_id(
messages=[msg.flatten() for msg in plan.chat_history],
timestamp_mode="normal",
read_mark=self.last_obs_time_mark,
@@ -224,7 +224,7 @@ class ChatterPlanFilter:
show_actions=True,
)
actions_before_now = get_actions_by_timestamp_with_chat(
actions_before_now = await get_actions_by_timestamp_with_chat(
chat_id=plan.chat_id,
timestamp_start=time.time() - 3600,
timestamp_end=time.time(),
@@ -319,7 +319,14 @@ class ChatterPlanFilter:
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat
# 获取聊天流的上下文
stream_context = message_manager.stream_contexts.get(plan.chat_id)
from src.plugin_system.apis.chat_api import get_chat_manager
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(plan.chat_id)
if not chat_stream:
logger.warning(f"[plan_filter] 聊天流 {plan.chat_id} 不存在")
return "最近没有聊天内容。", "没有未读消息。", []
stream_context = chat_stream.context_manager
# 获取真正的已读和未读消息
read_messages = stream_context.history_messages # 已读消息存储在history_messages中
@@ -338,7 +345,7 @@ class ChatterPlanFilter:
# 构建已读历史消息块
if read_messages:
read_content, read_ids = build_readable_messages_with_id(
read_content, read_ids = await build_readable_messages_with_id(
messages=[msg.flatten() for msg in read_messages[-50:]], # 限制数量
timestamp_mode="normal_no_YMD",
truncate=False,

View File

@@ -138,7 +138,7 @@ class ChatterActionPlanner:
# 更新StreamContext中的消息信息并刷新focus_energy
if context:
from src.chat.message_manager.message_manager import message_manager
message_manager.update_message(
await message_manager.update_message(
stream_id=self.chat_id,
message_id=message.message_id,
interest_value=message_interest,
@@ -148,7 +148,7 @@ class ChatterActionPlanner:
# 更新数据库中的消息记录
try:
from src.chat.message_receive.storage import MessageStorage
MessageStorage.update_message_interest_value(message.message_id, message_interest)
await MessageStorage.update_message_interest_value(message.message_id, message_interest)
logger.debug(f"已更新数据库中消息 {message.message_id} 的兴趣度为: {message_interest:.3f}")
except Exception as e:
logger.warning(f"更新数据库消息兴趣度失败: {e}")

View File

@@ -124,10 +124,10 @@ class EmojiAction(BaseAction):
emoji_base64, emoji_description = random.choice(all_emojis_data)
else:
# 获取最近的5条消息内容用于判断
recent_messages = message_api.get_recent_messages(chat_id=self.chat_id, limit=5)
recent_messages = await message_api.get_recent_messages(chat_id=self.chat_id, limit=5)
messages_text = ""
if recent_messages:
messages_text = message_api.build_readable_messages(
messages_text = await message_api.build_readable_messages(
messages=recent_messages,
timestamp_mode="normal_no_YMD",
truncate=False,
@@ -185,10 +185,10 @@ class EmojiAction(BaseAction):
elif global_config.emoji.emoji_selection_mode == "description":
# --- 详细描述选择模式 ---
# 获取最近的5条消息内容用于判断
recent_messages = message_api.get_recent_messages(chat_id=self.chat_id, limit=5)
recent_messages = await message_api.get_recent_messages(chat_id=self.chat_id, limit=5)
messages_text = ""
if recent_messages:
messages_text = message_api.build_readable_messages(
messages_text = await message_api.build_readable_messages(
messages=recent_messages,
timestamp_mode="normal_no_YMD",
truncate=False,

View File

@@ -118,7 +118,7 @@ class QZoneService:
async def read_and_process_feeds(self, target_name: str, stream_id: Optional[str]) -> Dict[str, Any]:
"""读取并处理指定好友的说说"""
target_person_id = person_api.get_person_id_by_name(target_name)
target_person_id = await person_api.get_person_id_by_name(target_name)
if not target_person_id:
return {"success": False, "message": f"找不到名为'{target_name}'的好友"}
target_qq = await person_api.get_person_value(target_person_id, "user_id")

View File

@@ -331,6 +331,7 @@ class NoticeHandler:
like_emoji_id = raw_message.get("likes")[0].get("emoji_id")
await event_manager.trigger_event(
<<<<<<< HEAD
NapcatEvent.ON_RECEIVED.EMOJI_LIEK,
permission_group=PLUGIN_NAME,
group_id=group_id,
@@ -342,6 +343,16 @@ class NoticeHandler:
type="text",
data=f"{user_name}使用Emoji表情{QQ_FACE.get(like_emoji_id, '')}回复了你的消息[{target_message_text}]",
)
=======
NapcatEvent.ON_RECEIVED.EMOJI_LIEK,
permission_group=PLUGIN_NAME,
group_id=group_id,
user_id=user_id,
message_id=raw_message.get("message_id",""),
emoji_id=like_emoji_id
)
seg_data = Seg(type="text",data=f"{user_name}使用Emoji表情{QQ_FACE.get(like_emoji_id, '')}回复了你的消息[{target_message_text}]")
>>>>>>> 9912d7f643d347cbadcf1e3d618aa78bcbf89cc4
return seg_data, user_info
async def handle_ban_notify(self, raw_message: dict, group_id: int) -> Tuple[Seg, UserInfo] | Tuple[None, None]:

View File

@@ -1,25 +0,0 @@
{
"manifest_version": 1,
"name": "web_search_tool",
"version": "1.0.0",
"description": "一个用于在互联网上搜索信息的工具",
"author": {
"name": "MoFox-Studio",
"url": "https://github.com/MoFox-Studio"
},
"license": "GPL-v3.0-or-later",
"host_application": {
"min_version": "0.10.0"
},
"keywords": ["web_search", "url_parser"],
"categories": ["web_search", "url_parser"],
"default_locale": "zh-CN",
"locales_path": "_locales",
"plugin_info": {
"is_built_in": false,
"plugin_type": "web_search"
}
}

View File

@@ -1,110 +0,0 @@
#!/usr/bin/env python3
"""
测试 ChatStream 的 deepcopy 功能
验证 asyncio.Task 序列化问题是否已解决
"""
import asyncio
import sys
import os
import copy
# 添加项目根目录到 Python 路径
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from src.chat.message_receive.chat_stream import ChatStream
from maim_message import UserInfo, GroupInfo
async def test_chat_stream_deepcopy():
"""测试 ChatStream 的 deepcopy 功能"""
print("[TEST] 开始测试 ChatStream deepcopy 功能...")
try:
# 创建测试用的用户和群组信息
user_info = UserInfo(
platform="test_platform",
user_id="test_user_123",
user_nickname="测试用户",
user_cardname="测试卡片名"
)
group_info = GroupInfo(
platform="test_platform",
group_id="test_group_456",
group_name="测试群组"
)
# 创建 ChatStream 实例
print("📝 创建 ChatStream 实例...")
stream_id = "test_stream_789"
platform = "test_platform"
chat_stream = ChatStream(
stream_id=stream_id,
platform=platform,
user_info=user_info,
group_info=group_info
)
print(f"[SUCCESS] ChatStream 创建成功: {chat_stream.stream_id}")
# 等待一下,让异步任务有机会创建
await asyncio.sleep(0.1)
# 尝试进行 deepcopy
print("[INFO] 尝试进行 deepcopy...")
copied_stream = copy.deepcopy(chat_stream)
print("[SUCCESS] deepcopy 成功!")
# 验证复制后的对象属性
print("\n[CHECK] 验证复制后的对象属性:")
print(f" - stream_id: {copied_stream.stream_id}")
print(f" - platform: {copied_stream.platform}")
print(f" - user_info: {copied_stream.user_info.user_nickname}")
print(f" - group_info: {copied_stream.group_info.group_name}")
# 检查 processing_task 是否被正确处理
if hasattr(copied_stream.stream_context, 'processing_task'):
print(f" - processing_task: {copied_stream.stream_context.processing_task}")
if copied_stream.stream_context.processing_task is None:
print(" [SUCCESS] processing_task 已被正确设置为 None")
else:
print(" [WARNING] processing_task 不为 None")
else:
print(" [SUCCESS] stream_context 没有 processing_task 属性")
# 验证原始对象和复制对象是不同的实例
if id(chat_stream) != id(copied_stream):
print("[SUCCESS] 原始对象和复制对象是不同的实例")
else:
print("[ERROR] 原始对象和复制对象是同一个实例")
# 验证基本属性是否正确复制
if (chat_stream.stream_id == copied_stream.stream_id and
chat_stream.platform == copied_stream.platform):
print("[SUCCESS] 基本属性正确复制")
else:
print("[ERROR] 基本属性复制失败")
print("\n[COMPLETE] 测试完成deepcopy 功能修复成功!")
return True
except Exception as e:
print(f"[ERROR] 测试失败: {e}")
import traceback
traceback.print_exc()
return False
if __name__ == "__main__":
# 运行测试
result = asyncio.run(test_chat_stream_deepcopy())
if result:
print("\n[SUCCESS] 所有测试通过!")
sys.exit(0)
else:
print("\n[ERROR] 测试失败!")
sys.exit(1)

View File

@@ -1,109 +0,0 @@
#!/usr/bin/env python3
"""
简单的 ChatStream deepcopy 测试
"""
import asyncio
import sys
import os
import copy
# 添加项目根目录到 Python 路径
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from src.chat.message_receive.chat_stream import ChatStream
from maim_message import UserInfo, GroupInfo
async def test_deepcopy():
"""测试 deepcopy 功能"""
print("开始测试 ChatStream deepcopy 功能...")
try:
# 创建测试用的用户和群组信息
user_info = UserInfo(
platform="test_platform",
user_id="test_user_123",
user_nickname="测试用户",
user_cardname="测试卡片名"
)
group_info = GroupInfo(
platform="test_platform",
group_id="test_group_456",
group_name="测试群组"
)
# 创建 ChatStream 实例
print("创建 ChatStream 实例...")
stream_id = "test_stream_789"
platform = "test_platform"
chat_stream = ChatStream(
stream_id=stream_id,
platform=platform,
user_info=user_info,
group_info=group_info
)
print(f"ChatStream 创建成功: {chat_stream.stream_id}")
# 等待一下,让异步任务有机会创建
await asyncio.sleep(0.1)
# 尝试进行 deepcopy
print("尝试进行 deepcopy...")
copied_stream = copy.deepcopy(chat_stream)
print("deepcopy 成功!")
# 验证复制后的对象属性
print("\n验证复制后的对象属性:")
print(f" - stream_id: {copied_stream.stream_id}")
print(f" - platform: {copied_stream.platform}")
print(f" - user_info: {copied_stream.user_info.user_nickname}")
print(f" - group_info: {copied_stream.group_info.group_name}")
# 检查 processing_task 是否被正确处理
if hasattr(copied_stream.stream_context, 'processing_task'):
print(f" - processing_task: {copied_stream.stream_context.processing_task}")
if copied_stream.stream_context.processing_task is None:
print(" SUCCESS: processing_task 已被正确设置为 None")
else:
print(" WARNING: processing_task 不为 None")
else:
print(" SUCCESS: stream_context 没有 processing_task 属性")
# 验证原始对象和复制对象是不同的实例
if id(chat_stream) != id(copied_stream):
print("SUCCESS: 原始对象和复制对象是不同的实例")
else:
print("ERROR: 原始对象和复制对象是同一个实例")
# 验证基本属性是否正确复制
if (chat_stream.stream_id == copied_stream.stream_id and
chat_stream.platform == copied_stream.platform):
print("SUCCESS: 基本属性正确复制")
else:
print("ERROR: 基本属性复制失败")
print("\n测试完成deepcopy 功能修复成功!")
return True
except Exception as e:
print(f"ERROR: 测试失败: {e}")
import traceback
traceback.print_exc()
return False
if __name__ == "__main__":
# 运行测试
result = asyncio.run(test_deepcopy())
if result:
print("\n所有测试通过!")
sys.exit(0)
else:
print("\n测试失败!")
sys.exit(1)