refactor(chat): 优化异步任务处理和消息管理逻辑
- 使用asyncio.create_task替代await调用,提升并发性能 - 简化流管理器的槽位获取逻辑,移除回退方案 - 重构上下文管理器的消息添加和更新机制 - 移除StreamContext中的冗余方法,保持数据模型的简洁性 - 优化兴趣度评分系统的更新流程,减少阻塞操作 这些改动主要关注性能优化和代码结构简化,不涉及功能变更。
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import random
|
||||
import time
|
||||
@@ -304,14 +305,6 @@ class ExpressionSelector:
|
||||
try:
|
||||
# start_time = time.time()
|
||||
content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
|
||||
# logger.info(f"LLM请求时间: {model_name} {time.time() - start_time} \n{prompt}")
|
||||
|
||||
# logger.info(f"模型名称: {model_name}")
|
||||
# logger.info(f"LLM返回结果: {content}")
|
||||
# if reasoning_content:
|
||||
# logger.info(f"LLM推理: {reasoning_content}")
|
||||
# else:
|
||||
# logger.info(f"LLM推理: 无")
|
||||
|
||||
if not content:
|
||||
logger.warning("LLM返回空结果")
|
||||
@@ -338,7 +331,7 @@ class ExpressionSelector:
|
||||
|
||||
# 对选中的所有表达方式,一次性更新count数
|
||||
if valid_expressions:
|
||||
await self.update_expressions_count_batch(valid_expressions, 0.006)
|
||||
asyncio.create_task(self.update_expressions_count_batch(valid_expressions, 0.006))
|
||||
|
||||
# logger.info(f"LLM从{len(all_expressions)}个情境中选择了{len(valid_expressions)}个")
|
||||
return valid_expressions
|
||||
|
||||
@@ -13,6 +13,7 @@ from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.plugin_system.base.component_types import ChatType
|
||||
|
||||
from .distribution_manager import stream_loop_manager
|
||||
|
||||
@@ -54,7 +55,13 @@ class SingleStreamContextManager:
|
||||
bool: 是否成功添加
|
||||
"""
|
||||
try:
|
||||
self.context.add_message(message)
|
||||
# 直接操作上下文的消息列表
|
||||
message.is_read = False
|
||||
self.context.unread_messages.append(message)
|
||||
|
||||
# 自动检测和更新chat type
|
||||
self._detect_chat_type(message)
|
||||
|
||||
# 在上下文管理器中计算兴趣值
|
||||
await self._calculate_message_interest(message)
|
||||
self.total_messages += 1
|
||||
@@ -78,7 +85,28 @@ class SingleStreamContextManager:
|
||||
bool: 是否成功更新
|
||||
"""
|
||||
try:
|
||||
self.context.update_message_info(message_id, **updates)
|
||||
# 直接在未读消息中查找并更新
|
||||
for message in self.context.unread_messages:
|
||||
if message.message_id == message_id:
|
||||
if "interest_value" in updates:
|
||||
message.interest_value = updates["interest_value"]
|
||||
if "actions" in updates:
|
||||
message.actions = updates["actions"]
|
||||
if "should_reply" in updates:
|
||||
message.should_reply = updates["should_reply"]
|
||||
break
|
||||
|
||||
# 在历史消息中查找并更新
|
||||
for message in self.context.history_messages:
|
||||
if message.message_id == message_id:
|
||||
if "interest_value" in updates:
|
||||
message.interest_value = updates["interest_value"]
|
||||
if "actions" in updates:
|
||||
message.actions = updates["actions"]
|
||||
if "should_reply" in updates:
|
||||
message.should_reply = updates["should_reply"]
|
||||
break
|
||||
|
||||
logger.debug(f"更新单流上下文消息: {self.stream_id}/{message_id}")
|
||||
return True
|
||||
except Exception as e:
|
||||
@@ -259,36 +287,17 @@ class SingleStreamContextManager:
|
||||
logger.error(f"计算消息兴趣度时发生错误: {e}", exc_info=True)
|
||||
return 0.5
|
||||
|
||||
async def add_message_async(self, message: DatabaseMessages, skip_energy_update: bool = False) -> bool:
|
||||
"""异步实现的 add_message:将消息添加到 context,并 await 能量更新与分发。"""
|
||||
try:
|
||||
self.context.add_message(message)
|
||||
|
||||
# 在上下文管理器中计算兴趣值
|
||||
await self._calculate_message_interest(message)
|
||||
|
||||
self.total_messages += 1
|
||||
self.last_access_time = time.time()
|
||||
|
||||
# 启动流的循环任务(如果还未启动)
|
||||
asyncio.create_task(stream_loop_manager.start_stream_loop(self.stream_id))
|
||||
|
||||
logger.debug(f"添加消息到单流上下文(异步): {self.stream_id}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"添加消息到单流上下文失败 (async) {self.stream_id}: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
async def update_message_async(self, message_id: str, updates: dict[str, Any]) -> bool:
|
||||
"""异步实现的 update_message:更新消息并在需要时 await 能量更新。"""
|
||||
try:
|
||||
self.context.update_message_info(message_id, **updates)
|
||||
|
||||
logger.debug(f"更新单流上下文消息(异步): {self.stream_id}/{message_id}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"更新单流上下文消息失败 (async) {self.stream_id}/{message_id}: {e}", exc_info=True)
|
||||
return False
|
||||
def _detect_chat_type(self, message: DatabaseMessages):
|
||||
"""根据消息内容自动检测聊天类型"""
|
||||
# 只有在第一次添加消息时才检测聊天类型,避免后续消息改变类型
|
||||
if len(self.context.unread_messages) == 1: # 只有这条消息
|
||||
# 如果消息包含群组信息,则为群聊
|
||||
if hasattr(message, "chat_info_group_id") and message.chat_info_group_id:
|
||||
self.context.chat_type = ChatType.GROUP
|
||||
elif hasattr(message, "chat_info_group_name") and message.chat_info_group_name:
|
||||
self.context.chat_type = ChatType.GROUP
|
||||
else:
|
||||
self.context.chat_type = ChatType.PRIVATE
|
||||
|
||||
async def clear_context_async(self) -> bool:
|
||||
"""异步实现的 clear_context:清空消息并 await 能量重算。"""
|
||||
|
||||
@@ -23,8 +23,6 @@ class StreamLoopManager:
|
||||
def __init__(self, max_concurrent_streams: int | None = None):
|
||||
# 流循环任务管理
|
||||
self.stream_loops: dict[str, asyncio.Task] = {}
|
||||
# 跟踪流使用的管理器类型
|
||||
self.stream_management_type: dict[str, str] = {} # stream_id -> "adaptive" or "fallback"
|
||||
|
||||
# 统计信息
|
||||
self.stats: dict[str, Any] = {
|
||||
@@ -115,7 +113,6 @@ class StreamLoopManager:
|
||||
return True
|
||||
|
||||
# 使用自适应流管理器获取槽位
|
||||
use_adaptive = False
|
||||
try:
|
||||
from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager
|
||||
adaptive_manager = get_adaptive_stream_manager()
|
||||
@@ -132,21 +129,14 @@ class StreamLoopManager:
|
||||
)
|
||||
|
||||
if slot_acquired:
|
||||
use_adaptive = True
|
||||
logger.debug(f"成功获取流处理槽位: {stream_id} (优先级: {priority.name})")
|
||||
else:
|
||||
logger.debug(f"自适应管理器拒绝槽位请求: {stream_id},尝试回退方案")
|
||||
else:
|
||||
logger.debug("自适应管理器未运行,使用原始方法")
|
||||
logger.debug("自适应管理器未运行")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"自适应管理器获取槽位失败,使用原始方法: {e}")
|
||||
|
||||
# 如果自适应管理器失败或未运行,使用回退方案
|
||||
if not use_adaptive:
|
||||
if not await self._fallback_acquire_slot(stream_id, force):
|
||||
logger.debug(f"回退方案也失败: {stream_id}")
|
||||
return False
|
||||
logger.debug(f"自适应管理器获取槽位失败: {e}")
|
||||
|
||||
# 创建流循环任务
|
||||
try:
|
||||
@@ -155,68 +145,22 @@ class StreamLoopManager:
|
||||
name=f"stream_loop_{stream_id}"
|
||||
)
|
||||
self.stream_loops[stream_id] = loop_task
|
||||
# 记录管理器类型
|
||||
self.stream_management_type[stream_id] = "adaptive" if use_adaptive else "fallback"
|
||||
|
||||
# 更新统计信息
|
||||
self.stats["active_streams"] += 1
|
||||
self.stats["total_loops"] += 1
|
||||
|
||||
logger.info(f"启动流循环任务: {stream_id} (管理器: {'adaptive' if use_adaptive else 'fallback'})")
|
||||
logger.info(f"启动流循环任务: {stream_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"启动流循环任务失败 {stream_id}: {e}")
|
||||
# 释放槽位
|
||||
if use_adaptive:
|
||||
try:
|
||||
from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager
|
||||
adaptive_manager = get_adaptive_stream_manager()
|
||||
adaptive_manager.release_stream_slot(stream_id)
|
||||
except:
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
async def _fallback_acquire_slot(self, stream_id: str, force: bool) -> bool:
|
||||
"""回退方案:获取槽位(原始方法)"""
|
||||
# 判断是否需要强制分发
|
||||
should_force = force or await self._should_force_dispatch_for_stream(stream_id)
|
||||
|
||||
# 检查是否超过最大并发限制
|
||||
current_streams = len(self.stream_loops)
|
||||
if current_streams >= self.max_concurrent_streams and not should_force:
|
||||
logger.warning(
|
||||
f"超过最大并发流数限制({current_streams}/{self.max_concurrent_streams}),无法启动流 {stream_id}"
|
||||
)
|
||||
return False
|
||||
|
||||
# 处理强制分发情况
|
||||
if should_force and current_streams >= self.max_concurrent_streams:
|
||||
logger.warning(
|
||||
f"流 {stream_id} 未读消息积压严重(>{self.force_dispatch_unread_threshold}),突破并发限制强制启动分发 (当前: {current_streams}/{self.max_concurrent_streams})"
|
||||
)
|
||||
# 检查是否有现有的分发循环,如果有则先移除
|
||||
if stream_id in self.stream_loops:
|
||||
logger.info(f"发现现有流循环 {stream_id},将先移除再重新创建")
|
||||
existing_task = self.stream_loops[stream_id]
|
||||
if not existing_task.done():
|
||||
existing_task.cancel()
|
||||
# 创建异步任务来等待取消完成,并添加异常处理
|
||||
cancel_task = asyncio.create_task(
|
||||
self._wait_for_task_cancel(stream_id, existing_task),
|
||||
name=f"cancel_existing_loop_{stream_id}"
|
||||
)
|
||||
# 为取消任务添加异常处理,避免孤儿任务
|
||||
cancel_task.add_done_callback(
|
||||
lambda task: logger.debug(f"取消任务完成: {stream_id}") if not task.exception()
|
||||
else logger.error(f"取消任务异常: {stream_id} - {task.exception()}")
|
||||
)
|
||||
# 从字典中移除
|
||||
del self.stream_loops[stream_id]
|
||||
current_streams -= 1 # 更新当前流数量
|
||||
|
||||
return True
|
||||
|
||||
def _determine_stream_priority(self, stream_id: str) -> "StreamPriority":
|
||||
"""确定流优先级"""
|
||||
try:
|
||||
@@ -237,20 +181,6 @@ class StreamLoopManager:
|
||||
from src.chat.message_manager.adaptive_stream_manager import StreamPriority
|
||||
return StreamPriority.NORMAL
|
||||
|
||||
# 创建流循环任务
|
||||
try:
|
||||
task = asyncio.create_task(
|
||||
self._stream_loop(stream_id),
|
||||
name=f"stream_loop_{stream_id}" # 为任务添加名称,便于调试
|
||||
)
|
||||
self.stream_loops[stream_id] = task
|
||||
self.stats["total_loops"] += 1
|
||||
|
||||
logger.info(f"启动流循环: {stream_id} (当前总数: {len(self.stream_loops)})")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"创建流循环任务失败: {stream_id} - {e}")
|
||||
return False
|
||||
|
||||
async def stop_stream_loop(self, stream_id: str) -> bool:
|
||||
"""停止指定流的循环任务
|
||||
@@ -342,17 +272,6 @@ class StreamLoopManager:
|
||||
# 4. 计算下次检查间隔
|
||||
interval = await self._calculate_interval(stream_id, has_messages)
|
||||
|
||||
if has_messages:
|
||||
updated_unread_count = self._get_unread_count(context)
|
||||
if self._needs_force_dispatch_for_context(context, updated_unread_count):
|
||||
interval = min(interval, max(self.force_dispatch_min_interval, 0.0))
|
||||
logger.debug(
|
||||
"流 %s 未读消息仍有 %d 条,使用加速分发间隔 %.2fs",
|
||||
stream_id,
|
||||
updated_unread_count,
|
||||
interval,
|
||||
)
|
||||
|
||||
# 5. sleep等待下次检查
|
||||
logger.info(f"流 {stream_id} 等待 {interval:.2f}s")
|
||||
await asyncio.sleep(interval)
|
||||
@@ -378,9 +297,6 @@ class StreamLoopManager:
|
||||
del self.stream_loops[stream_id]
|
||||
logger.debug(f"清理流循环标记: {stream_id}")
|
||||
|
||||
# 根据管理器类型释放相应的槽位
|
||||
management_type = self.stream_management_type.get(stream_id, "fallback")
|
||||
if management_type == "adaptive":
|
||||
# 释放自适应管理器的槽位
|
||||
try:
|
||||
from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager
|
||||
@@ -389,12 +305,6 @@ class StreamLoopManager:
|
||||
logger.debug(f"释放自适应流处理槽位: {stream_id}")
|
||||
except Exception as e:
|
||||
logger.debug(f"释放自适应流处理槽位失败: {e}")
|
||||
else:
|
||||
logger.debug(f"流 {stream_id} 使用回退方案,无需释放自适应槽位")
|
||||
|
||||
# 清理管理器类型记录
|
||||
if stream_id in self.stream_management_type:
|
||||
del self.stream_management_type[stream_id]
|
||||
|
||||
logger.info(f"流循环结束: {stream_id}")
|
||||
|
||||
@@ -417,7 +327,7 @@ class StreamLoopManager:
|
||||
logger.error(f"获取流上下文失败 {stream_id}: {e}")
|
||||
return None
|
||||
|
||||
async def _has_messages_to_process(self, context: Any) -> bool:
|
||||
async def _has_messages_to_process(self, context: StreamContext) -> bool:
|
||||
"""检查是否有消息需要处理
|
||||
|
||||
Args:
|
||||
@@ -464,7 +374,7 @@ class StreamLoopManager:
|
||||
success = results.get("success", False)
|
||||
|
||||
if success:
|
||||
await self._refresh_focus_energy(stream_id)
|
||||
asyncio.create_task(self._refresh_focus_energy(stream_id))
|
||||
process_time = time.time() - start_time
|
||||
logger.debug(f"流处理成功: {stream_id} (耗时: {process_time:.2f}s)")
|
||||
else:
|
||||
@@ -553,16 +463,16 @@ class StreamLoopManager:
|
||||
logger.debug(f"检查流 {stream_id} 是否需要强制分发失败: {e}")
|
||||
return False
|
||||
|
||||
def _get_unread_count(self, context: Any) -> int:
|
||||
def _get_unread_count(self, context: StreamContext) -> int:
|
||||
try:
|
||||
unread_messages = getattr(context, "unread_messages", None)
|
||||
unread_messages = context.unread_messages
|
||||
if unread_messages is None:
|
||||
return 0
|
||||
return len(unread_messages)
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
def _needs_force_dispatch_for_context(self, context: Any, unread_count: int | None = None) -> bool:
|
||||
def _needs_force_dispatch_for_context(self, context: StreamContext, unread_count: int | None = None) -> bool:
|
||||
if not self.force_dispatch_unread_threshold or self.force_dispatch_unread_threshold <= 0:
|
||||
return False
|
||||
|
||||
|
||||
@@ -185,7 +185,7 @@ class ChatterActionManager:
|
||||
logger.info(f"{log_prefix} 选择不回复,原因: {reason}")
|
||||
|
||||
# 存储no_reply信息到数据库
|
||||
await database_api.store_action_info(
|
||||
asyncio.create_task(database_api.store_action_info(
|
||||
chat_stream=chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=reason,
|
||||
@@ -193,10 +193,10 @@ class ChatterActionManager:
|
||||
thinking_id=thinking_id,
|
||||
action_data={"reason": reason},
|
||||
action_name="no_reply",
|
||||
)
|
||||
))
|
||||
|
||||
# 自动清空所有未读消息
|
||||
await self._clear_all_unread_messages(chat_stream.stream_id, "no_reply")
|
||||
asyncio.create_task(self._clear_all_unread_messages(chat_stream.stream_id, "no_reply"))
|
||||
|
||||
return {"action_type": "no_reply", "success": True, "reply_text": "", "command": ""}
|
||||
|
||||
@@ -214,12 +214,12 @@ class ChatterActionManager:
|
||||
|
||||
# 记录执行的动作到目标消息
|
||||
if success:
|
||||
await self._record_action_to_message(chat_stream, action_name, target_message, action_data)
|
||||
asyncio.create_task(self._record_action_to_message(chat_stream, action_name, target_message, action_data))
|
||||
# 自动清空所有未读消息
|
||||
if clear_unread_messages:
|
||||
await self._clear_all_unread_messages(chat_stream.stream_id, action_name)
|
||||
asyncio.create_task(self._clear_all_unread_messages(chat_stream.stream_id, action_name))
|
||||
# 重置打断计数
|
||||
await self._reset_interruption_count_after_action(chat_stream.stream_id)
|
||||
asyncio.create_task(self._reset_interruption_count_after_action(chat_stream.stream_id))
|
||||
|
||||
return {
|
||||
"action_type": action_name,
|
||||
@@ -260,13 +260,13 @@ class ChatterActionManager:
|
||||
)
|
||||
|
||||
# 记录回复动作到目标消息
|
||||
await self._record_action_to_message(chat_stream, "reply", target_message, action_data)
|
||||
asyncio.create_task(self._record_action_to_message(chat_stream, "reply", target_message, action_data))
|
||||
|
||||
if clear_unread_messages:
|
||||
await self._clear_all_unread_messages(chat_stream.stream_id, "reply")
|
||||
asyncio.create_task(self._clear_all_unread_messages(chat_stream.stream_id, "reply"))
|
||||
|
||||
# 回复成功,重置打断计数
|
||||
await self._reset_interruption_count_after_action(chat_stream.stream_id)
|
||||
asyncio.create_task(self._reset_interruption_count_after_action(chat_stream.stream_id))
|
||||
|
||||
return {"action_type": "reply", "success": True, "reply_text": reply_text, "loop_info": loop_info}
|
||||
|
||||
|
||||
@@ -287,13 +287,13 @@ class DefaultReplyer:
|
||||
try:
|
||||
# 构建 Prompt
|
||||
with Timer("构建Prompt", {}): # 内部计时器,可选保留
|
||||
prompt = await asyncio.create_task(self.build_prompt_reply_context(
|
||||
prompt = await self.build_prompt_reply_context(
|
||||
reply_to=reply_to,
|
||||
extra_info=extra_info,
|
||||
available_actions=available_actions,
|
||||
enable_tool=enable_tool,
|
||||
reply_message=reply_message,
|
||||
))
|
||||
)
|
||||
|
||||
if not prompt:
|
||||
logger.warning("构建prompt失败,跳过回复生成")
|
||||
|
||||
@@ -175,7 +175,6 @@ class PromptManager:
|
||||
self._prompts = {}
|
||||
self._counter = 0
|
||||
self._context = PromptContext()
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_message_scope(self, message_id: str | None = None):
|
||||
@@ -190,7 +189,6 @@ class PromptManager:
|
||||
logger.debug(f"从上下文中获取提示词: {name} {context_prompt}")
|
||||
return context_prompt
|
||||
|
||||
async with self._lock:
|
||||
if name not in self._prompts:
|
||||
raise KeyError(f"Prompt '{name}' not found")
|
||||
return self._prompts[name]
|
||||
|
||||
@@ -53,37 +53,7 @@ class StreamContext(BaseDataModel):
|
||||
priority_mode: str | None = None
|
||||
priority_info: dict | None = None
|
||||
|
||||
def add_message(self, message: "DatabaseMessages"):
|
||||
"""添加消息到上下文"""
|
||||
message.is_read = False
|
||||
self.unread_messages.append(message)
|
||||
|
||||
# 自动检测和更新chat type
|
||||
self._detect_chat_type(message)
|
||||
|
||||
def update_message_info(
|
||||
self, message_id: str, interest_value: float = None, actions: list = None, should_reply: bool = None
|
||||
):
|
||||
"""
|
||||
更新消息信息
|
||||
|
||||
Args:
|
||||
message_id: 消息ID
|
||||
interest_value: 兴趣度值
|
||||
actions: 执行的动作列表
|
||||
should_reply: 是否应该回复
|
||||
"""
|
||||
# 在未读消息中查找并更新
|
||||
for message in self.unread_messages:
|
||||
if message.message_id == message_id:
|
||||
message.update_message_info(interest_value, actions, should_reply)
|
||||
break
|
||||
|
||||
# 在历史消息中查找并更新
|
||||
for message in self.history_messages:
|
||||
if message.message_id == message_id:
|
||||
message.update_message_info(interest_value, actions, should_reply)
|
||||
break
|
||||
|
||||
def add_action_to_message(self, message_id: str, action: str):
|
||||
"""
|
||||
@@ -105,42 +75,8 @@ class StreamContext(BaseDataModel):
|
||||
message.add_action(action)
|
||||
break
|
||||
|
||||
def _detect_chat_type(self, message: "DatabaseMessages"):
|
||||
"""根据消息内容自动检测聊天类型"""
|
||||
# 只有在第一次添加消息时才检测聊天类型,避免后续消息改变类型
|
||||
if len(self.unread_messages) == 1: # 只有这条消息
|
||||
# 如果消息包含群组信息,则为群聊
|
||||
if hasattr(message, "chat_info_group_id") and message.chat_info_group_id:
|
||||
self.chat_type = ChatType.GROUP
|
||||
elif hasattr(message, "chat_info_group_name") and message.chat_info_group_name:
|
||||
self.chat_type = ChatType.GROUP
|
||||
else:
|
||||
self.chat_type = ChatType.PRIVATE
|
||||
|
||||
def update_chat_type(self, chat_type: ChatType):
|
||||
"""手动更新聊天类型"""
|
||||
self.chat_type = chat_type
|
||||
|
||||
def set_chat_mode(self, chat_mode: ChatMode):
|
||||
"""设置聊天模式"""
|
||||
self.chat_mode = chat_mode
|
||||
|
||||
def is_group_chat(self) -> bool:
|
||||
"""检查是否为群聊"""
|
||||
return self.chat_type == ChatType.GROUP
|
||||
|
||||
def is_private_chat(self) -> bool:
|
||||
"""检查是否为私聊"""
|
||||
return self.chat_type == ChatType.PRIVATE
|
||||
|
||||
def get_chat_type_display(self) -> str:
|
||||
"""获取聊天类型的显示名称"""
|
||||
if self.chat_type == ChatType.GROUP:
|
||||
return "群聊"
|
||||
elif self.chat_type == ChatType.PRIVATE:
|
||||
return "私聊"
|
||||
else:
|
||||
return "未知类型"
|
||||
|
||||
def mark_message_as_read(self, message_id: str):
|
||||
"""标记消息为已读"""
|
||||
|
||||
@@ -220,11 +220,11 @@ class ChatterPlanExecutor:
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
logger.error(f"执行回复动作失败: {action_info.action_type}, 错误: {error_message}")
|
||||
|
||||
'''
|
||||
# 记录用户关系追踪
|
||||
if success and action_info.action_message:
|
||||
await self._track_user_interaction(action_info, plan, reply_content)
|
||||
|
||||
'''
|
||||
execution_time = time.time() - start_time
|
||||
self.execution_stats["execution_times"].append(execution_time)
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
集成兴趣度评分系统和用户关系追踪机制,实现智能化的聊天决策。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from dataclasses import asdict
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
@@ -98,10 +99,11 @@ class ChatterActionPlanner:
|
||||
|
||||
unread_messages = context.get_unread_messages() if context else []
|
||||
# 2. 使用新的兴趣度管理系统进行评分
|
||||
score = 0.0
|
||||
should_reply = False
|
||||
message_interest = 0.0
|
||||
reply_not_available = False
|
||||
interest_updates: list[dict[str, Any]] = []
|
||||
message_should_act = False
|
||||
message_should_reply = False
|
||||
|
||||
if unread_messages:
|
||||
# 直接使用消息中已计算的标志,无需重复计算兴趣值
|
||||
@@ -111,16 +113,7 @@ class ChatterActionPlanner:
|
||||
message_should_reply = getattr(message, "should_reply", False)
|
||||
message_should_act = getattr(message, "should_act", False)
|
||||
|
||||
# 确保interest_value不是None
|
||||
if message_interest is None:
|
||||
message_interest = 0.3
|
||||
|
||||
# 更新最高兴趣度消息
|
||||
if message_interest > score:
|
||||
score = message_interest
|
||||
if message_should_reply:
|
||||
should_reply = True
|
||||
else:
|
||||
if not message_should_reply:
|
||||
reply_not_available = True
|
||||
|
||||
# 如果should_act为false,强制设为no_action
|
||||
@@ -142,22 +135,23 @@ class ChatterActionPlanner:
|
||||
"message_id": message.message_id,
|
||||
"interest_value": 0.0,
|
||||
"should_reply": False,
|
||||
"should_act": False,
|
||||
}
|
||||
)
|
||||
|
||||
if interest_updates:
|
||||
await self._commit_interest_updates(interest_updates)
|
||||
asyncio.create_task(self._commit_interest_updates(interest_updates))
|
||||
|
||||
# 检查兴趣度是否达到非回复动作阈值
|
||||
non_reply_action_interest_threshold = global_config.affinity_flow.non_reply_action_interest_threshold
|
||||
if score < non_reply_action_interest_threshold:
|
||||
logger.info(f"兴趣度 {score:.3f} 低于阈值 {non_reply_action_interest_threshold:.3f},不执行动作")
|
||||
if not message_should_act:
|
||||
logger.info(f"兴趣度 {message_interest:.3f} 低于阈值 {non_reply_action_interest_threshold:.3f},不执行动作")
|
||||
# 直接返回 no_action
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
|
||||
no_action = ActionPlannerInfo(
|
||||
action_type="no_action",
|
||||
reasoning=f"兴趣度评分 {score:.3f} 未达阈值 {non_reply_action_interest_threshold:.3f}",
|
||||
reasoning=f"兴趣度评分 {message_interest:.3f} 未达阈值 {non_reply_action_interest_threshold:.3f}",
|
||||
action_data={},
|
||||
action_message=None,
|
||||
)
|
||||
@@ -169,9 +163,6 @@ class ChatterActionPlanner:
|
||||
plan_filter = ChatterPlanFilter(self.chat_id, available_actions)
|
||||
filtered_plan = await plan_filter.filter(reply_not_available, initial_plan)
|
||||
|
||||
# 检查filtered_plan是否有reply动作,用于统计
|
||||
has_reply_action = any(decision.action_type == "reply" for decision in filtered_plan.decided_actions)
|
||||
|
||||
# 5. 使用 PlanExecutor 执行 Plan
|
||||
execution_result = await self.executor.execute(filtered_plan)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user