refactor(chat): 优化异步任务处理和消息管理逻辑
- 使用asyncio.create_task替代await调用,提升并发性能 - 简化流管理器的槽位获取逻辑,移除回退方案 - 重构上下文管理器的消息添加和更新机制 - 移除StreamContext中的冗余方法,保持数据模型的简洁性 - 优化兴趣度评分系统的更新流程,减少阻塞操作 这些改动主要关注性能优化和代码结构简化,不涉及功能变更。
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
|
import asyncio
|
||||||
import hashlib
|
import hashlib
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
@@ -304,14 +305,6 @@ class ExpressionSelector:
|
|||||||
try:
|
try:
|
||||||
# start_time = time.time()
|
# start_time = time.time()
|
||||||
content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
|
content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
|
||||||
# logger.info(f"LLM请求时间: {model_name} {time.time() - start_time} \n{prompt}")
|
|
||||||
|
|
||||||
# logger.info(f"模型名称: {model_name}")
|
|
||||||
# logger.info(f"LLM返回结果: {content}")
|
|
||||||
# if reasoning_content:
|
|
||||||
# logger.info(f"LLM推理: {reasoning_content}")
|
|
||||||
# else:
|
|
||||||
# logger.info(f"LLM推理: 无")
|
|
||||||
|
|
||||||
if not content:
|
if not content:
|
||||||
logger.warning("LLM返回空结果")
|
logger.warning("LLM返回空结果")
|
||||||
@@ -338,7 +331,7 @@ class ExpressionSelector:
|
|||||||
|
|
||||||
# 对选中的所有表达方式,一次性更新count数
|
# 对选中的所有表达方式,一次性更新count数
|
||||||
if valid_expressions:
|
if valid_expressions:
|
||||||
await self.update_expressions_count_batch(valid_expressions, 0.006)
|
asyncio.create_task(self.update_expressions_count_batch(valid_expressions, 0.006))
|
||||||
|
|
||||||
# logger.info(f"LLM从{len(all_expressions)}个情境中选择了{len(valid_expressions)}个")
|
# logger.info(f"LLM从{len(all_expressions)}个情境中选择了{len(valid_expressions)}个")
|
||||||
return valid_expressions
|
return valid_expressions
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from src.common.data_models.database_data_model import DatabaseMessages
|
|||||||
from src.common.data_models.message_manager_data_model import StreamContext
|
from src.common.data_models.message_manager_data_model import StreamContext
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
from src.plugin_system.base.component_types import ChatType
|
||||||
|
|
||||||
from .distribution_manager import stream_loop_manager
|
from .distribution_manager import stream_loop_manager
|
||||||
|
|
||||||
@@ -54,7 +55,13 @@ class SingleStreamContextManager:
|
|||||||
bool: 是否成功添加
|
bool: 是否成功添加
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
self.context.add_message(message)
|
# 直接操作上下文的消息列表
|
||||||
|
message.is_read = False
|
||||||
|
self.context.unread_messages.append(message)
|
||||||
|
|
||||||
|
# 自动检测和更新chat type
|
||||||
|
self._detect_chat_type(message)
|
||||||
|
|
||||||
# 在上下文管理器中计算兴趣值
|
# 在上下文管理器中计算兴趣值
|
||||||
await self._calculate_message_interest(message)
|
await self._calculate_message_interest(message)
|
||||||
self.total_messages += 1
|
self.total_messages += 1
|
||||||
@@ -78,7 +85,28 @@ class SingleStreamContextManager:
|
|||||||
bool: 是否成功更新
|
bool: 是否成功更新
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
self.context.update_message_info(message_id, **updates)
|
# 直接在未读消息中查找并更新
|
||||||
|
for message in self.context.unread_messages:
|
||||||
|
if message.message_id == message_id:
|
||||||
|
if "interest_value" in updates:
|
||||||
|
message.interest_value = updates["interest_value"]
|
||||||
|
if "actions" in updates:
|
||||||
|
message.actions = updates["actions"]
|
||||||
|
if "should_reply" in updates:
|
||||||
|
message.should_reply = updates["should_reply"]
|
||||||
|
break
|
||||||
|
|
||||||
|
# 在历史消息中查找并更新
|
||||||
|
for message in self.context.history_messages:
|
||||||
|
if message.message_id == message_id:
|
||||||
|
if "interest_value" in updates:
|
||||||
|
message.interest_value = updates["interest_value"]
|
||||||
|
if "actions" in updates:
|
||||||
|
message.actions = updates["actions"]
|
||||||
|
if "should_reply" in updates:
|
||||||
|
message.should_reply = updates["should_reply"]
|
||||||
|
break
|
||||||
|
|
||||||
logger.debug(f"更新单流上下文消息: {self.stream_id}/{message_id}")
|
logger.debug(f"更新单流上下文消息: {self.stream_id}/{message_id}")
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -259,36 +287,17 @@ class SingleStreamContextManager:
|
|||||||
logger.error(f"计算消息兴趣度时发生错误: {e}", exc_info=True)
|
logger.error(f"计算消息兴趣度时发生错误: {e}", exc_info=True)
|
||||||
return 0.5
|
return 0.5
|
||||||
|
|
||||||
async def add_message_async(self, message: DatabaseMessages, skip_energy_update: bool = False) -> bool:
|
def _detect_chat_type(self, message: DatabaseMessages):
|
||||||
"""异步实现的 add_message:将消息添加到 context,并 await 能量更新与分发。"""
|
"""根据消息内容自动检测聊天类型"""
|
||||||
try:
|
# 只有在第一次添加消息时才检测聊天类型,避免后续消息改变类型
|
||||||
self.context.add_message(message)
|
if len(self.context.unread_messages) == 1: # 只有这条消息
|
||||||
|
# 如果消息包含群组信息,则为群聊
|
||||||
# 在上下文管理器中计算兴趣值
|
if hasattr(message, "chat_info_group_id") and message.chat_info_group_id:
|
||||||
await self._calculate_message_interest(message)
|
self.context.chat_type = ChatType.GROUP
|
||||||
|
elif hasattr(message, "chat_info_group_name") and message.chat_info_group_name:
|
||||||
self.total_messages += 1
|
self.context.chat_type = ChatType.GROUP
|
||||||
self.last_access_time = time.time()
|
else:
|
||||||
|
self.context.chat_type = ChatType.PRIVATE
|
||||||
# 启动流的循环任务(如果还未启动)
|
|
||||||
asyncio.create_task(stream_loop_manager.start_stream_loop(self.stream_id))
|
|
||||||
|
|
||||||
logger.debug(f"添加消息到单流上下文(异步): {self.stream_id}")
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"添加消息到单流上下文失败 (async) {self.stream_id}: {e}", exc_info=True)
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def update_message_async(self, message_id: str, updates: dict[str, Any]) -> bool:
|
|
||||||
"""异步实现的 update_message:更新消息并在需要时 await 能量更新。"""
|
|
||||||
try:
|
|
||||||
self.context.update_message_info(message_id, **updates)
|
|
||||||
|
|
||||||
logger.debug(f"更新单流上下文消息(异步): {self.stream_id}/{message_id}")
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"更新单流上下文消息失败 (async) {self.stream_id}/{message_id}: {e}", exc_info=True)
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def clear_context_async(self) -> bool:
|
async def clear_context_async(self) -> bool:
|
||||||
"""异步实现的 clear_context:清空消息并 await 能量重算。"""
|
"""异步实现的 clear_context:清空消息并 await 能量重算。"""
|
||||||
|
|||||||
@@ -23,8 +23,6 @@ class StreamLoopManager:
|
|||||||
def __init__(self, max_concurrent_streams: int | None = None):
|
def __init__(self, max_concurrent_streams: int | None = None):
|
||||||
# 流循环任务管理
|
# 流循环任务管理
|
||||||
self.stream_loops: dict[str, asyncio.Task] = {}
|
self.stream_loops: dict[str, asyncio.Task] = {}
|
||||||
# 跟踪流使用的管理器类型
|
|
||||||
self.stream_management_type: dict[str, str] = {} # stream_id -> "adaptive" or "fallback"
|
|
||||||
|
|
||||||
# 统计信息
|
# 统计信息
|
||||||
self.stats: dict[str, Any] = {
|
self.stats: dict[str, Any] = {
|
||||||
@@ -115,7 +113,6 @@ class StreamLoopManager:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
# 使用自适应流管理器获取槽位
|
# 使用自适应流管理器获取槽位
|
||||||
use_adaptive = False
|
|
||||||
try:
|
try:
|
||||||
from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager
|
from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager
|
||||||
adaptive_manager = get_adaptive_stream_manager()
|
adaptive_manager = get_adaptive_stream_manager()
|
||||||
@@ -132,21 +129,14 @@ class StreamLoopManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if slot_acquired:
|
if slot_acquired:
|
||||||
use_adaptive = True
|
|
||||||
logger.debug(f"成功获取流处理槽位: {stream_id} (优先级: {priority.name})")
|
logger.debug(f"成功获取流处理槽位: {stream_id} (优先级: {priority.name})")
|
||||||
else:
|
else:
|
||||||
logger.debug(f"自适应管理器拒绝槽位请求: {stream_id},尝试回退方案")
|
logger.debug(f"自适应管理器拒绝槽位请求: {stream_id},尝试回退方案")
|
||||||
else:
|
else:
|
||||||
logger.debug("自适应管理器未运行,使用原始方法")
|
logger.debug("自适应管理器未运行")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"自适应管理器获取槽位失败,使用原始方法: {e}")
|
logger.debug(f"自适应管理器获取槽位失败: {e}")
|
||||||
|
|
||||||
# 如果自适应管理器失败或未运行,使用回退方案
|
|
||||||
if not use_adaptive:
|
|
||||||
if not await self._fallback_acquire_slot(stream_id, force):
|
|
||||||
logger.debug(f"回退方案也失败: {stream_id}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 创建流循环任务
|
# 创建流循环任务
|
||||||
try:
|
try:
|
||||||
@@ -155,68 +145,22 @@ class StreamLoopManager:
|
|||||||
name=f"stream_loop_{stream_id}"
|
name=f"stream_loop_{stream_id}"
|
||||||
)
|
)
|
||||||
self.stream_loops[stream_id] = loop_task
|
self.stream_loops[stream_id] = loop_task
|
||||||
# 记录管理器类型
|
|
||||||
self.stream_management_type[stream_id] = "adaptive" if use_adaptive else "fallback"
|
|
||||||
|
|
||||||
# 更新统计信息
|
# 更新统计信息
|
||||||
self.stats["active_streams"] += 1
|
self.stats["active_streams"] += 1
|
||||||
self.stats["total_loops"] += 1
|
self.stats["total_loops"] += 1
|
||||||
|
|
||||||
logger.info(f"启动流循环任务: {stream_id} (管理器: {'adaptive' if use_adaptive else 'fallback'})")
|
logger.info(f"启动流循环任务: {stream_id}")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"启动流循环任务失败 {stream_id}: {e}")
|
logger.error(f"启动流循环任务失败 {stream_id}: {e}")
|
||||||
# 释放槽位
|
# 释放槽位
|
||||||
if use_adaptive:
|
from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager
|
||||||
try:
|
adaptive_manager = get_adaptive_stream_manager()
|
||||||
from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager
|
adaptive_manager.release_stream_slot(stream_id)
|
||||||
adaptive_manager = get_adaptive_stream_manager()
|
|
||||||
adaptive_manager.release_stream_slot(stream_id)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def _fallback_acquire_slot(self, stream_id: str, force: bool) -> bool:
|
|
||||||
"""回退方案:获取槽位(原始方法)"""
|
|
||||||
# 判断是否需要强制分发
|
|
||||||
should_force = force or await self._should_force_dispatch_for_stream(stream_id)
|
|
||||||
|
|
||||||
# 检查是否超过最大并发限制
|
|
||||||
current_streams = len(self.stream_loops)
|
|
||||||
if current_streams >= self.max_concurrent_streams and not should_force:
|
|
||||||
logger.warning(
|
|
||||||
f"超过最大并发流数限制({current_streams}/{self.max_concurrent_streams}),无法启动流 {stream_id}"
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 处理强制分发情况
|
|
||||||
if should_force and current_streams >= self.max_concurrent_streams:
|
|
||||||
logger.warning(
|
|
||||||
f"流 {stream_id} 未读消息积压严重(>{self.force_dispatch_unread_threshold}),突破并发限制强制启动分发 (当前: {current_streams}/{self.max_concurrent_streams})"
|
|
||||||
)
|
|
||||||
# 检查是否有现有的分发循环,如果有则先移除
|
|
||||||
if stream_id in self.stream_loops:
|
|
||||||
logger.info(f"发现现有流循环 {stream_id},将先移除再重新创建")
|
|
||||||
existing_task = self.stream_loops[stream_id]
|
|
||||||
if not existing_task.done():
|
|
||||||
existing_task.cancel()
|
|
||||||
# 创建异步任务来等待取消完成,并添加异常处理
|
|
||||||
cancel_task = asyncio.create_task(
|
|
||||||
self._wait_for_task_cancel(stream_id, existing_task),
|
|
||||||
name=f"cancel_existing_loop_{stream_id}"
|
|
||||||
)
|
|
||||||
# 为取消任务添加异常处理,避免孤儿任务
|
|
||||||
cancel_task.add_done_callback(
|
|
||||||
lambda task: logger.debug(f"取消任务完成: {stream_id}") if not task.exception()
|
|
||||||
else logger.error(f"取消任务异常: {stream_id} - {task.exception()}")
|
|
||||||
)
|
|
||||||
# 从字典中移除
|
|
||||||
del self.stream_loops[stream_id]
|
|
||||||
current_streams -= 1 # 更新当前流数量
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _determine_stream_priority(self, stream_id: str) -> "StreamPriority":
|
def _determine_stream_priority(self, stream_id: str) -> "StreamPriority":
|
||||||
"""确定流优先级"""
|
"""确定流优先级"""
|
||||||
try:
|
try:
|
||||||
@@ -237,20 +181,6 @@ class StreamLoopManager:
|
|||||||
from src.chat.message_manager.adaptive_stream_manager import StreamPriority
|
from src.chat.message_manager.adaptive_stream_manager import StreamPriority
|
||||||
return StreamPriority.NORMAL
|
return StreamPriority.NORMAL
|
||||||
|
|
||||||
# 创建流循环任务
|
|
||||||
try:
|
|
||||||
task = asyncio.create_task(
|
|
||||||
self._stream_loop(stream_id),
|
|
||||||
name=f"stream_loop_{stream_id}" # 为任务添加名称,便于调试
|
|
||||||
)
|
|
||||||
self.stream_loops[stream_id] = task
|
|
||||||
self.stats["total_loops"] += 1
|
|
||||||
|
|
||||||
logger.info(f"启动流循环: {stream_id} (当前总数: {len(self.stream_loops)})")
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"创建流循环任务失败: {stream_id} - {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def stop_stream_loop(self, stream_id: str) -> bool:
|
async def stop_stream_loop(self, stream_id: str) -> bool:
|
||||||
"""停止指定流的循环任务
|
"""停止指定流的循环任务
|
||||||
@@ -342,17 +272,6 @@ class StreamLoopManager:
|
|||||||
# 4. 计算下次检查间隔
|
# 4. 计算下次检查间隔
|
||||||
interval = await self._calculate_interval(stream_id, has_messages)
|
interval = await self._calculate_interval(stream_id, has_messages)
|
||||||
|
|
||||||
if has_messages:
|
|
||||||
updated_unread_count = self._get_unread_count(context)
|
|
||||||
if self._needs_force_dispatch_for_context(context, updated_unread_count):
|
|
||||||
interval = min(interval, max(self.force_dispatch_min_interval, 0.0))
|
|
||||||
logger.debug(
|
|
||||||
"流 %s 未读消息仍有 %d 条,使用加速分发间隔 %.2fs",
|
|
||||||
stream_id,
|
|
||||||
updated_unread_count,
|
|
||||||
interval,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 5. sleep等待下次检查
|
# 5. sleep等待下次检查
|
||||||
logger.info(f"流 {stream_id} 等待 {interval:.2f}s")
|
logger.info(f"流 {stream_id} 等待 {interval:.2f}s")
|
||||||
await asyncio.sleep(interval)
|
await asyncio.sleep(interval)
|
||||||
@@ -378,23 +297,14 @@ class StreamLoopManager:
|
|||||||
del self.stream_loops[stream_id]
|
del self.stream_loops[stream_id]
|
||||||
logger.debug(f"清理流循环标记: {stream_id}")
|
logger.debug(f"清理流循环标记: {stream_id}")
|
||||||
|
|
||||||
# 根据管理器类型释放相应的槽位
|
# 释放自适应管理器的槽位
|
||||||
management_type = self.stream_management_type.get(stream_id, "fallback")
|
try:
|
||||||
if management_type == "adaptive":
|
from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager
|
||||||
# 释放自适应管理器的槽位
|
adaptive_manager = get_adaptive_stream_manager()
|
||||||
try:
|
adaptive_manager.release_stream_slot(stream_id)
|
||||||
from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager
|
logger.debug(f"释放自适应流处理槽位: {stream_id}")
|
||||||
adaptive_manager = get_adaptive_stream_manager()
|
except Exception as e:
|
||||||
adaptive_manager.release_stream_slot(stream_id)
|
logger.debug(f"释放自适应流处理槽位失败: {e}")
|
||||||
logger.debug(f"释放自适应流处理槽位: {stream_id}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug(f"释放自适应流处理槽位失败: {e}")
|
|
||||||
else:
|
|
||||||
logger.debug(f"流 {stream_id} 使用回退方案,无需释放自适应槽位")
|
|
||||||
|
|
||||||
# 清理管理器类型记录
|
|
||||||
if stream_id in self.stream_management_type:
|
|
||||||
del self.stream_management_type[stream_id]
|
|
||||||
|
|
||||||
logger.info(f"流循环结束: {stream_id}")
|
logger.info(f"流循环结束: {stream_id}")
|
||||||
|
|
||||||
@@ -417,7 +327,7 @@ class StreamLoopManager:
|
|||||||
logger.error(f"获取流上下文失败 {stream_id}: {e}")
|
logger.error(f"获取流上下文失败 {stream_id}: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _has_messages_to_process(self, context: Any) -> bool:
|
async def _has_messages_to_process(self, context: StreamContext) -> bool:
|
||||||
"""检查是否有消息需要处理
|
"""检查是否有消息需要处理
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -464,7 +374,7 @@ class StreamLoopManager:
|
|||||||
success = results.get("success", False)
|
success = results.get("success", False)
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
await self._refresh_focus_energy(stream_id)
|
asyncio.create_task(self._refresh_focus_energy(stream_id))
|
||||||
process_time = time.time() - start_time
|
process_time = time.time() - start_time
|
||||||
logger.debug(f"流处理成功: {stream_id} (耗时: {process_time:.2f}s)")
|
logger.debug(f"流处理成功: {stream_id} (耗时: {process_time:.2f}s)")
|
||||||
else:
|
else:
|
||||||
@@ -553,16 +463,16 @@ class StreamLoopManager:
|
|||||||
logger.debug(f"检查流 {stream_id} 是否需要强制分发失败: {e}")
|
logger.debug(f"检查流 {stream_id} 是否需要强制分发失败: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _get_unread_count(self, context: Any) -> int:
|
def _get_unread_count(self, context: StreamContext) -> int:
|
||||||
try:
|
try:
|
||||||
unread_messages = getattr(context, "unread_messages", None)
|
unread_messages = context.unread_messages
|
||||||
if unread_messages is None:
|
if unread_messages is None:
|
||||||
return 0
|
return 0
|
||||||
return len(unread_messages)
|
return len(unread_messages)
|
||||||
except Exception:
|
except Exception:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
def _needs_force_dispatch_for_context(self, context: Any, unread_count: int | None = None) -> bool:
|
def _needs_force_dispatch_for_context(self, context: StreamContext, unread_count: int | None = None) -> bool:
|
||||||
if not self.force_dispatch_unread_threshold or self.force_dispatch_unread_threshold <= 0:
|
if not self.force_dispatch_unread_threshold or self.force_dispatch_unread_threshold <= 0:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|||||||
@@ -185,7 +185,7 @@ class ChatterActionManager:
|
|||||||
logger.info(f"{log_prefix} 选择不回复,原因: {reason}")
|
logger.info(f"{log_prefix} 选择不回复,原因: {reason}")
|
||||||
|
|
||||||
# 存储no_reply信息到数据库
|
# 存储no_reply信息到数据库
|
||||||
await database_api.store_action_info(
|
asyncio.create_task(database_api.store_action_info(
|
||||||
chat_stream=chat_stream,
|
chat_stream=chat_stream,
|
||||||
action_build_into_prompt=False,
|
action_build_into_prompt=False,
|
||||||
action_prompt_display=reason,
|
action_prompt_display=reason,
|
||||||
@@ -193,10 +193,10 @@ class ChatterActionManager:
|
|||||||
thinking_id=thinking_id,
|
thinking_id=thinking_id,
|
||||||
action_data={"reason": reason},
|
action_data={"reason": reason},
|
||||||
action_name="no_reply",
|
action_name="no_reply",
|
||||||
)
|
))
|
||||||
|
|
||||||
# 自动清空所有未读消息
|
# 自动清空所有未读消息
|
||||||
await self._clear_all_unread_messages(chat_stream.stream_id, "no_reply")
|
asyncio.create_task(self._clear_all_unread_messages(chat_stream.stream_id, "no_reply"))
|
||||||
|
|
||||||
return {"action_type": "no_reply", "success": True, "reply_text": "", "command": ""}
|
return {"action_type": "no_reply", "success": True, "reply_text": "", "command": ""}
|
||||||
|
|
||||||
@@ -214,12 +214,12 @@ class ChatterActionManager:
|
|||||||
|
|
||||||
# 记录执行的动作到目标消息
|
# 记录执行的动作到目标消息
|
||||||
if success:
|
if success:
|
||||||
await self._record_action_to_message(chat_stream, action_name, target_message, action_data)
|
asyncio.create_task(self._record_action_to_message(chat_stream, action_name, target_message, action_data))
|
||||||
# 自动清空所有未读消息
|
# 自动清空所有未读消息
|
||||||
if clear_unread_messages:
|
if clear_unread_messages:
|
||||||
await self._clear_all_unread_messages(chat_stream.stream_id, action_name)
|
asyncio.create_task(self._clear_all_unread_messages(chat_stream.stream_id, action_name))
|
||||||
# 重置打断计数
|
# 重置打断计数
|
||||||
await self._reset_interruption_count_after_action(chat_stream.stream_id)
|
asyncio.create_task(self._reset_interruption_count_after_action(chat_stream.stream_id))
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"action_type": action_name,
|
"action_type": action_name,
|
||||||
@@ -260,13 +260,13 @@ class ChatterActionManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 记录回复动作到目标消息
|
# 记录回复动作到目标消息
|
||||||
await self._record_action_to_message(chat_stream, "reply", target_message, action_data)
|
asyncio.create_task(self._record_action_to_message(chat_stream, "reply", target_message, action_data))
|
||||||
|
|
||||||
if clear_unread_messages:
|
if clear_unread_messages:
|
||||||
await self._clear_all_unread_messages(chat_stream.stream_id, "reply")
|
asyncio.create_task(self._clear_all_unread_messages(chat_stream.stream_id, "reply"))
|
||||||
|
|
||||||
# 回复成功,重置打断计数
|
# 回复成功,重置打断计数
|
||||||
await self._reset_interruption_count_after_action(chat_stream.stream_id)
|
asyncio.create_task(self._reset_interruption_count_after_action(chat_stream.stream_id))
|
||||||
|
|
||||||
return {"action_type": "reply", "success": True, "reply_text": reply_text, "loop_info": loop_info}
|
return {"action_type": "reply", "success": True, "reply_text": reply_text, "loop_info": loop_info}
|
||||||
|
|
||||||
|
|||||||
@@ -287,13 +287,13 @@ class DefaultReplyer:
|
|||||||
try:
|
try:
|
||||||
# 构建 Prompt
|
# 构建 Prompt
|
||||||
with Timer("构建Prompt", {}): # 内部计时器,可选保留
|
with Timer("构建Prompt", {}): # 内部计时器,可选保留
|
||||||
prompt = await asyncio.create_task(self.build_prompt_reply_context(
|
prompt = await self.build_prompt_reply_context(
|
||||||
reply_to=reply_to,
|
reply_to=reply_to,
|
||||||
extra_info=extra_info,
|
extra_info=extra_info,
|
||||||
available_actions=available_actions,
|
available_actions=available_actions,
|
||||||
enable_tool=enable_tool,
|
enable_tool=enable_tool,
|
||||||
reply_message=reply_message,
|
reply_message=reply_message,
|
||||||
))
|
)
|
||||||
|
|
||||||
if not prompt:
|
if not prompt:
|
||||||
logger.warning("构建prompt失败,跳过回复生成")
|
logger.warning("构建prompt失败,跳过回复生成")
|
||||||
|
|||||||
@@ -175,7 +175,6 @@ class PromptManager:
|
|||||||
self._prompts = {}
|
self._prompts = {}
|
||||||
self._counter = 0
|
self._counter = 0
|
||||||
self._context = PromptContext()
|
self._context = PromptContext()
|
||||||
self._lock = asyncio.Lock()
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def async_message_scope(self, message_id: str | None = None):
|
async def async_message_scope(self, message_id: str | None = None):
|
||||||
@@ -190,10 +189,9 @@ class PromptManager:
|
|||||||
logger.debug(f"从上下文中获取提示词: {name} {context_prompt}")
|
logger.debug(f"从上下文中获取提示词: {name} {context_prompt}")
|
||||||
return context_prompt
|
return context_prompt
|
||||||
|
|
||||||
async with self._lock:
|
if name not in self._prompts:
|
||||||
if name not in self._prompts:
|
raise KeyError(f"Prompt '{name}' not found")
|
||||||
raise KeyError(f"Prompt '{name}' not found")
|
return self._prompts[name]
|
||||||
return self._prompts[name]
|
|
||||||
|
|
||||||
def generate_name(self, template: str) -> str:
|
def generate_name(self, template: str) -> str:
|
||||||
"""为未命名的prompt生成名称"""
|
"""为未命名的prompt生成名称"""
|
||||||
|
|||||||
@@ -53,38 +53,8 @@ class StreamContext(BaseDataModel):
|
|||||||
priority_mode: str | None = None
|
priority_mode: str | None = None
|
||||||
priority_info: dict | None = None
|
priority_info: dict | None = None
|
||||||
|
|
||||||
def add_message(self, message: "DatabaseMessages"):
|
|
||||||
"""添加消息到上下文"""
|
|
||||||
message.is_read = False
|
|
||||||
self.unread_messages.append(message)
|
|
||||||
|
|
||||||
# 自动检测和更新chat type
|
|
||||||
self._detect_chat_type(message)
|
|
||||||
|
|
||||||
def update_message_info(
|
|
||||||
self, message_id: str, interest_value: float = None, actions: list = None, should_reply: bool = None
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
更新消息信息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message_id: 消息ID
|
|
||||||
interest_value: 兴趣度值
|
|
||||||
actions: 执行的动作列表
|
|
||||||
should_reply: 是否应该回复
|
|
||||||
"""
|
|
||||||
# 在未读消息中查找并更新
|
|
||||||
for message in self.unread_messages:
|
|
||||||
if message.message_id == message_id:
|
|
||||||
message.update_message_info(interest_value, actions, should_reply)
|
|
||||||
break
|
|
||||||
|
|
||||||
# 在历史消息中查找并更新
|
|
||||||
for message in self.history_messages:
|
|
||||||
if message.message_id == message_id:
|
|
||||||
message.update_message_info(interest_value, actions, should_reply)
|
|
||||||
break
|
|
||||||
|
|
||||||
def add_action_to_message(self, message_id: str, action: str):
|
def add_action_to_message(self, message_id: str, action: str):
|
||||||
"""
|
"""
|
||||||
向指定消息添加执行的动作
|
向指定消息添加执行的动作
|
||||||
@@ -105,42 +75,8 @@ class StreamContext(BaseDataModel):
|
|||||||
message.add_action(action)
|
message.add_action(action)
|
||||||
break
|
break
|
||||||
|
|
||||||
def _detect_chat_type(self, message: "DatabaseMessages"):
|
|
||||||
"""根据消息内容自动检测聊天类型"""
|
|
||||||
# 只有在第一次添加消息时才检测聊天类型,避免后续消息改变类型
|
|
||||||
if len(self.unread_messages) == 1: # 只有这条消息
|
|
||||||
# 如果消息包含群组信息,则为群聊
|
|
||||||
if hasattr(message, "chat_info_group_id") and message.chat_info_group_id:
|
|
||||||
self.chat_type = ChatType.GROUP
|
|
||||||
elif hasattr(message, "chat_info_group_name") and message.chat_info_group_name:
|
|
||||||
self.chat_type = ChatType.GROUP
|
|
||||||
else:
|
|
||||||
self.chat_type = ChatType.PRIVATE
|
|
||||||
|
|
||||||
def update_chat_type(self, chat_type: ChatType):
|
|
||||||
"""手动更新聊天类型"""
|
|
||||||
self.chat_type = chat_type
|
|
||||||
|
|
||||||
def set_chat_mode(self, chat_mode: ChatMode):
|
|
||||||
"""设置聊天模式"""
|
|
||||||
self.chat_mode = chat_mode
|
|
||||||
|
|
||||||
def is_group_chat(self) -> bool:
|
|
||||||
"""检查是否为群聊"""
|
|
||||||
return self.chat_type == ChatType.GROUP
|
|
||||||
|
|
||||||
def is_private_chat(self) -> bool:
|
|
||||||
"""检查是否为私聊"""
|
|
||||||
return self.chat_type == ChatType.PRIVATE
|
|
||||||
|
|
||||||
def get_chat_type_display(self) -> str:
|
|
||||||
"""获取聊天类型的显示名称"""
|
|
||||||
if self.chat_type == ChatType.GROUP:
|
|
||||||
return "群聊"
|
|
||||||
elif self.chat_type == ChatType.PRIVATE:
|
|
||||||
return "私聊"
|
|
||||||
else:
|
|
||||||
return "未知类型"
|
|
||||||
|
|
||||||
def mark_message_as_read(self, message_id: str):
|
def mark_message_as_read(self, message_id: str):
|
||||||
"""标记消息为已读"""
|
"""标记消息为已读"""
|
||||||
|
|||||||
@@ -220,11 +220,11 @@ class ChatterPlanExecutor:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_message = str(e)
|
error_message = str(e)
|
||||||
logger.error(f"执行回复动作失败: {action_info.action_type}, 错误: {error_message}")
|
logger.error(f"执行回复动作失败: {action_info.action_type}, 错误: {error_message}")
|
||||||
|
'''
|
||||||
# 记录用户关系追踪
|
# 记录用户关系追踪
|
||||||
if success and action_info.action_message:
|
if success and action_info.action_message:
|
||||||
await self._track_user_interaction(action_info, plan, reply_content)
|
await self._track_user_interaction(action_info, plan, reply_content)
|
||||||
|
'''
|
||||||
execution_time = time.time() - start_time
|
execution_time = time.time() - start_time
|
||||||
self.execution_stats["execution_times"].append(execution_time)
|
self.execution_stats["execution_times"].append(execution_time)
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
集成兴趣度评分系统和用户关系追踪机制,实现智能化的聊天决策。
|
集成兴趣度评分系统和用户关系追踪机制,实现智能化的聊天决策。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
@@ -98,10 +99,11 @@ class ChatterActionPlanner:
|
|||||||
|
|
||||||
unread_messages = context.get_unread_messages() if context else []
|
unread_messages = context.get_unread_messages() if context else []
|
||||||
# 2. 使用新的兴趣度管理系统进行评分
|
# 2. 使用新的兴趣度管理系统进行评分
|
||||||
score = 0.0
|
message_interest = 0.0
|
||||||
should_reply = False
|
|
||||||
reply_not_available = False
|
reply_not_available = False
|
||||||
interest_updates: list[dict[str, Any]] = []
|
interest_updates: list[dict[str, Any]] = []
|
||||||
|
message_should_act = False
|
||||||
|
message_should_reply = False
|
||||||
|
|
||||||
if unread_messages:
|
if unread_messages:
|
||||||
# 直接使用消息中已计算的标志,无需重复计算兴趣值
|
# 直接使用消息中已计算的标志,无需重复计算兴趣值
|
||||||
@@ -111,17 +113,8 @@ class ChatterActionPlanner:
|
|||||||
message_should_reply = getattr(message, "should_reply", False)
|
message_should_reply = getattr(message, "should_reply", False)
|
||||||
message_should_act = getattr(message, "should_act", False)
|
message_should_act = getattr(message, "should_act", False)
|
||||||
|
|
||||||
# 确保interest_value不是None
|
if not message_should_reply:
|
||||||
if message_interest is None:
|
reply_not_available = True
|
||||||
message_interest = 0.3
|
|
||||||
|
|
||||||
# 更新最高兴趣度消息
|
|
||||||
if message_interest > score:
|
|
||||||
score = message_interest
|
|
||||||
if message_should_reply:
|
|
||||||
should_reply = True
|
|
||||||
else:
|
|
||||||
reply_not_available = True
|
|
||||||
|
|
||||||
# 如果should_act为false,强制设为no_action
|
# 如果should_act为false,强制设为no_action
|
||||||
if not message_should_act:
|
if not message_should_act:
|
||||||
@@ -142,22 +135,23 @@ class ChatterActionPlanner:
|
|||||||
"message_id": message.message_id,
|
"message_id": message.message_id,
|
||||||
"interest_value": 0.0,
|
"interest_value": 0.0,
|
||||||
"should_reply": False,
|
"should_reply": False,
|
||||||
|
"should_act": False,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
if interest_updates:
|
if interest_updates:
|
||||||
await self._commit_interest_updates(interest_updates)
|
asyncio.create_task(self._commit_interest_updates(interest_updates))
|
||||||
|
|
||||||
# 检查兴趣度是否达到非回复动作阈值
|
# 检查兴趣度是否达到非回复动作阈值
|
||||||
non_reply_action_interest_threshold = global_config.affinity_flow.non_reply_action_interest_threshold
|
non_reply_action_interest_threshold = global_config.affinity_flow.non_reply_action_interest_threshold
|
||||||
if score < non_reply_action_interest_threshold:
|
if not message_should_act:
|
||||||
logger.info(f"兴趣度 {score:.3f} 低于阈值 {non_reply_action_interest_threshold:.3f},不执行动作")
|
logger.info(f"兴趣度 {message_interest:.3f} 低于阈值 {non_reply_action_interest_threshold:.3f},不执行动作")
|
||||||
# 直接返回 no_action
|
# 直接返回 no_action
|
||||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||||
|
|
||||||
no_action = ActionPlannerInfo(
|
no_action = ActionPlannerInfo(
|
||||||
action_type="no_action",
|
action_type="no_action",
|
||||||
reasoning=f"兴趣度评分 {score:.3f} 未达阈值 {non_reply_action_interest_threshold:.3f}",
|
reasoning=f"兴趣度评分 {message_interest:.3f} 未达阈值 {non_reply_action_interest_threshold:.3f}",
|
||||||
action_data={},
|
action_data={},
|
||||||
action_message=None,
|
action_message=None,
|
||||||
)
|
)
|
||||||
@@ -169,9 +163,6 @@ class ChatterActionPlanner:
|
|||||||
plan_filter = ChatterPlanFilter(self.chat_id, available_actions)
|
plan_filter = ChatterPlanFilter(self.chat_id, available_actions)
|
||||||
filtered_plan = await plan_filter.filter(reply_not_available, initial_plan)
|
filtered_plan = await plan_filter.filter(reply_not_available, initial_plan)
|
||||||
|
|
||||||
# 检查filtered_plan是否有reply动作,用于统计
|
|
||||||
has_reply_action = any(decision.action_type == "reply" for decision in filtered_plan.decided_actions)
|
|
||||||
|
|
||||||
# 5. 使用 PlanExecutor 执行 Plan
|
# 5. 使用 PlanExecutor 执行 Plan
|
||||||
execution_result = await self.executor.execute(filtered_plan)
|
execution_result = await self.executor.execute(filtered_plan)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user