diff --git a/src/chat/utils/prompt_builder.py b/src/chat/utils/prompt_builder.py index 95643c722..1db532b5d 100644 --- a/src/chat/utils/prompt_builder.py +++ b/src/chat/utils/prompt_builder.py @@ -7,33 +7,11 @@ from contextlib import asynccontextmanager from typing import Dict, Any, Optional, List, Union from src.common.logger import get_logger -from src.common.tool_history import ToolHistoryManager install(extra_lines=3) logger = get_logger("prompt_build") -# 创建工具历史管理器实例 -tool_history_manager = ToolHistoryManager() - -def get_tool_history_prompt(message_id: Optional[str] = None) -> str: - """获取工具历史提示词 - - Args: - message_id: 会话ID, 用于只获取当前会话的历史 - - Returns: - 格式化的工具历史提示词 - """ - from src.config.config import global_config - - if not global_config.tool.history.enable_prompt_history: - return "" - - return tool_history_manager.get_recent_history_prompt( - chat_id=message_id - ) - class PromptContext: def __init__(self): self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {} @@ -49,7 +27,7 @@ class PromptContext: @_current_context.setter def _current_context(self, value: Optional[str]): """设置当前协程的上下文ID""" - self._current_context_var.set(value) + self._current_context_var.set(value) # type: ignore @asynccontextmanager async def async_scope(self, context_id: Optional[str] = None): @@ -73,7 +51,7 @@ class PromptContext: # 保存当前协程的上下文值,不影响其他协程 previous_context = self._current_context # 设置当前协程的新上下文 - token = self._current_context_var.set(context_id) if context_id else None + token = self._current_context_var.set(context_id) if context_id else None # type: ignore else: # 如果没有提供新上下文,保持当前上下文不变 previous_context = self._current_context @@ -111,7 +89,8 @@ class PromptContext: """异步注册提示模板到指定作用域""" async with self._context_lock: if target_context := context_id or self._current_context: - self._context_prompts.setdefault(target_context, {})[prompt.name] = prompt + if prompt.name: + self._context_prompts.setdefault(target_context, {})[prompt.name] = prompt class PromptManager: @@ -153,40 +132,15 @@ class PromptManager: def add_prompt(self, name: str, fstr: str) -> "Prompt": prompt = Prompt(fstr, name=name) - self._prompts[prompt.name] = prompt + if prompt.name: + self._prompts[prompt.name] = prompt return prompt async def format_prompt(self, name: str, **kwargs) -> str: # 获取当前提示词 prompt = await self.get_prompt_async(name) - # 获取当前会话ID - message_id = self._context._current_context - - # 获取工具历史提示词 - tool_history = "" - if name in ['action_prompt', 'replyer_prompt', 'planner_prompt', 'tool_executor_prompt']: - tool_history = get_tool_history_prompt(message_id) - # 获取基本格式化结果 result = prompt.format(**kwargs) - - # 如果有工具历史,插入到适当位置 - if tool_history: - # 查找合适的插入点 - # 在人格信息和身份块之后,但在主要内容之前 - identity_end = result.find("```\n现在,你说:") - if identity_end == -1: - # 如果找不到特定标记,尝试在第一个段落后插入 - first_double_newline = result.find("\n\n") - if first_double_newline != -1: - # 在第一个双换行后插入 - result = f"{result[:first_double_newline + 2]}{tool_history}\n{result[first_double_newline + 2:]}" - else: - # 如果找不到合适的位置,添加到开头 - result = f"{tool_history}\n\n{result}" - else: - # 在找到的位置插入 - result = f"{result[:identity_end]}\n{tool_history}\n{result[identity_end:]}" return result @@ -195,6 +149,11 @@ global_prompt_manager = PromptManager() class Prompt(str): + template: str + name: Optional[str] + args: List[str] + _args: List[Any] + _kwargs: Dict[str, Any] # 临时标记,作为类常量 _TEMP_LEFT_BRACE = "__ESCAPED_LEFT_BRACE__" _TEMP_RIGHT_BRACE = "__ESCAPED_RIGHT_BRACE__" @@ -215,7 +174,7 @@ class Prompt(str): """将临时标记还原为实际的花括号字符""" return template.replace(Prompt._TEMP_LEFT_BRACE, "{").replace(Prompt._TEMP_RIGHT_BRACE, "}") - def __new__(cls, fstr, name: Optional[str] = None, args: Union[List[Any], tuple[Any, ...]] = None, **kwargs): + def __new__(cls, fstr, name: Optional[str] = None, args: Optional[Union[List[Any], tuple[Any, ...]]] = None, **kwargs): # 如果传入的是元组,转换为列表 if isinstance(args, tuple): args = list(args) @@ -251,7 +210,7 @@ class Prompt(str): @classmethod async def create_async( - cls, fstr, name: Optional[str] = None, args: Union[List[Any], tuple[Any, ...]] = None, **kwargs + cls, fstr, name: Optional[str] = None, args: Optional[Union[List[Any], tuple[Any, ...]]] = None, **kwargs ): """异步创建Prompt实例""" prompt = cls(fstr, name, args, **kwargs) @@ -260,7 +219,9 @@ class Prompt(str): return prompt @classmethod - def _format_template(cls, template, args: List[Any] = None, kwargs: Dict[str, Any] = None) -> str: + def _format_template(cls, template, args: Optional[List[Any]] = None, kwargs: Optional[Dict[str, Any]] = None) -> str: + if kwargs is None: + kwargs = {} # 预处理模板中的转义花括号 processed_template = cls._process_escaped_braces(template) diff --git a/src/common/cache_manager.py b/src/common/cache_manager.py index 7b0a8ec92..d4f872d30 100644 --- a/src/common/cache_manager.py +++ b/src/common/cache_manager.py @@ -4,7 +4,7 @@ import hashlib from pathlib import Path import numpy as np import faiss -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Union, List from src.common.logger import get_logger from src.llm_models.utils_model import LLMRequest from src.config.config import global_config, model_config @@ -141,7 +141,7 @@ class CacheManager: # 步骤 2a: L1 语义缓存 (FAISS) if query_embedding is not None and self.l1_vector_index.ntotal > 0: faiss.normalize_L2(query_embedding) - distances, indices = self.l1_vector_index.search(query_embedding, 1) + distances, indices = self.l1_vector_index.search(query_embedding, 1) # type: ignore if indices.size > 0 and distances[0][0] > 0.75: # IP 越大越相似 hit_index = indices[0][0] l1_hit_key = self.l1_vector_id_to_key.get(hit_index) @@ -348,4 +348,64 @@ class CacheManager: logger.info(f"清理了 {len(expired_keys)} 个过期的L1缓存条目") # 全局实例 -tool_cache = CacheManager() \ No newline at end of file +tool_cache = CacheManager() + +import inspect +import time + +def wrap_tool_executor(): + """ + 包装工具执行器以添加缓存功能 + 这个函数应该在系统启动时被调用一次 + """ + from src.plugin_system.core.tool_use import ToolExecutor + from src.plugin_system.apis.tool_api import get_tool_instance + original_execute = ToolExecutor.execute_tool_call + + async def wrapped_execute_tool_call(self, tool_call, tool_instance=None): + if not tool_instance: + tool_instance = get_tool_instance(tool_call.func_name) + + if not tool_instance or not tool_instance.enable_cache: + return await original_execute(self, tool_call, tool_instance) + + try: + tool_file_path = inspect.getfile(tool_instance.__class__) + semantic_query = None + if tool_instance.semantic_cache_query_key: + semantic_query = tool_call.args.get(tool_instance.semantic_cache_query_key) + + cached_result = await tool_cache.get( + tool_name=tool_call.func_name, + function_args=tool_call.args, + tool_file_path=tool_file_path, + semantic_query=semantic_query + ) + if cached_result: + logger.info(f"{getattr(self, 'log_prefix', '')}使用缓存结果,跳过工具 {tool_call.func_name} 执行") + return cached_result + except Exception as e: + logger.error(f"{getattr(self, 'log_prefix', '')}检查工具缓存时出错: {e}") + + result = await original_execute(self, tool_call, tool_instance) + + try: + tool_file_path = inspect.getfile(tool_instance.__class__) + semantic_query = None + if tool_instance.semantic_cache_query_key: + semantic_query = tool_call.args.get(tool_instance.semantic_cache_query_key) + + await tool_cache.set( + tool_name=tool_call.func_name, + function_args=tool_call.args, + tool_file_path=tool_file_path, + data=result, + ttl=tool_instance.cache_ttl, + semantic_query=semantic_query + ) + except Exception as e: + logger.error(f"{getattr(self, 'log_prefix', '')}设置工具缓存时出错: {e}") + + return result + + ToolExecutor.execute_tool_call = wrapped_execute_tool_call \ No newline at end of file diff --git a/src/common/tool_history.py b/src/common/tool_history.py deleted file mode 100644 index b3edb12ce..000000000 --- a/src/common/tool_history.py +++ /dev/null @@ -1,405 +0,0 @@ -"""工具执行历史记录模块""" -import time -from datetime import datetime -from typing import Any, Dict, List, Optional, Union -import json -from pathlib import Path -import inspect - -from .logger import get_logger -from src.config.config import global_config -from src.common.cache_manager import tool_cache - -logger = get_logger("tool_history") - -class ToolHistoryManager: - """工具执行历史记录管理器""" - - _instance = None - _initialized = False - - def __new__(cls): - if cls._instance is None: - cls._instance = super().__new__(cls) - return cls._instance - - def __init__(self): - if not self._initialized: - self._history: List[Dict[str, Any]] = [] - self._initialized = True - self._data_dir = Path("data/tool_history") - self._data_dir.mkdir(parents=True, exist_ok=True) - self._history_file = self._data_dir / "tool_history.jsonl" - self._load_history() - - def _save_history(self): - """保存所有历史记录到文件""" - try: - with self._history_file.open("w", encoding="utf-8") as f: - for record in self._history: - f.write(json.dumps(record, ensure_ascii=False) + "\n") - except Exception as e: - logger.error(f"保存工具调用记录失败: {e}") - - def _save_record(self, record: Dict[str, Any]): - """保存单条记录到文件""" - try: - with self._history_file.open("a", encoding="utf-8") as f: - f.write(json.dumps(record, ensure_ascii=False) + "\n") - except Exception as e: - logger.error(f"保存工具调用记录失败: {e}") - - def _clean_expired_records(self): - """清理已过期的记录""" - original_count = len(self._history) - self._history = [record for record in self._history if record.get("ttl_count", 0) < record.get("ttl", 5)] - cleaned_count = original_count - len(self._history) - - if cleaned_count > 0: - logger.info(f"清理了 {cleaned_count} 条过期的工具历史记录,剩余 {len(self._history)} 条") - self._save_history() - else: - logger.debug("没有需要清理的过期工具历史记录") - - def record_tool_call(self, - tool_name: str, - args: Dict[str, Any], - result: Any, - execution_time: float, - status: str, - chat_id: Optional[str] = None, - ttl: int = 5): - """记录工具调用 - - Args: - tool_name: 工具名称 - args: 工具调用参数 - result: 工具返回结果 - execution_time: 执行时间(秒) - status: 执行状态("completed"或"error") - chat_id: 聊天ID,与ChatManager中的chat_id对应,用于标识群聊或私聊会话 - ttl: 该记录的生命周期值,插入提示词多少次后删除,默认为5 - """ - # 检查是否启用历史记录且ttl大于0 - if not global_config.tool.history.enable_history or ttl <= 0: - return - - # 先清理过期记录 - self._clean_expired_records() - - try: - # 创建记录 - record = { - "tool_name": tool_name, - "timestamp": datetime.now().isoformat(), - "arguments": self._sanitize_args(args), - "result": self._sanitize_result(result), - "execution_time": execution_time, - "status": status, - "chat_id": chat_id, - "ttl": ttl, - "ttl_count": 0 - } - - # 添加到内存中的历史记录 - self._history.append(record) - - # 保存到文件 - self._save_record(record) - - if status == "completed": - logger.info(f"工具 {tool_name} 调用完成,耗时:{execution_time:.2f}s") - else: - logger.error(f"工具 {tool_name} 调用失败:{result}") - - except Exception as e: - logger.error(f"记录工具调用时发生错误: {e}") - - def _sanitize_args(self, args: Dict[str, Any]) -> Dict[str, Any]: - """清理参数中的敏感信息""" - sensitive_keys = ['api_key', 'token', 'password', 'secret'] - sanitized = args.copy() - - def _sanitize_value(value): - if isinstance(value, dict): - return {k: '***' if k.lower() in sensitive_keys else _sanitize_value(v) - for k, v in value.items()} - return value - - return {k: '***' if k.lower() in sensitive_keys else _sanitize_value(v) - for k, v in sanitized.items()} - - def _sanitize_result(self, result: Any) -> Any: - """清理结果中的敏感信息""" - if isinstance(result, dict): - return self._sanitize_args(result) - return result - - def _load_history(self): - """加载历史记录文件""" - try: - if self._history_file.exists(): - self._history = [] - with self._history_file.open("r", encoding="utf-8") as f: - for line in f: - try: - record = json.loads(line) - if record.get("ttl_count", 0) < record.get("ttl", 5): # 只加载未过期的记录 - self._history.append(record) - except json.JSONDecodeError: - continue - logger.info(f"成功加载了 {len(self._history)} 条历史记录") - except Exception as e: - logger.error(f"加载历史记录失败: {e}") - - def query_history(self, - tool_names: Optional[List[str]] = None, - start_time: Optional[Union[datetime, str]] = None, - end_time: Optional[Union[datetime, str]] = None, - chat_id: Optional[str] = None, - limit: Optional[int] = None, - status: Optional[str] = None) -> List[Dict[str, Any]]: - """查询工具调用历史 - - Args: - tool_names: 工具名称列表,为空则查询所有工具 - start_time: 开始时间,可以是datetime对象或ISO格式字符串 - end_time: 结束时间,可以是datetime对象或ISO格式字符串 - chat_id: 聊天ID,与ChatManager中的chat_id对应,用于查询特定群聊或私聊的历史记录 - limit: 返回记录数量限制 - status: 执行状态筛选("completed"或"error") - - Returns: - 符合条件的历史记录列表 - """ - # 先清理过期记录 - self._clean_expired_records() - def _parse_time(time_str: Optional[Union[datetime, str]]) -> Optional[datetime]: - if isinstance(time_str, datetime): - return time_str - elif isinstance(time_str, str): - return datetime.fromisoformat(time_str) - return None - - filtered_history = self._history - - # 按工具名筛选 - if tool_names: - filtered_history = [ - record for record in filtered_history - if record["tool_name"] in tool_names - ] - - # 按时间范围筛选 - start_dt = _parse_time(start_time) - end_dt = _parse_time(end_time) - - if start_dt: - filtered_history = [ - record for record in filtered_history - if datetime.fromisoformat(record["timestamp"]) >= start_dt - ] - - if end_dt: - filtered_history = [ - record for record in filtered_history - if datetime.fromisoformat(record["timestamp"]) <= end_dt - ] - - # 按聊天ID筛选 - if chat_id: - filtered_history = [ - record for record in filtered_history - if record.get("chat_id") == chat_id - ] - - # 按状态筛选 - if status: - filtered_history = [ - record for record in filtered_history - if record["status"] == status - ] - - # 应用数量限制 - if limit: - filtered_history = filtered_history[-limit:] - - return filtered_history - - def get_recent_history_prompt(self, - limit: Optional[int] = None, - chat_id: Optional[str] = None) -> str: - """ - 获取最近工具调用历史的提示词 - - Args: - limit: 返回的历史记录数量,如果不提供则使用配置中的max_history - chat_id: 会话ID,用于只获取当前会话的历史 - - Returns: - 格式化的历史记录提示词 - """ - # 检查是否启用历史记录 - if not global_config.tool.history.enable_history: - return "" - - # 使用配置中的最大历史记录数 - if limit is None: - limit = global_config.tool.history.max_history - - recent_history = self.query_history( - chat_id=chat_id, - limit=limit - ) - - if not recent_history: - return "" - - prompt = "\n工具执行历史:\n" - needs_save = False - updated_history = [] - - for record in recent_history: - # 增加ttl计数 - record["ttl_count"] = record.get("ttl_count", 0) + 1 - needs_save = True - - # 如果未超过ttl,则添加到提示词中 - if record["ttl_count"] < record.get("ttl", 5): - # 提取结果中的name和content - result = record['result'] - if isinstance(result, dict): - name = result.get('name', record['tool_name']) - content = result.get('content', str(result)) - else: - name = record['tool_name'] - content = str(result) - - # 格式化内容,去除多余空白和换行 - content = content.strip().replace('\n', ' ') - - # 如果内容太长则截断 - if len(content) > 200: - content = content[:200] + "..." - - prompt += f"{name}: \n{content}\n\n" - updated_history.append(record) - - # 更新历史记录并保存 - if needs_save: - self._history = updated_history - self._save_history() - - return prompt - - def clear_history(self): - """清除历史记录""" - self._history.clear() - self._save_history() - logger.info("工具调用历史记录已清除") - - -def wrap_tool_executor(): - """ - 包装工具执行器以添加历史记录和缓存功能 - 这个函数应该在系统启动时被调用一次 - """ - from src.plugin_system.core.tool_use import ToolExecutor - from src.plugin_system.apis.tool_api import get_tool_instance - original_execute = ToolExecutor.execute_tool_call - history_manager = ToolHistoryManager() - - async def wrapped_execute_tool_call(self, tool_call, tool_instance=None): - start_time = time.time() - - # 确保我们有 tool_instance - if not tool_instance: - tool_instance = get_tool_instance(tool_call.func_name) - - # 如果没有 tool_instance,就无法进行缓存检查,直接执行 - if not tool_instance: - result = await original_execute(self, tool_call, None) - execution_time = time.time() - start_time - history_manager.record_tool_call( - tool_name=tool_call.func_name, - args=tool_call.args, - result=result, - execution_time=execution_time, - status="completed", - chat_id=getattr(self, 'chat_id', None), - ttl=5 # Default TTL - ) - return result - - # 新的缓存逻辑 - if tool_instance.enable_cache: - try: - tool_file_path = inspect.getfile(tool_instance.__class__) - semantic_query = None - if tool_instance.semantic_cache_query_key: - semantic_query = tool_call.args.get(tool_instance.semantic_cache_query_key) - - cached_result = await tool_cache.get( - tool_name=tool_call.func_name, - function_args=tool_call.args, - tool_file_path=tool_file_path, - semantic_query=semantic_query - ) - if cached_result: - logger.info(f"{self.log_prefix}使用缓存结果,跳过工具 {tool_call.func_name} 执行") - return cached_result - except Exception as e: - logger.error(f"{self.log_prefix}检查工具缓存时出错: {e}") - - try: - result = await original_execute(self, tool_call, tool_instance) - execution_time = time.time() - start_time - - # 缓存结果 - if tool_instance.enable_cache: - try: - tool_file_path = inspect.getfile(tool_instance.__class__) - semantic_query = None - if tool_instance.semantic_cache_query_key: - semantic_query = tool_call.args.get(tool_instance.semantic_cache_query_key) - - await tool_cache.set( - tool_name=tool_call.func_name, - function_args=tool_call.args, - tool_file_path=tool_file_path, - data=result, - ttl=tool_instance.cache_ttl, - semantic_query=semantic_query - ) - except Exception as e: - logger.error(f"{self.log_prefix}设置工具缓存时出错: {e}") - - # 记录成功的调用 - history_manager.record_tool_call( - tool_name=tool_call.func_name, - args=tool_call.args, - result=result, - execution_time=execution_time, - status="completed", - chat_id=getattr(self, 'chat_id', None), - ttl=tool_instance.history_ttl - ) - - return result - - except Exception as e: - execution_time = time.time() - start_time - # 记录失败的调用 - history_manager.record_tool_call( - tool_name=tool_call.func_name, - args=tool_call.args, - result=str(e), - execution_time=execution_time, - status="error", - chat_id=getattr(self, 'chat_id', None), - ttl=tool_instance.history_ttl - ) - raise - - # 替换原始方法 - ToolExecutor.execute_tool_call = wrapped_execute_tool_call \ No newline at end of file diff --git a/src/plugin_system/apis/tool_api.py b/src/plugin_system/apis/tool_api.py index ec8ddec39..da17f9305 100644 --- a/src/plugin_system/apis/tool_api.py +++ b/src/plugin_system/apis/tool_api.py @@ -1,9 +1,7 @@ -from typing import Any, Dict, List, Optional, Type, Union -from datetime import datetime +from typing import Any, Dict, List, Optional, Type from src.plugin_system.base.base_tool import BaseTool from src.plugin_system.base.component_types import ComponentType -from src.common.tool_history import ToolHistoryManager from src.common.logger import get_logger logger = get_logger("tool_api") @@ -33,110 +31,4 @@ def get_llm_available_tool_definitions(): from src.plugin_system.core import component_registry llm_available_tools = component_registry.get_llm_available_tools() - return [(name, tool_class.get_tool_definition()) for name, tool_class in llm_available_tools.items()] - -def get_tool_history( - tool_names: Optional[List[str]] = None, - start_time: Optional[Union[datetime, str]] = None, - end_time: Optional[Union[datetime, str]] = None, - chat_id: Optional[str] = None, - limit: Optional[int] = None, - status: Optional[str] = None -) -> List[Dict[str, Any]]: - """ - 获取工具调用历史记录 - - Args: - tool_names: 工具名称列表,为空则查询所有工具 - start_time: 开始时间,可以是datetime对象或ISO格式字符串 - end_time: 结束时间,可以是datetime对象或ISO格式字符串 - chat_id: 会话ID,用于筛选特定会话的调用 - limit: 返回记录数量限制 - status: 执行状态筛选("completed"或"error") - - Returns: - List[Dict]: 工具调用记录列表,每条记录包含以下字段: - - tool_name: 工具名称 - - timestamp: 调用时间 - - arguments: 调用参数 - - result: 调用结果 - - execution_time: 执行时间 - - status: 执行状态 - - chat_id: 会话ID - """ - history_manager = ToolHistoryManager() - return history_manager.query_history( - tool_names=tool_names, - start_time=start_time, - end_time=end_time, - chat_id=chat_id, - limit=limit, - status=status - ) - - -def get_tool_history_text( - tool_names: Optional[List[str]] = None, - start_time: Optional[Union[datetime, str]] = None, - end_time: Optional[Union[datetime, str]] = None, - chat_id: Optional[str] = None, - limit: Optional[int] = None, - status: Optional[str] = None -) -> str: - """ - 获取工具调用历史记录的文本格式 - - Args: - tool_names: 工具名称列表,为空则查询所有工具 - start_time: 开始时间,可以是datetime对象或ISO格式字符串 - end_time: 结束时间,可以是datetime对象或ISO格式字符串 - chat_id: 会话ID,用于筛选特定会话的调用 - limit: 返回记录数量限制 - status: 执行状态筛选("completed"或"error") - - Returns: - str: 格式化的工具调用历史记录文本 - """ - history = get_tool_history( - tool_names=tool_names, - start_time=start_time, - end_time=end_time, - chat_id=chat_id, - limit=limit, - status=status - ) - - if not history: - return "没有找到工具调用记录" - - text = "工具调用历史记录:\n" - for record in history: - # 提取结果中的name和content - result = record['result'] - if isinstance(result, dict): - name = result.get('name', record['tool_name']) - content = result.get('content', str(result)) - else: - name = record['tool_name'] - content = str(result) - - # 格式化内容 - content = content.strip().replace('\n', ' ') - if len(content) > 200: - content = content[:200] + "..." - - # 格式化时间 - timestamp = datetime.fromisoformat(record['timestamp']).strftime("%Y-%m-%d %H:%M:%S") - - text += f"[{timestamp}] {name}\n" - text += f"结果: {content}\n\n" - - return text - - -def clear_tool_history() -> None: - """ - 清除所有工具调用历史记录 - """ - history_manager = ToolHistoryManager() - history_manager.clear_history() \ No newline at end of file + return [(name, tool_class.get_tool_definition()) for name, tool_class in llm_available_tools.items()] \ No newline at end of file