明天好像没什么
2025-11-07 21:01:45 +08:00
parent 80b040da2f
commit c8d7c09625
49 changed files with 854 additions and 872 deletions

View File

@@ -5,12 +5,12 @@ MCP Client Manager
"""
import asyncio
import orjson
import shutil
from pathlib import Path
from typing import Any
import mcp.types
import orjson
from fastmcp.client import Client, StdioTransport, StreamableHttpTransport
from src.common.logger import get_logger

View File

@@ -4,11 +4,13 @@
"""
import time
from typing import Any, Optional
from dataclasses import dataclass, asdict, field
from dataclasses import dataclass, field
from typing import Any
import orjson
from src.common.logger import get_logger
from src.common.cache_manager import tool_cache
from src.common.logger import get_logger
logger = get_logger("stream_tool_history")
@@ -18,10 +20,10 @@ class ToolCallRecord:
"""工具调用记录"""
tool_name: str
args: dict[str, Any]
result: Optional[dict[str, Any]] = None
result: dict[str, Any] | None = None
status: str = "success" # success, error, pending
timestamp: float = field(default_factory=time.time)
execution_time: Optional[float] = None # 执行耗时(秒)
execution_time: float | None = None # 执行耗时(秒)
cache_hit: bool = False # 是否命中缓存
result_preview: str = "" # 结果预览
error_message: str = "" # 错误信息
@@ -32,9 +34,9 @@ class ToolCallRecord:
 content = self.result.get("content", "")
 if isinstance(content, str):
 self.result_preview = content[:500] + ("..." if len(content) > 500 else "")
-elif isinstance(content, (list, dict)):
+elif isinstance(content, list | dict):
 try:
-self.result_preview = orjson.dumps(content, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')[:500] + "..."
+self.result_preview = orjson.dumps(content, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")[:500] + "..."
 except Exception:
 self.result_preview = str(content)[:500] + "..."
 else:
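The two ToolCallRecord hunks above are part of a commit-wide migration from typing.Optional to PEP 604 union syntax (X | None), which on Python 3.10+ is also accepted directly by isinstance in place of the older tuple form. A minimal standalone sketch of the equivalence (illustrative only, not code from this repository):

import time
from dataclasses import dataclass, field


@dataclass
class Record:
    # "dict | None" (PEP 604) means the same thing as Optional[dict]
    result: dict | None = None
    timestamp: float = field(default_factory=time.time)


content = [1, 2, 3]
# Union types work directly in isinstance on Python 3.10+
assert isinstance(content, list | dict)
assert isinstance(content, (list, dict))  # older tuple form, still valid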
@@ -105,7 +107,7 @@ class StreamToolHistoryManager:
logger.debug(f"[{self.chat_id}] 添加工具调用记录: {record.tool_name}, 缓存命中: {record.cache_hit}")
async def get_cached_result(self, tool_name: str, args: dict[str, Any]) -> Optional[dict[str, Any]]:
async def get_cached_result(self, tool_name: str, args: dict[str, Any]) -> dict[str, Any] | None:
"""从缓存或历史记录中获取结果
Args:
@@ -160,9 +162,9 @@ class StreamToolHistoryManager:
 return None
 async def cache_result(self, tool_name: str, args: dict[str, Any], result: dict[str, Any],
-execution_time: Optional[float] = None,
-tool_file_path: Optional[str] = None,
-ttl: Optional[int] = None) -> None:
+execution_time: float | None = None,
+tool_file_path: str | None = None,
+ttl: int | None = None) -> None:
 """Cache a tool call result
 Args:
@@ -207,7 +209,7 @@ class StreamToolHistoryManager:
 except Exception as e:
 logger.warning(f"[{self.chat_id}] Failed to set cache entry: {e}")
-async def get_recent_history(self, count: int = 5, status_filter: Optional[str] = None) -> list[ToolCallRecord]:
+async def get_recent_history(self, count: int = 5, status_filter: str | None = None) -> list[ToolCallRecord]:
 """Get the most recent history records
 Args:
@@ -295,7 +297,7 @@ class StreamToolHistoryManager:
 self._history.clear()
 logger.info(f"[{self.chat_id}] Tool call history cleared")
-def _search_memory_cache(self, tool_name: str, args: dict[str, Any]) -> Optional[dict[str, Any]]:
+def _search_memory_cache(self, tool_name: str, args: dict[str, Any]) -> dict[str, Any] | None:
 """Search the in-memory history for a cached result
 Args:
@@ -333,7 +335,7 @@ class StreamToolHistoryManager:
 return tool_path_mapping.get(tool_name, f"src/plugins/tools/{tool_name}.py")
-def _extract_semantic_query(self, tool_name: str, args: dict[str, Any]) -> Optional[str]:
+def _extract_semantic_query(self, tool_name: str, args: dict[str, Any]) -> str | None:
 """Extract the semantic query parameter
 Args:
@@ -370,7 +372,7 @@ class StreamToolHistoryManager:
return ""
try:
args_str = orjson.dumps(args, option=orjson.OPT_SORT_KEYS).decode('utf-8')
args_str = orjson.dumps(args, option=orjson.OPT_SORT_KEYS).decode("utf-8")
if len(args_str) > max_length:
args_str = args_str[:max_length] + "..."
return args_str
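The hunk above builds the cache key for a call by serializing the arguments with sorted keys, so logically identical argument dicts always produce the same string, and then truncating the result. A standalone sketch of the same idea; the function name and length limit below are illustrative, only the orjson call mirrors the diff:

import orjson


def args_fingerprint(args: dict, max_length: int = 200) -> str:
    # OPT_SORT_KEYS makes the serialization independent of insertion order
    args_str = orjson.dumps(args, option=orjson.OPT_SORT_KEYS).decode("utf-8")
    if len(args_str) > max_length:
        args_str = args_str[:max_length] + "..."
    return args_str


print(args_fingerprint({"b": 2, "a": 1}))  # -> {"a":1,"b":2}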
@@ -411,4 +413,4 @@ def cleanup_stream_manager(chat_id: str) -> None:
"""
if chat_id in _stream_managers:
del _stream_managers[chat_id]
logger.info(f"已清理聊天 {chat_id} 的工具历史记录管理器")
logger.info(f"已清理聊天 {chat_id} 的工具历史记录管理器")

View File

@@ -1,5 +1,6 @@
 import inspect
 import time
+from dataclasses import asdict
 from typing import Any
 from src.chat.utils.prompt import Prompt, global_prompt_manager
@@ -10,8 +11,7 @@ from src.llm_models.utils_model import LLMRequest
 from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance
 from src.plugin_system.base.base_tool import BaseTool
 from src.plugin_system.core.global_announcement_manager import global_announcement_manager
-from src.plugin_system.core.stream_tool_history import get_stream_tool_history_manager, ToolCallRecord
-from dataclasses import asdict
+from src.plugin_system.core.stream_tool_history import ToolCallRecord, get_stream_tool_history_manager
 logger = get_logger("tool_use")
@@ -140,7 +140,7 @@ class ToolExecutor:
 # Build the tool call history text
 tool_history = self.history_manager.format_for_prompt(max_records=5, include_results=True)
 # Get persona information
 personality_core = global_config.personality.personality_core
 personality_side = global_config.personality.personality_side
@@ -197,7 +197,7 @@ class ToolExecutor:
 return tool_definitions
 async def execute_tool_calls(self, tool_calls: list[ToolCall] | None) -> tuple[list[dict[str, Any]], list[str]]:
 """Execute tool calls
@@ -338,9 +338,8 @@ class ToolExecutor:
 if tool_instance and result and tool_instance.enable_cache:
 try:
 tool_file_path = inspect.getfile(tool_instance.__class__)
-semantic_query = None
 if tool_instance.semantic_cache_query_key:
-semantic_query = function_args.get(tool_instance.semantic_cache_query_key)
+function_args.get(tool_instance.semantic_cache_query_key)
 await self.history_manager.cache_result(
 tool_name=tool_call.func_name,
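In the hunk above, the executor records inspect.getfile(tool_instance.__class__) as tool_file_path before caching. The diff does not show how that path is used downstream; a plausible reading (an assumption, not confirmed by this commit) is that tying a cache entry to the file defining the tool lets the cache be invalidated when the tool's source changes. A stdlib-only sketch of that idea:

import inspect
from pathlib import Path


class EchoTool:
    """Stand-in for a BaseTool subclass, used only to demonstrate inspect.getfile."""
    enable_cache = True


def cache_key_parts(tool_cls) -> tuple[str, float]:
    # inspect.getfile returns the path of the file that defines the class;
    # pairing it with the file's mtime is the assumed invalidation signal.
    path = inspect.getfile(tool_cls)
    return path, Path(path).stat().st_mtime


if __name__ == "__main__":
    print(cache_key_parts(EchoTool))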