feat(tool_history): 实现流工具历史管理器,以增强工具调用跟踪和缓存- 添加了 StreamToolHistoryManager,用于管理聊天流级别的工具调用历史。- 引入了 ToolCallRecord,用于详细记录工具调用,包括执行时间和缓存命中情况。- 集成了内存缓存和全局缓存系统,以高效检索结果。- 更新了 ToolExecutor,以使用新的历史管理器记录和获取工具调用。- 增强了 ExaSearchEngine,以限制返回结果数量并提升答案质量。- 重构了 CacheManager 中的缓存管理,以包括工具调用统计和性能指标。

This commit is contained in:
Windpicker-owo
2025-11-06 14:22:59 +08:00
parent ec51277539
commit e6f3dfc1e7
5 changed files with 743 additions and 183 deletions

View File

@@ -0,0 +1,414 @@
"""
流式工具历史记录管理器
用于在聊天流级别管理工具调用历史,支持智能缓存和上下文感知
"""
import time
from typing import Any, Optional
from dataclasses import dataclass, asdict, field
import orjson
from src.common.logger import get_logger
from src.common.cache_manager import tool_cache
logger = get_logger("stream_tool_history")
@dataclass
class ToolCallRecord:
"""工具调用记录"""
tool_name: str
args: dict[str, Any]
result: Optional[dict[str, Any]] = None
status: str = "success" # success, error, pending
timestamp: float = field(default_factory=time.time)
execution_time: Optional[float] = None # 执行耗时(秒)
cache_hit: bool = False # 是否命中缓存
result_preview: str = "" # 结果预览
error_message: str = "" # 错误信息
def __post_init__(self):
"""后处理:生成结果预览"""
if self.result and not self.result_preview:
content = self.result.get("content", "")
if isinstance(content, str):
self.result_preview = content[:500] + ("..." if len(content) > 500 else "")
elif isinstance(content, (list, dict)):
try:
self.result_preview = orjson.dumps(content, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')[:500] + "..."
except Exception:
self.result_preview = str(content)[:500] + "..."
else:
self.result_preview = str(content)[:500] + "..."
class StreamToolHistoryManager:
"""流式工具历史记录管理器
提供以下功能:
1. 工具调用历史的持久化管理
2. 智能缓存集成和结果去重
3. 上下文感知的历史记录检索
4. 性能监控和统计
"""
def __init__(self, chat_id: str, max_history: int = 20, enable_memory_cache: bool = True):
"""初始化历史记录管理器
Args:
chat_id: 聊天ID用于隔离不同聊天流的历史记录
max_history: 最大历史记录数量
enable_memory_cache: 是否启用内存缓存
"""
self.chat_id = chat_id
self.max_history = max_history
self.enable_memory_cache = enable_memory_cache
# 内存中的历史记录,按时间顺序排列
self._history: list[ToolCallRecord] = []
# 性能统计
self._stats = {
"total_calls": 0,
"cache_hits": 0,
"cache_misses": 0,
"total_execution_time": 0.0,
"average_execution_time": 0.0,
}
logger.info(f"[{chat_id}] 工具历史记录管理器初始化完成,最大历史: {max_history}")
async def add_tool_call(self, record: ToolCallRecord) -> None:
"""添加工具调用记录
Args:
record: 工具调用记录
"""
# 维护历史记录大小
if len(self._history) >= self.max_history:
# 移除最旧的记录
removed_record = self._history.pop(0)
logger.debug(f"[{self.chat_id}] 移除旧记录: {removed_record.tool_name}")
# 添加新记录
self._history.append(record)
# 更新统计
self._stats["total_calls"] += 1
if record.cache_hit:
self._stats["cache_hits"] += 1
else:
self._stats["cache_misses"] += 1
if record.execution_time is not None:
self._stats["total_execution_time"] += record.execution_time
self._stats["average_execution_time"] = self._stats["total_execution_time"] / self._stats["total_calls"]
logger.debug(f"[{self.chat_id}] 添加工具调用记录: {record.tool_name}, 缓存命中: {record.cache_hit}")
async def get_cached_result(self, tool_name: str, args: dict[str, Any]) -> Optional[dict[str, Any]]:
"""从缓存或历史记录中获取结果
Args:
tool_name: 工具名称
args: 工具参数
Returns:
缓存的结果如果不存在则返回None
"""
# 首先检查内存中的历史记录
if self.enable_memory_cache:
memory_result = self._search_memory_cache(tool_name, args)
if memory_result:
logger.info(f"[{self.chat_id}] 内存缓存命中: {tool_name}")
return memory_result
# 然后检查全局缓存系统
try:
# 这里需要工具实例来获取文件路径,但为了解耦,我们先尝试从历史记录中推断
tool_file_path = self._infer_tool_path(tool_name)
# 尝试语义缓存(如果可以推断出语义查询参数)
semantic_query = self._extract_semantic_query(tool_name, args)
cached_result = await tool_cache.get(
tool_name=tool_name,
function_args=args,
tool_file_path=tool_file_path,
semantic_query=semantic_query,
)
if cached_result:
logger.info(f"[{self.chat_id}] 全局缓存命中: {tool_name}")
# 将结果同步到内存缓存
if self.enable_memory_cache:
record = ToolCallRecord(
tool_name=tool_name,
args=args,
result=cached_result,
status="success",
cache_hit=True,
timestamp=time.time(),
)
await self.add_tool_call(record)
return cached_result
except Exception as e:
logger.warning(f"[{self.chat_id}] 缓存查询失败: {e}")
return None
async def cache_result(self, tool_name: str, args: dict[str, Any], result: dict[str, Any],
execution_time: Optional[float] = None,
tool_file_path: Optional[str] = None,
ttl: Optional[int] = None) -> None:
"""缓存工具调用结果
Args:
tool_name: 工具名称
args: 工具参数
result: 执行结果
execution_time: 执行耗时
tool_file_path: 工具文件路径
ttl: 缓存TTL
"""
# 添加到内存历史记录
record = ToolCallRecord(
tool_name=tool_name,
args=args,
result=result,
status="success",
execution_time=execution_time,
cache_hit=False,
timestamp=time.time(),
)
await self.add_tool_call(record)
# 同步到全局缓存系统
try:
if tool_file_path is None:
tool_file_path = self._infer_tool_path(tool_name)
# 尝试语义缓存
semantic_query = self._extract_semantic_query(tool_name, args)
await tool_cache.set(
tool_name=tool_name,
function_args=args,
tool_file_path=tool_file_path,
data=result,
ttl=ttl,
semantic_query=semantic_query,
)
logger.debug(f"[{self.chat_id}] 结果已缓存: {tool_name}")
except Exception as e:
logger.warning(f"[{self.chat_id}] 缓存设置失败: {e}")
async def get_recent_history(self, count: int = 5, status_filter: Optional[str] = None) -> list[ToolCallRecord]:
"""获取最近的历史记录
Args:
count: 返回的记录数量
status_filter: 状态过滤器可选值success, error, pending
Returns:
历史记录列表
"""
history = self._history.copy()
# 应用状态过滤
if status_filter:
history = [record for record in history if record.status == status_filter]
# 返回最近的记录
return history[-count:] if history else []
def format_for_prompt(self, max_records: int = 5, include_results: bool = True) -> str:
"""格式化历史记录为提示词
Args:
max_records: 最大记录数量
include_results: 是否包含结果预览
Returns:
格式化的提示词字符串
"""
if not self._history:
return ""
recent_records = self._history[-max_records:]
lines = ["## 🔧 最近工具调用记录"]
for i, record in enumerate(recent_records, 1):
status_icon = "" if record.status == "success" else "" if record.status == "error" else ""
# 格式化参数
args_preview = self._format_args_preview(record.args)
# 基础信息
lines.append(f"{i}. {status_icon} **{record.tool_name}**({args_preview})")
# 添加执行时间和缓存信息
if record.execution_time is not None:
time_info = f"{record.execution_time:.2f}s"
cache_info = "🎯缓存" if record.cache_hit else "🔍执行"
lines.append(f" ⏱️ {time_info} | {cache_info}")
# 添加结果预览
if include_results and record.result_preview:
lines.append(f" 📝 结果: {record.result_preview}")
# 添加错误信息
if record.status == "error" and record.error_message:
lines.append(f" ❌ 错误: {record.error_message}")
# 添加统计信息
if self._stats["total_calls"] > 0:
cache_hit_rate = (self._stats["cache_hits"] / self._stats["total_calls"]) * 100
avg_time = self._stats["average_execution_time"]
lines.append(f"\n📊 工具统计: 总计{self._stats['total_calls']}次 | 缓存命中率{cache_hit_rate:.1f}% | 平均耗时{avg_time:.2f}s")
return "\n".join(lines)
def get_stats(self) -> dict[str, Any]:
"""获取性能统计信息
Returns:
统计信息字典
"""
cache_hit_rate = 0.0
if self._stats["total_calls"] > 0:
cache_hit_rate = (self._stats["cache_hits"] / self._stats["total_calls"]) * 100
return {
**self._stats,
"cache_hit_rate": cache_hit_rate,
"history_size": len(self._history),
"chat_id": self.chat_id,
}
def clear_history(self) -> None:
"""清除历史记录"""
self._history.clear()
logger.info(f"[{self.chat_id}] 工具历史记录已清除")
def _search_memory_cache(self, tool_name: str, args: dict[str, Any]) -> Optional[dict[str, Any]]:
"""在内存历史记录中搜索缓存
Args:
tool_name: 工具名称
args: 工具参数
Returns:
匹配的结果如果不存在则返回None
"""
for record in reversed(self._history): # 从最新的开始搜索
if (record.tool_name == tool_name and
record.status == "success" and
record.args == args):
return record.result
return None
def _infer_tool_path(self, tool_name: str) -> str:
"""推断工具文件路径
Args:
tool_name: 工具名称
Returns:
推断的文件路径
"""
# 基于工具名称推断路径,这是一个简化的实现
# 在实际使用中,可能需要更复杂的映射逻辑
tool_path_mapping = {
"web_search": "src/plugins/built_in/web_search_tool/tools/web_search.py",
"memory_create": "src/memory_graph/tools/memory_tools.py",
"memory_search": "src/memory_graph/tools/memory_tools.py",
"user_profile_update": "src/plugins/built_in/affinity_flow_chatter/tools/user_profile_tool.py",
"chat_stream_impression_update": "src/plugins/built_in/affinity_flow_chatter/tools/chat_stream_impression_tool.py",
}
return tool_path_mapping.get(tool_name, f"src/plugins/tools/{tool_name}.py")
def _extract_semantic_query(self, tool_name: str, args: dict[str, Any]) -> Optional[str]:
"""提取语义查询参数
Args:
tool_name: 工具名称
args: 工具参数
Returns:
语义查询字符串如果不存在则返回None
"""
# 为不同工具定义语义查询参数映射
semantic_query_mapping = {
"web_search": "query",
"memory_search": "query",
"knowledge_search": "query",
}
query_key = semantic_query_mapping.get(tool_name)
if query_key and query_key in args:
return str(args[query_key])
return None
def _format_args_preview(self, args: dict[str, Any], max_length: int = 100) -> str:
"""格式化参数预览
Args:
args: 参数字典
max_length: 最大长度
Returns:
格式化的参数预览字符串
"""
if not args:
return ""
try:
args_str = orjson.dumps(args, option=orjson.OPT_SORT_KEYS).decode('utf-8')
if len(args_str) > max_length:
args_str = args_str[:max_length] + "..."
return args_str
except Exception:
# 如果序列化失败,使用简单格式
parts = []
for k, v in list(args.items())[:3]: # 最多显示3个参数
parts.append(f"{k}={str(v)[:20]}")
result = ", ".join(parts)
if len(parts) >= 3 or len(result) > max_length:
result += "..."
return result
# 全局管理器字典按chat_id索引
_stream_managers: dict[str, StreamToolHistoryManager] = {}
def get_stream_tool_history_manager(chat_id: str) -> StreamToolHistoryManager:
"""获取指定聊天的工具历史记录管理器
Args:
chat_id: 聊天ID
Returns:
工具历史记录管理器实例
"""
if chat_id not in _stream_managers:
_stream_managers[chat_id] = StreamToolHistoryManager(chat_id)
return _stream_managers[chat_id]
def cleanup_stream_manager(chat_id: str) -> None:
"""清理指定聊天的管理器
Args:
chat_id: 聊天ID
"""
if chat_id in _stream_managers:
del _stream_managers[chat_id]
logger.info(f"已清理聊天 {chat_id} 的工具历史记录管理器")

View File

@@ -3,7 +3,6 @@ import time
from typing import Any
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.cache_manager import tool_cache
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.payload_content import ToolCall
@@ -11,6 +10,8 @@ from src.llm_models.utils_model import LLMRequest
from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance
from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
from src.plugin_system.core.stream_tool_history import get_stream_tool_history_manager, ToolCallRecord
from dataclasses import asdict
logger = get_logger("tool_use")
@@ -36,15 +37,29 @@ def init_tool_executor_prompt():
{tool_history}
## 🔧 工具使用
## 🔧 工具决策指南
根据上下文判断是否需要使用工具。每个工具都有详细的description说明其用途和参数请根据工具定义决定是否调用。
**核心原则:**
- 根据上下文智能判断是否需要使用工具
- 每个工具都有详细的description说明其用途和参数
- 避免重复调用历史记录中已执行的工具(除非参数不同)
- 优先考虑使用已有的缓存结果,避免重复调用
**历史记录说明:**
- 上方显示的是**之前**的工具调用记录
- 请参考历史记录避免重复调用相同参数的工具
- 如果历史记录中已有相关结果,可以考虑直接回答而不调用工具
**⚠️ 记忆创建特别提醒:**
创建记忆时subject主体必须使用对话历史中显示的**真实发送人名字**
- ✅ 正确:从"Prou(12345678): ..."中提取"Prou"作为subject
- ❌ 错误:使用"用户""对方"等泛指词
**工具调用策略:**
1. **避免重复调用**:查看历史记录,如果最近已调用过相同工具且参数一致,无需重复调用
2. **智能选择工具**:根据消息内容选择最合适的工具,避免过度使用
3. **参数优化**:确保工具参数简洁有效,避免冗余信息
**执行指令:**
- 需要使用工具 → 直接调用相应的工具函数
- 不需要工具 → 输出 "No tool needed"
@@ -81,9 +96,8 @@ class ToolExecutor:
"""待处理的第二步工具调用,格式为 {tool_name: step_two_definition}"""
self._log_prefix_initialized = False
# 工具调用历史
self.tool_call_history: list[dict[str, Any]] = []
"""工具调用历史,包含工具名称、参数和结果"""
# 流式工具历史记录管理器
self.history_manager = get_stream_tool_history_manager(chat_id)
# logger.info(f"{self.log_prefix}工具执行器初始化完成") # 移到异步初始化中
@@ -125,7 +139,7 @@ class ToolExecutor:
bot_name = global_config.bot.nickname
# 构建工具调用历史文本
tool_history = self._format_tool_history()
tool_history = self.history_manager.format_for_prompt(max_records=5, include_results=True)
# 获取人设信息
personality_core = global_config.personality.personality_core
@@ -183,83 +197,7 @@ class ToolExecutor:
return tool_definitions
def _format_tool_history(self, max_history: int = 5) -> str:
"""格式化工具调用历史为文本
Args:
max_history: 最多显示的历史记录数量
Returns:
格式化的工具历史文本
"""
if not self.tool_call_history:
return ""
# 只取最近的几条历史
recent_history = self.tool_call_history[-max_history:]
history_lines = ["历史工具调用记录:"]
for i, record in enumerate(recent_history, 1):
tool_name = record.get("tool_name", "unknown")
args = record.get("args", {})
result_preview = record.get("result_preview", "")
status = record.get("status", "success")
# 格式化参数
args_str = ", ".join([f"{k}={v}" for k, v in args.items()])
# 格式化记录
status_emoji = "" if status == "success" else ""
history_lines.append(f"{i}. {status_emoji} {tool_name}({args_str})")
if result_preview:
# 限制结果预览长度
if len(result_preview) > 200:
result_preview = result_preview[:200] + "..."
history_lines.append(f" 结果: {result_preview}")
return "\n".join(history_lines)
def _add_tool_to_history(self, tool_name: str, args: dict, result: dict | None, status: str = "success"):
"""添加工具调用到历史记录
Args:
tool_name: 工具名称
args: 工具参数
result: 工具结果
status: 执行状态 (success/error)
"""
# 生成结果预览
result_preview = ""
if result:
content = result.get("content", "")
if isinstance(content, str):
result_preview = content
elif isinstance(content, list | dict):
import orjson
try:
result_preview = orjson.dumps(content, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')
except Exception:
result_preview = str(content)
else:
result_preview = str(content)
record = {
"tool_name": tool_name,
"args": args,
"result_preview": result_preview,
"status": status,
"timestamp": time.time(),
}
self.tool_call_history.append(record)
# 限制历史记录数量,避免内存溢出
max_history_size = 5
if len(self.tool_call_history) > max_history_size:
self.tool_call_history = self.tool_call_history[-max_history_size:]
async def execute_tool_calls(self, tool_calls: list[ToolCall] | None) -> tuple[list[dict[str, Any]], list[str]]:
"""执行工具调用
@@ -320,10 +258,20 @@ class ToolExecutor:
logger.debug(f"{self.log_prefix}工具{tool_name}结果内容: {preview}...")
# 记录到历史
self._add_tool_to_history(tool_name, tool_args, result, status="success")
await self.history_manager.add_tool_call(ToolCallRecord(
tool_name=tool_name,
args=tool_args,
result=result,
status="success"
))
else:
# 工具返回空结果也记录到历史
self._add_tool_to_history(tool_name, tool_args, None, status="success")
await self.history_manager.add_tool_call(ToolCallRecord(
tool_name=tool_name,
args=tool_args,
result=None,
status="success"
))
except Exception as e:
logger.error(f"{self.log_prefix}工具{tool_name}执行失败: {e}")
@@ -338,62 +286,72 @@ class ToolExecutor:
tool_results.append(error_info)
# 记录失败到历史
self._add_tool_to_history(tool_name, tool_args, None, status="error")
await self.history_manager.add_tool_call(ToolCallRecord(
tool_name=tool_name,
args=tool_args,
result=None,
status="error",
error_message=str(e)
))
return tool_results, used_tools
async def execute_tool_call(
self, tool_call: ToolCall, tool_instance: BaseTool | None = None
) -> dict[str, Any] | None:
"""执行单个工具调用,并处理缓存"""
"""执行单个工具调用,集成流式历史记录管理器"""
start_time = time.time()
function_args = tool_call.args or {}
tool_instance = tool_instance or get_tool_instance(tool_call.func_name, self.chat_stream)
# 如果工具不存在或未启用缓存,则直接执行
if not tool_instance or not tool_instance.enable_cache:
return await self._original_execute_tool_call(tool_call, tool_instance)
# 尝试从历史记录管理器获取缓存结果
if tool_instance and tool_instance.enable_cache:
try:
cached_result = await self.history_manager.get_cached_result(
tool_name=tool_call.func_name,
args=function_args
)
if cached_result:
execution_time = time.time() - start_time
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具 {tool_call.func_name} 执行")
# --- 缓存逻辑开始 ---
try:
tool_file_path = inspect.getfile(tool_instance.__class__)
semantic_query = None
if tool_instance.semantic_cache_query_key:
semantic_query = function_args.get(tool_instance.semantic_cache_query_key)
# 记录缓存命中到历史
await self.history_manager.add_tool_call(ToolCallRecord(
tool_name=tool_call.func_name,
args=function_args,
result=cached_result,
status="success",
execution_time=execution_time,
cache_hit=True
))
cached_result = await tool_cache.get(
tool_name=tool_call.func_name,
function_args=function_args,
tool_file_path=tool_file_path,
semantic_query=semantic_query,
)
if cached_result:
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具 {tool_call.func_name} 执行")
return cached_result
except Exception as e:
logger.error(f"{self.log_prefix}检查工具缓存时出错: {e}")
return cached_result
except Exception as e:
logger.error(f"{self.log_prefix}检查历史缓存时出错: {e}")
# 缓存未命中,执行原始工具调用
# 缓存未命中,执行工具调用
result = await self._original_execute_tool_call(tool_call, tool_instance)
# 将结果存入缓存
try:
tool_file_path = inspect.getfile(tool_instance.__class__)
semantic_query = None
if tool_instance.semantic_cache_query_key:
semantic_query = function_args.get(tool_instance.semantic_cache_query_key)
# 记录执行结果到历史管理器
execution_time = time.time() - start_time
if tool_instance and result and tool_instance.enable_cache:
try:
tool_file_path = inspect.getfile(tool_instance.__class__)
semantic_query = None
if tool_instance.semantic_cache_query_key:
semantic_query = function_args.get(tool_instance.semantic_cache_query_key)
await tool_cache.set(
tool_name=tool_call.func_name,
function_args=function_args,
tool_file_path=tool_file_path,
data=result,
ttl=tool_instance.cache_ttl,
semantic_query=semantic_query,
)
except Exception as e:
logger.error(f"{self.log_prefix}设置工具缓存时出错: {e}")
# --- 缓存逻辑结束 ---
await self.history_manager.cache_result(
tool_name=tool_call.func_name,
args=function_args,
result=result,
execution_time=execution_time,
tool_file_path=tool_file_path,
ttl=tool_instance.cache_ttl
)
except Exception as e:
logger.error(f"{self.log_prefix}缓存结果到历史管理器时出错: {e}")
return result
@@ -528,21 +486,31 @@ class ToolExecutor:
logger.info(f"{self.log_prefix}直接工具执行成功: {tool_name}")
# 记录到历史
self._add_tool_to_history(tool_name, tool_args, result, status="success")
await self.history_manager.add_tool_call(ToolCallRecord(
tool_name=tool_name,
args=tool_args,
result=result,
status="success"
))
return tool_info
except Exception as e:
logger.error(f"{self.log_prefix}直接工具执行失败 {tool_name}: {e}")
# 记录失败到历史
self._add_tool_to_history(tool_name, tool_args, None, status="error")
await self.history_manager.add_tool_call(ToolCallRecord(
tool_name=tool_name,
args=tool_args,
result=None,
status="error",
error_message=str(e)
))
return None
def clear_tool_history(self):
"""清除工具调用历史"""
self.tool_call_history.clear()
logger.debug(f"{self.log_prefix}已清除工具调用历史")
self.history_manager.clear_history()
def get_tool_history(self) -> list[dict[str, Any]]:
"""获取工具调用历史
@@ -550,7 +518,17 @@ class ToolExecutor:
Returns:
工具调用历史列表
"""
return self.tool_call_history.copy()
# 返回最近的历史记录
records = self.history_manager.get_recent_history(count=10)
return [asdict(record) for record in records]
def get_tool_stats(self) -> dict[str, Any]:
"""获取工具统计信息
Returns:
工具统计信息字典
"""
return self.history_manager.get_stats()
"""