refactor(chat): 迁移数据库操作为异步模式并修复相关调用
将同步数据库操作全面迁移为异步模式,主要涉及: - 将 `with get_db_session()` 改为 `async with get_db_session()` - 修复相关异步调用链,确保 await 正确传递 - 优化消息管理器、上下文管理器等核心组件的异步处理 - 移除同步的 person_id 获取方法,避免协程对象传递问题 修复 deepcopy 在 StreamContext 中的序列化问题,跳过不可序列化的 asyncio.Task 对象 删除无用的测试文件和废弃的插件清单文件
This commit is contained in:
@@ -261,7 +261,7 @@ class AntiPromptInjector:
|
||||
logger.warning("无法删除消息:缺少message_id")
|
||||
return
|
||||
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
# 删除对应的消息记录
|
||||
stmt = delete(Messages).where(Messages.message_id == message_id)
|
||||
result = session.execute(stmt)
|
||||
@@ -287,7 +287,7 @@ class AntiPromptInjector:
|
||||
logger.warning("无法更新消息:缺少message_id")
|
||||
return
|
||||
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
# 更新消息内容
|
||||
stmt = (
|
||||
update(Messages)
|
||||
|
||||
@@ -42,7 +42,7 @@ class SingleStreamContextManager:
|
||||
self._update_access_stats()
|
||||
return self.context
|
||||
|
||||
def add_message(self, message: DatabaseMessages, skip_energy_update: bool = False) -> bool:
|
||||
async def add_message(self, message: DatabaseMessages, skip_energy_update: bool = False) -> bool:
|
||||
"""添加消息到上下文
|
||||
|
||||
Args:
|
||||
@@ -53,30 +53,21 @@ class SingleStreamContextManager:
|
||||
bool: 是否成功添加
|
||||
"""
|
||||
try:
|
||||
# 添加消息到上下文
|
||||
self.context.add_message(message)
|
||||
|
||||
# 计算消息兴趣度
|
||||
interest_value = self._calculate_message_interest(message)
|
||||
interest_value = await self._calculate_message_interest(message)
|
||||
message.interest_value = interest_value
|
||||
|
||||
# 更新统计
|
||||
self.total_messages += 1
|
||||
self.last_access_time = time.time()
|
||||
|
||||
# 更新能量和分发
|
||||
if not skip_energy_update:
|
||||
self._update_stream_energy()
|
||||
await self._update_stream_energy()
|
||||
distribution_manager.add_stream_message(self.stream_id, 1)
|
||||
|
||||
logger.debug(f"添加消息到单流上下文: {self.stream_id} (兴趣度: {interest_value:.3f})")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"添加消息到单流上下文失败 {self.stream_id}: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def update_message(self, message_id: str, updates: Dict[str, Any]) -> bool:
|
||||
async def update_message(self, message_id: str, updates: Dict[str, Any]) -> bool:
|
||||
"""更新上下文中的消息
|
||||
|
||||
Args:
|
||||
@@ -87,16 +78,11 @@ class SingleStreamContextManager:
|
||||
bool: 是否成功更新
|
||||
"""
|
||||
try:
|
||||
# 更新消息信息
|
||||
self.context.update_message_info(message_id, **updates)
|
||||
|
||||
# 如果更新了兴趣度,重新计算能量
|
||||
if "interest_value" in updates:
|
||||
self._update_stream_energy()
|
||||
|
||||
await self._update_stream_energy()
|
||||
logger.debug(f"更新单流上下文消息: {self.stream_id}/{message_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新单流上下文消息失败 {self.stream_id}/{message_id}: {e}", exc_info=True)
|
||||
return False
|
||||
@@ -164,16 +150,13 @@ class SingleStreamContextManager:
|
||||
logger.error(f"标记消息已读失败 {self.stream_id}: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def clear_context(self) -> bool:
|
||||
async def clear_context(self) -> bool:
|
||||
"""清空上下文"""
|
||||
try:
|
||||
# 清空消息
|
||||
if hasattr(self.context, "unread_messages"):
|
||||
self.context.unread_messages.clear()
|
||||
if hasattr(self.context, "history_messages"):
|
||||
self.context.history_messages.clear()
|
||||
|
||||
# 重置状态
|
||||
reset_attrs = ["interruption_count", "afc_threshold_adjustment", "last_check_time"]
|
||||
for attr in reset_attrs:
|
||||
if hasattr(self.context, attr):
|
||||
@@ -181,13 +164,9 @@ class SingleStreamContextManager:
|
||||
setattr(self.context, attr, 0)
|
||||
else:
|
||||
setattr(self.context, attr, time.time())
|
||||
|
||||
# 重新计算能量
|
||||
self._update_stream_energy()
|
||||
|
||||
await self._update_stream_energy()
|
||||
logger.info(f"清空单流上下文: {self.stream_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"清空单流上下文失败 {self.stream_id}: {e}", exc_info=True)
|
||||
return False
|
||||
@@ -249,39 +228,115 @@ class SingleStreamContextManager:
|
||||
self.last_access_time = time.time()
|
||||
self.access_count += 1
|
||||
|
||||
def _calculate_message_interest(self, message: DatabaseMessages) -> float:
|
||||
"""计算消息兴趣度"""
|
||||
async def _calculate_message_interest(self, message: DatabaseMessages) -> float:
|
||||
"""异步实现:使用插件的异步评分器正确 await 计算兴趣度并返回分数。"""
|
||||
try:
|
||||
# 使用插件内部的兴趣度评分系统
|
||||
try:
|
||||
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
|
||||
|
||||
# 使用插件内部的兴趣度评分系统计算(同步方式)
|
||||
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import (
|
||||
chatter_interest_scoring_system,
|
||||
)
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
interest_score = loop.run_until_complete(
|
||||
chatter_interest_scoring_system._calculate_single_message_score(
|
||||
interest_score = await chatter_interest_scoring_system._calculate_single_message_score(
|
||||
message=message, bot_nickname=global_config.bot.nickname
|
||||
)
|
||||
)
|
||||
interest_value = interest_score.total_score
|
||||
interest_value = interest_score.total_score
|
||||
logger.debug(f"使用插件内部系统计算兴趣度: {interest_value:.3f}")
|
||||
return interest_value
|
||||
except Exception as e:
|
||||
logger.warning(f"插件内部兴趣度计算失败: {e}")
|
||||
return 0.5
|
||||
except Exception as e:
|
||||
logger.warning(f"插件内部兴趣度计算加载失败,使用默认值: {e}")
|
||||
return 0.5
|
||||
except Exception as e:
|
||||
logger.error(f"计算消息兴趣度失败: {e}")
|
||||
return 0.5
|
||||
|
||||
logger.debug(f"使用插件内部系统计算兴趣度: {interest_value:.3f}")
|
||||
async def _calculate_message_interest_async(self, message: DatabaseMessages) -> float:
|
||||
"""异步实现:使用插件的异步评分器正确 await 计算兴趣度并返回分数。"""
|
||||
try:
|
||||
try:
|
||||
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import (
|
||||
chatter_interest_scoring_system,
|
||||
)
|
||||
|
||||
# 直接 await 插件的异步方法
|
||||
try:
|
||||
interest_score = await chatter_interest_scoring_system._calculate_single_message_score(
|
||||
message=message, bot_nickname=global_config.bot.nickname
|
||||
)
|
||||
interest_value = interest_score.total_score
|
||||
logger.debug(f"使用插件内部系统计算兴趣度: {interest_value:.3f}")
|
||||
return interest_value
|
||||
except Exception as e:
|
||||
logger.warning(f"插件内部兴趣度计算失败: {e}")
|
||||
return 0.5
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"插件内部兴趣度计算失败,使用默认值: {e}")
|
||||
interest_value = 0.5 # 默认中等兴趣度
|
||||
|
||||
return interest_value
|
||||
logger.warning(f"插件内部兴趣度计算加载失败,使用默认值: {e}")
|
||||
return 0.5
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"计算消息兴趣度失败: {e}")
|
||||
return 0.5
|
||||
|
||||
async def add_message_async(self, message: DatabaseMessages, skip_energy_update: bool = False) -> bool:
|
||||
"""异步实现的 add_message:将消息添加到 context,并 await 能量更新与分发。"""
|
||||
try:
|
||||
self.context.add_message(message)
|
||||
|
||||
interest_value = await self._calculate_message_interest_async(message)
|
||||
message.interest_value = interest_value
|
||||
|
||||
self.total_messages += 1
|
||||
self.last_access_time = time.time()
|
||||
|
||||
if not skip_energy_update:
|
||||
await self._update_stream_energy()
|
||||
distribution_manager.add_stream_message(self.stream_id, 1)
|
||||
|
||||
logger.debug(f"添加消息到单流上下文(异步): {self.stream_id} (兴趣度: {interest_value:.3f})")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"添加消息到单流上下文失败 (async) {self.stream_id}: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
async def update_message_async(self, message_id: str, updates: Dict[str, Any]) -> bool:
|
||||
"""异步实现的 update_message:更新消息并在需要时 await 能量更新。"""
|
||||
try:
|
||||
self.context.update_message_info(message_id, **updates)
|
||||
if "interest_value" in updates:
|
||||
await self._update_stream_energy()
|
||||
|
||||
logger.debug(f"更新单流上下文消息(异步): {self.stream_id}/{message_id}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"更新单流上下文消息失败 (async) {self.stream_id}/{message_id}: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
async def clear_context_async(self) -> bool:
|
||||
"""异步实现的 clear_context:清空消息并 await 能量重算。"""
|
||||
try:
|
||||
if hasattr(self.context, "unread_messages"):
|
||||
self.context.unread_messages.clear()
|
||||
if hasattr(self.context, "history_messages"):
|
||||
self.context.history_messages.clear()
|
||||
|
||||
reset_attrs = ["interruption_count", "afc_threshold_adjustment", "last_check_time"]
|
||||
for attr in reset_attrs:
|
||||
if hasattr(self.context, attr):
|
||||
if attr in ["interruption_count", "afc_threshold_adjustment"]:
|
||||
setattr(self.context, attr, 0)
|
||||
else:
|
||||
setattr(self.context, attr, time.time())
|
||||
|
||||
await self._update_stream_energy()
|
||||
logger.info(f"清空单流上下文(异步): {self.stream_id}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"清空单流上下文失败 (async) {self.stream_id}: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
async def _update_stream_energy(self):
|
||||
"""更新流能量"""
|
||||
try:
|
||||
@@ -305,4 +360,4 @@ class SingleStreamContextManager:
|
||||
distribution_manager.update_stream_energy(self.stream_id, energy)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新单流能量失败 {self.stream_id}: {e}")
|
||||
logger.error(f"更新单流能量失败 {self.stream_id}: {e}")
|
||||
|
||||
@@ -75,29 +75,23 @@ class MessageManager:
|
||||
|
||||
logger.info("消息管理器已停止")
|
||||
|
||||
def add_message(self, stream_id: str, message: DatabaseMessages):
|
||||
async def add_message(self, stream_id: str, message: DatabaseMessages):
|
||||
"""添加消息到指定聊天流"""
|
||||
try:
|
||||
# 通过 ChatManager 获取 ChatStream
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = chat_manager.get_stream(stream_id)
|
||||
|
||||
if not chat_stream:
|
||||
logger.warning(f"MessageManager.add_message: 聊天流 {stream_id} 不存在")
|
||||
return
|
||||
|
||||
# 使用 ChatStream 的 context_manager 添加消息
|
||||
success = chat_stream.context_manager.add_message(message)
|
||||
|
||||
success = await chat_stream.context_manager.add_message(message)
|
||||
if success:
|
||||
logger.debug(f"添加消息到聊天流 {stream_id}: {message.message_id}")
|
||||
else:
|
||||
logger.warning(f"添加消息到聊天流 {stream_id} 失败")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"添加消息到聊天流 {stream_id} 时发生错误: {e}")
|
||||
|
||||
def update_message(
|
||||
async def update_message(
|
||||
self,
|
||||
stream_id: str,
|
||||
message_id: str,
|
||||
@@ -107,15 +101,11 @@ class MessageManager:
|
||||
):
|
||||
"""更新消息信息"""
|
||||
try:
|
||||
# 通过 ChatManager 获取 ChatStream
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = chat_manager.get_stream(stream_id)
|
||||
|
||||
if not chat_stream:
|
||||
logger.warning(f"MessageManager.update_message: 聊天流 {stream_id} 不存在")
|
||||
return
|
||||
|
||||
# 构建更新字典
|
||||
updates = {}
|
||||
if interest_value is not None:
|
||||
updates["interest_value"] = interest_value
|
||||
@@ -123,41 +113,30 @@ class MessageManager:
|
||||
updates["actions"] = actions
|
||||
if should_reply is not None:
|
||||
updates["should_reply"] = should_reply
|
||||
|
||||
# 使用 ChatStream 的 context_manager 更新消息
|
||||
if updates:
|
||||
success = chat_stream.context_manager.update_message(message_id, updates)
|
||||
success = await chat_stream.context_manager.update_message(message_id, updates)
|
||||
if success:
|
||||
logger.debug(f"更新消息 {message_id} 成功")
|
||||
else:
|
||||
logger.warning(f"更新消息 {message_id} 失败")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新消息 {message_id} 时发生错误: {e}")
|
||||
|
||||
def add_action(self, stream_id: str, message_id: str, action: str):
|
||||
async def add_action(self, stream_id: str, message_id: str, action: str):
|
||||
"""添加动作到消息"""
|
||||
try:
|
||||
# 通过 ChatManager 获取 ChatStream
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = chat_manager.get_stream(stream_id)
|
||||
|
||||
if not chat_stream:
|
||||
logger.warning(f"MessageManager.add_action: 聊天流 {stream_id} 不存在")
|
||||
return
|
||||
|
||||
# 使用 ChatStream 的 context_manager 添加动作
|
||||
# 注意:这里需要根据实际的 API 调整
|
||||
# 假设我们可以通过 update_message 来添加动作
|
||||
success = chat_stream.context_manager.update_message(
|
||||
success = await chat_stream.context_manager.update_message(
|
||||
message_id, {"actions": [action]}
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.debug(f"为消息 {message_id} 添加动作 {action} 成功")
|
||||
else:
|
||||
logger.warning(f"为消息 {message_id} 添加动作 {action} 失败")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"为消息 {message_id} 添加动作时发生错误: {e}")
|
||||
|
||||
@@ -382,36 +361,27 @@ class MessageManager:
|
||||
"start_time": self.stats.start_time,
|
||||
}
|
||||
|
||||
def cleanup_inactive_streams(self, max_inactive_hours: int = 24):
|
||||
async def cleanup_inactive_streams(self, max_inactive_hours: int = 24):
|
||||
"""清理不活跃的聊天流"""
|
||||
try:
|
||||
# 通过 ChatManager 清理不活跃的流
|
||||
chat_manager = get_chat_manager()
|
||||
current_time = time.time()
|
||||
max_inactive_seconds = max_inactive_hours * 3600
|
||||
|
||||
inactive_streams = []
|
||||
for stream_id, chat_stream in chat_manager.streams.items():
|
||||
# 检查最后活跃时间
|
||||
if current_time - chat_stream.last_active_time > max_inactive_seconds:
|
||||
inactive_streams.append(stream_id)
|
||||
|
||||
# 清理不活跃的流
|
||||
for stream_id in inactive_streams:
|
||||
try:
|
||||
# 清理流的内容
|
||||
chat_stream.context_manager.clear_context()
|
||||
# 从 ChatManager 中移除
|
||||
await chat_stream.context_manager.clear_context()
|
||||
del chat_manager.streams[stream_id]
|
||||
logger.info(f"清理不活跃聊天流: {stream_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"清理聊天流 {stream_id} 失败: {e}")
|
||||
|
||||
if inactive_streams:
|
||||
logger.info(f"已清理 {len(inactive_streams)} 个不活跃聊天流")
|
||||
else:
|
||||
logger.debug("没有需要清理的不活跃聊天流")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"清理不活跃聊天流时发生错误: {e}")
|
||||
|
||||
|
||||
@@ -514,7 +514,7 @@ class ChatBot:
|
||||
db_message.chat_info_group_platform = message.chat_stream.group_info.platform
|
||||
|
||||
# 添加消息到消息管理器
|
||||
message_manager.add_message(message.chat_stream.stream_id, db_message)
|
||||
await message_manager.add_message(message.chat_stream.stream_id, db_message)
|
||||
logger.debug(f"消息已添加到消息管理器: {message.chat_stream.stream_id}")
|
||||
|
||||
if template_group_name:
|
||||
|
||||
@@ -389,94 +389,105 @@ class ChatStream:
|
||||
from sqlalchemy import select, desc
|
||||
import asyncio
|
||||
|
||||
async def _load_messages():
|
||||
def _db_query():
|
||||
with get_db_session() as session:
|
||||
# 查询该stream_id的最近20条消息
|
||||
async def _load_history_messages_async():
|
||||
"""异步加载并转换历史消息到 stream_context(在事件循环中运行)。"""
|
||||
try:
|
||||
async with get_db_session() as session:
|
||||
stmt = (
|
||||
select(Messages)
|
||||
.where(Messages.chat_info_stream_id == self.stream_id)
|
||||
.order_by(desc(Messages.time))
|
||||
.limit(global_config.chat.max_context_size)
|
||||
)
|
||||
result = session.execute(stmt)
|
||||
results = result.scalars().all()
|
||||
return results
|
||||
result = await session.execute(stmt)
|
||||
db_messages = result.scalars().all()
|
||||
|
||||
# 在线程中执行数据库查询
|
||||
db_messages = await asyncio.to_thread(_db_query)
|
||||
# 转换为DatabaseMessages对象并添加到StreamContext
|
||||
for db_msg in db_messages:
|
||||
try:
|
||||
import orjson
|
||||
|
||||
# 转换为DatabaseMessages对象并添加到StreamContext
|
||||
for db_msg in db_messages:
|
||||
actions = None
|
||||
if db_msg.actions:
|
||||
try:
|
||||
actions = orjson.loads(db_msg.actions)
|
||||
except (orjson.JSONDecodeError, TypeError):
|
||||
actions = None
|
||||
|
||||
db_message = DatabaseMessages(
|
||||
message_id=db_msg.message_id,
|
||||
time=db_msg.time,
|
||||
chat_id=db_msg.chat_id,
|
||||
reply_to=db_msg.reply_to,
|
||||
interest_value=db_msg.interest_value,
|
||||
key_words=db_msg.key_words,
|
||||
key_words_lite=db_msg.key_words_lite,
|
||||
is_mentioned=db_msg.is_mentioned,
|
||||
processed_plain_text=db_msg.processed_plain_text,
|
||||
display_message=db_msg.display_message,
|
||||
priority_mode=db_msg.priority_mode,
|
||||
priority_info=db_msg.priority_info,
|
||||
additional_config=db_msg.additional_config,
|
||||
is_emoji=db_msg.is_emoji,
|
||||
is_picid=db_msg.is_picid,
|
||||
is_command=db_msg.is_command,
|
||||
is_notify=db_msg.is_notify,
|
||||
user_id=db_msg.user_id,
|
||||
user_nickname=db_msg.user_nickname,
|
||||
user_cardname=db_msg.user_cardname,
|
||||
user_platform=db_msg.user_platform,
|
||||
chat_info_group_id=db_msg.chat_info_group_id,
|
||||
chat_info_group_name=db_msg.chat_info_group_name,
|
||||
chat_info_group_platform=db_msg.chat_info_group_platform,
|
||||
chat_info_user_id=db_msg.chat_info_user_id,
|
||||
chat_info_user_nickname=db_msg.chat_info_user_nickname,
|
||||
chat_info_user_cardname=db_msg.chat_info_user_cardname,
|
||||
chat_info_user_platform=db_msg.chat_info_user_platform,
|
||||
chat_info_stream_id=db_msg.chat_info_stream_id,
|
||||
chat_info_platform=db_msg.chat_info_platform,
|
||||
chat_info_create_time=db_msg.chat_info_create_time,
|
||||
chat_info_last_active_time=db_msg.chat_info_last_active_time,
|
||||
actions=actions,
|
||||
should_reply=getattr(db_msg, "should_reply", False) or False,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"加载历史消息 {db_message.message_id} - interest_value: {db_message.interest_value}"
|
||||
)
|
||||
|
||||
db_message.is_read = True
|
||||
self.stream_context.history_messages.append(db_message)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"转换消息 {getattr(db_msg, 'message_id', '<unknown>')} 失败: {e}")
|
||||
continue
|
||||
|
||||
if self.stream_context.history_messages:
|
||||
logger.info(
|
||||
f"已从数据库加载 {len(self.stream_context.history_messages)} 条历史消息到聊天流 {self.stream_id}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"异步加载历史消息失败: {e}")
|
||||
|
||||
# 在已有事件循环中,避免调用 asyncio.run()
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
# 没有运行的事件循环,安全地运行并等待完成
|
||||
asyncio.run(_load_history_messages_async())
|
||||
else:
|
||||
# 如果事件循环正在运行,在后台创建任务
|
||||
if loop.is_running():
|
||||
try:
|
||||
# 从SQLAlchemy模型转换为DatabaseMessages数据模型
|
||||
import orjson
|
||||
|
||||
# 解析actions字段(JSON格式)
|
||||
actions = None
|
||||
if db_msg.actions:
|
||||
try:
|
||||
actions = orjson.loads(db_msg.actions)
|
||||
except (orjson.JSONDecodeError, TypeError):
|
||||
actions = None
|
||||
|
||||
db_message = DatabaseMessages(
|
||||
message_id=db_msg.message_id,
|
||||
time=db_msg.time,
|
||||
chat_id=db_msg.chat_id,
|
||||
reply_to=db_msg.reply_to,
|
||||
interest_value=db_msg.interest_value,
|
||||
key_words=db_msg.key_words,
|
||||
key_words_lite=db_msg.key_words_lite,
|
||||
is_mentioned=db_msg.is_mentioned,
|
||||
processed_plain_text=db_msg.processed_plain_text,
|
||||
display_message=db_msg.display_message,
|
||||
priority_mode=db_msg.priority_mode,
|
||||
priority_info=db_msg.priority_info,
|
||||
additional_config=db_msg.additional_config,
|
||||
is_emoji=db_msg.is_emoji,
|
||||
is_picid=db_msg.is_picid,
|
||||
is_command=db_msg.is_command,
|
||||
is_notify=db_msg.is_notify,
|
||||
user_id=db_msg.user_id,
|
||||
user_nickname=db_msg.user_nickname,
|
||||
user_cardname=db_msg.user_cardname,
|
||||
user_platform=db_msg.user_platform,
|
||||
chat_info_group_id=db_msg.chat_info_group_id,
|
||||
chat_info_group_name=db_msg.chat_info_group_name,
|
||||
chat_info_group_platform=db_msg.chat_info_group_platform,
|
||||
chat_info_user_id=db_msg.chat_info_user_id,
|
||||
chat_info_user_nickname=db_msg.chat_info_user_nickname,
|
||||
chat_info_user_cardname=db_msg.chat_info_user_cardname,
|
||||
chat_info_user_platform=db_msg.chat_info_user_platform,
|
||||
chat_info_stream_id=db_msg.chat_info_stream_id,
|
||||
chat_info_platform=db_msg.chat_info_platform,
|
||||
chat_info_create_time=db_msg.chat_info_create_time,
|
||||
chat_info_last_active_time=db_msg.chat_info_last_active_time,
|
||||
actions=actions,
|
||||
should_reply=getattr(db_msg, "should_reply", False) or False,
|
||||
)
|
||||
|
||||
# 添加调试日志:检查从数据库加载的interest_value
|
||||
logger.debug(
|
||||
f"加载历史消息 {db_message.message_id} - interest_value: {db_message.interest_value}"
|
||||
)
|
||||
|
||||
# 标记为已读并添加到历史消息
|
||||
db_message.is_read = True
|
||||
self.stream_context.history_messages.append(db_message)
|
||||
|
||||
asyncio.create_task(_load_history_messages_async())
|
||||
except Exception as e:
|
||||
logger.warning(f"转换消息 {db_msg.message_id} 失败: {e}")
|
||||
continue
|
||||
|
||||
if self.stream_context.history_messages:
|
||||
logger.info(
|
||||
f"已从数据库加载 {len(self.stream_context.history_messages)} 条历史消息到聊天流 {self.stream_id}"
|
||||
)
|
||||
|
||||
# 创建任务来加载历史消息
|
||||
asyncio.create_task(_load_messages())
|
||||
# 如果无法创建任务,退回到阻塞运行
|
||||
logger.warning(f"无法在事件循环中创建后台任务,尝试阻塞运行: {e}")
|
||||
asyncio.run(_load_history_messages_async())
|
||||
else:
|
||||
# loop 存在但未运行,使用 asyncio.run
|
||||
asyncio.run(_load_history_messages_async())
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载历史消息失败: {e}")
|
||||
@@ -498,7 +509,7 @@ class ChatManager:
|
||||
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
|
||||
self.last_messages: Dict[str, "MessageRecv"] = {} # stream_id -> last_message
|
||||
# try:
|
||||
# with get_db_session() as session:
|
||||
# async with get_db_session() as session:
|
||||
# db.connect(reuse_if_open=True)
|
||||
# # 确保 ChatStreams 表存在
|
||||
# session.execute(text("CREATE TABLE IF NOT EXISTS chat_streams (stream_id TEXT PRIMARY KEY, platform TEXT, create_time REAL, last_active_time REAL, user_platform TEXT, user_id TEXT, user_nickname TEXT, user_cardname TEXT, group_platform TEXT, group_id TEXT, group_name TEXT)"))
|
||||
|
||||
@@ -219,7 +219,7 @@ class MessageStorage:
|
||||
return match.group(0)
|
||||
|
||||
@staticmethod
|
||||
def update_message_interest_value(message_id: str, interest_value: float) -> None:
|
||||
async def update_message_interest_value(message_id: str, interest_value: float) -> None:
|
||||
"""
|
||||
更新数据库中消息的interest_value字段
|
||||
|
||||
@@ -228,11 +228,11 @@ class MessageStorage:
|
||||
interest_value: 兴趣度值
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
# 更新消息的interest_value字段
|
||||
stmt = update(Messages).where(Messages.message_id == message_id).values(interest_value=interest_value)
|
||||
result = session.execute(stmt)
|
||||
session.commit()
|
||||
result = await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
if result.rowcount > 0:
|
||||
logger.debug(f"成功更新消息 {message_id} 的interest_value为 {interest_value}")
|
||||
@@ -244,7 +244,7 @@ class MessageStorage:
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def fix_zero_interest_values(chat_id: str, since_time: float) -> int:
|
||||
async def fix_zero_interest_values(chat_id: str, since_time: float) -> int:
|
||||
"""
|
||||
修复指定聊天中interest_value为0或null的历史消息记录
|
||||
|
||||
@@ -256,7 +256,7 @@ class MessageStorage:
|
||||
修复的记录数量
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
from sqlalchemy import select, update
|
||||
from src.common.database.sqlalchemy_models import Messages
|
||||
|
||||
@@ -271,7 +271,7 @@ class MessageStorage:
|
||||
)
|
||||
).limit(50) # 限制每次修复的数量,避免性能问题
|
||||
|
||||
result = session.execute(query)
|
||||
result = await session.execute(query)
|
||||
messages_to_fix = result.scalars().all()
|
||||
fixed_count = 0
|
||||
|
||||
@@ -297,12 +297,12 @@ class MessageStorage:
|
||||
Messages.message_id == msg.message_id
|
||||
).values(interest_value=default_interest)
|
||||
|
||||
result = session.execute(update_stmt)
|
||||
result = await session.execute(update_stmt)
|
||||
if result.rowcount > 0:
|
||||
fixed_count += 1
|
||||
logger.debug(f"修复消息 {msg.message_id} 的interest_value为 {default_interest}")
|
||||
|
||||
session.commit()
|
||||
await session.commit()
|
||||
logger.info(f"共修复了 {fixed_count} 条历史消息的interest_value值")
|
||||
return fixed_count
|
||||
|
||||
|
||||
@@ -297,15 +297,12 @@ class ChatterActionManager:
|
||||
return
|
||||
|
||||
# 通过message_manager更新消息的动作记录并刷新focus_energy
|
||||
if chat_stream.stream_id in message_manager.stream_contexts:
|
||||
message_manager.add_action(
|
||||
stream_id=chat_stream.stream_id,
|
||||
message_id=target_message_id,
|
||||
action=action_name
|
||||
)
|
||||
logger.debug(f"已记录动作 {action_name} 到消息 {target_message_id} 并更新focus_energy")
|
||||
else:
|
||||
logger.debug(f"未找到stream_context: {chat_stream.stream_id}")
|
||||
await message_manager.add_action(
|
||||
stream_id=chat_stream.stream_id,
|
||||
message_id=target_message_id,
|
||||
action=action_name
|
||||
)
|
||||
logger.debug(f"已记录动作 {action_name} 到消息 {target_message_id} 并更新focus_energy")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"记录动作到消息失败: {e}")
|
||||
@@ -315,8 +312,11 @@ class ChatterActionManager:
|
||||
"""在动作执行成功后重置打断计数"""
|
||||
from src.chat.message_manager.message_manager import message_manager
|
||||
try:
|
||||
if stream_id in message_manager.stream_contexts:
|
||||
context = message_manager.stream_contexts[stream_id]
|
||||
from src.plugin_system.apis.chat_api import get_chat_manager
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = chat_manager.get_stream(stream_id)
|
||||
if chat_stream:
|
||||
context = chat_stream.context_manager
|
||||
if context.interruption_count > 0:
|
||||
old_count = context.interruption_count
|
||||
old_afc_adjustment = context.get_afc_threshold_adjustment()
|
||||
|
||||
@@ -73,7 +73,7 @@ class ActionModifier:
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
|
||||
# 获取聊天类型
|
||||
is_group_chat, _ = await get_chat_type_and_target_info(self.chat_id)
|
||||
is_group_chat, _ = get_chat_type_and_target_info(self.chat_id)
|
||||
all_registered_actions = component_registry.get_components_by_type(ComponentType.ACTION)
|
||||
|
||||
chat_type_removals = []
|
||||
|
||||
@@ -684,8 +684,11 @@ class DefaultReplyer:
|
||||
from src.chat.message_manager.message_manager import message_manager
|
||||
|
||||
# 获取聊天流的上下文
|
||||
stream_context = message_manager.stream_contexts.get(chat_id)
|
||||
if stream_context:
|
||||
from src.plugin_system.apis.chat_api import get_chat_manager
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = chat_manager.get_stream(chat_id)
|
||||
if chat_stream:
|
||||
stream_context = chat_stream.context_manager
|
||||
# 使用真正的已读和未读消息
|
||||
read_messages = stream_context.history_messages # 已读消息
|
||||
unread_messages = stream_context.get_unread_messages() # 未读消息
|
||||
@@ -693,7 +696,7 @@ class DefaultReplyer:
|
||||
# 构建已读历史消息 prompt
|
||||
read_history_prompt = ""
|
||||
if read_messages:
|
||||
read_content = build_readable_messages(
|
||||
read_content = await build_readable_messages(
|
||||
[msg.flatten() for msg in read_messages[-50:]], # 限制数量
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
@@ -716,7 +719,7 @@ class DefaultReplyer:
|
||||
]
|
||||
|
||||
if filtered_fallback_messages:
|
||||
read_content = build_readable_messages(
|
||||
read_content = await build_readable_messages(
|
||||
filtered_fallback_messages,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
@@ -754,7 +757,7 @@ class DefaultReplyer:
|
||||
if platform and user_id:
|
||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
||||
person_info_manager = get_person_info_manager()
|
||||
sender_name = person_info_manager.get_value_sync(person_id, "person_name") or "未知用户"
|
||||
sender_name = person_info_manager.get_value(person_id, "person_name") or "未知用户"
|
||||
else:
|
||||
sender_name = "未知用户"
|
||||
|
||||
@@ -819,7 +822,7 @@ class DefaultReplyer:
|
||||
# 构建已读历史消息 prompt
|
||||
read_history_prompt = ""
|
||||
if read_messages:
|
||||
read_content = build_readable_messages(
|
||||
read_content = await build_readable_messages(
|
||||
read_messages[-50:],
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
@@ -853,7 +856,7 @@ class DefaultReplyer:
|
||||
if platform and user_id:
|
||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
||||
person_info_manager = get_person_info_manager()
|
||||
sender_name = person_info_manager.get_value_sync(person_id, "person_name") or "未知用户"
|
||||
sender_name = person_info_manager.get_value(person_id, "person_name") or "未知用户"
|
||||
else:
|
||||
sender_name = "未知用户"
|
||||
|
||||
@@ -1027,7 +1030,7 @@ class DefaultReplyer:
|
||||
|
||||
# 检查是否是bot自己的名字,如果是则替换为"(你)"
|
||||
bot_user_id = str(global_config.bot.qq_account)
|
||||
current_user_id = person_info_manager.get_value_sync(person_id, "user_id")
|
||||
current_user_id = person_info_manager.get_value(person_id, "user_id")
|
||||
current_platform = reply_message.get("chat_info_platform")
|
||||
|
||||
if current_user_id == bot_user_id and current_platform == global_config.bot.platform:
|
||||
@@ -1046,7 +1049,7 @@ class DefaultReplyer:
|
||||
target = "(无消息内容)"
|
||||
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
||||
person_id = await person_info_manager.get_person_id_by_person_name(sender)
|
||||
platform = chat_stream.platform
|
||||
|
||||
target = replace_user_references_sync(target, chat_stream.platform, replace_bot_name=True)
|
||||
@@ -1071,7 +1074,7 @@ class DefaultReplyer:
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.33),
|
||||
)
|
||||
chat_talking_prompt_short = build_readable_messages(
|
||||
chat_talking_prompt_short = await build_readable_messages(
|
||||
message_list_before_short,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
@@ -1324,7 +1327,7 @@ class DefaultReplyer:
|
||||
timestamp=time.time(),
|
||||
limit=min(int(global_config.chat.max_context_size * 0.33), 15),
|
||||
)
|
||||
chat_talking_prompt_half = build_readable_messages(
|
||||
chat_talking_prompt_half = await build_readable_messages(
|
||||
message_list_before_now_half,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
@@ -1523,7 +1526,7 @@ class DefaultReplyer:
|
||||
|
||||
# 获取用户ID
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
||||
person_id = await person_info_manager.get_person_id_by_person_name(sender)
|
||||
if not person_id:
|
||||
logger.warning(f"未找到用户 {sender} 的ID,跳过信息提取")
|
||||
return f"你完全不认识{sender},不理解ta的相关信息。"
|
||||
|
||||
@@ -46,7 +46,7 @@ def replace_user_references_sync(
|
||||
if replace_bot_name and user_id == global_config.bot.qq_account:
|
||||
return f"{global_config.bot.nickname}(你)"
|
||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
||||
return person_info_manager.get_value_sync(person_id, "person_name") or user_id # type: ignore
|
||||
return person_info_manager.get_value(person_id, "person_name") or user_id # type: ignore
|
||||
|
||||
name_resolver = default_resolver
|
||||
|
||||
@@ -254,7 +254,7 @@ def get_raw_msg_by_timestamp_with_chat_users(
|
||||
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
|
||||
|
||||
|
||||
def get_actions_by_timestamp_with_chat(
|
||||
async def get_actions_by_timestamp_with_chat(
|
||||
chat_id: str,
|
||||
timestamp_start: float = 0,
|
||||
timestamp_end: float = time.time(),
|
||||
@@ -273,22 +273,21 @@ def get_actions_by_timestamp_with_chat(
|
||||
f"limit={limit}, limit_mode={limit_mode}"
|
||||
)
|
||||
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
if limit > 0:
|
||||
if limit_mode == "latest":
|
||||
query = session.execute(
|
||||
result = await session.execute(
|
||||
select(ActionRecords)
|
||||
.where(
|
||||
and_(
|
||||
ActionRecords.chat_id == chat_id,
|
||||
ActionRecords.time > timestamp_start,
|
||||
ActionRecords.time < timestamp_end,
|
||||
ActionRecords.time >= timestamp_start,
|
||||
ActionRecords.time <= timestamp_end,
|
||||
)
|
||||
)
|
||||
.order_by(ActionRecords.time.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
actions = list(query.scalars())
|
||||
actions = list(result.scalars())
|
||||
actions_result = []
|
||||
for action in reversed(actions):
|
||||
action_dict = {
|
||||
@@ -305,38 +304,39 @@ def get_actions_by_timestamp_with_chat(
|
||||
"chat_info_platform": action.chat_info_platform,
|
||||
}
|
||||
actions_result.append(action_dict)
|
||||
else: # earliest
|
||||
query = session.execute(
|
||||
select(ActionRecords)
|
||||
.where(
|
||||
and_(
|
||||
ActionRecords.chat_id == chat_id,
|
||||
ActionRecords.time > timestamp_start,
|
||||
ActionRecords.time < timestamp_end,
|
||||
)
|
||||
)
|
||||
.order_by(ActionRecords.time.asc())
|
||||
.limit(limit)
|
||||
)
|
||||
actions = list(query.scalars())
|
||||
actions_result = []
|
||||
for action in actions:
|
||||
action_dict = {
|
||||
"id": action.id,
|
||||
"action_id": action.action_id,
|
||||
"time": action.time,
|
||||
"action_name": action.action_name,
|
||||
"action_data": action.action_data,
|
||||
"action_done": action.action_done,
|
||||
"action_build_into_prompt": action.action_build_into_prompt,
|
||||
"action_prompt_display": action.action_prompt_display,
|
||||
"chat_id": action.chat_id,
|
||||
"chat_info_stream_id": action.chat_info_stream_id,
|
||||
"chat_info_platform": action.chat_info_platform,
|
||||
}
|
||||
actions_result.append(action_dict)
|
||||
else: # earliest
|
||||
result = await session.execute(
|
||||
select(ActionRecords)
|
||||
.where(
|
||||
and_(
|
||||
ActionRecords.chat_id == chat_id,
|
||||
ActionRecords.time > timestamp_start,
|
||||
ActionRecords.time < timestamp_end,
|
||||
)
|
||||
)
|
||||
.order_by(ActionRecords.time.asc())
|
||||
.limit(limit)
|
||||
)
|
||||
actions = list(result.scalars())
|
||||
actions_result = []
|
||||
for action in actions:
|
||||
action_dict = {
|
||||
"id": action.id,
|
||||
"action_id": action.action_id,
|
||||
"time": action.time,
|
||||
"action_name": action.action_name,
|
||||
"action_data": action.action_data,
|
||||
"action_done": action.action_done,
|
||||
"action_build_into_prompt": action.action_build_into_prompt,
|
||||
"action_prompt_display": action.action_prompt_display,
|
||||
"chat_id": action.chat_id,
|
||||
"chat_info_stream_id": action.chat_info_stream_id,
|
||||
"chat_info_platform": action.chat_info_platform,
|
||||
}
|
||||
actions_result.append(action_dict)
|
||||
else:
|
||||
query = session.execute(
|
||||
result = await session.execute(
|
||||
select(ActionRecords)
|
||||
.where(
|
||||
and_(
|
||||
@@ -347,7 +347,7 @@ def get_actions_by_timestamp_with_chat(
|
||||
)
|
||||
.order_by(ActionRecords.time.asc())
|
||||
)
|
||||
actions = list(query.scalars())
|
||||
actions = list(result.scalars())
|
||||
actions_result = []
|
||||
for action in actions:
|
||||
action_dict = {
|
||||
@@ -367,14 +367,14 @@ def get_actions_by_timestamp_with_chat(
|
||||
return actions_result
|
||||
|
||||
|
||||
def get_actions_by_timestamp_with_chat_inclusive(
|
||||
async def get_actions_by_timestamp_with_chat_inclusive(
|
||||
chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表"""
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
if limit > 0:
|
||||
if limit_mode == "latest":
|
||||
query = session.execute(
|
||||
result = await session.execute(
|
||||
select(ActionRecords)
|
||||
.where(
|
||||
and_(
|
||||
@@ -386,10 +386,10 @@ def get_actions_by_timestamp_with_chat_inclusive(
|
||||
.order_by(ActionRecords.time.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
actions = list(query.scalars())
|
||||
actions = list(result.scalars())
|
||||
return [action.__dict__ for action in reversed(actions)]
|
||||
else: # earliest
|
||||
query = session.execute(
|
||||
result = await session.execute(
|
||||
select(ActionRecords)
|
||||
.where(
|
||||
and_(
|
||||
@@ -402,7 +402,7 @@ def get_actions_by_timestamp_with_chat_inclusive(
|
||||
.limit(limit)
|
||||
)
|
||||
else:
|
||||
query = session.execute(
|
||||
query = await session.execute(
|
||||
select(ActionRecords)
|
||||
.where(
|
||||
and_(
|
||||
@@ -507,7 +507,7 @@ def num_new_messages_since_with_users(
|
||||
return count_messages(message_filter=filter_query)
|
||||
|
||||
|
||||
def _build_readable_messages_internal(
|
||||
async def _build_readable_messages_internal(
|
||||
messages: List[Dict[str, Any]],
|
||||
replace_bot_name: bool = True,
|
||||
merge_messages: bool = False,
|
||||
@@ -627,7 +627,7 @@ def _build_readable_messages_internal(
|
||||
if replace_bot_name and user_id == global_config.bot.qq_account:
|
||||
person_name = f"{global_config.bot.nickname}(你)"
|
||||
else:
|
||||
person_name = person_info_manager.get_value_sync(person_id, "person_name") # type: ignore
|
||||
person_name = await person_info_manager.get_value(person_id, "person_name") # type: ignore
|
||||
|
||||
# 如果 person_name 未设置,则使用消息中的 nickname 或默认名称
|
||||
if not person_name:
|
||||
@@ -800,7 +800,7 @@ def _build_readable_messages_internal(
|
||||
)
|
||||
|
||||
|
||||
def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
|
||||
async def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
|
||||
# sourcery skip: use-contextlib-suppress
|
||||
"""
|
||||
构建图片映射信息字符串,显示图片的具体描述内容
|
||||
@@ -823,8 +823,8 @@ def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
|
||||
# 从数据库中获取图片描述
|
||||
description = "[图片内容未知]" # 默认描述
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
result = session.execute(select(Images).where(Images.image_id == pic_id))
|
||||
async with get_db_session() as session:
|
||||
result = await session.execute(select(Images).where(Images.image_id == pic_id))
|
||||
image = result.scalar_one_or_none()
|
||||
if image and image.description: # type: ignore
|
||||
description = image.description
|
||||
@@ -922,17 +922,17 @@ async def build_readable_messages_with_list(
|
||||
将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。
|
||||
允许通过参数控制格式化行为。
|
||||
"""
|
||||
formatted_string, details_list, pic_id_mapping, _ = _build_readable_messages_internal(
|
||||
formatted_string, details_list, pic_id_mapping, _ = await _build_readable_messages_internal(
|
||||
messages, replace_bot_name, merge_messages, timestamp_mode, truncate
|
||||
)
|
||||
|
||||
if pic_mapping_info := build_pic_mapping_info(pic_id_mapping):
|
||||
if pic_mapping_info := await build_pic_mapping_info(pic_id_mapping):
|
||||
formatted_string = f"{pic_mapping_info}\n\n{formatted_string}"
|
||||
|
||||
return formatted_string, details_list
|
||||
|
||||
|
||||
def build_readable_messages_with_id(
|
||||
async def build_readable_messages_with_id(
|
||||
messages: List[Dict[str, Any]],
|
||||
replace_bot_name: bool = True,
|
||||
merge_messages: bool = False,
|
||||
@@ -948,7 +948,7 @@ def build_readable_messages_with_id(
|
||||
"""
|
||||
message_id_list = assign_message_ids(messages)
|
||||
|
||||
formatted_string = build_readable_messages(
|
||||
formatted_string = await build_readable_messages(
|
||||
messages=messages,
|
||||
replace_bot_name=replace_bot_name,
|
||||
merge_messages=merge_messages,
|
||||
@@ -960,10 +960,16 @@ def build_readable_messages_with_id(
|
||||
message_id_list=message_id_list,
|
||||
)
|
||||
|
||||
# 如果存在图片映射信息,附加之
|
||||
if pic_mapping_info := await build_pic_mapping_info({}):
|
||||
# 如果当前没有图片映射则不附加
|
||||
if pic_mapping_info:
|
||||
formatted_string = f"{pic_mapping_info}\n\n{formatted_string}"
|
||||
|
||||
return formatted_string, message_id_list
|
||||
|
||||
|
||||
def build_readable_messages(
|
||||
async def build_readable_messages(
|
||||
messages: List[Dict[str, Any]],
|
||||
replace_bot_name: bool = True,
|
||||
merge_messages: bool = False,
|
||||
@@ -1004,9 +1010,9 @@ def build_readable_messages(
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
# 获取这个时间范围内的动作记录,并匹配chat_id
|
||||
actions_in_range = session.execute(
|
||||
actions_in_range = (await session.execute(
|
||||
select(ActionRecords)
|
||||
.where(
|
||||
and_(
|
||||
@@ -1014,15 +1020,15 @@ def build_readable_messages(
|
||||
)
|
||||
)
|
||||
.order_by(ActionRecords.time)
|
||||
).scalars()
|
||||
)).scalars()
|
||||
|
||||
# 获取最新消息之后的第一个动作记录
|
||||
action_after_latest = session.execute(
|
||||
action_after_latest = (await session.execute(
|
||||
select(ActionRecords)
|
||||
.where(and_(ActionRecords.time > max_time, ActionRecords.chat_id == chat_id))
|
||||
.order_by(ActionRecords.time)
|
||||
.limit(1)
|
||||
).scalars()
|
||||
)).scalars()
|
||||
|
||||
# 合并两部分动作记录,并转为 dict,避免 DetachedInstanceError
|
||||
actions = [
|
||||
@@ -1053,7 +1059,7 @@ def build_readable_messages(
|
||||
|
||||
if read_mark <= 0:
|
||||
# 没有有效的 read_mark,直接格式化所有消息
|
||||
formatted_string, _, pic_id_mapping, _ = _build_readable_messages_internal(
|
||||
formatted_string, _, pic_id_mapping, _ = await _build_readable_messages_internal(
|
||||
copy_messages,
|
||||
replace_bot_name,
|
||||
merge_messages,
|
||||
@@ -1064,7 +1070,7 @@ def build_readable_messages(
|
||||
)
|
||||
|
||||
# 生成图片映射信息并添加到最前面
|
||||
pic_mapping_info = build_pic_mapping_info(pic_id_mapping)
|
||||
pic_mapping_info = await build_pic_mapping_info(pic_id_mapping)
|
||||
if pic_mapping_info:
|
||||
return f"{pic_mapping_info}\n\n{formatted_string}"
|
||||
else:
|
||||
@@ -1079,7 +1085,7 @@ def build_readable_messages(
|
||||
pic_counter = 1
|
||||
|
||||
# 分别格式化,但使用共享的图片映射
|
||||
formatted_before, _, pic_id_mapping, pic_counter = _build_readable_messages_internal(
|
||||
formatted_before, _, pic_id_mapping, pic_counter = await _build_readable_messages_internal(
|
||||
messages_before_mark,
|
||||
replace_bot_name,
|
||||
merge_messages,
|
||||
@@ -1090,7 +1096,7 @@ def build_readable_messages(
|
||||
show_pic=show_pic,
|
||||
message_id_list=message_id_list,
|
||||
)
|
||||
formatted_after, _, pic_id_mapping, _ = _build_readable_messages_internal(
|
||||
formatted_after, _, pic_id_mapping, _ = await _build_readable_messages_internal(
|
||||
messages_after_mark,
|
||||
replace_bot_name,
|
||||
merge_messages,
|
||||
@@ -1106,7 +1112,7 @@ def build_readable_messages(
|
||||
|
||||
# 生成图片映射信息
|
||||
if pic_id_mapping:
|
||||
pic_mapping_info = f"图片信息:\n{build_pic_mapping_info(pic_id_mapping)}\n聊天记录信息:\n"
|
||||
pic_mapping_info = f"图片信息:\n{await build_pic_mapping_info(pic_id_mapping)}\n聊天记录信息:\n"
|
||||
else:
|
||||
pic_mapping_info = "聊天记录信息:\n"
|
||||
|
||||
@@ -1229,7 +1235,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
|
||||
|
||||
# 在最前面添加图片映射信息
|
||||
final_output_lines = []
|
||||
pic_mapping_info = build_pic_mapping_info(pic_id_mapping)
|
||||
pic_mapping_info = await build_pic_mapping_info(pic_id_mapping)
|
||||
if pic_mapping_info:
|
||||
final_output_lines.append(pic_mapping_info)
|
||||
final_output_lines.append("\n\n")
|
||||
|
||||
@@ -494,7 +494,7 @@ class Prompt:
|
||||
chat_history = ""
|
||||
if self.parameters.message_list_before_now_long:
|
||||
recent_messages = self.parameters.message_list_before_now_long[-10:]
|
||||
chat_history = build_readable_messages(
|
||||
chat_history = await build_readable_messages(
|
||||
recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True
|
||||
)
|
||||
|
||||
@@ -535,7 +535,7 @@ class Prompt:
|
||||
chat_history = ""
|
||||
if self.parameters.message_list_before_now_long:
|
||||
recent_messages = self.parameters.message_list_before_now_long[-20:]
|
||||
chat_history = build_readable_messages(
|
||||
chat_history = await build_readable_messages(
|
||||
recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True
|
||||
)
|
||||
|
||||
@@ -589,7 +589,7 @@ class Prompt:
|
||||
chat_history = ""
|
||||
if self.parameters.message_list_before_now_long:
|
||||
recent_messages = self.parameters.message_list_before_now_long[-15:]
|
||||
chat_history = build_readable_messages(
|
||||
chat_history = await build_readable_messages(
|
||||
recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True
|
||||
)
|
||||
|
||||
@@ -863,7 +863,7 @@ class Prompt:
|
||||
|
||||
# 获取用户ID
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
||||
person_id = await person_info_manager.get_person_id_by_person_name(sender)
|
||||
if not person_id:
|
||||
logger.warning(f"未找到用户 {sender} 的ID,跳过信息提取")
|
||||
return f"你完全不认识{sender},不理解ta的相关信息。"
|
||||
@@ -904,7 +904,7 @@ class Prompt:
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def parse_reply_target_id(reply_to: str) -> str:
|
||||
async def parse_reply_target_id(reply_to: str) -> str:
|
||||
"""
|
||||
解析回复目标中的用户ID
|
||||
|
||||
@@ -924,9 +924,9 @@ class Prompt:
|
||||
|
||||
# 获取用户ID
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
||||
person_id = await person_info_manager.get_person_id_by_person_name(sender)
|
||||
if person_id:
|
||||
user_id = person_info_manager.get_value_sync(person_id, "user_id")
|
||||
user_id = person_info_manager.get_value(person_id, "user_id")
|
||||
return str(user_id) if user_id else ""
|
||||
|
||||
return ""
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import random
|
||||
import re
|
||||
import string
|
||||
@@ -662,9 +663,32 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
|
||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
||||
person_name = None
|
||||
if person_id:
|
||||
# get_value is async, so await it directly
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_name = person_info_manager.get_value_sync(person_id, "person_name")
|
||||
try:
|
||||
# 如果没有运行的事件循环,直接 asyncio.run
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
# 如果事件循环在运行,从其他线程提交并等待结果
|
||||
try:
|
||||
from concurrent.futures import TimeoutError
|
||||
|
||||
fut = asyncio.run_coroutine_threadsafe(
|
||||
person_info_manager.get_value(person_id, "person_name"), loop
|
||||
)
|
||||
person_name = fut.result(timeout=2)
|
||||
except Exception as e:
|
||||
# 无法在运行循环上安全等待,退回为 None
|
||||
logger.debug(f"无法通过运行的事件循环获取 person_name: {e}")
|
||||
person_name = None
|
||||
else:
|
||||
person_name = asyncio.run(person_info_manager.get_value(person_id, "person_name"))
|
||||
except RuntimeError:
|
||||
# get_event_loop 在某些上下文可能抛出 RuntimeError,退回到 asyncio.run
|
||||
try:
|
||||
person_name = asyncio.run(person_info_manager.get_value(person_id, "person_name"))
|
||||
except Exception as e:
|
||||
logger.debug(f"获取 person_name 失败: {e}")
|
||||
person_name = None
|
||||
|
||||
target_info["person_id"] = person_id
|
||||
target_info["person_name"] = person_name
|
||||
|
||||
@@ -344,6 +344,39 @@ class StreamContext(BaseDataModel):
|
||||
"""获取优先级信息"""
|
||||
return self.priority_info
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
"""自定义深拷贝,跳过不可序列化的 asyncio.Task (processing_task)。
|
||||
|
||||
deepcopy 在内部可能会尝试 pickle 某些对象(如 asyncio.Task),
|
||||
这会在多线程或运行时事件循环中导致 TypeError。这里我们手动复制
|
||||
__dict__ 中的字段,确保 processing_task 被设置为 None,其他字段使用
|
||||
copy.deepcopy 递归复制。
|
||||
"""
|
||||
import copy
|
||||
|
||||
# 如果已经复制过,直接返回缓存结果
|
||||
obj_id = id(self)
|
||||
if obj_id in memo:
|
||||
return memo[obj_id]
|
||||
|
||||
# 创建一个未初始化的新实例,然后逐个字段深拷贝
|
||||
cls = self.__class__
|
||||
new = cls.__new__(cls)
|
||||
memo[obj_id] = new
|
||||
|
||||
for k, v in self.__dict__.items():
|
||||
if k == "processing_task":
|
||||
# 不复制 asyncio.Task,避免无法 pickling
|
||||
setattr(new, k, None)
|
||||
else:
|
||||
try:
|
||||
setattr(new, k, copy.deepcopy(v, memo))
|
||||
except Exception:
|
||||
# 如果某个字段无法深拷贝,退回到原始引用(安全性谨慎)
|
||||
setattr(new, k, v)
|
||||
|
||||
return new
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageManagerStats(BaseDataModel):
|
||||
|
||||
@@ -30,7 +30,7 @@ def _model_to_dict(instance: Base) -> Dict[str, Any]:
|
||||
return {col.name: instance.__dict__.get(col.name) for col in instance.__table__.columns}
|
||||
|
||||
|
||||
def find_messages(
|
||||
async def find_messages(
|
||||
message_filter: dict[str, Any],
|
||||
sort: Optional[List[tuple[str, int]]] = None,
|
||||
limit: int = 0,
|
||||
@@ -51,7 +51,7 @@ def find_messages(
|
||||
消息字典列表,如果出错则返回空列表。
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
query = select(Messages)
|
||||
|
||||
# 应用过滤器
|
||||
@@ -101,8 +101,8 @@ def find_messages(
|
||||
# 获取时间最早的 limit 条记录,已经是正序
|
||||
query = query.order_by(Messages.time.asc()).limit(limit)
|
||||
try:
|
||||
results = result = session.execute(query)
|
||||
result.scalars().all()
|
||||
result = await session.execute(query)
|
||||
results = result.scalars().all()
|
||||
except Exception as e:
|
||||
logger.error(f"执行earliest查询失败: {e}")
|
||||
results = []
|
||||
@@ -110,8 +110,8 @@ def find_messages(
|
||||
# 获取时间最晚的 limit 条记录
|
||||
query = query.order_by(Messages.time.desc()).limit(limit)
|
||||
try:
|
||||
latest_results = result = session.execute(query)
|
||||
result.scalars().all()
|
||||
result = await session.execute(query)
|
||||
latest_results = result.scalars().all()
|
||||
# 将结果按时间正序排列
|
||||
results = sorted(latest_results, key=lambda msg: msg.time)
|
||||
except Exception as e:
|
||||
@@ -135,8 +135,8 @@ def find_messages(
|
||||
if sort_terms:
|
||||
query = query.order_by(*sort_terms)
|
||||
try:
|
||||
results = result = session.execute(query)
|
||||
result.scalars().all()
|
||||
result = await session.execute(query)
|
||||
results = result.scalars().all()
|
||||
except Exception as e:
|
||||
logger.error(f"执行无限制查询失败: {e}")
|
||||
results = []
|
||||
@@ -152,7 +152,7 @@ def find_messages(
|
||||
return []
|
||||
|
||||
|
||||
def count_messages(message_filter: dict[str, Any]) -> int:
|
||||
async def count_messages(message_filter: dict[str, Any]) -> int:
|
||||
"""
|
||||
根据提供的过滤器计算消息数量。
|
||||
|
||||
@@ -163,7 +163,7 @@ def count_messages(message_filter: dict[str, Any]) -> int:
|
||||
符合条件的消息数量,如果出错则返回 0。
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
query = select(func.count(Messages.id))
|
||||
|
||||
# 应用过滤器
|
||||
@@ -201,7 +201,7 @@ def count_messages(message_filter: dict[str, Any]) -> int:
|
||||
if conditions:
|
||||
query = query.where(*conditions)
|
||||
|
||||
count = session.execute(query).scalar()
|
||||
count = (await session.execute(query)).scalar()
|
||||
return count or 0
|
||||
except Exception as e:
|
||||
log_message = f"使用 SQLAlchemy 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"
|
||||
|
||||
@@ -148,8 +148,8 @@ class MainSystem:
|
||||
# 停止消息重组器
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
from src.plugin_system import EventType
|
||||
|
||||
asyncio.run(event_manager.trigger_event(EventType.ON_STOP, permission_group="SYSTEM"))
|
||||
|
||||
from src.utils.message_chunker import reassembler
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
@@ -110,7 +110,7 @@ class ChatMood:
|
||||
limit=int(global_config.chat.max_context_size / 3),
|
||||
limit_mode="last",
|
||||
)
|
||||
chat_talking_prompt = build_readable_messages(
|
||||
chat_talking_prompt = await build_readable_messages(
|
||||
message_list_before_now,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
@@ -159,7 +159,7 @@ class ChatMood:
|
||||
limit=15,
|
||||
limit_mode="last",
|
||||
)
|
||||
chat_talking_prompt = build_readable_messages(
|
||||
chat_talking_prompt = await build_readable_messages(
|
||||
message_list_before_now,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import datetime
|
||||
import hashlib
|
||||
@@ -57,7 +58,7 @@ class PersonInfoManager:
|
||||
self.person_name_list = {}
|
||||
self.qv_name_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="relation.qv_name")
|
||||
# try:
|
||||
# with get_db_session() as session:
|
||||
# async with get_db_session() as session:
|
||||
# db.connect(reuse_if_open=True)
|
||||
# # 设置连接池参数(仅对SQLite有效)
|
||||
# if hasattr(db, "execute_sql"):
|
||||
@@ -75,7 +76,7 @@ class PersonInfoManager:
|
||||
try:
|
||||
pass
|
||||
# 在这里获取会话
|
||||
# with get_db_session() as session:
|
||||
# async with get_db_session() as session:
|
||||
# for record in session.execute(
|
||||
# select(PersonInfo.person_id, PersonInfo.person_name).where(PersonInfo.person_name.is_not(None))
|
||||
# ).fetchall():
|
||||
@@ -87,58 +88,25 @@ class PersonInfoManager:
|
||||
|
||||
@staticmethod
|
||||
def get_person_id(platform: str, user_id: Union[int, str]) -> str:
|
||||
"""获取唯一id"""
|
||||
"""获取唯一id(同步)
|
||||
|
||||
说明: 原来该方法为异步并在内部尝试执行数据库检查/迁移,导致在许多调用处未 await 时返回 coroutine 对象。
|
||||
为了避免将 coroutine 传递到其它同步调用(例如数据库查询条件)中,这里将方法改为同步并仅返回基于 platform 和 user_id 的 MD5 哈希值。
|
||||
|
||||
注意: 这会跳过原有的 napcat->qq 迁移检查逻辑。如需保留迁移,请使用显式的、在合适时机执行的迁移任务。
|
||||
"""
|
||||
# 检查platform是否为None或空
|
||||
if platform is None:
|
||||
platform = "unknown"
|
||||
|
||||
if "-" in platform:
|
||||
platform = platform.split("-")[1]
|
||||
# 在此处打一个补丁,如果platform为qq,尝试生成id后检查是否存在,如果不存在,则将平台换为napcat后再次检查,如果存在,则更新原id为platform为qq的id
|
||||
|
||||
components = [platform, str(user_id)]
|
||||
key = "_".join(components)
|
||||
|
||||
# 如果不是 qq 平台,直接返回计算的 id
|
||||
if platform != "qq":
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
|
||||
qq_id = hashlib.md5(key.encode()).hexdigest()
|
||||
|
||||
# 对于 qq 平台,先检查该 person_id 是否已存在;如果存在直接返回
|
||||
def _db_check_and_migrate_sync(p_id: str, raw_user_id: str):
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
# 检查 qq_id 是否存在
|
||||
existing_qq = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
|
||||
if existing_qq:
|
||||
return p_id
|
||||
|
||||
# 如果 qq_id 不存在,尝试使用 napcat 作为平台生成对应 id 并检查
|
||||
nap_components = ["napcat", str(raw_user_id)]
|
||||
nap_key = "_".join(nap_components)
|
||||
nap_id = hashlib.md5(nap_key.encode()).hexdigest()
|
||||
|
||||
existing_nap = session.execute(select(PersonInfo).where(PersonInfo.person_id == nap_id)).scalar()
|
||||
if not existing_nap:
|
||||
# napcat 也不存在,返回 qq_id(未命中)
|
||||
return p_id
|
||||
|
||||
# napcat 存在,迁移该记录:更新 person_id 与 platform -> qq
|
||||
try:
|
||||
# 更新现有 napcat 记录
|
||||
existing_nap.person_id = p_id
|
||||
existing_nap.platform = "qq"
|
||||
existing_nap.user_id = str(raw_user_id)
|
||||
session.commit()
|
||||
return p_id
|
||||
except Exception:
|
||||
session.rollback()
|
||||
return p_id
|
||||
except Exception as e:
|
||||
logger.error(f"检查/迁移 napcat->qq 时出错: {e}")
|
||||
return p_id
|
||||
|
||||
return _db_check_and_migrate_sync(qq_id, user_id)
|
||||
# 直接返回计算的 id(同步)
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
|
||||
async def is_person_known(self, platform: str, user_id: int):
|
||||
"""判断是否认识某人"""
|
||||
@@ -157,17 +125,25 @@ class PersonInfoManager:
|
||||
logger.error(f"检查用户 {person_id} 是否已知时出错 (SQLAlchemy): {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def get_person_id_by_person_name(person_name: str) -> str:
|
||||
"""根据用户名获取用户ID"""
|
||||
async def get_person_id_by_person_name(self, person_name: str) -> str:
|
||||
"""
|
||||
根据用户名获取用户ID(同步)
|
||||
|
||||
说明: 为了避免在多个调用点将 coroutine 误传递到数据库查询中,
|
||||
此处提供一个同步实现。优先在内存缓存 `self.person_name_list` 中查找,
|
||||
若未命中则返回空字符串。若后续需要更强的一致性,可在异步上下文
|
||||
额外实现带 await 的查询方法。
|
||||
"""
|
||||
try:
|
||||
# 在需要时获取会话
|
||||
async with get_db_session() as session:
|
||||
record = result = await session.execute(select(PersonInfo).where(PersonInfo.person_name == person_name))
|
||||
result.scalar()
|
||||
return record.person_id if record else ""
|
||||
# 优先使用内存缓存加速查找:self.person_name_list maps person_id -> person_name
|
||||
for pid, pname in self.person_name_list.items():
|
||||
if pname == person_name:
|
||||
return pid
|
||||
|
||||
# 未找到缓存命中,避免在同步路径中进行阻塞的数据库查询,直接返回空字符串
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (SQLAlchemy): {e}")
|
||||
logger.error(f"根据用户名 {person_name} 获取用户ID时出错: {e}")
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
@@ -578,26 +554,15 @@ class PersonInfoManager:
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_value(person_id: str, field_name: str) -> Any:
|
||||
async def get_value(person_id: str, field_name: str) -> Any:
|
||||
"""获取单个字段值(同步版本)"""
|
||||
if not person_id:
|
||||
logger.debug("get_value获取失败:person_id不能为空")
|
||||
return None
|
||||
|
||||
import asyncio
|
||||
|
||||
async def _get_record_sync():
|
||||
async with get_db_session() as session:
|
||||
result = await session.execute(select(PersonInfo).where(PersonInfo.person_id == person_id))
|
||||
record = result.scalar()
|
||||
return record
|
||||
|
||||
try:
|
||||
record = asyncio.run(_get_record_sync())
|
||||
except RuntimeError:
|
||||
# 如果当前线程已经有事件循环在运行,则使用现有的循环
|
||||
loop = asyncio.get_running_loop()
|
||||
record = loop.run_until_complete(_get_record_sync())
|
||||
async with get_db_session() as session:
|
||||
result = await session.execute(select(PersonInfo).where(PersonInfo.person_id == person_id))
|
||||
record = result.scalar()
|
||||
|
||||
model_fields = [column.name for column in PersonInfo.__table__.columns]
|
||||
|
||||
|
||||
@@ -176,7 +176,7 @@ class RelationshipFetcher:
|
||||
# 查询用户关系数据
|
||||
relationships = await db_query(
|
||||
UserRelationships,
|
||||
filters=[UserRelationships.user_id == str(person_info_manager.get_value_sync(person_id, "user_id"))],
|
||||
filters=[UserRelationships.user_id == str(person_info_manager.get_value(person_id, "user_id"))],
|
||||
limit=1,
|
||||
)
|
||||
|
||||
@@ -259,7 +259,7 @@ class RelationshipFetcher:
|
||||
# 记录信息获取请求
|
||||
self.info_fetching_cache.append(
|
||||
{
|
||||
"person_id": get_person_info_manager().get_person_id_by_person_name(person_name),
|
||||
"person_id": await get_person_info_manager().get_person_id_by_person_name(person_name),
|
||||
"person_name": person_name,
|
||||
"info_type": info_type,
|
||||
"start_time": time.time(),
|
||||
|
||||
@@ -412,7 +412,7 @@ def count_new_messages_for_users(chat_id: str, start_time: float, end_time: floa
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def build_readable_messages_to_str(
|
||||
async def build_readable_messages_to_str(
|
||||
messages: List[Dict[str, Any]],
|
||||
replace_bot_name: bool = True,
|
||||
merge_messages: bool = False,
|
||||
@@ -436,7 +436,7 @@ def build_readable_messages_to_str(
|
||||
Returns:
|
||||
格式化后的可读字符串
|
||||
"""
|
||||
return build_readable_messages(
|
||||
return await build_readable_messages(
|
||||
messages, replace_bot_name, merge_messages, timestamp_mode, read_mark, truncate, show_actions
|
||||
)
|
||||
|
||||
|
||||
@@ -134,7 +134,7 @@ async def is_person_known(platform: str, user_id: int) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def get_person_id_by_name(person_name: str) -> str:
|
||||
async def get_person_id_by_name(person_name: str) -> str:
|
||||
"""根据用户名获取person_id
|
||||
|
||||
Args:
|
||||
@@ -148,7 +148,7 @@ def get_person_id_by_name(person_name: str) -> str:
|
||||
"""
|
||||
try:
|
||||
person_info_manager = get_person_info_manager()
|
||||
return person_info_manager.get_person_id_by_person_name(person_name)
|
||||
return await person_info_manager.get_person_id_by_person_name(person_name)
|
||||
except Exception as e:
|
||||
logger.error(f"[PersonAPI] 根据用户名获取person_id失败: person_name={person_name}, error={e}")
|
||||
return ""
|
||||
|
||||
@@ -542,7 +542,22 @@ class PluginManager:
|
||||
plugin_instance.on_unload()
|
||||
|
||||
# 从组件注册表中移除插件的所有组件
|
||||
asyncio.run(component_registry.unregister_plugin(plugin_name))
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
fut = asyncio.run_coroutine_threadsafe(
|
||||
component_registry.unregister_plugin(plugin_name), loop
|
||||
)
|
||||
fut.result(timeout=5)
|
||||
else:
|
||||
asyncio.run(component_registry.unregister_plugin(plugin_name))
|
||||
except Exception:
|
||||
# 最后兜底:直接同步调用(如果 unregister_plugin 为非协程)或忽略错误
|
||||
try:
|
||||
# 如果 unregister_plugin 是普通函数
|
||||
component_registry.unregister_plugin(plugin_name)
|
||||
except Exception as e:
|
||||
logger.debug(f"卸载插件时调用 component_registry.unregister_plugin 失败: {e}")
|
||||
|
||||
# 从已加载插件中移除
|
||||
del self.loaded_plugins[plugin_name]
|
||||
|
||||
@@ -199,7 +199,7 @@ class ChatterInterestScoringSystem:
|
||||
# 如果内存中没有,尝试从关系追踪器获取
|
||||
if hasattr(self, "relationship_tracker") and self.relationship_tracker:
|
||||
try:
|
||||
relationship_score = self.relationship_tracker.get_user_relationship_score(user_id)
|
||||
relationship_score = await self.relationship_tracker.get_user_relationship_score(user_id)
|
||||
# 同时更新内存缓存
|
||||
self.user_relationships[user_id] = relationship_score
|
||||
return relationship_score
|
||||
|
||||
@@ -182,7 +182,7 @@ class ChatterPlanFilter:
|
||||
if plan.mode == ChatMode.PROACTIVE:
|
||||
long_term_memory_block = await self._get_long_term_memory_context()
|
||||
|
||||
chat_content_block, message_id_list = build_readable_messages_with_id(
|
||||
chat_content_block, message_id_list = await build_readable_messages_with_id(
|
||||
messages=[msg.flatten() for msg in plan.chat_history],
|
||||
timestamp_mode="normal",
|
||||
truncate=False,
|
||||
@@ -190,7 +190,7 @@ class ChatterPlanFilter:
|
||||
)
|
||||
|
||||
prompt_template = await global_prompt_manager.get_prompt_async("proactive_planner_prompt")
|
||||
actions_before_now = get_actions_by_timestamp_with_chat(
|
||||
actions_before_now = await get_actions_by_timestamp_with_chat(
|
||||
chat_id=plan.chat_id,
|
||||
timestamp_start=time.time() - 3600,
|
||||
timestamp_end=time.time(),
|
||||
@@ -216,7 +216,7 @@ class ChatterPlanFilter:
|
||||
)
|
||||
|
||||
# 为了兼容性,保留原有的chat_content_block
|
||||
chat_content_block, _ = build_readable_messages_with_id(
|
||||
chat_content_block, _ = await build_readable_messages_with_id(
|
||||
messages=[msg.flatten() for msg in plan.chat_history],
|
||||
timestamp_mode="normal",
|
||||
read_mark=self.last_obs_time_mark,
|
||||
@@ -224,7 +224,7 @@ class ChatterPlanFilter:
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
actions_before_now = get_actions_by_timestamp_with_chat(
|
||||
actions_before_now = await get_actions_by_timestamp_with_chat(
|
||||
chat_id=plan.chat_id,
|
||||
timestamp_start=time.time() - 3600,
|
||||
timestamp_end=time.time(),
|
||||
@@ -319,7 +319,14 @@ class ChatterPlanFilter:
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat
|
||||
|
||||
# 获取聊天流的上下文
|
||||
stream_context = message_manager.stream_contexts.get(plan.chat_id)
|
||||
from src.plugin_system.apis.chat_api import get_chat_manager
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = chat_manager.get_stream(plan.chat_id)
|
||||
if not chat_stream:
|
||||
logger.warning(f"[plan_filter] 聊天流 {plan.chat_id} 不存在")
|
||||
return "最近没有聊天内容。", "没有未读消息。", []
|
||||
|
||||
stream_context = chat_stream.context_manager
|
||||
|
||||
# 获取真正的已读和未读消息
|
||||
read_messages = stream_context.history_messages # 已读消息存储在history_messages中
|
||||
@@ -338,7 +345,7 @@ class ChatterPlanFilter:
|
||||
|
||||
# 构建已读历史消息块
|
||||
if read_messages:
|
||||
read_content, read_ids = build_readable_messages_with_id(
|
||||
read_content, read_ids = await build_readable_messages_with_id(
|
||||
messages=[msg.flatten() for msg in read_messages[-50:]], # 限制数量
|
||||
timestamp_mode="normal_no_YMD",
|
||||
truncate=False,
|
||||
|
||||
@@ -138,7 +138,7 @@ class ChatterActionPlanner:
|
||||
# 更新StreamContext中的消息信息并刷新focus_energy
|
||||
if context:
|
||||
from src.chat.message_manager.message_manager import message_manager
|
||||
message_manager.update_message(
|
||||
await message_manager.update_message(
|
||||
stream_id=self.chat_id,
|
||||
message_id=message.message_id,
|
||||
interest_value=message_interest,
|
||||
@@ -148,7 +148,7 @@ class ChatterActionPlanner:
|
||||
# 更新数据库中的消息记录
|
||||
try:
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
MessageStorage.update_message_interest_value(message.message_id, message_interest)
|
||||
await MessageStorage.update_message_interest_value(message.message_id, message_interest)
|
||||
logger.debug(f"已更新数据库中消息 {message.message_id} 的兴趣度为: {message_interest:.3f}")
|
||||
except Exception as e:
|
||||
logger.warning(f"更新数据库消息兴趣度失败: {e}")
|
||||
|
||||
@@ -124,10 +124,10 @@ class EmojiAction(BaseAction):
|
||||
emoji_base64, emoji_description = random.choice(all_emojis_data)
|
||||
else:
|
||||
# 获取最近的5条消息内容用于判断
|
||||
recent_messages = message_api.get_recent_messages(chat_id=self.chat_id, limit=5)
|
||||
recent_messages = await message_api.get_recent_messages(chat_id=self.chat_id, limit=5)
|
||||
messages_text = ""
|
||||
if recent_messages:
|
||||
messages_text = message_api.build_readable_messages(
|
||||
messages_text = await message_api.build_readable_messages(
|
||||
messages=recent_messages,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
truncate=False,
|
||||
@@ -185,10 +185,10 @@ class EmojiAction(BaseAction):
|
||||
elif global_config.emoji.emoji_selection_mode == "description":
|
||||
# --- 详细描述选择模式 ---
|
||||
# 获取最近的5条消息内容用于判断
|
||||
recent_messages = message_api.get_recent_messages(chat_id=self.chat_id, limit=5)
|
||||
recent_messages = await message_api.get_recent_messages(chat_id=self.chat_id, limit=5)
|
||||
messages_text = ""
|
||||
if recent_messages:
|
||||
messages_text = message_api.build_readable_messages(
|
||||
messages_text = await message_api.build_readable_messages(
|
||||
messages=recent_messages,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
truncate=False,
|
||||
|
||||
@@ -118,7 +118,7 @@ class QZoneService:
|
||||
|
||||
async def read_and_process_feeds(self, target_name: str, stream_id: Optional[str]) -> Dict[str, Any]:
|
||||
"""读取并处理指定好友的说说"""
|
||||
target_person_id = person_api.get_person_id_by_name(target_name)
|
||||
target_person_id = await person_api.get_person_id_by_name(target_name)
|
||||
if not target_person_id:
|
||||
return {"success": False, "message": f"找不到名为'{target_name}'的好友"}
|
||||
target_qq = await person_api.get_person_value(target_person_id, "user_id")
|
||||
|
||||
@@ -331,6 +331,7 @@ class NoticeHandler:
|
||||
|
||||
like_emoji_id = raw_message.get("likes")[0].get("emoji_id")
|
||||
await event_manager.trigger_event(
|
||||
<<<<<<< HEAD
|
||||
NapcatEvent.ON_RECEIVED.EMOJI_LIEK,
|
||||
permission_group=PLUGIN_NAME,
|
||||
group_id=group_id,
|
||||
@@ -342,6 +343,16 @@ class NoticeHandler:
|
||||
type="text",
|
||||
data=f"{user_name}使用Emoji表情{QQ_FACE.get(like_emoji_id, '')}回复了你的消息[{target_message_text}]",
|
||||
)
|
||||
=======
|
||||
NapcatEvent.ON_RECEIVED.EMOJI_LIEK,
|
||||
permission_group=PLUGIN_NAME,
|
||||
group_id=group_id,
|
||||
user_id=user_id,
|
||||
message_id=raw_message.get("message_id",""),
|
||||
emoji_id=like_emoji_id
|
||||
)
|
||||
seg_data = Seg(type="text",data=f"{user_name}使用Emoji表情{QQ_FACE.get(like_emoji_id, '')}回复了你的消息[{target_message_text}]")
|
||||
>>>>>>> 9912d7f643d347cbadcf1e3d618aa78bcbf89cc4
|
||||
return seg_data, user_info
|
||||
|
||||
async def handle_ban_notify(self, raw_message: dict, group_id: int) -> Tuple[Seg, UserInfo] | Tuple[None, None]:
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
{
|
||||
"manifest_version": 1,
|
||||
"name": "web_search_tool",
|
||||
"version": "1.0.0",
|
||||
"description": "一个用于在互联网上搜索信息的工具",
|
||||
"author": {
|
||||
"name": "MoFox-Studio",
|
||||
"url": "https://github.com/MoFox-Studio"
|
||||
},
|
||||
"license": "GPL-v3.0-or-later",
|
||||
|
||||
"host_application": {
|
||||
"min_version": "0.10.0"
|
||||
},
|
||||
"keywords": ["web_search", "url_parser"],
|
||||
"categories": ["web_search", "url_parser"],
|
||||
|
||||
"default_locale": "zh-CN",
|
||||
"locales_path": "_locales",
|
||||
|
||||
"plugin_info": {
|
||||
"is_built_in": false,
|
||||
"plugin_type": "web_search"
|
||||
}
|
||||
}
|
||||
@@ -1,110 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
测试 ChatStream 的 deepcopy 功能
|
||||
验证 asyncio.Task 序列化问题是否已解决
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
import copy
|
||||
|
||||
# 添加项目根目录到 Python 路径
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from maim_message import UserInfo, GroupInfo
|
||||
|
||||
|
||||
async def test_chat_stream_deepcopy():
|
||||
"""测试 ChatStream 的 deepcopy 功能"""
|
||||
print("[TEST] 开始测试 ChatStream deepcopy 功能...")
|
||||
|
||||
try:
|
||||
# 创建测试用的用户和群组信息
|
||||
user_info = UserInfo(
|
||||
platform="test_platform",
|
||||
user_id="test_user_123",
|
||||
user_nickname="测试用户",
|
||||
user_cardname="测试卡片名"
|
||||
)
|
||||
|
||||
group_info = GroupInfo(
|
||||
platform="test_platform",
|
||||
group_id="test_group_456",
|
||||
group_name="测试群组"
|
||||
)
|
||||
|
||||
# 创建 ChatStream 实例
|
||||
print("📝 创建 ChatStream 实例...")
|
||||
stream_id = "test_stream_789"
|
||||
platform = "test_platform"
|
||||
|
||||
chat_stream = ChatStream(
|
||||
stream_id=stream_id,
|
||||
platform=platform,
|
||||
user_info=user_info,
|
||||
group_info=group_info
|
||||
)
|
||||
|
||||
print(f"[SUCCESS] ChatStream 创建成功: {chat_stream.stream_id}")
|
||||
|
||||
# 等待一下,让异步任务有机会创建
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# 尝试进行 deepcopy
|
||||
print("[INFO] 尝试进行 deepcopy...")
|
||||
copied_stream = copy.deepcopy(chat_stream)
|
||||
|
||||
print("[SUCCESS] deepcopy 成功!")
|
||||
|
||||
# 验证复制后的对象属性
|
||||
print("\n[CHECK] 验证复制后的对象属性:")
|
||||
print(f" - stream_id: {copied_stream.stream_id}")
|
||||
print(f" - platform: {copied_stream.platform}")
|
||||
print(f" - user_info: {copied_stream.user_info.user_nickname}")
|
||||
print(f" - group_info: {copied_stream.group_info.group_name}")
|
||||
|
||||
# 检查 processing_task 是否被正确处理
|
||||
if hasattr(copied_stream.stream_context, 'processing_task'):
|
||||
print(f" - processing_task: {copied_stream.stream_context.processing_task}")
|
||||
if copied_stream.stream_context.processing_task is None:
|
||||
print(" [SUCCESS] processing_task 已被正确设置为 None")
|
||||
else:
|
||||
print(" [WARNING] processing_task 不为 None")
|
||||
else:
|
||||
print(" [SUCCESS] stream_context 没有 processing_task 属性")
|
||||
|
||||
# 验证原始对象和复制对象是不同的实例
|
||||
if id(chat_stream) != id(copied_stream):
|
||||
print("[SUCCESS] 原始对象和复制对象是不同的实例")
|
||||
else:
|
||||
print("[ERROR] 原始对象和复制对象是同一个实例")
|
||||
|
||||
# 验证基本属性是否正确复制
|
||||
if (chat_stream.stream_id == copied_stream.stream_id and
|
||||
chat_stream.platform == copied_stream.platform):
|
||||
print("[SUCCESS] 基本属性正确复制")
|
||||
else:
|
||||
print("[ERROR] 基本属性复制失败")
|
||||
|
||||
print("\n[COMPLETE] 测试完成!deepcopy 功能修复成功!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"[ERROR] 测试失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 运行测试
|
||||
result = asyncio.run(test_chat_stream_deepcopy())
|
||||
|
||||
if result:
|
||||
print("\n[SUCCESS] 所有测试通过!")
|
||||
sys.exit(0)
|
||||
else:
|
||||
print("\n[ERROR] 测试失败!")
|
||||
sys.exit(1)
|
||||
@@ -1,109 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
简单的 ChatStream deepcopy 测试
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
import copy
|
||||
|
||||
# 添加项目根目录到 Python 路径
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from maim_message import UserInfo, GroupInfo
|
||||
|
||||
|
||||
async def test_deepcopy():
|
||||
"""测试 deepcopy 功能"""
|
||||
print("开始测试 ChatStream deepcopy 功能...")
|
||||
|
||||
try:
|
||||
# 创建测试用的用户和群组信息
|
||||
user_info = UserInfo(
|
||||
platform="test_platform",
|
||||
user_id="test_user_123",
|
||||
user_nickname="测试用户",
|
||||
user_cardname="测试卡片名"
|
||||
)
|
||||
|
||||
group_info = GroupInfo(
|
||||
platform="test_platform",
|
||||
group_id="test_group_456",
|
||||
group_name="测试群组"
|
||||
)
|
||||
|
||||
# 创建 ChatStream 实例
|
||||
print("创建 ChatStream 实例...")
|
||||
stream_id = "test_stream_789"
|
||||
platform = "test_platform"
|
||||
|
||||
chat_stream = ChatStream(
|
||||
stream_id=stream_id,
|
||||
platform=platform,
|
||||
user_info=user_info,
|
||||
group_info=group_info
|
||||
)
|
||||
|
||||
print(f"ChatStream 创建成功: {chat_stream.stream_id}")
|
||||
|
||||
# 等待一下,让异步任务有机会创建
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# 尝试进行 deepcopy
|
||||
print("尝试进行 deepcopy...")
|
||||
copied_stream = copy.deepcopy(chat_stream)
|
||||
|
||||
print("deepcopy 成功!")
|
||||
|
||||
# 验证复制后的对象属性
|
||||
print("\n验证复制后的对象属性:")
|
||||
print(f" - stream_id: {copied_stream.stream_id}")
|
||||
print(f" - platform: {copied_stream.platform}")
|
||||
print(f" - user_info: {copied_stream.user_info.user_nickname}")
|
||||
print(f" - group_info: {copied_stream.group_info.group_name}")
|
||||
|
||||
# 检查 processing_task 是否被正确处理
|
||||
if hasattr(copied_stream.stream_context, 'processing_task'):
|
||||
print(f" - processing_task: {copied_stream.stream_context.processing_task}")
|
||||
if copied_stream.stream_context.processing_task is None:
|
||||
print(" SUCCESS: processing_task 已被正确设置为 None")
|
||||
else:
|
||||
print(" WARNING: processing_task 不为 None")
|
||||
else:
|
||||
print(" SUCCESS: stream_context 没有 processing_task 属性")
|
||||
|
||||
# 验证原始对象和复制对象是不同的实例
|
||||
if id(chat_stream) != id(copied_stream):
|
||||
print("SUCCESS: 原始对象和复制对象是不同的实例")
|
||||
else:
|
||||
print("ERROR: 原始对象和复制对象是同一个实例")
|
||||
|
||||
# 验证基本属性是否正确复制
|
||||
if (chat_stream.stream_id == copied_stream.stream_id and
|
||||
chat_stream.platform == copied_stream.platform):
|
||||
print("SUCCESS: 基本属性正确复制")
|
||||
else:
|
||||
print("ERROR: 基本属性复制失败")
|
||||
|
||||
print("\n测试完成!deepcopy 功能修复成功!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"ERROR: 测试失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 运行测试
|
||||
result = asyncio.run(test_deepcopy())
|
||||
|
||||
if result:
|
||||
print("\n所有测试通过!")
|
||||
sys.exit(0)
|
||||
else:
|
||||
print("\n测试失败!")
|
||||
sys.exit(1)
|
||||
Reference in New Issue
Block a user