diff --git a/src/common/cache_manager.py b/src/common/cache_manager.py
index 9118adb40..f1abfe9d9 100644
--- a/src/common/cache_manager.py
+++ b/src/common/cache_manager.py
@@ -574,6 +574,126 @@ class CacheManager:
             }
         }
 
+    async def recall_relevant_cache(
+        self,
+        query_text: str,
+        tool_name: str | None = None,
+        top_k: int = 3,
+        similarity_threshold: float = 0.70,
+    ) -> list[dict[str, Any]]:
+        """
+        Proactively recall relevant cache entries based on semantic similarity.
+
+        Used to scan the cache before replying and surface historical search
+        results related to the current conversation.
+
+        Args:
+            query_text: Query text used for semantic matching (usually the last few chat messages)
+            tool_name: Optional; restrict recall to a specific tool's cache (e.g. "web_search")
+            top_k: Maximum number of results to return
+            similarity_threshold: Similarity threshold (L2 distance; smaller means more similar)
+
+        Returns:
+            A list of relevant cache entries, each containing {tool_name, query, content, similarity}
+        """
+        if not query_text or not self.embedding_model:
+            return []
+
+        try:
+            # Generate the query embedding
+            embedding_result = await self.embedding_model.get_embedding(query_text)
+            if not embedding_result:
+                return []
+
+            embedding_vector = embedding_result[0] if isinstance(embedding_result, tuple) else embedding_result
+            validated_embedding = self._validate_embedding(embedding_vector)
+            if validated_embedding is None:
+                return []
+
+            query_embedding = np.array([validated_embedding], dtype="float32")
+
+            # Query the vector database (L2 distance)
+            results = vector_db_service.query(
+                collection_name=self.semantic_cache_collection_name,
+                query_embeddings=query_embedding.tolist(),
+                n_results=top_k * 2,  # Over-fetch; results are filtered below
+            )
+
+            if not results or not results.get("ids") or not results["ids"][0]:
+                logger.debug("[cache recall] No relevant cache found")
+                return []
+
+            recalled_items = []
+            ids = results["ids"][0] if isinstance(results["ids"][0], list) else [results["ids"][0]]
+            distances = results.get("distances", [[]])[0] if results.get("distances") else []
+
+            for i, cache_key in enumerate(ids):
+                distance = distances[i] if i < len(distances) else 1.0
+
+                # Skip entries that are not similar enough
+                if distance > similarity_threshold:
+                    continue
+
+                # Fetch the cached data from the database
+                cache_obj = await db_query(
+                    model_class=CacheEntries,
+                    query_type="get",
+                    filters={"cache_key": cache_key},
+                    single_result=True,
+                )
+
+                if not cache_obj:
+                    continue
+
+                # Skip expired entries
+                expires_at = getattr(cache_obj, "expires_at", 0)
+                if time.time() >= expires_at:
+                    continue
+
+                # Get the tool name and filter by it if requested
+                cached_tool_name = getattr(cache_obj, "tool_name", "")
+                if tool_name and cached_tool_name != tool_name:
+                    continue
+
+                # Parse the cached content
+                try:
+                    cache_value = getattr(cache_obj, "cache_value", "{}")
+                    data = orjson.loads(cache_value)
+                    content = data.get("content", "") if isinstance(data, dict) else str(data)
+
+                    # Extract the original query from the cache_key (format: tool_name::{"query": "xxx", ...}::file_hash)
+                    original_query = ""
+                    try:
+                        key_parts = cache_key.split("::")
+                        if len(key_parts) >= 2:
+                            args_json = key_parts[1]
+                            args = orjson.loads(args_json)
+                            original_query = args.get("query", "")
+                    except Exception:
+                        pass
+
+                    recalled_items.append({
+                        "tool_name": cached_tool_name,
+                        "query": original_query,
+                        "content": content,
+                        "similarity": 1.0 - distance,  # Convert distance to a similarity score
+                    })
+
+                except Exception as e:
+                    logger.warning(f"Failed to parse cached content: {e}")
+                    continue
+
+                if len(recalled_items) >= top_k:
+                    break
+
+            if recalled_items:
+                logger.info(f"[cache recall] Found {len(recalled_items)} relevant cache entries")
+
+            return recalled_items
+
+        except Exception as e:
+            logger.error(f"[cache recall] Semantic recall failed: {e}")
+            return []
+
 
 # Global instance
 tool_cache = CacheManager()
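A minimal usage sketch of the new recall path, assuming the global `tool_cache` instance exported at the bottom of `cache_manager.py` above; the query string and parameter values are illustrative, not defaults used anywhere in this change:

```python
# Sketch only: calls CacheManager.recall_relevant_cache() through the global
# tool_cache instance added above; the query text is made up for the example.
import asyncio

from src.common.cache_manager import tool_cache


async def demo() -> None:
    items = await tool_cache.recall_relevant_cache(
        query_text="any recent news about product X?",
        tool_name="web_search",     # restrict recall to web search entries
        top_k=3,
        similarity_threshold=0.70,  # L2 distance cutoff; smaller is stricter
    )
    for item in items:
        # Each entry carries the originating tool, the original query,
        # the cached content, and a (1 - distance) similarity score.
        print(item["tool_name"], item["query"], f"{item['similarity']:.0%}")


asyncio.run(demo())
```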
diff --git a/src/plugin_system/core/stream_tool_history.py b/src/plugin_system/core/stream_tool_history.py
index 8e498395b..8f8b2f48b 100644
--- a/src/plugin_system/core/stream_tool_history.py
+++ b/src/plugin_system/core/stream_tool_history.py
@@ -32,15 +32,35 @@ class ToolCallRecord:
         """Post-processing: generate a result preview"""
         if self.result and not self.result_preview:
             content = self.result.get("content", "")
+            # Do not truncate results from important tools such as web search
+            no_truncate_tools = {"web_search", "web_surfing", "knowledge_search"}
+            should_truncate = self.tool_name not in no_truncate_tools
+            max_length = 500 if should_truncate else 10000  # Give web search a much larger limit
+
             if isinstance(content, str):
-                self.result_preview = content[:500] + ("..." if len(content) > 500 else "")
+                if len(content) > max_length:
+                    self.result_preview = content[:max_length] + "..."
+                else:
+                    self.result_preview = content
             elif isinstance(content, list | dict):
                 try:
-                    self.result_preview = orjson.dumps(content, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")[:500] + "..."
+                    json_str = orjson.dumps(content, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")
+                    if len(json_str) > max_length:
+                        self.result_preview = json_str[:max_length] + "..."
+                    else:
+                        self.result_preview = json_str
                 except Exception:
-                    self.result_preview = str(content)[:500] + "..."
+                    str_content = str(content)
+                    if len(str_content) > max_length:
+                        self.result_preview = str_content[:max_length] + "..."
+                    else:
+                        self.result_preview = str_content
             else:
-                self.result_preview = str(content)[:500] + "..."
+                str_content = str(content)
+                if len(str_content) > max_length:
+                    self.result_preview = str_content[:max_length] + "..."
+                else:
+                    self.result_preview = str_content
 
 
 class StreamToolHistoryManager:
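The new length check is repeated verbatim in all three branches of `__post_init__`; a possible follow-up, sketched under the assumption that behavior stays identical, would be a small helper (the name `_truncate_preview` is hypothetical and not part of this diff):

```python
# Hypothetical helper, not in the diff above; it applies the same rule:
# keep at most max_length characters and append "..." only when truncating.
def _truncate_preview(text: str, max_length: int) -> str:
    return text[:max_length] + "..." if len(text) > max_length else text


# The three branches would then collapse to:
#     self.result_preview = _truncate_preview(content, max_length)
#     self.result_preview = _truncate_preview(json_str, max_length)
#     self.result_preview = _truncate_preview(str(content), max_length)
```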
diff --git a/src/plugins/built_in/kokoro_flow_chatter/context_builder.py b/src/plugins/built_in/kokoro_flow_chatter/context_builder.py
index cc6c5cfbf..b57f06f42 100644
--- a/src/plugins/built_in/kokoro_flow_chatter/context_builder.py
+++ b/src/plugins/built_in/kokoro_flow_chatter/context_builder.py
@@ -334,12 +334,46 @@ class KFCContextBuilder:
         tool_executor = ToolExecutor(chat_id=self.chat_id)
 
-        # First, fetch the current history (before executing new tool calls)
+        info_parts = []
+
+        # ========== 1. Proactively recall web search caches ==========
+        try:
+            from src.common.cache_manager import tool_cache
+
+            # Use the chat history as the semantic query
+            query_text = chat_history if chat_history else target_message
+            recalled_caches = await tool_cache.recall_relevant_cache(
+                query_text=query_text,
+                tool_name="web_search",  # Only recall web search caches
+                top_k=2,
+                similarity_threshold=0.65,  # Similarity threshold
+            )
+
+            if recalled_caches:
+                recall_parts = ["### 🔍 Relevant historical search results"]
+                for item in recalled_caches:
+                    original_query = item.get("query", "")
+                    content = item.get("content", "")
+                    similarity = item.get("similarity", 0)
+                    if content:
+                        # Truncate overly long content
+                        if len(content) > 500:
+                            content = content[:500] + "..."
+                        recall_parts.append(f"**Search「{original_query}」** (relevance: {similarity:.0%})\n{content}")
+
+                info_parts.append("\n\n".join(recall_parts))
+                logger.info(f"[cache recall] Recalled {len(recalled_caches)} relevant search cache entries")
+        except Exception as e:
+            logger.debug(f"[cache recall] Recall failed (non-critical): {e}")
+
+        # ========== 2. Fetch the tool call history ==========
         tool_history_str = tool_executor.history_manager.format_for_prompt(
             max_records=3, include_results=True
         )
+        if tool_history_str:
+            info_parts.append(tool_history_str)
 
-        # Then execute the tool calls
+        # ========== 3. Execute tool calls ==========
         tool_results, _, _ = await tool_executor.execute_from_chat_message(
             sender=sender_name,
             target_message=target_message,
@@ -347,12 +381,6 @@ class KFCContextBuilder:
             return_details=False,
         )
 
-        info_parts = []
-
-        # Show the previous tool call history (excluding the current call)
-        if tool_history_str:
-            info_parts.append(tool_history_str)
-
         # Show the results of the current tool calls (brief info)
         if tool_results:
             current_results_parts = ["### 🔧 Tool information just retrieved"]
diff --git a/src/plugins/built_in/web_search_tool/tools/web_search.py b/src/plugins/built_in/web_search_tool/tools/web_search.py
index 82ce0325a..d9238fae5 100644
--- a/src/plugins/built_in/web_search_tool/tools/web_search.py
+++ b/src/plugins/built_in/web_search_tool/tools/web_search.py
@@ -28,7 +28,12 @@ class WebSurfingTool(BaseTool):
     name: str = "web_search"
     description: str = (
-        "Performs a web search. This tool must be used when the user explicitly asks to search, or when the latest information, news, or updates about companies, products, or events are needed"
+        "Web search tool. Use it when:\n"
+        "1. You are not sure of the answer to the user's question or need to verify it\n"
+        "2. The question involves up-to-date information (news, products, events, time-sensitive content)\n"
+        "3. Specific data, facts, or definitions need to be looked up\n"
+        "4. The user explicitly asks for a search\n"
+        "Don't worry about call frequency; search results are cached."
     )
     available_for_llm: bool = True
     parameters: ClassVar[list] = [
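Tying the changes together, here is a rough, self-contained sketch of the prompt section that the `context_builder.py` recall step above produces for a single recalled item, using the same format strings added in that file; the item values are invented for illustration:

```python
# Illustration only: mirrors the recall formatting added in context_builder.py.
# The item dict is fabricated; real items come from recall_relevant_cache().
item = {
    "query": "release date of product X",
    "content": "Product X is expected to ship in Q3 ...",
    "similarity": 0.82,
}

recall_parts = ["### 🔍 Relevant historical search results"]
content = item["content"]
if len(content) > 500:
    content = content[:500] + "..."
recall_parts.append(f"**Search「{item['query']}」** (relevance: {item['similarity']:.0%})\n{content}")

print("\n\n".join(recall_parts))
# Output:
# ### 🔍 Relevant historical search results
#
# **Search「release date of product X」** (relevance: 82%)
# Product X is expected to ship in Q3 ...
```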