ruff
This commit is contained in:
@@ -5,12 +5,12 @@ MCP Client Manager
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import orjson
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import mcp.types
|
||||
import orjson
|
||||
from fastmcp.client import Client, StdioTransport, StreamableHttpTransport
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -4,11 +4,13 @@
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
from dataclasses import dataclass, asdict, field
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from src.common.cache_manager import tool_cache
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("stream_tool_history")
|
||||
|
||||
@@ -18,10 +20,10 @@ class ToolCallRecord:
|
||||
"""工具调用记录"""
|
||||
tool_name: str
|
||||
args: dict[str, Any]
|
||||
result: Optional[dict[str, Any]] = None
|
||||
result: dict[str, Any] | None = None
|
||||
status: str = "success" # success, error, pending
|
||||
timestamp: float = field(default_factory=time.time)
|
||||
execution_time: Optional[float] = None # 执行耗时(秒)
|
||||
execution_time: float | None = None # 执行耗时(秒)
|
||||
cache_hit: bool = False # 是否命中缓存
|
||||
result_preview: str = "" # 结果预览
|
||||
error_message: str = "" # 错误信息
|
||||
@@ -32,9 +34,9 @@ class ToolCallRecord:
|
||||
content = self.result.get("content", "")
|
||||
if isinstance(content, str):
|
||||
self.result_preview = content[:500] + ("..." if len(content) > 500 else "")
|
||||
elif isinstance(content, (list, dict)):
|
||||
elif isinstance(content, list | dict):
|
||||
try:
|
||||
self.result_preview = orjson.dumps(content, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')[:500] + "..."
|
||||
self.result_preview = orjson.dumps(content, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")[:500] + "..."
|
||||
except Exception:
|
||||
self.result_preview = str(content)[:500] + "..."
|
||||
else:
|
||||
@@ -105,7 +107,7 @@ class StreamToolHistoryManager:
|
||||
|
||||
logger.debug(f"[{self.chat_id}] 添加工具调用记录: {record.tool_name}, 缓存命中: {record.cache_hit}")
|
||||
|
||||
async def get_cached_result(self, tool_name: str, args: dict[str, Any]) -> Optional[dict[str, Any]]:
|
||||
async def get_cached_result(self, tool_name: str, args: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""从缓存或历史记录中获取结果
|
||||
|
||||
Args:
|
||||
@@ -160,9 +162,9 @@ class StreamToolHistoryManager:
|
||||
return None
|
||||
|
||||
async def cache_result(self, tool_name: str, args: dict[str, Any], result: dict[str, Any],
|
||||
execution_time: Optional[float] = None,
|
||||
tool_file_path: Optional[str] = None,
|
||||
ttl: Optional[int] = None) -> None:
|
||||
execution_time: float | None = None,
|
||||
tool_file_path: str | None = None,
|
||||
ttl: int | None = None) -> None:
|
||||
"""缓存工具调用结果
|
||||
|
||||
Args:
|
||||
@@ -207,7 +209,7 @@ class StreamToolHistoryManager:
|
||||
except Exception as e:
|
||||
logger.warning(f"[{self.chat_id}] 缓存设置失败: {e}")
|
||||
|
||||
async def get_recent_history(self, count: int = 5, status_filter: Optional[str] = None) -> list[ToolCallRecord]:
|
||||
async def get_recent_history(self, count: int = 5, status_filter: str | None = None) -> list[ToolCallRecord]:
|
||||
"""获取最近的历史记录
|
||||
|
||||
Args:
|
||||
@@ -295,7 +297,7 @@ class StreamToolHistoryManager:
|
||||
self._history.clear()
|
||||
logger.info(f"[{self.chat_id}] 工具历史记录已清除")
|
||||
|
||||
def _search_memory_cache(self, tool_name: str, args: dict[str, Any]) -> Optional[dict[str, Any]]:
|
||||
def _search_memory_cache(self, tool_name: str, args: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""在内存历史记录中搜索缓存
|
||||
|
||||
Args:
|
||||
@@ -333,7 +335,7 @@ class StreamToolHistoryManager:
|
||||
|
||||
return tool_path_mapping.get(tool_name, f"src/plugins/tools/{tool_name}.py")
|
||||
|
||||
def _extract_semantic_query(self, tool_name: str, args: dict[str, Any]) -> Optional[str]:
|
||||
def _extract_semantic_query(self, tool_name: str, args: dict[str, Any]) -> str | None:
|
||||
"""提取语义查询参数
|
||||
|
||||
Args:
|
||||
@@ -370,7 +372,7 @@ class StreamToolHistoryManager:
|
||||
return ""
|
||||
|
||||
try:
|
||||
args_str = orjson.dumps(args, option=orjson.OPT_SORT_KEYS).decode('utf-8')
|
||||
args_str = orjson.dumps(args, option=orjson.OPT_SORT_KEYS).decode("utf-8")
|
||||
if len(args_str) > max_length:
|
||||
args_str = args_str[:max_length] + "..."
|
||||
return args_str
|
||||
@@ -411,4 +413,4 @@ def cleanup_stream_manager(chat_id: str) -> None:
|
||||
"""
|
||||
if chat_id in _stream_managers:
|
||||
del _stream_managers[chat_id]
|
||||
logger.info(f"已清理聊天 {chat_id} 的工具历史记录管理器")
|
||||
logger.info(f"已清理聊天 {chat_id} 的工具历史记录管理器")
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import inspect
|
||||
import time
|
||||
from dataclasses import asdict
|
||||
from typing import Any
|
||||
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
@@ -10,8 +11,7 @@ from src.llm_models.utils_model import LLMRequest
|
||||
from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance
|
||||
from src.plugin_system.base.base_tool import BaseTool
|
||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||
from src.plugin_system.core.stream_tool_history import get_stream_tool_history_manager, ToolCallRecord
|
||||
from dataclasses import asdict
|
||||
from src.plugin_system.core.stream_tool_history import ToolCallRecord, get_stream_tool_history_manager
|
||||
|
||||
logger = get_logger("tool_use")
|
||||
|
||||
@@ -140,7 +140,7 @@ class ToolExecutor:
|
||||
|
||||
# 构建工具调用历史文本
|
||||
tool_history = self.history_manager.format_for_prompt(max_records=5, include_results=True)
|
||||
|
||||
|
||||
# 获取人设信息
|
||||
personality_core = global_config.personality.personality_core
|
||||
personality_side = global_config.personality.personality_side
|
||||
@@ -197,7 +197,7 @@ class ToolExecutor:
|
||||
|
||||
return tool_definitions
|
||||
|
||||
|
||||
|
||||
async def execute_tool_calls(self, tool_calls: list[ToolCall] | None) -> tuple[list[dict[str, Any]], list[str]]:
|
||||
"""执行工具调用
|
||||
|
||||
@@ -338,9 +338,8 @@ class ToolExecutor:
|
||||
if tool_instance and result and tool_instance.enable_cache:
|
||||
try:
|
||||
tool_file_path = inspect.getfile(tool_instance.__class__)
|
||||
semantic_query = None
|
||||
if tool_instance.semantic_cache_query_key:
|
||||
semantic_query = function_args.get(tool_instance.semantic_cache_query_key)
|
||||
function_args.get(tool_instance.semantic_cache_query_key)
|
||||
|
||||
await self.history_manager.cache_result(
|
||||
tool_name=tool_call.func_name,
|
||||
|
||||
Reference in New Issue
Block a user