From 17c1d4b4f99610a1ef39dae5c7ec3e8edfe22d97 Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Thu, 6 Nov 2025 12:47:56 +0800 Subject: [PATCH 1/6] =?UTF-8?q?feat:=20=E5=B0=86=20JSON=20=E5=A4=84?= =?UTF-8?q?=E7=90=86=E5=BA=93=E4=BB=8E=20json=20=E6=9B=B4=E6=94=B9?= =?UTF-8?q?=E4=B8=BA=20orjson=EF=BC=8C=E4=BB=A5=E6=8F=90=E9=AB=98=E6=80=A7?= =?UTF-8?q?=E8=83=BD=E5=92=8C=E5=85=BC=E5=AE=B9=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../message_manager/global_notice_manager.py | 8 +++---- src/memory_graph/manager.py | 1 + src/memory_graph/storage/vector_store.py | 12 +++++----- src/memory_graph/tools/memory_tools.py | 24 +++++++++++-------- src/plugin_system/apis/storage_api.py | 8 +++---- src/plugin_system/core/mcp_client_manager.py | 6 ++--- src/plugin_system/core/tool_use.py | 4 ++-- .../proactive/proactive_thinking_executor.py | 2 +- .../services/reply_tracker_service.py | 8 +++---- .../built_in/napcat_adapter_plugin/plugin.py | 8 +++---- .../src/message_chunker.py | 22 ++++++++--------- .../src/recv_handler/message_handler.py | 14 +++++------ .../src/recv_handler/notice_handler.py | 8 +++---- .../napcat_adapter_plugin/src/send_handler.py | 4 ++-- .../napcat_adapter_plugin/src/utils.py | 20 ++++++++-------- .../web_search_tool/engines/metaso_engine.py | 6 ++--- tools/memory_visualizer/visualizer_server.py | 2 +- tools/memory_visualizer/visualizer_simple.py | 4 ++-- 18 files changed, 83 insertions(+), 78 deletions(-) diff --git a/src/chat/message_manager/global_notice_manager.py b/src/chat/message_manager/global_notice_manager.py index 7f382835f..db8a5aa2b 100644 --- a/src/chat/message_manager/global_notice_manager.py +++ b/src/chat/message_manager/global_notice_manager.py @@ -323,8 +323,8 @@ class GlobalNoticeManager: return message.additional_config.get("is_notice", False) elif isinstance(message.additional_config, str): # 兼容JSON字符串格式 - import json - config = 
json.loads(message.additional_config) + import orjson + config = orjson.loads(message.additional_config) return config.get("is_notice", False) # 检查消息类型或其他标识 @@ -349,8 +349,8 @@ class GlobalNoticeManager: if isinstance(message.additional_config, dict): return message.additional_config.get("notice_type") elif isinstance(message.additional_config, str): - import json - config = json.loads(message.additional_config) + import orjson + config = orjson.loads(message.additional_config) return config.get("notice_type") return None except Exception: diff --git a/src/memory_graph/manager.py b/src/memory_graph/manager.py index eeb4f6d2c..7161878f6 100644 --- a/src/memory_graph/manager.py +++ b/src/memory_graph/manager.py @@ -137,6 +137,7 @@ class MemoryManager: graph_store=self.graph_store, persistence_manager=self.persistence, embedding_generator=self.embedding_generator, + max_expand_depth=getattr(self.config, 'max_expand_depth', 1), # 从配置读取默认深度 ) self._initialized = True diff --git a/src/memory_graph/storage/vector_store.py b/src/memory_graph/storage/vector_store.py index 74a148c24..243167b40 100644 --- a/src/memory_graph/storage/vector_store.py +++ b/src/memory_graph/storage/vector_store.py @@ -102,8 +102,8 @@ class VectorStore: # 处理额外的元数据,将 list 转换为 JSON 字符串 for key, value in node.metadata.items(): if isinstance(value, (list, dict)): - import json - metadata[key] = json.dumps(value, ensure_ascii=False) + import orjson + metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode('utf-8') elif isinstance(value, (str, int, float, bool)) or value is None: metadata[key] = value else: @@ -141,7 +141,7 @@ class VectorStore: try: # 准备元数据 - import json + import orjson metadatas = [] for n in valid_nodes: metadata = { @@ -151,7 +151,7 @@ class VectorStore: } for key, value in n.metadata.items(): if isinstance(value, (list, dict)): - metadata[key] = json.dumps(value, ensure_ascii=False) + metadata[key] = orjson.dumps(value, 
option=orjson.OPT_NON_STR_KEYS).decode('utf-8') elif isinstance(value, (str, int, float, bool)) or value is None: metadata[key] = value # type: ignore else: @@ -207,7 +207,7 @@ class VectorStore: ) # 解析结果 - import json + import orjson similar_nodes = [] if results["ids"] and results["ids"][0]: for i, node_id in enumerate(results["ids"][0]): @@ -223,7 +223,7 @@ class VectorStore: for key, value in list(metadata.items()): if isinstance(value, str) and (value.startswith('[') or value.startswith('{')): try: - metadata[key] = json.loads(value) + metadata[key] = orjson.loads(value) except: pass # 保持原值 diff --git a/src/memory_graph/tools/memory_tools.py b/src/memory_graph/tools/memory_tools.py index ff844e357..692671e85 100644 --- a/src/memory_graph/tools/memory_tools.py +++ b/src/memory_graph/tools/memory_tools.py @@ -34,6 +34,7 @@ class MemoryTools: graph_store: GraphStore, persistence_manager: PersistenceManager, embedding_generator: Optional[EmbeddingGenerator] = None, + max_expand_depth: int = 1, ): """ 初始化工具集 @@ -43,11 +44,13 @@ class MemoryTools: graph_store: 图存储 persistence_manager: 持久化管理器 embedding_generator: 嵌入生成器(可选) + max_expand_depth: 图扩展深度的默认值(从配置读取) """ self.vector_store = vector_store self.graph_store = graph_store self.persistence_manager = persistence_manager self._initialized = False + self.max_expand_depth = max_expand_depth # 保存配置的默认值 # 初始化组件 self.extractor = MemoryExtractor() @@ -448,11 +451,12 @@ class MemoryTools: try: query = params.get("query", "") top_k = params.get("top_k", 10) - expand_depth = params.get("expand_depth", 1) + # 使用配置中的默认值而不是硬编码的 1 + expand_depth = params.get("expand_depth", self.max_expand_depth) use_multi_query = params.get("use_multi_query", True) context = params.get("context", None) - logger.info(f"搜索记忆: {query} (top_k={top_k}, multi_query={use_multi_query})") + logger.info(f"搜索记忆: {query} (top_k={top_k}, expand_depth={expand_depth}, multi_query={use_multi_query})") # 0. 
确保初始化 await self._ensure_initialized() @@ -474,9 +478,9 @@ class MemoryTools: ids = metadata["memory_ids"] # 确保是列表 if isinstance(ids, str): - import json + import orjson try: - ids = json.loads(ids) + ids = orjson.loads(ids) except: ids = [ids] if isinstance(ids, list): @@ -649,11 +653,11 @@ class MemoryTools: response, _ = await llm.generate_response_async(prompt, temperature=0.3, max_tokens=250) - import json, re + import orjson, re response = re.sub(r'```json\s*', '', response) response = re.sub(r'```\s*$', '', response).strip() - data = json.loads(response) + data = orjson.loads(response) queries = data.get("queries", []) result = [(item.get("text", "").strip(), float(item.get("weight", 0.5))) @@ -799,9 +803,9 @@ class MemoryTools: # 确保是列表 if isinstance(ids, str): - import json + import orjson try: - ids = json.loads(ids) + ids = orjson.loads(ids) except Exception as e: logger.warning(f"JSON 解析失败: {e}") ids = [ids] @@ -910,9 +914,9 @@ class MemoryTools: # 提取记忆ID neighbor_memory_ids = neighbor_node_data.get("memory_ids", []) if isinstance(neighbor_memory_ids, str): - import json + import orjson try: - neighbor_memory_ids = json.loads(neighbor_memory_ids) + neighbor_memory_ids = orjson.loads(neighbor_memory_ids) except: neighbor_memory_ids = [neighbor_memory_ids] diff --git a/src/plugin_system/apis/storage_api.py b/src/plugin_system/apis/storage_api.py index 66c7d4e79..2c8060473 100644 --- a/src/plugin_system/apis/storage_api.py +++ b/src/plugin_system/apis/storage_api.py @@ -7,7 +7,7 @@ """ import atexit -import json +import orjson import os import threading from typing import Any, ClassVar @@ -100,10 +100,10 @@ class PluginStorage: if os.path.exists(self.file_path): with open(self.file_path, encoding="utf-8") as f: content = f.read() - self._data = json.loads(content) if content else {} + self._data = orjson.loads(content) if content else {} else: self._data = {} - except (json.JSONDecodeError, Exception) as e: + except (orjson.JSONDecodeError, Exception) as e: 
logger.warning(f"从 '{self.file_path}' 加载数据失败: {e},将初始化为空数据。") self._data = {} @@ -125,7 +125,7 @@ class PluginStorage: try: with open(self.file_path, "w", encoding="utf-8") as f: - json.dump(self._data, f, indent=4, ensure_ascii=False) + f.write(orjson.dumps(self._data, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS).decode('utf-8')) self._dirty = False # 保存后重置标志 logger.debug(f"插件 '{self.name}' 的数据已成功保存到磁盘。") except Exception as e: diff --git a/src/plugin_system/core/mcp_client_manager.py b/src/plugin_system/core/mcp_client_manager.py index bf7713ac7..3bc7d6cdb 100644 --- a/src/plugin_system/core/mcp_client_manager.py +++ b/src/plugin_system/core/mcp_client_manager.py @@ -5,7 +5,7 @@ MCP Client Manager """ import asyncio -import json +import orjson import shutil from pathlib import Path from typing import Any @@ -89,7 +89,7 @@ class MCPClientManager: try: with open(self.config_path, encoding="utf-8") as f: - config_data = json.load(f) + config_data = orjson.loads(f.read()) servers = {} mcp_servers = config_data.get("mcpServers", {}) @@ -106,7 +106,7 @@ class MCPClientManager: logger.info(f"成功加载 {len(servers)} 个 MCP 服务器配置") return servers - except json.JSONDecodeError as e: + except orjson.JSONDecodeError as e: logger.error(f"解析 MCP 配置文件失败: {e}") return {} except Exception as e: diff --git a/src/plugin_system/core/tool_use.py b/src/plugin_system/core/tool_use.py index 11832ff51..6aa36aa6a 100644 --- a/src/plugin_system/core/tool_use.py +++ b/src/plugin_system/core/tool_use.py @@ -236,10 +236,10 @@ class ToolExecutor: if isinstance(content, str): result_preview = content elif isinstance(content, list | dict): - import json + import orjson try: - result_preview = json.dumps(content, ensure_ascii=False) + result_preview = orjson.dumps(content, option=orjson.OPT_NON_STR_KEYS).decode('utf-8') except Exception: result_preview = str(content) else: diff --git a/src/plugins/built_in/affinity_flow_chatter/proactive/proactive_thinking_executor.py 
b/src/plugins/built_in/affinity_flow_chatter/proactive/proactive_thinking_executor.py index 7c11d490f..4ed5aa406 100644 --- a/src/plugins/built_in/affinity_flow_chatter/proactive/proactive_thinking_executor.py +++ b/src/plugins/built_in/affinity_flow_chatter/proactive/proactive_thinking_executor.py @@ -3,7 +3,7 @@ 当定时任务触发时,负责搜集信息、调用LLM决策、并根据决策生成回复 """ -import json +import orjson from datetime import datetime from typing import Any, Literal diff --git a/src/plugins/built_in/maizone_refactored/services/reply_tracker_service.py b/src/plugins/built_in/maizone_refactored/services/reply_tracker_service.py index 26306f837..3f70bd7fb 100644 --- a/src/plugins/built_in/maizone_refactored/services/reply_tracker_service.py +++ b/src/plugins/built_in/maizone_refactored/services/reply_tracker_service.py @@ -3,7 +3,7 @@ 负责记录和管理已回复过的评论ID,避免重复回复 -import json +import orjson import time from pathlib import Path from typing import Any @@ -71,7 +71,7 @@ class ReplyTrackerService: self.replied_comments = {} return - data = json.loads(file_content) + data = orjson.loads(file_content) if self._validate_data(data): self.replied_comments = data logger.info( @@ -81,7 +81,7 @@ else: logger.error("加载的数据格式无效,将创建新的记录") self.replied_comments = {} - except json.JSONDecodeError as e: + except orjson.JSONDecodeError as e: logger.error(f"解析回复记录文件失败: {e}") self._backup_corrupted_file() self.replied_comments = {} @@ -118,7 +118,7 @@ # 先写入临时文件 with open(temp_file, "w", encoding="utf-8") as f: - json.dump(self.replied_comments, f, ensure_ascii=False, indent=2) + f.write(orjson.dumps(self.replied_comments, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS).decode('utf-8')) # 如果写入成功,重命名为正式文件 if temp_file.stat().st_size > 0: # 确保写入成功 diff --git a/src/plugins/built_in/napcat_adapter_plugin/plugin.py b/src/plugins/built_in/napcat_adapter_plugin/plugin.py index 92dc32608..a228cec7b 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/plugin.py +++ 
b/src/plugins/built_in/napcat_adapter_plugin/plugin.py @@ -1,6 +1,6 @@ import asyncio import inspect -import json +import orjson from typing import ClassVar, List import websockets as Server @@ -44,10 +44,10 @@ async def message_recv(server_connection: Server.ServerConnection): # 只在debug模式下记录原始消息 if logger.level <= 10: # DEBUG level logger.debug(f"{raw_message[:1500]}..." if (len(raw_message) > 1500) else raw_message) - decoded_raw_message: dict = json.loads(raw_message) + decoded_raw_message: dict = orjson.loads(raw_message) try: # 首先尝试解析原始消息 - decoded_raw_message: dict = json.loads(raw_message) + decoded_raw_message: dict = orjson.loads(raw_message) # 检查是否是切片消息 (来自 MMC) if chunker.is_chunk_message(decoded_raw_message): @@ -71,7 +71,7 @@ async def message_recv(server_connection: Server.ServerConnection): elif post_type is None: await put_response(decoded_raw_message) - except json.JSONDecodeError as e: + except orjson.JSONDecodeError as e: logger.error(f"消息解析失败: {e}") logger.debug(f"原始消息: {raw_message[:500]}...") except Exception as e: diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/message_chunker.py b/src/plugins/built_in/napcat_adapter_plugin/src/message_chunker.py index db6c18e59..86902354f 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/message_chunker.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/message_chunker.py @@ -5,7 +5,7 @@ """ import asyncio -import json +import orjson import time import uuid from typing import Any, Dict, List, Optional, Union @@ -34,7 +34,7 @@ class MessageChunker: """判断消息是否需要切片""" try: if isinstance(message, dict): - message_str = json.dumps(message, ensure_ascii=False) + message_str = orjson.dumps(message, option=orjson.OPT_NON_STR_KEYS).decode('utf-8') else: message_str = message return len(message_str.encode("utf-8")) > self.max_chunk_size @@ -58,7 +58,7 @@ class MessageChunker: try: # 统一转换为字符串 if isinstance(message, dict): - message_str = json.dumps(message, ensure_ascii=False) + message_str = 
orjson.dumps(message, option=orjson.OPT_NON_STR_KEYS).decode('utf-8') else: message_str = message @@ -116,7 +116,7 @@ class MessageChunker: """判断是否是切片消息""" try: if isinstance(message, str): - data = json.loads(message) + data = orjson.loads(message) else: data = message @@ -126,7 +126,7 @@ class MessageChunker: and "__mmc_chunk_data__" in data and "__mmc_is_chunked__" in data ) - except (json.JSONDecodeError, TypeError): + except (orjson.JSONDecodeError, TypeError): return False @@ -187,7 +187,7 @@ class MessageReassembler: try: # 统一转换为字典 if isinstance(message, str): - chunk_data = json.loads(message) + chunk_data = orjson.loads(message) else: chunk_data = message @@ -197,8 +197,8 @@ class MessageReassembler: if "_original_message" in chunk_data: # 这是一个被包装的非切片消息,解包返回 try: - return json.loads(chunk_data["_original_message"]) - except json.JSONDecodeError: + return orjson.loads(chunk_data["_original_message"]) + except orjson.JSONDecodeError: return {"text_message": chunk_data["_original_message"]} else: return chunk_data @@ -251,14 +251,14 @@ class MessageReassembler: # 尝试反序列化重组后的消息 try: - return json.loads(reassembled_message) - except json.JSONDecodeError: + return orjson.loads(reassembled_message) + except orjson.JSONDecodeError: # 如果不能反序列化为JSON,则作为文本消息返回 return {"text_message": reassembled_message} return None - except (json.JSONDecodeError, KeyError, TypeError) as e: + except (orjson.JSONDecodeError, KeyError, TypeError) as e: logger.error(f"处理切片消息时出错: {e}") return None diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py index adedb19bd..34ec17772 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py @@ -1,5 +1,5 @@ import base64 -import json +import orjson import time import uuid from pathlib import Path @@ -783,7 +783,7 
@@ class MessageHandler: # 检查JSON消息格式 if not message_data or "data" not in message_data: logger.warning("JSON消息格式不正确") - return Seg(type="json", data=json.dumps(message_data)) + return Seg(type="json", data=orjson.dumps(message_data).decode('utf-8')) try: # 尝试将json_data解析为Python对象 @@ -1146,13 +1146,13 @@ class MessageHandler: return None forward_message_id = forward_message_data.get("id") request_uuid = str(uuid.uuid4()) - payload = json.dumps( + payload = orjson.dumps( { "action": "get_forward_msg", "params": {"message_id": forward_message_id}, "echo": request_uuid, } - ) + ).decode('utf-8') try: connection = self.get_server_connection() if not connection: @@ -1167,9 +1167,9 @@ class MessageHandler: logger.error(f"获取转发消息失败: {str(e)}") return None logger.debug( - f"转发消息原始格式:{json.dumps(response)[:80]}..." - if len(json.dumps(response)) > 80 - else json.dumps(response) + f"转发消息原始格式:{orjson.dumps(response).decode('utf-8')[:80]}..." + if len(orjson.dumps(response).decode('utf-8')) > 80 + else orjson.dumps(response).decode('utf-8') ) response_data: Dict = response.get("data") if not response_data: diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py index 67ad380c8..866028472 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py @@ -1,5 +1,5 @@ import asyncio -import json +import orjson import time from typing import ClassVar, Optional, Tuple @@ -241,7 +241,7 @@ class NoticeHandler: message_base: MessageBase = MessageBase( message_info=message_info, message_segment=handled_message, - raw_message=json.dumps(raw_message), + raw_message=orjson.dumps(raw_message).decode('utf-8'), ) if system_notice: @@ -602,7 +602,7 @@ class NoticeHandler: message_base: MessageBase = MessageBase( message_info=message_info, message_segment=seg_message, - 
raw_message=json.dumps( + raw_message=orjson.dumps( { "post_type": "notice", "notice_type": "group_ban", @@ -611,7 +611,7 @@ class NoticeHandler: "user_id": user_id, "operator_id": None, # 自然解除禁言没有操作者 } - ), + ).decode('utf-8'), ) await self.put_notice(message_base) diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py index f90dab7f8..3df7432b8 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py @@ -1,4 +1,4 @@ -import json +import orjson import random import time import uuid @@ -605,7 +605,7 @@ class SendHandler: async def send_message_to_napcat(self, action: str, params: dict, timeout: float = 20.0) -> dict: request_uuid = str(uuid.uuid4()) - payload = json.dumps({"action": action, "params": params, "echo": request_uuid}) + payload = orjson.dumps({"action": action, "params": params, "echo": request_uuid}).decode('utf-8') # 获取当前连接 connection = self.get_server_connection() diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/utils.py b/src/plugins/built_in/napcat_adapter_plugin/src/utils.py index 263e0dcbd..b597a60f9 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/utils.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/utils.py @@ -1,6 +1,6 @@ import base64 import io -import json +import orjson import ssl import uuid from typing import List, Optional, Tuple, Union @@ -34,7 +34,7 @@ async def get_group_info(websocket: Server.ServerConnection, group_id: int) -> d """ logger.debug("获取群聊信息中") request_uuid = str(uuid.uuid4()) - payload = json.dumps({"action": "get_group_info", "params": {"group_id": group_id}, "echo": request_uuid}) + payload = orjson.dumps({"action": "get_group_info", "params": {"group_id": group_id}, "echo": request_uuid}).decode('utf-8') try: await websocket.send(payload) socket_response: dict = await get_response(request_uuid) @@ -56,7 +56,7 @@ 
async def get_group_detail_info(websocket: Server.ServerConnection, group_id: in """ logger.debug("获取群详细信息中") request_uuid = str(uuid.uuid4()) - payload = json.dumps({"action": "get_group_detail_info", "params": {"group_id": group_id}, "echo": request_uuid}) + payload = orjson.dumps({"action": "get_group_detail_info", "params": {"group_id": group_id}, "echo": request_uuid}).decode('utf-8') try: await websocket.send(payload) socket_response: dict = await get_response(request_uuid) @@ -78,13 +78,13 @@ async def get_member_info(websocket: Server.ServerConnection, group_id: int, use """ logger.debug("获取群成员信息中") request_uuid = str(uuid.uuid4()) - payload = json.dumps( + payload = orjson.dumps( { "action": "get_group_member_info", "params": {"group_id": group_id, "user_id": user_id, "no_cache": True}, "echo": request_uuid, } - ) + ).decode('utf-8') try: await websocket.send(payload) socket_response: dict = await get_response(request_uuid) @@ -146,7 +146,7 @@ async def get_self_info(websocket: Server.ServerConnection) -> dict | None: """ logger.debug("获取自身信息中") request_uuid = str(uuid.uuid4()) - payload = json.dumps({"action": "get_login_info", "params": {}, "echo": request_uuid}) + payload = orjson.dumps({"action": "get_login_info", "params": {}, "echo": request_uuid}).decode('utf-8') try: await websocket.send(payload) response: dict = await get_response(request_uuid) @@ -183,7 +183,7 @@ async def get_stranger_info(websocket: Server.ServerConnection, user_id: int) -> """ logger.debug("获取陌生人信息中") request_uuid = str(uuid.uuid4()) - payload = json.dumps({"action": "get_stranger_info", "params": {"user_id": user_id}, "echo": request_uuid}) + payload = orjson.dumps({"action": "get_stranger_info", "params": {"user_id": user_id}, "echo": request_uuid}).decode('utf-8') try: await websocket.send(payload) response: dict = await get_response(request_uuid) @@ -208,7 +208,7 @@ async def get_message_detail(websocket: Server.ServerConnection, message_id: Uni """ logger.debug("获取消息详情中") 
request_uuid = str(uuid.uuid4()) - payload = json.dumps({"action": "get_msg", "params": {"message_id": message_id}, "echo": request_uuid}) + payload = orjson.dumps({"action": "get_msg", "params": {"message_id": message_id}, "echo": request_uuid}).decode('utf-8') try: await websocket.send(payload) response: dict = await get_response(request_uuid, 30) # 增加超时时间到30秒 @@ -236,13 +236,13 @@ async def get_record_detail( """ logger.debug("获取语音消息详情中") request_uuid = str(uuid.uuid4()) - payload = json.dumps( + payload = orjson.dumps( { "action": "get_record", "params": {"file": file, "file_id": file_id, "out_format": "wav"}, "echo": request_uuid, } - ) + ).decode('utf-8') try: await websocket.send(payload) response: dict = await get_response(request_uuid, 30) # 增加超时时间到30秒 diff --git a/src/plugins/built_in/web_search_tool/engines/metaso_engine.py b/src/plugins/built_in/web_search_tool/engines/metaso_engine.py index 7a0f30999..78e7e67cb 100644 --- a/src/plugins/built_in/web_search_tool/engines/metaso_engine.py +++ b/src/plugins/built_in/web_search_tool/engines/metaso_engine.py @@ -1,7 +1,7 @@ """ Metaso Search Engine (Chat Completions Mode) """ -import json +import orjson from typing import Any import httpx @@ -43,12 +43,12 @@ class MetasoClient: if data_str == "[DONE]": break try: - data = json.loads(data_str) + data = orjson.loads(data_str) delta = data.get("choices", [{}])[0].get("delta", {}) content_chunk = delta.get("content") if content_chunk: full_response_content += content_chunk - except json.JSONDecodeError: + except orjson.JSONDecodeError: logger.warning(f"Metaso stream: could not decode JSON line: {data_str}") continue diff --git a/tools/memory_visualizer/visualizer_server.py b/tools/memory_visualizer/visualizer_server.py index 7b606e200..222f38053 100644 --- a/tools/memory_visualizer/visualizer_server.py +++ b/tools/memory_visualizer/visualizer_server.py @@ -5,7 +5,7 @@ """ import asyncio -import json +import orjson import logging from datetime import datetime from 
pathlib import Path diff --git a/tools/memory_visualizer/visualizer_simple.py b/tools/memory_visualizer/visualizer_simple.py index d43b490d0..3a1d4047f 100644 --- a/tools/memory_visualizer/visualizer_simple.py +++ b/tools/memory_visualizer/visualizer_simple.py @@ -4,7 +4,7 @@ 直接从存储的数据文件生成可视化,无需启动完整的记忆管理器 """ -import json +import orjson import sys from pathlib import Path from datetime import datetime @@ -122,7 +122,7 @@ def load_graph_data(file_path: Optional[Path] = None) -> Dict[str, Any]: print(f"📂 加载图数据: {graph_file}") with open(graph_file, 'r', encoding='utf-8') as f: - data = json.load(f) + data = orjson.loads(f.read()) # 解析数据 nodes_dict = {} From fa353bf9d100c4d3e6707012b69ccb959bf58194 Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Thu, 6 Nov 2025 13:11:54 +0800 Subject: [PATCH 2/6] =?UTF-8?q?feat(web=5Fsearch):=20=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E7=AD=94=E6=A1=88=E6=A8=A1=E5=BC=8F=E6=94=AF=E6=8C=81=EF=BC=8C?= =?UTF-8?q?=E4=BC=98=E5=8C=96Exa=E6=90=9C=E7=B4=A2=E5=BC=95=E6=93=8E?= =?UTF-8?q?=E7=9A=84=E7=BB=93=E6=9E=9C=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../web_search_tool/engines/exa_engine.py | 98 +++++++++++++++++-- .../web_search_tool/tools/web_search.py | 47 +++++++-- 2 files changed, 126 insertions(+), 19 deletions(-) diff --git a/src/plugins/built_in/web_search_tool/engines/exa_engine.py b/src/plugins/built_in/web_search_tool/engines/exa_engine.py index 37655eb53..e09232249 100644 --- a/src/plugins/built_in/web_search_tool/engines/exa_engine.py +++ b/src/plugins/built_in/web_search_tool/engines/exa_engine.py @@ -39,7 +39,7 @@ class ExaSearchEngine(BaseSearchEngine): return self.api_manager.is_available() async def search(self, args: dict[str, Any]) -> list[dict[str, Any]]: - """执行Exa搜索""" + """执行优化的Exa搜索(使用answer模式)""" if not self.is_available(): return [] @@ -47,7 +47,16 @@ class ExaSearchEngine(BaseSearchEngine): num_results = args.get("num_results", 
3) time_range = args.get("time_range", "any") - exa_args = {"num_results": num_results, "text": True, "highlights": True} + # 优化的搜索参数 - 更注重答案质量 + exa_args = { + "num_results": num_results, + "text": True, + "highlights": True, + "summary": True, # 启用自动摘要 + "include_text": True, # 包含全文内容 + } + + # 时间范围过滤 if time_range != "any": today = datetime.now() start_date = today - timedelta(days=7 if time_range == "week" else 30) @@ -61,18 +70,89 @@ class ExaSearchEngine(BaseSearchEngine): return [] loop = asyncio.get_running_loop() + # 使用search_and_contents获取完整内容,优化为answer模式 func = functools.partial(exa_client.search_and_contents, query, **exa_args) search_response = await loop.run_in_executor(None, func) - return [ - { + # 优化结果处理 - 更注重答案质量 + results = [] + for res in search_response.results: + # 获取最佳内容片段 + highlights = getattr(res, "highlights", []) + summary = getattr(res, "summary", "") + text = getattr(res, "text", "") + + # 智能内容选择:摘要 > 高亮 > 文本开头 + if summary and len(summary) > 50: + snippet = summary.strip() + elif highlights: + snippet = " ".join(highlights).strip() + elif text: + snippet = text[:300] + "..." if len(text) > 300 else text + else: + snippet = "内容获取失败" + + # 只保留有意义的摘要 + if len(snippet) < 30: + snippet = text[:200] + "..." 
if text and len(text) > 200 else snippet + + results.append({ "title": res.title, "url": res.url, - "snippet": " ".join(getattr(res, "highlights", [])) or (getattr(res, "text", "")[:250] + "..."), + "snippet": snippet, "provider": "Exa", - } - for res in search_response.results - ] + "answer_focused": True, # 标记为答案导向的搜索 + }) + + return results except Exception as e: - logger.error(f"Exa 搜索失败: {e}") + logger.error(f"Exa answer模式搜索失败: {e}") + return [] + + async def answer_search(self, args: dict[str, Any]) -> list[dict[str, Any]]: + """执行Exa快速答案搜索 - 最精简的搜索模式""" + if not self.is_available(): + return [] + + query = args["query"] + num_results = min(args.get("num_results", 2), 2) # 限制结果数量,专注质量 + + # 精简的搜索参数 - 专注快速答案 + exa_args = { + "num_results": num_results, + "text": False, # 不需要全文 + "highlights": True, # 只要关键高亮 + "summary": True, # 优先摘要 + } + + try: + exa_client = self.api_manager.get_next_client() + if not exa_client: + return [] + + loop = asyncio.get_running_loop() + func = functools.partial(exa_client.search_and_contents, query, **exa_args) + search_response = await loop.run_in_executor(None, func) + + # 极简结果处理 - 只保留最核心信息 + results = [] + for res in search_response.results: + summary = getattr(res, "summary", "") + highlights = getattr(res, "highlights", []) + + # 优先使用摘要,否则使用高亮 + answer_text = summary.strip() if summary and len(summary) > 30 else " ".join(highlights).strip() + + if answer_text and len(answer_text) > 20: + results.append({ + "title": res.title, + "url": res.url, + "snippet": answer_text[:400] + "..." 
if len(answer_text) > 400 else answer_text, + "provider": "Exa-Answer", + "answer_mode": True # 标记为纯答案模式 + }) + + return results + except Exception as e: + logger.error(f"Exa快速答案搜索失败: {e}") return [] diff --git a/src/plugins/built_in/web_search_tool/tools/web_search.py b/src/plugins/built_in/web_search_tool/tools/web_search.py index 466dae538..eaac1d7e1 100644 --- a/src/plugins/built_in/web_search_tool/tools/web_search.py +++ b/src/plugins/built_in/web_search_tool/tools/web_search.py @@ -41,6 +41,13 @@ class WebSurfingTool(BaseTool): False, ["any", "week", "month"], ), + ( + "answer_mode", + ToolParamType.BOOLEAN, + "是否启用答案模式(仅适用于Exa搜索引擎)。启用后将返回更精简、直接的答案,减少冗余信息。默认为False。", + False, + None, + ), ] # type: ignore def __init__(self, plugin_config=None, chat_stream=None): @@ -97,13 +104,19 @@ class WebSurfingTool(BaseTool): ) -> dict[str, Any]: """并行搜索策略:同时使用所有启用的搜索引擎""" search_tasks = [] + answer_mode = function_args.get("answer_mode", False) for engine_name in enabled_engines: engine = self.engines.get(engine_name) if engine and engine.is_available(): custom_args = function_args.copy() custom_args["num_results"] = custom_args.get("num_results", 5) - search_tasks.append(engine.search(custom_args)) + + # 如果启用了answer模式且是Exa引擎,使用answer_search方法 + if answer_mode and engine_name == "exa" and hasattr(engine, 'answer_search'): + search_tasks.append(engine.answer_search(custom_args)) + else: + search_tasks.append(engine.search(custom_args)) if not search_tasks: @@ -137,17 +150,23 @@ class WebSurfingTool(BaseTool): self, function_args: dict[str, Any], enabled_engines: list[str] ) -> dict[str, Any]: """回退搜索策略:按顺序尝试搜索引擎,失败则尝试下一个""" + answer_mode = function_args.get("answer_mode", False) + for engine_name in enabled_engines: engine = self.engines.get(engine_name) if not engine or not engine.is_available(): - continue try: custom_args = function_args.copy() custom_args["num_results"] = custom_args.get("num_results", 5) - results = await engine.search(custom_args) + # 
如果启用了answer模式且是Exa引擎,使用answer_search方法 + if answer_mode and engine_name == "exa" and hasattr(engine, 'answer_search'): + logger.info("使用Exa答案模式进行搜索(fallback策略)") + results = await engine.answer_search(custom_args) + else: + results = await engine.search(custom_args) if results: # 如果有结果,直接返回 formatted_content = format_search_results(results) @@ -164,22 +183,30 @@ class WebSurfingTool(BaseTool): async def _execute_single_search(self, function_args: dict[str, Any], enabled_engines: list[str]) -> dict[str, Any]: """单一搜索策略:只使用第一个可用的搜索引擎""" + answer_mode = function_args.get("answer_mode", False) + for engine_name in enabled_engines: engine = self.engines.get(engine_name) if not engine or not engine.is_available(): - continue try: custom_args = function_args.copy() custom_args["num_results"] = custom_args.get("num_results", 5) - results = await engine.search(custom_args) - formatted_content = format_search_results(results) - return { - "type": "web_search_result", - "content": formatted_content, - } + # 如果启用了answer模式且是Exa引擎,使用answer_search方法 + if answer_mode and engine_name == "exa" and hasattr(engine, 'answer_search'): + logger.info("使用Exa答案模式进行搜索") + results = await engine.answer_search(custom_args) + else: + results = await engine.search(custom_args) + + if results: + formatted_content = format_search_results(results) + return { + "type": "web_search_result", + "content": formatted_content, + } except Exception as e: logger.error(f"{engine_name} 搜索失败: {e}") From ffdd4c6b9c03926a038dd1a83001811de237b145 Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Thu, 6 Nov 2025 14:22:59 +0800 Subject: [PATCH 3/6] =?UTF-8?q?feat(tool=5Fhistory):=20=E5=AE=9E=E7=8E=B0?= =?UTF-8?q?=E6=B5=81=E5=B7=A5=E5=85=B7=E5=8E=86=E5=8F=B2=E7=AE=A1=E7=90=86?= =?UTF-8?q?=E5=99=A8=EF=BC=8C=E4=BB=A5=E5=A2=9E=E5=BC=BA=E5=B7=A5=E5=85=B7?= =?UTF-8?q?=E8=B0=83=E7=94=A8=E8=B7=9F=E8=B8=AA=E5=92=8C=E7=BC=93=E5=AD=98?= 
=?UTF-8?q?-=20=E6=B7=BB=E5=8A=A0=E4=BA=86=20StreamToolHistoryManager?= =?UTF-8?q?=EF=BC=8C=E7=94=A8=E4=BA=8E=E7=AE=A1=E7=90=86=E8=81=8A=E5=A4=A9?= =?UTF-8?q?=E6=B5=81=E7=BA=A7=E5=88=AB=E7=9A=84=E5=B7=A5=E5=85=B7=E8=B0=83?= =?UTF-8?q?=E7=94=A8=E5=8E=86=E5=8F=B2=E3=80=82-=20=E5=BC=95=E5=85=A5?= =?UTF-8?q?=E4=BA=86=20ToolCallRecord=EF=BC=8C=E7=94=A8=E4=BA=8E=E8=AF=A6?= =?UTF-8?q?=E7=BB=86=E8=AE=B0=E5=BD=95=E5=B7=A5=E5=85=B7=E8=B0=83=E7=94=A8?= =?UTF-8?q?=EF=BC=8C=E5=8C=85=E6=8B=AC=E6=89=A7=E8=A1=8C=E6=97=B6=E9=97=B4?= =?UTF-8?q?=E5=92=8C=E7=BC=93=E5=AD=98=E5=91=BD=E4=B8=AD=E6=83=85=E5=86=B5?= =?UTF-8?q?=E3=80=82-=20=E9=9B=86=E6=88=90=E4=BA=86=E5=86=85=E5=AD=98?= =?UTF-8?q?=E7=BC=93=E5=AD=98=E5=92=8C=E5=85=A8=E5=B1=80=E7=BC=93=E5=AD=98?= =?UTF-8?q?=E7=B3=BB=E7=BB=9F=EF=BC=8C=E4=BB=A5=E9=AB=98=E6=95=88=E6=A3=80?= =?UTF-8?q?=E7=B4=A2=E7=BB=93=E6=9E=9C=E3=80=82-=20=E6=9B=B4=E6=96=B0?= =?UTF-8?q?=E4=BA=86=20ToolExecutor=EF=BC=8C=E4=BB=A5=E4=BD=BF=E7=94=A8?= =?UTF-8?q?=E6=96=B0=E7=9A=84=E5=8E=86=E5=8F=B2=E7=AE=A1=E7=90=86=E5=99=A8?= =?UTF-8?q?=E8=AE=B0=E5=BD=95=E5=92=8C=E8=8E=B7=E5=8F=96=E5=B7=A5=E5=85=B7?= =?UTF-8?q?=E8=B0=83=E7=94=A8=E3=80=82-=20=E5=A2=9E=E5=BC=BA=E4=BA=86=20Ex?= =?UTF-8?q?aSearchEngine=EF=BC=8C=E4=BB=A5=E9=99=90=E5=88=B6=E8=BF=94?= =?UTF-8?q?=E5=9B=9E=E7=BB=93=E6=9E=9C=E6=95=B0=E9=87=8F=E5=B9=B6=E6=8F=90?= =?UTF-8?q?=E5=8D=87=E7=AD=94=E6=A1=88=E8=B4=A8=E9=87=8F=E3=80=82-=20?= =?UTF-8?q?=E9=87=8D=E6=9E=84=E4=BA=86=20CacheManager=20=E4=B8=AD=E7=9A=84?= =?UTF-8?q?=E7=BC=93=E5=AD=98=E7=AE=A1=E7=90=86=EF=BC=8C=E4=BB=A5=E5=8C=85?= =?UTF-8?q?=E6=8B=AC=E5=B7=A5=E5=85=B7=E8=B0=83=E7=94=A8=E7=BB=9F=E8=AE=A1?= =?UTF-8?q?=E5=92=8C=E6=80=A7=E8=83=BD=E6=8C=87=E6=A0=87=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/replyer/default_generator.py | 28 +- src/common/cache_manager.py | 239 ++++++++-- src/plugin_system/core/stream_tool_history.py | 414 ++++++++++++++++++ src/plugin_system/core/tool_use.py | 
240 +++++----- .../web_search_tool/engines/exa_engine.py | 5 +- 5 files changed, 743 insertions(+), 183 deletions(-) create mode 100644 src/plugin_system/core/stream_tool_history.py diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index 993b3b024..b6c4808ba 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -667,32 +667,46 @@ class DefaultReplyer: return "" try: - # 使用工具执行器获取信息 + # 首先获取当前的历史记录(在执行新工具调用之前) + tool_history_str = self.tool_executor.history_manager.format_for_prompt(max_records=3, include_results=True) + + # 然后执行工具调用 tool_results, _, _ = await self.tool_executor.execute_from_chat_message( sender=sender, target_message=target, chat_history=chat_history, return_details=False ) + info_parts = [] + + # 显示之前的工具调用历史(不包括当前这次调用) + if tool_history_str: + info_parts.append(tool_history_str) + + # 显示当前工具调用的结果(简要信息) if tool_results: - tool_info_str = "以下是你通过工具获取到的实时信息:\n" + current_results_parts = ["## 🔧 刚获取的工具信息"] for tool_result in tool_results: tool_name = tool_result.get("tool_name", "unknown") content = tool_result.get("content", "") result_type = tool_result.get("type", "tool_result") - tool_info_str += f"- 【{tool_name}】{result_type}: {content}\n" + # 不进行截断,让工具自己处理结果长度 + current_results_parts.append(f"- **{tool_name}**: {content}") - tool_info_str += "以上是你获取到的实时信息,请在回复时参考这些信息。" + info_parts.append("\n".join(current_results_parts)) logger.info(f"获取到 {len(tool_results)} 个工具结果") - return tool_info_str - else: - logger.debug("未获取到任何工具结果") + # 如果没有任何信息,返回空字符串 + if not info_parts: + logger.debug("未获取到任何工具结果或历史记录") return "" + return "\n\n".join(info_parts) + except Exception as e: logger.error(f"工具信息获取失败: {e}") return "" + def _parse_reply_target(self, target_message: str) -> tuple[str, str]: """解析回复目标消息 - 使用共享工具""" from src.chat.utils.prompt import Prompt diff --git a/src/common/cache_manager.py b/src/common/cache_manager.py index b656b0ca1..7a4f6eda6 100644 --- 
a/src/common/cache_manager.py +++ b/src/common/cache_manager.py @@ -57,8 +57,16 @@ class CacheManager: # 嵌入模型 self.embedding_model = LLMRequest(model_config.model_task_config.embedding) + # 工具调用统计 + self.tool_stats = { + "total_tool_calls": 0, + "cache_hits_by_tool": {}, # 按工具名称统计缓存命中 + "execution_times_by_tool": {}, # 按工具名称统计执行时间 + "most_used_tools": {}, # 最常用的工具 + } + self._initialized = True - logger.info("缓存管理器已初始化: L1 (内存+FAISS), L2 (数据库+ChromaDB)") + logger.info("缓存管理器已初始化: L1 (内存+FAISS), L2 (数据库+ChromaDB) + 工具统计") @staticmethod def _validate_embedding(embedding_result: Any) -> np.ndarray | None: @@ -363,58 +371,205 @@ class CacheManager: def get_health_stats(self) -> dict[str, Any]: """获取缓存健康统计信息""" - from src.common.memory_utils import format_size - + # 简化的健康统计,不包含内存监控(因为相关属性未定义) return { "l1_count": len(self.l1_kv_cache), - "l1_memory": self.l1_current_memory, - "l1_memory_formatted": format_size(self.l1_current_memory), - "l1_max_memory": self.l1_max_memory, - "l1_memory_usage_percent": round((self.l1_current_memory / self.l1_max_memory) * 100, 2), - "l1_max_size": self.l1_max_size, - "l1_size_usage_percent": round((len(self.l1_kv_cache) / self.l1_max_size) * 100, 2), - "average_item_size": self.l1_current_memory // len(self.l1_kv_cache) if self.l1_kv_cache else 0, - "average_item_size_formatted": format_size(self.l1_current_memory // len(self.l1_kv_cache)) if self.l1_kv_cache else "0 B", - "largest_item_size": max(self.l1_size_map.values()) if self.l1_size_map else 0, - "largest_item_size_formatted": format_size(max(self.l1_size_map.values())) if self.l1_size_map else "0 B", + "l1_vector_count": self.l1_vector_index.ntotal if hasattr(self.l1_vector_index, 'ntotal') else 0, + "tool_stats": { + "total_tool_calls": self.tool_stats.get("total_tool_calls", 0), + "tracked_tools": len(self.tool_stats.get("most_used_tools", {})), + "cache_hits": sum(data.get("hits", 0) for data in self.tool_stats.get("cache_hits_by_tool", {}).values()), + "cache_misses": 
sum(data.get("misses", 0) for data in self.tool_stats.get("cache_hits_by_tool", {}).values()), + } } - + def check_health(self) -> tuple[bool, list[str]]: """检查缓存健康状态 - + Returns: (is_healthy, warnings) - 是否健康,警告列表 """ warnings = [] - - # 检查内存使用 - memory_usage = (self.l1_current_memory / self.l1_max_memory) * 100 - if memory_usage > 90: - warnings.append(f"⚠️ L1缓存内存使用率过高: {memory_usage:.1f}%") - elif memory_usage > 75: - warnings.append(f"⚡ L1缓存内存使用率较高: {memory_usage:.1f}%") - - # 检查条目数 - size_usage = (len(self.l1_kv_cache) / self.l1_max_size) * 100 - if size_usage > 90: - warnings.append(f"⚠️ L1缓存条目数过多: {size_usage:.1f}%") - - # 检查平均条目大小 - if self.l1_kv_cache: - avg_size = self.l1_current_memory // len(self.l1_kv_cache) - if avg_size > 100 * 1024: # >100KB - from src.common.memory_utils import format_size - warnings.append(f"⚡ 平均缓存条目过大: {format_size(avg_size)}") - - # 检查最大单条目 - if self.l1_size_map: - max_size = max(self.l1_size_map.values()) - if max_size > 500 * 1024: # >500KB - from src.common.memory_utils import format_size - warnings.append(f"⚠️ 发现超大缓存条目: {format_size(max_size)}") - + + # 检查L1缓存大小 + l1_size = len(self.l1_kv_cache) + if l1_size > 1000: # 如果超过1000个条目 + warnings.append(f"⚠️ L1缓存条目数较多: {l1_size}") + + # 检查向量索引大小 + vector_count = self.l1_vector_index.ntotal if hasattr(self.l1_vector_index, 'ntotal') else 0 + if isinstance(vector_count, int) and vector_count > 500: + warnings.append(f"⚠️ 向量索引条目数较多: {vector_count}") + + # 检查工具统计健康 + total_calls = self.tool_stats.get("total_tool_calls", 0) + if total_calls > 0: + total_hits = sum(data.get("hits", 0) for data in self.tool_stats.get("cache_hits_by_tool", {}).values()) + cache_hit_rate = (total_hits / total_calls) * 100 + if cache_hit_rate < 50: # 缓存命中率低于50% + warnings.append(f"⚡ 整体缓存命中率较低: {cache_hit_rate:.1f}%") + return len(warnings) == 0, warnings + async def get_tool_result_with_stats(self, + tool_name: str, + function_args: dict[str, Any], + tool_file_path: str | Path, + semantic_query: str | None 
= None) -> tuple[Any | None, bool]: + """获取工具结果并更新统计信息 + + Args: + tool_name: 工具名称 + function_args: 函数参数 + tool_file_path: 工具文件路径 + semantic_query: 语义查询字符串 + + Returns: + Tuple[结果, 是否命中缓存] + """ + # 更新总调用次数 + self.tool_stats["total_tool_calls"] += 1 + + # 更新工具使用统计 + if tool_name not in self.tool_stats["most_used_tools"]: + self.tool_stats["most_used_tools"][tool_name] = 0 + self.tool_stats["most_used_tools"][tool_name] += 1 + + # 尝试获取缓存 + result = await self.get(tool_name, function_args, tool_file_path, semantic_query) + + # 更新缓存命中统计 + if tool_name not in self.tool_stats["cache_hits_by_tool"]: + self.tool_stats["cache_hits_by_tool"][tool_name] = {"hits": 0, "misses": 0} + + if result is not None: + self.tool_stats["cache_hits_by_tool"][tool_name]["hits"] += 1 + logger.info(f"工具缓存命中: {tool_name}") + return result, True + else: + self.tool_stats["cache_hits_by_tool"][tool_name]["misses"] += 1 + return None, False + + async def set_tool_result_with_stats(self, + tool_name: str, + function_args: dict[str, Any], + tool_file_path: str | Path, + data: Any, + execution_time: float | None = None, + ttl: int | None = None, + semantic_query: str | None = None): + """存储工具结果并更新统计信息 + + Args: + tool_name: 工具名称 + function_args: 函数参数 + tool_file_path: 工具文件路径 + data: 结果数据 + execution_time: 执行时间 + ttl: 缓存TTL + semantic_query: 语义查询字符串 + """ + # 更新执行时间统计 + if execution_time is not None: + if tool_name not in self.tool_stats["execution_times_by_tool"]: + self.tool_stats["execution_times_by_tool"][tool_name] = [] + self.tool_stats["execution_times_by_tool"][tool_name].append(execution_time) + + # 只保留最近100次的执行时间记录 + if len(self.tool_stats["execution_times_by_tool"][tool_name]) > 100: + self.tool_stats["execution_times_by_tool"][tool_name] = \ + self.tool_stats["execution_times_by_tool"][tool_name][-100:] + + # 存储到缓存 + await self.set(tool_name, function_args, tool_file_path, data, ttl, semantic_query) + + def get_tool_performance_stats(self) -> dict[str, Any]: + """获取工具性能统计信息 + + Returns: 
+ 统计信息字典 + """ + stats = self.tool_stats.copy() + + # 计算平均执行时间 + avg_times = {} + for tool_name, times in stats["execution_times_by_tool"].items(): + if times: + avg_times[tool_name] = { + "average": sum(times) / len(times), + "min": min(times), + "max": max(times), + "count": len(times), + } + + # 计算缓存命中率 + cache_hit_rates = {} + for tool_name, hit_data in stats["cache_hits_by_tool"].items(): + total = hit_data["hits"] + hit_data["misses"] + if total > 0: + cache_hit_rates[tool_name] = { + "hit_rate": (hit_data["hits"] / total) * 100, + "hits": hit_data["hits"], + "misses": hit_data["misses"], + "total": total, + } + + # 按使用频率排序工具 + most_used = sorted(stats["most_used_tools"].items(), key=lambda x: x[1], reverse=True) + + return { + "total_tool_calls": stats["total_tool_calls"], + "average_execution_times": avg_times, + "cache_hit_rates": cache_hit_rates, + "most_used_tools": most_used[:10], # 前10个最常用工具 + "cache_health": self.get_health_stats(), + } + + def get_tool_recommendations(self) -> dict[str, Any]: + """获取工具优化建议 + + Returns: + 优化建议字典 + """ + recommendations = [] + + # 分析缓存命中率低的工具 + cache_hit_rates = {} + for tool_name, hit_data in self.tool_stats["cache_hits_by_tool"].items(): + total = hit_data["hits"] + hit_data["misses"] + if total >= 5: # 至少调用5次才分析 + hit_rate = (hit_data["hits"] / total) * 100 + cache_hit_rates[tool_name] = hit_rate + + if hit_rate < 30: # 缓存命中率低于30% + recommendations.append({ + "tool": tool_name, + "type": "low_cache_hit_rate", + "message": f"工具 {tool_name} 的缓存命中率仅为 {hit_rate:.1f}%,建议检查缓存配置或参数变化频率", + "severity": "medium" if hit_rate > 10 else "high", + }) + + # 分析执行时间长的工具 + for tool_name, times in self.tool_stats["execution_times_by_tool"].items(): + if len(times) >= 3: # 至少3次执行才分析 + avg_time = sum(times) / len(times) + if avg_time > 5.0: # 平均执行时间超过5秒 + recommendations.append({ + "tool": tool_name, + "type": "slow_execution", + "message": f"工具 {tool_name} 平均执行时间较长 ({avg_time:.2f}s),建议优化算法或增加缓存", + "severity": "medium" if avg_time < 
10.0 else "high", + }) + + return { + "recommendations": recommendations, + "summary": { + "total_issues": len(recommendations), + "high_priority": len([r for r in recommendations if r["severity"] == "high"]), + "medium_priority": len([r for r in recommendations if r["severity"] == "medium"]), + } + } + # 全局实例 tool_cache = CacheManager() diff --git a/src/plugin_system/core/stream_tool_history.py b/src/plugin_system/core/stream_tool_history.py new file mode 100644 index 000000000..a15c77040 --- /dev/null +++ b/src/plugin_system/core/stream_tool_history.py @@ -0,0 +1,414 @@ +""" +流式工具历史记录管理器 +用于在聊天流级别管理工具调用历史,支持智能缓存和上下文感知 +""" + +import time +from typing import Any, Optional +from dataclasses import dataclass, asdict, field +import orjson +from src.common.logger import get_logger +from src.common.cache_manager import tool_cache + +logger = get_logger("stream_tool_history") + + +@dataclass +class ToolCallRecord: + """工具调用记录""" + tool_name: str + args: dict[str, Any] + result: Optional[dict[str, Any]] = None + status: str = "success" # success, error, pending + timestamp: float = field(default_factory=time.time) + execution_time: Optional[float] = None # 执行耗时(秒) + cache_hit: bool = False # 是否命中缓存 + result_preview: str = "" # 结果预览 + error_message: str = "" # 错误信息 + + def __post_init__(self): + """后处理:生成结果预览""" + if self.result and not self.result_preview: + content = self.result.get("content", "") + if isinstance(content, str): + self.result_preview = content[:500] + ("..." if len(content) > 500 else "") + elif isinstance(content, (list, dict)): + try: + self.result_preview = orjson.dumps(content, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')[:500] + "..." + except Exception: + self.result_preview = str(content)[:500] + "..." + else: + self.result_preview = str(content)[:500] + "..." + + +class StreamToolHistoryManager: + """流式工具历史记录管理器 + + 提供以下功能: + 1. 工具调用历史的持久化管理 + 2. 智能缓存集成和结果去重 + 3. 上下文感知的历史记录检索 + 4. 
性能监控和统计 + """ + + def __init__(self, chat_id: str, max_history: int = 20, enable_memory_cache: bool = True): + """初始化历史记录管理器 + + Args: + chat_id: 聊天ID,用于隔离不同聊天流的历史记录 + max_history: 最大历史记录数量 + enable_memory_cache: 是否启用内存缓存 + """ + self.chat_id = chat_id + self.max_history = max_history + self.enable_memory_cache = enable_memory_cache + + # 内存中的历史记录,按时间顺序排列 + self._history: list[ToolCallRecord] = [] + + # 性能统计 + self._stats = { + "total_calls": 0, + "cache_hits": 0, + "cache_misses": 0, + "total_execution_time": 0.0, + "average_execution_time": 0.0, + } + + logger.info(f"[{chat_id}] 工具历史记录管理器初始化完成,最大历史: {max_history}") + + async def add_tool_call(self, record: ToolCallRecord) -> None: + """添加工具调用记录 + + Args: + record: 工具调用记录 + """ + # 维护历史记录大小 + if len(self._history) >= self.max_history: + # 移除最旧的记录 + removed_record = self._history.pop(0) + logger.debug(f"[{self.chat_id}] 移除旧记录: {removed_record.tool_name}") + + # 添加新记录 + self._history.append(record) + + # 更新统计 + self._stats["total_calls"] += 1 + if record.cache_hit: + self._stats["cache_hits"] += 1 + else: + self._stats["cache_misses"] += 1 + + if record.execution_time is not None: + self._stats["total_execution_time"] += record.execution_time + self._stats["average_execution_time"] = self._stats["total_execution_time"] / self._stats["total_calls"] + + logger.debug(f"[{self.chat_id}] 添加工具调用记录: {record.tool_name}, 缓存命中: {record.cache_hit}") + + async def get_cached_result(self, tool_name: str, args: dict[str, Any]) -> Optional[dict[str, Any]]: + """从缓存或历史记录中获取结果 + + Args: + tool_name: 工具名称 + args: 工具参数 + + Returns: + 缓存的结果,如果不存在则返回None + """ + # 首先检查内存中的历史记录 + if self.enable_memory_cache: + memory_result = self._search_memory_cache(tool_name, args) + if memory_result: + logger.info(f"[{self.chat_id}] 内存缓存命中: {tool_name}") + return memory_result + + # 然后检查全局缓存系统 + try: + # 这里需要工具实例来获取文件路径,但为了解耦,我们先尝试从历史记录中推断 + tool_file_path = self._infer_tool_path(tool_name) + + # 尝试语义缓存(如果可以推断出语义查询参数) + semantic_query = 
self._extract_semantic_query(tool_name, args) + + cached_result = await tool_cache.get( + tool_name=tool_name, + function_args=args, + tool_file_path=tool_file_path, + semantic_query=semantic_query, + ) + + if cached_result: + logger.info(f"[{self.chat_id}] 全局缓存命中: {tool_name}") + + # 将结果同步到内存缓存 + if self.enable_memory_cache: + record = ToolCallRecord( + tool_name=tool_name, + args=args, + result=cached_result, + status="success", + cache_hit=True, + timestamp=time.time(), + ) + await self.add_tool_call(record) + + return cached_result + + except Exception as e: + logger.warning(f"[{self.chat_id}] 缓存查询失败: {e}") + + return None + + async def cache_result(self, tool_name: str, args: dict[str, Any], result: dict[str, Any], + execution_time: Optional[float] = None, + tool_file_path: Optional[str] = None, + ttl: Optional[int] = None) -> None: + """缓存工具调用结果 + + Args: + tool_name: 工具名称 + args: 工具参数 + result: 执行结果 + execution_time: 执行耗时 + tool_file_path: 工具文件路径 + ttl: 缓存TTL + """ + # 添加到内存历史记录 + record = ToolCallRecord( + tool_name=tool_name, + args=args, + result=result, + status="success", + execution_time=execution_time, + cache_hit=False, + timestamp=time.time(), + ) + await self.add_tool_call(record) + + # 同步到全局缓存系统 + try: + if tool_file_path is None: + tool_file_path = self._infer_tool_path(tool_name) + + # 尝试语义缓存 + semantic_query = self._extract_semantic_query(tool_name, args) + + await tool_cache.set( + tool_name=tool_name, + function_args=args, + tool_file_path=tool_file_path, + data=result, + ttl=ttl, + semantic_query=semantic_query, + ) + + logger.debug(f"[{self.chat_id}] 结果已缓存: {tool_name}") + + except Exception as e: + logger.warning(f"[{self.chat_id}] 缓存设置失败: {e}") + + async def get_recent_history(self, count: int = 5, status_filter: Optional[str] = None) -> list[ToolCallRecord]: + """获取最近的历史记录 + + Args: + count: 返回的记录数量 + status_filter: 状态过滤器,可选值:success, error, pending + + Returns: + 历史记录列表 + """ + history = self._history.copy() + + # 应用状态过滤 + if 
status_filter: + history = [record for record in history if record.status == status_filter] + + # 返回最近的记录 + return history[-count:] if history else [] + + def format_for_prompt(self, max_records: int = 5, include_results: bool = True) -> str: + """格式化历史记录为提示词 + + Args: + max_records: 最大记录数量 + include_results: 是否包含结果预览 + + Returns: + 格式化的提示词字符串 + """ + if not self._history: + return "" + + recent_records = self._history[-max_records:] + + lines = ["## 🔧 最近工具调用记录"] + for i, record in enumerate(recent_records, 1): + status_icon = "✅" if record.status == "success" else "❌" if record.status == "error" else "⏳" + + # 格式化参数 + args_preview = self._format_args_preview(record.args) + + # 基础信息 + lines.append(f"{i}. {status_icon} **{record.tool_name}**({args_preview})") + + # 添加执行时间和缓存信息 + if record.execution_time is not None: + time_info = f"{record.execution_time:.2f}s" + cache_info = "🎯缓存" if record.cache_hit else "🔍执行" + lines.append(f" ⏱️ {time_info} | {cache_info}") + + # 添加结果预览 + if include_results and record.result_preview: + lines.append(f" 📝 结果: {record.result_preview}") + + # 添加错误信息 + if record.status == "error" and record.error_message: + lines.append(f" ❌ 错误: {record.error_message}") + + # 添加统计信息 + if self._stats["total_calls"] > 0: + cache_hit_rate = (self._stats["cache_hits"] / self._stats["total_calls"]) * 100 + avg_time = self._stats["average_execution_time"] + lines.append(f"\n📊 工具统计: 总计{self._stats['total_calls']}次 | 缓存命中率{cache_hit_rate:.1f}% | 平均耗时{avg_time:.2f}s") + + return "\n".join(lines) + + def get_stats(self) -> dict[str, Any]: + """获取性能统计信息 + + Returns: + 统计信息字典 + """ + cache_hit_rate = 0.0 + if self._stats["total_calls"] > 0: + cache_hit_rate = (self._stats["cache_hits"] / self._stats["total_calls"]) * 100 + + return { + **self._stats, + "cache_hit_rate": cache_hit_rate, + "history_size": len(self._history), + "chat_id": self.chat_id, + } + + def clear_history(self) -> None: + """清除历史记录""" + self._history.clear() + logger.info(f"[{self.chat_id}] 
工具历史记录已清除") + + def _search_memory_cache(self, tool_name: str, args: dict[str, Any]) -> Optional[dict[str, Any]]: + """在内存历史记录中搜索缓存 + + Args: + tool_name: 工具名称 + args: 工具参数 + + Returns: + 匹配的结果,如果不存在则返回None + """ + for record in reversed(self._history): # 从最新的开始搜索 + if (record.tool_name == tool_name and + record.status == "success" and + record.args == args): + return record.result + return None + + def _infer_tool_path(self, tool_name: str) -> str: + """推断工具文件路径 + + Args: + tool_name: 工具名称 + + Returns: + 推断的文件路径 + """ + # 基于工具名称推断路径,这是一个简化的实现 + # 在实际使用中,可能需要更复杂的映射逻辑 + tool_path_mapping = { + "web_search": "src/plugins/built_in/web_search_tool/tools/web_search.py", + "memory_create": "src/memory_graph/tools/memory_tools.py", + "memory_search": "src/memory_graph/tools/memory_tools.py", + "user_profile_update": "src/plugins/built_in/affinity_flow_chatter/tools/user_profile_tool.py", + "chat_stream_impression_update": "src/plugins/built_in/affinity_flow_chatter/tools/chat_stream_impression_tool.py", + } + + return tool_path_mapping.get(tool_name, f"src/plugins/tools/{tool_name}.py") + + def _extract_semantic_query(self, tool_name: str, args: dict[str, Any]) -> Optional[str]: + """提取语义查询参数 + + Args: + tool_name: 工具名称 + args: 工具参数 + + Returns: + 语义查询字符串,如果不存在则返回None + """ + # 为不同工具定义语义查询参数映射 + semantic_query_mapping = { + "web_search": "query", + "memory_search": "query", + "knowledge_search": "query", + } + + query_key = semantic_query_mapping.get(tool_name) + if query_key and query_key in args: + return str(args[query_key]) + + return None + + def _format_args_preview(self, args: dict[str, Any], max_length: int = 100) -> str: + """格式化参数预览 + + Args: + args: 参数字典 + max_length: 最大长度 + + Returns: + 格式化的参数预览字符串 + """ + if not args: + return "" + + try: + args_str = orjson.dumps(args, option=orjson.OPT_SORT_KEYS).decode('utf-8') + if len(args_str) > max_length: + args_str = args_str[:max_length] + "..." 
+ return args_str + except Exception: + # 如果序列化失败,使用简单格式 + parts = [] + for k, v in list(args.items())[:3]: # 最多显示3个参数 + parts.append(f"{k}={str(v)[:20]}") + result = ", ".join(parts) + if len(parts) >= 3 or len(result) > max_length: + result += "..." + return result + + +# 全局管理器字典,按chat_id索引 +_stream_managers: dict[str, StreamToolHistoryManager] = {} + + +def get_stream_tool_history_manager(chat_id: str) -> StreamToolHistoryManager: + """获取指定聊天的工具历史记录管理器 + + Args: + chat_id: 聊天ID + + Returns: + 工具历史记录管理器实例 + """ + if chat_id not in _stream_managers: + _stream_managers[chat_id] = StreamToolHistoryManager(chat_id) + return _stream_managers[chat_id] + + +def cleanup_stream_manager(chat_id: str) -> None: + """清理指定聊天的管理器 + + Args: + chat_id: 聊天ID + """ + if chat_id in _stream_managers: + del _stream_managers[chat_id] + logger.info(f"已清理聊天 {chat_id} 的工具历史记录管理器") \ No newline at end of file diff --git a/src/plugin_system/core/tool_use.py b/src/plugin_system/core/tool_use.py index 6aa36aa6a..c705ac66f 100644 --- a/src/plugin_system/core/tool_use.py +++ b/src/plugin_system/core/tool_use.py @@ -3,7 +3,6 @@ import time from typing import Any from src.chat.utils.prompt import Prompt, global_prompt_manager -from src.common.cache_manager import tool_cache from src.common.logger import get_logger from src.config.config import global_config, model_config from src.llm_models.payload_content import ToolCall @@ -11,6 +10,8 @@ from src.llm_models.utils_model import LLMRequest from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance from src.plugin_system.base.base_tool import BaseTool from src.plugin_system.core.global_announcement_manager import global_announcement_manager +from src.plugin_system.core.stream_tool_history import get_stream_tool_history_manager, ToolCallRecord +from dataclasses import asdict logger = get_logger("tool_use") @@ -36,15 +37,29 @@ def init_tool_executor_prompt(): {tool_history} -## 🔧 工具使用 +## 🔧 工具决策指南 
-根据上下文判断是否需要使用工具。每个工具都有详细的description说明其用途和参数,请根据工具定义决定是否调用。 +**核心原则:** +- 根据上下文智能判断是否需要使用工具 +- 每个工具都有详细的description说明其用途和参数 +- 避免重复调用历史记录中已执行的工具(除非参数不同) +- 优先考虑使用已有的缓存结果,避免重复调用 + +**历史记录说明:** +- 上方显示的是**之前**的工具调用记录 +- 请参考历史记录避免重复调用相同参数的工具 +- 如果历史记录中已有相关结果,可以考虑直接回答而不调用工具 **⚠️ 记忆创建特别提醒:** 创建记忆时,subject(主体)必须使用对话历史中显示的**真实发送人名字**! - ✅ 正确:从"Prou(12345678): ..."中提取"Prou"作为subject - ❌ 错误:使用"用户"、"对方"等泛指词 +**工具调用策略:** +1. **避免重复调用**:查看历史记录,如果最近已调用过相同工具且参数一致,无需重复调用 +2. **智能选择工具**:根据消息内容选择最合适的工具,避免过度使用 +3. **参数优化**:确保工具参数简洁有效,避免冗余信息 + **执行指令:** - 需要使用工具 → 直接调用相应的工具函数 - 不需要工具 → 输出 "No tool needed" @@ -81,9 +96,8 @@ class ToolExecutor: """待处理的第二步工具调用,格式为 {tool_name: step_two_definition}""" self._log_prefix_initialized = False - # 工具调用历史 - self.tool_call_history: list[dict[str, Any]] = [] - """工具调用历史,包含工具名称、参数和结果""" + # 流式工具历史记录管理器 + self.history_manager = get_stream_tool_history_manager(chat_id) # logger.info(f"{self.log_prefix}工具执行器初始化完成") # 移到异步初始化中 @@ -125,7 +139,7 @@ class ToolExecutor: bot_name = global_config.bot.nickname # 构建工具调用历史文本 - tool_history = self._format_tool_history() + tool_history = self.history_manager.format_for_prompt(max_records=5, include_results=True) # 获取人设信息 personality_core = global_config.personality.personality_core @@ -183,83 +197,7 @@ class ToolExecutor: return tool_definitions - def _format_tool_history(self, max_history: int = 5) -> str: - """格式化工具调用历史为文本 - - Args: - max_history: 最多显示的历史记录数量 - - Returns: - 格式化的工具历史文本 - """ - if not self.tool_call_history: - return "" - - # 只取最近的几条历史 - recent_history = self.tool_call_history[-max_history:] - - history_lines = ["历史工具调用记录:"] - for i, record in enumerate(recent_history, 1): - tool_name = record.get("tool_name", "unknown") - args = record.get("args", {}) - result_preview = record.get("result_preview", "") - status = record.get("status", "success") - - # 格式化参数 - args_str = ", ".join([f"{k}={v}" for k, v in args.items()]) - - # 格式化记录 - status_emoji = "✓" if status == "success" else "✗" - 
history_lines.append(f"{i}. {status_emoji} {tool_name}({args_str})") - - if result_preview: - # 限制结果预览长度 - if len(result_preview) > 200: - result_preview = result_preview[:200] + "..." - history_lines.append(f" 结果: {result_preview}") - - return "\n".join(history_lines) - - def _add_tool_to_history(self, tool_name: str, args: dict, result: dict | None, status: str = "success"): - """添加工具调用到历史记录 - - Args: - tool_name: 工具名称 - args: 工具参数 - result: 工具结果 - status: 执行状态 (success/error) - """ - # 生成结果预览 - result_preview = "" - if result: - content = result.get("content", "") - if isinstance(content, str): - result_preview = content - elif isinstance(content, list | dict): - import orjson - - try: - result_preview = orjson.dumps(content, option=orjson.OPT_NON_STR_KEYS).decode('utf-8') - except Exception: - result_preview = str(content) - else: - result_preview = str(content) - - record = { - "tool_name": tool_name, - "args": args, - "result_preview": result_preview, - "status": status, - "timestamp": time.time(), - } - - self.tool_call_history.append(record) - - # 限制历史记录数量,避免内存溢出 - max_history_size = 5 - if len(self.tool_call_history) > max_history_size: - self.tool_call_history = self.tool_call_history[-max_history_size:] - + async def execute_tool_calls(self, tool_calls: list[ToolCall] | None) -> tuple[list[dict[str, Any]], list[str]]: """执行工具调用 @@ -320,10 +258,20 @@ class ToolExecutor: logger.debug(f"{self.log_prefix}工具{tool_name}结果内容: {preview}...") # 记录到历史 - self._add_tool_to_history(tool_name, tool_args, result, status="success") + await self.history_manager.add_tool_call(ToolCallRecord( + tool_name=tool_name, + args=tool_args, + result=result, + status="success" + )) else: # 工具返回空结果也记录到历史 - self._add_tool_to_history(tool_name, tool_args, None, status="success") + await self.history_manager.add_tool_call(ToolCallRecord( + tool_name=tool_name, + args=tool_args, + result=None, + status="success" + )) except Exception as e: 
logger.error(f"{self.log_prefix}工具{tool_name}执行失败: {e}") @@ -338,62 +286,72 @@ class ToolExecutor: tool_results.append(error_info) # 记录失败到历史 - self._add_tool_to_history(tool_name, tool_args, None, status="error") + await self.history_manager.add_tool_call(ToolCallRecord( + tool_name=tool_name, + args=tool_args, + result=None, + status="error", + error_message=str(e) + )) return tool_results, used_tools async def execute_tool_call( self, tool_call: ToolCall, tool_instance: BaseTool | None = None ) -> dict[str, Any] | None: - """执行单个工具调用,并处理缓存""" + """执行单个工具调用,集成流式历史记录管理器""" + start_time = time.time() function_args = tool_call.args or {} tool_instance = tool_instance or get_tool_instance(tool_call.func_name, self.chat_stream) - # 如果工具不存在或未启用缓存,则直接执行 - if not tool_instance or not tool_instance.enable_cache: - return await self._original_execute_tool_call(tool_call, tool_instance) + # 尝试从历史记录管理器获取缓存结果 + if tool_instance and tool_instance.enable_cache: + try: + cached_result = await self.history_manager.get_cached_result( + tool_name=tool_call.func_name, + args=function_args + ) + if cached_result: + execution_time = time.time() - start_time + logger.info(f"{self.log_prefix}使用缓存结果,跳过工具 {tool_call.func_name} 执行") - # --- 缓存逻辑开始 --- - try: - tool_file_path = inspect.getfile(tool_instance.__class__) - semantic_query = None - if tool_instance.semantic_cache_query_key: - semantic_query = function_args.get(tool_instance.semantic_cache_query_key) + # 记录缓存命中到历史 + await self.history_manager.add_tool_call(ToolCallRecord( + tool_name=tool_call.func_name, + args=function_args, + result=cached_result, + status="success", + execution_time=execution_time, + cache_hit=True + )) - cached_result = await tool_cache.get( - tool_name=tool_call.func_name, - function_args=function_args, - tool_file_path=tool_file_path, - semantic_query=semantic_query, - ) - if cached_result: - logger.info(f"{self.log_prefix}使用缓存结果,跳过工具 {tool_call.func_name} 执行") - return cached_result - except Exception as e: 
- logger.error(f"{self.log_prefix}检查工具缓存时出错: {e}") + return cached_result + except Exception as e: + logger.error(f"{self.log_prefix}检查历史缓存时出错: {e}") - # 缓存未命中,执行原始工具调用 + # 缓存未命中,执行工具调用 result = await self._original_execute_tool_call(tool_call, tool_instance) - # 将结果存入缓存 - try: - tool_file_path = inspect.getfile(tool_instance.__class__) - semantic_query = None - if tool_instance.semantic_cache_query_key: - semantic_query = function_args.get(tool_instance.semantic_cache_query_key) + # 记录执行结果到历史管理器 + execution_time = time.time() - start_time + if tool_instance and result and tool_instance.enable_cache: + try: + tool_file_path = inspect.getfile(tool_instance.__class__) + semantic_query = None + if tool_instance.semantic_cache_query_key: + semantic_query = function_args.get(tool_instance.semantic_cache_query_key) - await tool_cache.set( - tool_name=tool_call.func_name, - function_args=function_args, - tool_file_path=tool_file_path, - data=result, - ttl=tool_instance.cache_ttl, - semantic_query=semantic_query, - ) - except Exception as e: - logger.error(f"{self.log_prefix}设置工具缓存时出错: {e}") - # --- 缓存逻辑结束 --- + await self.history_manager.cache_result( + tool_name=tool_call.func_name, + args=function_args, + result=result, + execution_time=execution_time, + tool_file_path=tool_file_path, + ttl=tool_instance.cache_ttl + ) + except Exception as e: + logger.error(f"{self.log_prefix}缓存结果到历史管理器时出错: {e}") return result @@ -528,21 +486,31 @@ class ToolExecutor: logger.info(f"{self.log_prefix}直接工具执行成功: {tool_name}") # 记录到历史 - self._add_tool_to_history(tool_name, tool_args, result, status="success") + await self.history_manager.add_tool_call(ToolCallRecord( + tool_name=tool_name, + args=tool_args, + result=result, + status="success" + )) return tool_info except Exception as e: logger.error(f"{self.log_prefix}直接工具执行失败 {tool_name}: {e}") # 记录失败到历史 - self._add_tool_to_history(tool_name, tool_args, None, status="error") + await self.history_manager.add_tool_call(ToolCallRecord( + 
tool_name=tool_name, + args=tool_args, + result=None, + status="error", + error_message=str(e) + )) return None def clear_tool_history(self): """清除工具调用历史""" - self.tool_call_history.clear() - logger.debug(f"{self.log_prefix}已清除工具调用历史") + self.history_manager.clear_history() def get_tool_history(self) -> list[dict[str, Any]]: """获取工具调用历史 @@ -550,7 +518,17 @@ class ToolExecutor: Returns: 工具调用历史列表 """ - return self.tool_call_history.copy() + # 返回最近的历史记录 + records = self.history_manager.get_recent_history(count=10) + return [asdict(record) for record in records] + + def get_tool_stats(self) -> dict[str, Any]: + """获取工具统计信息 + + Returns: + 工具统计信息字典 + """ + return self.history_manager.get_stats() """ diff --git a/src/plugins/built_in/web_search_tool/engines/exa_engine.py b/src/plugins/built_in/web_search_tool/engines/exa_engine.py index e09232249..323216e81 100644 --- a/src/plugins/built_in/web_search_tool/engines/exa_engine.py +++ b/src/plugins/built_in/web_search_tool/engines/exa_engine.py @@ -44,7 +44,7 @@ class ExaSearchEngine(BaseSearchEngine): return [] query = args["query"] - num_results = args.get("num_results", 3) + num_results = min(args.get("num_results", 5), 5) # 默认5个结果,但限制最多5个 time_range = args.get("time_range", "any") # 优化的搜索参数 - 更注重答案质量 @@ -53,7 +53,6 @@ class ExaSearchEngine(BaseSearchEngine): "text": True, "highlights": True, "summary": True, # 启用自动摘要 - "include_text": True, # 包含全文内容 } # 时间范围过滤 @@ -115,7 +114,7 @@ class ExaSearchEngine(BaseSearchEngine): return [] query = args["query"] - num_results = min(args.get("num_results", 2), 2) # 限制结果数量,专注质量 + num_results = min(args.get("num_results", 3), 3) # answer模式默认3个结果,专注质量 # 精简的搜索参数 - 专注快速答案 exa_args = { From e9b37e032d8386d3906780c19781c34f11a66ee5 Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Thu, 6 Nov 2025 14:26:30 +0800 Subject: [PATCH 4/6] =?UTF-8?q?feat(memory):=20=E4=BC=98=E5=8C=96=E8=AE=B0?= =?UTF-8?q?=E5=BF=86=E6=A3=80=E7=B4=A2=E5=8A=A9=E6=89=8B=E7=9A=84=E6=9F=A5?= 
=?UTF-8?q?=E8=AF=A2=E7=94=9F=E6=88=90=E9=80=BB=E8=BE=91=EF=BC=8C=E6=9B=B4?= =?UTF-8?q?=E6=96=B0=E7=A4=BA=E4=BE=8B=E4=BB=A5=E6=8F=90=E9=AB=98=E5=87=86?= =?UTF-8?q?=E7=A1=AE=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/memory_graph/manager.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/src/memory_graph/manager.py b/src/memory_graph/manager.py index 7161878f6..69b0f5d1d 100644 --- a/src/memory_graph/manager.py +++ b/src/memory_graph/manager.py @@ -363,17 +363,14 @@ class MemoryManager: # 构建上下文信息 chat_history = context.get("chat_history", "") if context else "" - sender = context.get("sender", "") if context else "" - participants = context.get("participants", []) if context else [] - participants_str = "、".join(participants) if participants else "无" - + prompt = f"""你是记忆检索助手。为提高检索准确率,请为查询生成3-5个不同角度的搜索语句。 **核心原则(重要!):** -对于包含多个概念的复杂查询(如"杰瑞喵如何评价新的记忆系统"),应该生成: +对于包含多个概念的复杂查询(如"小明如何评价新的记忆系统"),应该生成: 1. 完整查询(包含所有要素)- 权重1.0 2. 每个关键概念的独立查询(如"新的记忆系统")- 权重0.8,避免被主体淹没! -3. 主体+动作组合(如"杰瑞喵 评价")- 权重0.6 +3. 主体+动作组合(如"小明 评价")- 权重0.6 4. 
泛化查询(如"记忆系统")- 权重0.7 **要求:** @@ -382,9 +379,7 @@ class MemoryManager: - 查询简洁(5-20字) - 直接输出JSON,不要添加说明 -**已知参与者:** {participants_str} **对话上下文:** {chat_history[-300:] if chat_history else "无"} -**当前查询:** {sender}: {query} **输出JSON格式:** ```json From 0da5c04ba291b749619d3383a09ecf26a2feb11d Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Thu, 6 Nov 2025 14:29:14 +0800 Subject: [PATCH 5/6] =?UTF-8?q?fix(memory):=20=E6=9B=B4=E6=96=B0=E5=A4=8D?= =?UTF-8?q?=E6=9D=82=E6=9F=A5=E8=AF=A2=E7=A4=BA=E4=BE=8B=E4=BB=A5=E6=8F=90?= =?UTF-8?q?=E9=AB=98=E6=A3=80=E7=B4=A2=E5=87=86=E7=A1=AE=E7=8E=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/memory_graph/manager.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/memory_graph/manager.py b/src/memory_graph/manager.py index 69b0f5d1d..6f20b4889 100644 --- a/src/memory_graph/manager.py +++ b/src/memory_graph/manager.py @@ -367,11 +367,11 @@ class MemoryManager: prompt = f"""你是记忆检索助手。为提高检索准确率,请为查询生成3-5个不同角度的搜索语句。 **核心原则(重要!):** -对于包含多个概念的复杂查询(如"小明如何评价新的记忆系统"),应该生成: +对于包含多个概念的复杂查询(如"小明如何评价小王"),应该生成: 1. 完整查询(包含所有要素)- 权重1.0 -2. 每个关键概念的独立查询(如"新的记忆系统")- 权重0.8,避免被主体淹没! +2. 每个关键概念的独立查询(如"小明"、"小王")- 权重0.8,避免被主体淹没! 3. 主体+动作组合(如"小明 评价")- 权重0.6 -4. 泛化查询(如"记忆系统")- 权重0.7 +4. 
泛化查询(如"评价")- 权重0.7 **要求:** - 第一个必须是原始查询或同义改写 From d75476d41cd860828496e232c27d8be1f73d4dfe Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Thu, 6 Nov 2025 15:15:53 +0800 Subject: [PATCH 6/6] =?UTF-8?q?feat:=20=E5=A2=9E=E5=BC=BA=E8=81=8A?= =?UTF-8?q?=E5=A4=A9=E5=9B=9E=E5=A4=8D=E7=94=9F=E6=88=90=E5=99=A8=EF=BC=8C?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E5=8F=82=E4=B8=8E=E8=80=85=E4=BF=A1=E6=81=AF?= =?UTF-8?q?=E5=92=8C=E8=81=8A=E5=A4=A9=E5=8E=86=E5=8F=B2=E5=A4=84=E7=90=86?= =?UTF-8?q?=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/replyer/default_generator.py | 44 ++++++++++++++-- src/memory_graph/manager.py | 2 - src/memory_graph/tools/memory_tools.py | 52 ++++++++++++++----- .../planner/plan_filter.py | 2 +- 4 files changed, 82 insertions(+), 18 deletions(-) diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index b6c4808ba..a3d9e5a5c 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -553,18 +553,56 @@ class DefaultReplyer: if user_info_obj: sender_name = getattr(user_info_obj, "user_nickname", "") or getattr(user_info_obj, "user_cardname", "") + # 获取参与者信息 + participants = [] + try: + # 尝试从聊天流中获取参与者信息 + if hasattr(stream, 'chat_history_manager'): + history_manager = stream.chat_history_manager + # 获取最近的参与者列表 + recent_records = history_manager.get_memory_chat_history( + user_id=getattr(stream, "user_id", ""), + count=10, + memory_types=["chat_message", "system_message"] + ) + # 提取唯一的参与者名称 + for record in recent_records[:5]: # 最近5条记录 + content = record.get("content", {}) + participant = content.get("participant_name") + if participant and participant not in participants: + participants.append(participant) + + # 如果消息包含发送者信息,也添加到参与者列表 + if content.get("sender_name") and content.get("sender_name") not in participants: + participants.append(content.get("sender_name")) + except Exception as e: + 
logger.debug(f"获取参与者信息失败: {e}") + + # 如果发送者不在参与者列表中,添加进去 + if sender_name and sender_name not in participants: + participants.insert(0, sender_name) + + # 格式化聊天历史为更友好的格式 + formatted_history = "" + if chat_history: + # 移除过长的历史记录,只保留最近部分 + lines = chat_history.strip().split('\n') + recent_lines = lines[-10:] if len(lines) > 10 else lines + formatted_history = '\n'.join(recent_lines) + query_context = { - "chat_history": chat_history if chat_history else "", + "chat_history": formatted_history, "sender": sender_name, + "participants": participants, } - # 使用记忆管理器的智能检索(自动优化查询) + # 使用记忆管理器的智能检索(多查询策略) memories = await manager.search_memories( query=target, top_k=10, min_importance=0.3, include_forgotten=False, - optimize_query=True, + use_multi_query=True, context=query_context, ) diff --git a/src/memory_graph/manager.py b/src/memory_graph/manager.py index 6f20b4889..aef5bded0 100644 --- a/src/memory_graph/manager.py +++ b/src/memory_graph/manager.py @@ -432,7 +432,6 @@ class MemoryManager: time_range: Optional[Tuple[datetime, datetime]] = None, min_importance: float = 0.0, include_forgotten: bool = False, - optimize_query: bool = True, use_multi_query: bool = True, expand_depth: int = 1, context: Optional[Dict[str, Any]] = None, @@ -453,7 +452,6 @@ class MemoryManager: time_range: 时间范围过滤 (start, end) min_importance: 最小重要性 include_forgotten: 是否包含已遗忘的记忆 - optimize_query: 是否使用小模型优化查询(已弃用,被 use_multi_query 替代) use_multi_query: 是否使用多查询策略(推荐,默认True) expand_depth: 图扩展深度(0=禁用, 1=推荐, 2-3=深度探索) context: 查询上下文(用于优化) diff --git a/src/memory_graph/tools/memory_tools.py b/src/memory_graph/tools/memory_tools.py index 692671e85..fddc151b0 100644 --- a/src/memory_graph/tools/memory_tools.py +++ b/src/memory_graph/tools/memory_tools.py @@ -629,27 +629,55 @@ class MemoryTools: try: from src.llm_models.utils_model import LLMRequest from src.config.config import model_config - + llm = LLMRequest( model_set=model_config.model_task_config.utils_small, request_type="memory.multi_query" ) - - 
participants = context.get("participants", []) if context else [] - prompt = f"""为查询生成3-5个不同角度的搜索语句(JSON格式)。 -**查询:** {query} + # 获取上下文信息 + participants = context.get("participants", []) if context else [] + chat_history = context.get("chat_history", "") if context else "" + sender = context.get("sender", "") if context else "" + + # 处理聊天历史,提取最近5条左右的对话 + recent_chat = "" + if chat_history: + lines = chat_history.strip().split('\n') + # 取最近5条消息 + recent_lines = lines[-5:] if len(lines) > 5 else lines + recent_chat = '\n'.join(recent_lines) + + prompt = f"""基于聊天上下文为查询生成3-5个不同角度的搜索语句(JSON格式)。 + +**当前查询:** {query} +**发送者:** {sender if sender else '未知'} **参与者:** {', '.join(participants) if participants else '无'} -**原则:** 对复杂查询(如"杰瑞喵如何评价新的记忆系统"),应生成: -1. 完整查询(权重1.0) -2. 每个关键概念独立查询(权重0.8)- 重要! -3. 主体+动作(权重0.6) +**最近聊天记录(最近5条):** +{recent_chat if recent_chat else '无聊天历史'} -**输出JSON:** +**分析原则:** +1. **上下文理解**:根据聊天历史理解查询的真实意图 +2. **指代消解**:识别并代换"他"、"她"、"它"、"那个"等指代词 +3. **话题关联**:结合最近讨论的话题生成更精准的查询 +4. **查询分解**:对复杂查询分解为多个子查询 + +**生成策略:** +1. **完整查询**(权重1.0):结合上下文的完整查询,包含指代消解 +2. **关键概念查询**(权重0.8):查询中的核心概念,特别是聊天中提到的实体 +3. **话题扩展查询**(权重0.7):基于最近聊天话题的相关查询 +4. **动作/情感查询**(权重0.6):如果涉及情感或动作,生成相关查询 + +**输出JSON格式:** ```json -{{"queries": [{{"text": "查询1", "weight": 1.0}}, {{"text": "查询2", "weight": 0.8}}]}} -```""" +{{"queries": [{{"text": "查询语句", "weight": 1.0}}, {{"text": "查询语句", "weight": 0.8}}]}} +``` + +**示例:** +- 查询:"他怎么样了?" 
+ 聊天中提到"小明生病了" → "小明身体恢复情况" +- 查询:"那个项目" + 聊天中讨论"记忆系统开发" → "记忆系统项目进展" +""" response, _ = await llm.generate_response_async(prompt, temperature=0.3, max_tokens=250) diff --git a/src/plugins/built_in/affinity_flow_chatter/planner/plan_filter.py b/src/plugins/built_in/affinity_flow_chatter/planner/plan_filter.py index 14e31c400..a194a1705 100644 --- a/src/plugins/built_in/affinity_flow_chatter/planner/plan_filter.py +++ b/src/plugins/built_in/affinity_flow_chatter/planner/plan_filter.py @@ -652,7 +652,7 @@ class ChatterPlanFilter: enhanced_memories = await memory_manager.search_memories( query=query, top_k=5, - optimize_query=False, # 直接使用关键词查询 + use_multi_query=False, # 直接使用关键词查询 ) if not enhanced_memories: