Merge branch 'feature/memory-graph-system' of https://github.com/MoFox-Studio/MoFox_Bot into feature/memory-graph-system

This commit is contained in:
tt-P607
2025-11-06 16:52:19 +08:00
24 changed files with 1035 additions and 304 deletions

View File

@@ -323,8 +323,8 @@ class GlobalNoticeManager:
                 return message.additional_config.get("is_notice", False)
             elif isinstance(message.additional_config, str):
                 # Compatible with the JSON-string format
-                import json
-                config = json.loads(message.additional_config)
+                import orjson
+                config = orjson.loads(message.additional_config)
                 return config.get("is_notice", False)
             # Check the message type or other markers
@@ -349,8 +349,8 @@ class GlobalNoticeManager:
             if isinstance(message.additional_config, dict):
                 return message.additional_config.get("notice_type")
             elif isinstance(message.additional_config, str):
-                import json
-                config = json.loads(message.additional_config)
+                import orjson
+                config = orjson.loads(message.additional_config)
                 return config.get("notice_type")
             return None
         except Exception:
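
The `json` → `orjson` swap above repeats across nearly every file in this commit. A minimal standalone sketch of the API differences the migration has to respect (not taken from the diff):

```python
import orjson

# orjson.dumps returns bytes and always emits UTF-8 (there is no
# ensure_ascii flag), so call .decode("utf-8") wherever a str is required.
payload = orjson.dumps({"is_notice": True}).decode("utf-8")

# orjson.loads accepts both str and bytes.
config = orjson.loads(payload)
assert config["is_notice"] is True

# orjson.JSONDecodeError subclasses ValueError, so except clauses that
# previously caught json.JSONDecodeError translate one-for-one.
try:
    orjson.loads("not json")
except orjson.JSONDecodeError as e:
    print(f"parse failed: {e}")
```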

View File

@@ -553,18 +553,56 @@ class DefaultReplyer:
         if user_info_obj:
             sender_name = getattr(user_info_obj, "user_nickname", "") or getattr(user_info_obj, "user_cardname", "")

+        # Collect participant info
+        participants = []
+        try:
+            # Try to read participant info from the chat stream
+            if hasattr(stream, 'chat_history_manager'):
+                history_manager = stream.chat_history_manager
+                # Fetch the recent participant list
+                recent_records = history_manager.get_memory_chat_history(
+                    user_id=getattr(stream, "user_id", ""),
+                    count=10,
+                    memory_types=["chat_message", "system_message"]
+                )
+                # Extract unique participant names
+                for record in recent_records[:5]:  # last 5 records
+                    content = record.get("content", {})
+                    participant = content.get("participant_name")
+                    if participant and participant not in participants:
+                        participants.append(participant)
+                    # If the message carries sender info, add it to the participant list too
+                    if content.get("sender_name") and content.get("sender_name") not in participants:
+                        participants.append(content.get("sender_name"))
+        except Exception as e:
+            logger.debug(f"Failed to collect participant info: {e}")
+
+        # If the sender is not in the participant list, prepend them
+        if sender_name and sender_name not in participants:
+            participants.insert(0, sender_name)
+
+        # Reformat the chat history into a friendlier shape
+        formatted_history = ""
+        if chat_history:
+            # Drop overly long history and keep only the recent part
+            lines = chat_history.strip().split('\n')
+            recent_lines = lines[-10:] if len(lines) > 10 else lines
+            formatted_history = '\n'.join(recent_lines)
+
         query_context = {
-            "chat_history": chat_history if chat_history else "",
+            "chat_history": formatted_history,
             "sender": sender_name,
+            "participants": participants,
         }

-        # Smart retrieval via the memory manager (auto-optimized query)
+        # Smart retrieval via the memory manager (multi-query strategy)
         memories = await manager.search_memories(
             query=target,
             top_k=10,
             min_importance=0.3,
             include_forgotten=False,
-            optimize_query=True,
+            use_multi_query=True,
             context=query_context,
         )
@@ -667,32 +705,46 @@ class DefaultReplyer:
             return ""

         try:
-            # Fetch information via the tool executor
+            # First grab the existing history (before running new tool calls)
+            tool_history_str = self.tool_executor.history_manager.format_for_prompt(max_records=3, include_results=True)
+
+            # Then run the tool calls
             tool_results, _, _ = await self.tool_executor.execute_from_chat_message(
                 sender=sender, target_message=target, chat_history=chat_history, return_details=False
             )

+            info_parts = []
+
+            # Show earlier tool-call history (excluding this call)
+            if tool_history_str:
+                info_parts.append(tool_history_str)
+
+            # Show the current tool-call results (brief)
             if tool_results:
-                tool_info_str = "Below is the real-time information you obtained via tools:\n"
+                current_results_parts = ["## 🔧 Tool info just fetched"]
                 for tool_result in tool_results:
                     tool_name = tool_result.get("tool_name", "unknown")
                     content = tool_result.get("content", "")
                     result_type = tool_result.get("type", "tool_result")
-                    tool_info_str += f"- 【{tool_name}】{result_type}: {content}\n"
-                tool_info_str += "The above is the real-time information you obtained; please consult it when replying."
+                    # No truncation; let each tool manage its own result length
+                    current_results_parts.append(f"- **{tool_name}**: {content}")
+                info_parts.append("\n".join(current_results_parts))
                 logger.info(f"Got {len(tool_results)} tool results")
-                return tool_info_str
-            else:
-                logger.debug("No tool results obtained")
+
+            # If there is nothing at all, return an empty string
+            if not info_parts:
+                logger.debug("No tool results or history records obtained")
                 return ""
+
+            return "\n\n".join(info_parts)
         except Exception as e:
             logger.error(f"Tool info fetch failed: {e}")
             return ""

     def _parse_reply_target(self, target_message: str) -> tuple[str, str]:
         """Parse the reply-target message - uses the shared utility"""
         from src.chat.utils.prompt import Prompt
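
Assembled, the block is history first, then fresh results, joined by blank lines; a tiny illustration with invented contents:

```python
# Hypothetical contents; shapes mirror the format_for_prompt() output and
# the "Tool info just fetched" section built above.
info_parts = [
    "## 🔧 Recent tool-call records\n1. ✅ **web_search**({\"query\": \"weather\"})",
    "## 🔧 Tool info just fetched\n- **web_search**: sunny, 23°C",
]
prompt_block = "\n\n".join(info_parts)
```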

View File

@@ -57,8 +57,16 @@ class CacheManager:
         # Embedding model
         self.embedding_model = LLMRequest(model_config.model_task_config.embedding)

+        # Tool-call statistics
+        self.tool_stats = {
+            "total_tool_calls": 0,
+            "cache_hits_by_tool": {},  # cache hits keyed by tool name
+            "execution_times_by_tool": {},  # execution times keyed by tool name
+            "most_used_tools": {},  # most frequently used tools
+        }
+
         self._initialized = True
-        logger.info("Cache manager initialized: L1 (memory+FAISS), L2 (database+ChromaDB)")
+        logger.info("Cache manager initialized: L1 (memory+FAISS), L2 (database+ChromaDB) + tool stats")

     @staticmethod
     def _validate_embedding(embedding_result: Any) -> np.ndarray | None:
@@ -363,20 +371,16 @@ class CacheManager:
     def get_health_stats(self) -> dict[str, Any]:
         """Return cache health statistics"""
-        from src.common.memory_utils import format_size
+        # Simplified health stats without memory monitoring (the related attributes are undefined)
         return {
             "l1_count": len(self.l1_kv_cache),
-            "l1_memory": self.l1_current_memory,
-            "l1_memory_formatted": format_size(self.l1_current_memory),
-            "l1_max_memory": self.l1_max_memory,
-            "l1_memory_usage_percent": round((self.l1_current_memory / self.l1_max_memory) * 100, 2),
-            "l1_max_size": self.l1_max_size,
-            "l1_size_usage_percent": round((len(self.l1_kv_cache) / self.l1_max_size) * 100, 2),
-            "average_item_size": self.l1_current_memory // len(self.l1_kv_cache) if self.l1_kv_cache else 0,
-            "average_item_size_formatted": format_size(self.l1_current_memory // len(self.l1_kv_cache)) if self.l1_kv_cache else "0 B",
-            "largest_item_size": max(self.l1_size_map.values()) if self.l1_size_map else 0,
-            "largest_item_size_formatted": format_size(max(self.l1_size_map.values())) if self.l1_size_map else "0 B",
+            "l1_vector_count": self.l1_vector_index.ntotal if hasattr(self.l1_vector_index, 'ntotal') else 0,
+            "tool_stats": {
+                "total_tool_calls": self.tool_stats.get("total_tool_calls", 0),
+                "tracked_tools": len(self.tool_stats.get("most_used_tools", {})),
+                "cache_hits": sum(data.get("hits", 0) for data in self.tool_stats.get("cache_hits_by_tool", {}).values()),
+                "cache_misses": sum(data.get("misses", 0) for data in self.tool_stats.get("cache_hits_by_tool", {}).values()),
+            }
         }

     def check_health(self) -> tuple[bool, list[str]]:
@@ -387,34 +391,185 @@ class CacheManager:
         """
         warnings = []

-        # Check memory usage
-        memory_usage = (self.l1_current_memory / self.l1_max_memory) * 100
-        if memory_usage > 90:
-            warnings.append(f"⚠️ L1 cache memory usage too high: {memory_usage:.1f}%")
-        elif memory_usage > 75:
-            warnings.append(f"⚡ L1 cache memory usage elevated: {memory_usage:.1f}%")
+        # Check the L1 cache size
+        l1_size = len(self.l1_kv_cache)
+        if l1_size > 1000:  # more than 1000 entries
+            warnings.append(f"⚠️ L1 cache holds many entries: {l1_size}")

-        # Check the entry count
-        size_usage = (len(self.l1_kv_cache) / self.l1_max_size) * 100
-        if size_usage > 90:
-            warnings.append(f"⚠️ L1 cache entry count high: {size_usage:.1f}%")
+        # Check the vector index size
+        vector_count = self.l1_vector_index.ntotal if hasattr(self.l1_vector_index, 'ntotal') else 0
+        if isinstance(vector_count, int) and vector_count > 500:
+            warnings.append(f"⚠️ Vector index holds many entries: {vector_count}")

-        # Check the average entry size
-        if self.l1_kv_cache:
-            avg_size = self.l1_current_memory // len(self.l1_kv_cache)
-            if avg_size > 100 * 1024:  # >100KB
-                from src.common.memory_utils import format_size
-                warnings.append(f"⚡ Average cache entry too large: {format_size(avg_size)}")
-
-        # Check the largest single entry
-        if self.l1_size_map:
-            max_size = max(self.l1_size_map.values())
-            if max_size > 500 * 1024:  # >500KB
-                from src.common.memory_utils import format_size
-                warnings.append(f"⚠️ Oversized cache entry found: {format_size(max_size)}")
+        # Check tool-statistics health
+        total_calls = self.tool_stats.get("total_tool_calls", 0)
+        if total_calls > 0:
+            total_hits = sum(data.get("hits", 0) for data in self.tool_stats.get("cache_hits_by_tool", {}).values())
+            cache_hit_rate = (total_hits / total_calls) * 100
+            if cache_hit_rate < 50:  # hit rate below 50%
+                warnings.append(f"⚡ Overall cache hit rate is low: {cache_hit_rate:.1f}%")

         return len(warnings) == 0, warnings

+    async def get_tool_result_with_stats(self,
+                                         tool_name: str,
+                                         function_args: dict[str, Any],
+                                         tool_file_path: str | Path,
+                                         semantic_query: str | None = None) -> tuple[Any | None, bool]:
+        """Fetch a tool result and update the statistics.
+
+        Args:
+            tool_name: tool name
+            function_args: function arguments
+            tool_file_path: path of the tool's source file
+            semantic_query: semantic query string
+
+        Returns:
+            Tuple[result, whether the cache was hit]
+        """
+        # Bump the total call counter
+        self.tool_stats["total_tool_calls"] += 1
+
+        # Update per-tool usage stats
+        if tool_name not in self.tool_stats["most_used_tools"]:
+            self.tool_stats["most_used_tools"][tool_name] = 0
+        self.tool_stats["most_used_tools"][tool_name] += 1
+
+        # Try the cache
+        result = await self.get(tool_name, function_args, tool_file_path, semantic_query)
+
+        # Update the hit/miss stats
+        if tool_name not in self.tool_stats["cache_hits_by_tool"]:
+            self.tool_stats["cache_hits_by_tool"][tool_name] = {"hits": 0, "misses": 0}
+
+        if result is not None:
+            self.tool_stats["cache_hits_by_tool"][tool_name]["hits"] += 1
+            logger.info(f"Tool cache hit: {tool_name}")
+            return result, True
+        else:
+            self.tool_stats["cache_hits_by_tool"][tool_name]["misses"] += 1
+            return None, False
+
+    async def set_tool_result_with_stats(self,
+                                         tool_name: str,
+                                         function_args: dict[str, Any],
+                                         tool_file_path: str | Path,
+                                         data: Any,
+                                         execution_time: float | None = None,
+                                         ttl: int | None = None,
+                                         semantic_query: str | None = None):
+        """Store a tool result and update the statistics.
+
+        Args:
+            tool_name: tool name
+            function_args: function arguments
+            tool_file_path: path of the tool's source file
+            data: result data
+            execution_time: execution time
+            ttl: cache TTL
+            semantic_query: semantic query string
+        """
+        # Update the execution-time stats
+        if execution_time is not None:
+            if tool_name not in self.tool_stats["execution_times_by_tool"]:
+                self.tool_stats["execution_times_by_tool"][tool_name] = []
+            self.tool_stats["execution_times_by_tool"][tool_name].append(execution_time)
+            # Keep only the 100 most recent execution times
+            if len(self.tool_stats["execution_times_by_tool"][tool_name]) > 100:
+                self.tool_stats["execution_times_by_tool"][tool_name] = \
+                    self.tool_stats["execution_times_by_tool"][tool_name][-100:]
+
+        # Store in the cache
+        await self.set(tool_name, function_args, tool_file_path, data, ttl, semantic_query)
+
+    def get_tool_performance_stats(self) -> dict[str, Any]:
+        """Return tool performance statistics.
+
+        Returns:
+            Statistics dictionary
+        """
+        stats = self.tool_stats.copy()
+
+        # Compute average execution times
+        avg_times = {}
+        for tool_name, times in stats["execution_times_by_tool"].items():
+            if times:
+                avg_times[tool_name] = {
+                    "average": sum(times) / len(times),
+                    "min": min(times),
+                    "max": max(times),
+                    "count": len(times),
+                }
+
+        # Compute cache hit rates
+        cache_hit_rates = {}
+        for tool_name, hit_data in stats["cache_hits_by_tool"].items():
+            total = hit_data["hits"] + hit_data["misses"]
+            if total > 0:
+                cache_hit_rates[tool_name] = {
+                    "hit_rate": (hit_data["hits"] / total) * 100,
+                    "hits": hit_data["hits"],
+                    "misses": hit_data["misses"],
+                    "total": total,
+                }
+
+        # Sort tools by usage frequency
+        most_used = sorted(stats["most_used_tools"].items(), key=lambda x: x[1], reverse=True)
+
+        return {
+            "total_tool_calls": stats["total_tool_calls"],
+            "average_execution_times": avg_times,
+            "cache_hit_rates": cache_hit_rates,
+            "most_used_tools": most_used[:10],  # top 10 most used tools
+            "cache_health": self.get_health_stats(),
+        }
+
+    def get_tool_recommendations(self) -> dict[str, Any]:
+        """Return tool optimization recommendations.
+
+        Returns:
+            Recommendations dictionary
+        """
+        recommendations = []
+
+        # Flag tools with low cache hit rates
+        cache_hit_rates = {}
+        for tool_name, hit_data in self.tool_stats["cache_hits_by_tool"].items():
+            total = hit_data["hits"] + hit_data["misses"]
+            if total >= 5:  # analyze only after at least 5 calls
+                hit_rate = (hit_data["hits"] / total) * 100
+                cache_hit_rates[tool_name] = hit_rate
+                if hit_rate < 30:  # hit rate below 30%
+                    recommendations.append({
+                        "tool": tool_name,
+                        "type": "low_cache_hit_rate",
+                        "message": f"Tool {tool_name} has a cache hit rate of only {hit_rate:.1f}%; check the cache configuration or how often its arguments change",
+                        "severity": "medium" if hit_rate > 10 else "high",
+                    })
+
+        # Flag tools with long execution times
+        for tool_name, times in self.tool_stats["execution_times_by_tool"].items():
+            if len(times) >= 3:  # analyze only after at least 3 runs
+                avg_time = sum(times) / len(times)
+                if avg_time > 5.0:  # average execution time over 5 seconds
+                    recommendations.append({
+                        "tool": tool_name,
+                        "type": "slow_execution",
+                        "message": f"Tool {tool_name} has a long average execution time ({avg_time:.2f}s); consider optimizing it or adding caching",
+                        "severity": "medium" if avg_time < 10.0 else "high",
+                    })
+
+        return {
+            "recommendations": recommendations,
+            "summary": {
+                "total_issues": len(recommendations),
+                "high_priority": len([r for r in recommendations if r["severity"] == "high"]),
+                "medium_priority": len([r for r in recommendations if r["severity"] == "medium"]),
+            }
+        }
+

 # Global instance
 tool_cache = CacheManager()
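
A hedged sketch of how a caller might drive the new stats-aware wrappers; the tool name and file path are invented, while the method signatures are the ones added above:

```python
import time

async def run_tool_cached(args: dict) -> dict:
    # First consult the cache; call/hit counters are updated as a side effect.
    result, hit = await tool_cache.get_tool_result_with_stats(
        tool_name="web_search",                            # hypothetical tool
        function_args=args,
        tool_file_path="src/plugins/tools/web_search.py",  # hypothetical path
        semantic_query=args.get("query"),
    )
    if hit:
        return result

    started = time.time()
    result = {"content": "..."}  # stand-in for the real tool execution
    # Store the result and record the measured execution time.
    await tool_cache.set_tool_result_with_stats(
        tool_name="web_search",
        function_args=args,
        tool_file_path="src/plugins/tools/web_search.py",
        data=result,
        execution_time=time.time() - started,
    )
    return result
```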

View File

@@ -137,6 +137,7 @@ class MemoryManager:
             graph_store=self.graph_store,
             persistence_manager=self.persistence,
             embedding_generator=self.embedding_generator,
+            max_expand_depth=getattr(self.config, 'max_expand_depth', 1),  # default depth from config
         )

         self._initialized = True
@@ -362,18 +363,15 @@ class MemoryManager:
         # Build the context info
         chat_history = context.get("chat_history", "") if context else ""
-        sender = context.get("sender", "") if context else ""
-        participants = context.get("participants", []) if context else []
-        participants_str = "、".join(participants) if participants else "none"

         prompt = f"""You are a memory-retrieval assistant. To improve retrieval accuracy, generate 3-5 search statements for the query from different angles.

 **Core principle (important!):**
-For complex queries containing multiple concepts (e.g. "how does 杰瑞喵 rate the new memory system"), generate:
+For complex queries containing multiple concepts (e.g. "how does 小明 rate 小王"), generate:
 1. The full query (all elements included) - weight 1.0
-2. A standalone query per key concept (e.g. "the new memory system") - weight 0.8 (so it is not drowned out by the subject!)
-3. Subject + action combination (e.g. "杰瑞喵 rates") - weight 0.6
-4. A generalized query (e.g. "memory system") - weight 0.7
+2. A standalone query per key concept (e.g. "小明", "小王") - weight 0.8 (so it is not drowned out by the subject!)
+3. Subject + action combination (e.g. "小明 rates") - weight 0.6
+4. A generalized query (e.g. "rates") - weight 0.7

 **Requirements:**
 - The first must be the original query or a synonymous rewrite
@@ -381,9 +379,7 @@ class MemoryManager:
 - Keep queries concise (5-20 characters)
 - Output JSON directly, with no extra commentary

-**Known participants:** {participants_str}
 **Conversation context:** {chat_history[-300:] if chat_history else "none"}
-**Current query:** {sender}: {query}

 **Output JSON format:**
 ```json
@@ -436,7 +432,6 @@ class MemoryManager:
         time_range: Optional[Tuple[datetime, datetime]] = None,
         min_importance: float = 0.0,
         include_forgotten: bool = False,
-        optimize_query: bool = True,
         use_multi_query: bool = True,
         expand_depth: int = 1,
         context: Optional[Dict[str, Any]] = None,
@@ -457,7 +452,6 @@ class MemoryManager:
             time_range: time-range filter (start, end)
             min_importance: minimum importance
             include_forgotten: whether to include forgotten memories
-            optimize_query: whether to optimize the query with a small model (deprecated, superseded by use_multi_query)
             use_multi_query: whether to use the multi-query strategy (recommended, default True)
             expand_depth: graph expansion depth (0=off, 1=recommended, 2-3=deep exploration)
             context: query context (used for optimization)
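
Callers migrate by swapping the deprecated flag for `use_multi_query`; a minimal sketch of the updated call, with argument names from the signature above and illustrative values:

```python
memories = await manager.search_memories(
    query="how does 小明 rate 小王",
    top_k=10,
    min_importance=0.3,
    include_forgotten=False,
    use_multi_query=True,  # replaces the deprecated optimize_query flag
    expand_depth=1,        # 0 = off, 1 = recommended, 2-3 = deep exploration
    context={
        "chat_history": "...",            # recent lines, newest last
        "sender": "小明",
        "participants": ["小明", "小王"],
    },
)
```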

View File

@@ -102,8 +102,8 @@ class VectorStore:
                 # Handle extra metadata: convert lists/dicts to JSON strings
                 for key, value in node.metadata.items():
                     if isinstance(value, (list, dict)):
-                        import json
-                        metadata[key] = json.dumps(value, ensure_ascii=False)
+                        import orjson
+                        metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')
                     elif isinstance(value, (str, int, float, bool)) or value is None:
                         metadata[key] = value
                     else:
@@ -141,7 +141,7 @@ class VectorStore:
             try:
                 # Prepare the metadata
-                import json
+                import orjson
                 metadatas = []
                 for n in valid_nodes:
                     metadata = {
@@ -151,7 +151,7 @@ class VectorStore:
                     }
                     for key, value in n.metadata.items():
                         if isinstance(value, (list, dict)):
-                            metadata[key] = json.dumps(value, ensure_ascii=False)
+                            metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')
                         elif isinstance(value, (str, int, float, bool)) or value is None:
                             metadata[key] = value  # type: ignore
                         else:
@@ -207,7 +207,7 @@ class VectorStore:
             )

             # Parse the results
-            import json
+            import orjson
             similar_nodes = []
             if results["ids"] and results["ids"][0]:
                 for i, node_id in enumerate(results["ids"][0]):
@@ -223,7 +223,7 @@ class VectorStore:
                     for key, value in list(metadata.items()):
                         if isinstance(value, str) and (value.startswith('[') or value.startswith('{')):
                             try:
-                                metadata[key] = json.loads(value)
+                                metadata[key] = orjson.loads(value)
                             except:
                                 pass  # keep the original value

View File

@@ -34,6 +34,7 @@ class MemoryTools:
         graph_store: GraphStore,
         persistence_manager: PersistenceManager,
         embedding_generator: Optional[EmbeddingGenerator] = None,
+        max_expand_depth: int = 1,
     ):
         """
         Initialize the tool set
@@ -43,11 +44,13 @@ class MemoryTools:
             graph_store: graph store
             persistence_manager: persistence manager
             embedding_generator: embedding generator (optional)
+            max_expand_depth: default graph expansion depth (read from config)
         """
         self.vector_store = vector_store
         self.graph_store = graph_store
         self.persistence_manager = persistence_manager
         self._initialized = False
+        self.max_expand_depth = max_expand_depth  # keep the configured default

         # Initialize components
         self.extractor = MemoryExtractor()
@@ -448,11 +451,12 @@ class MemoryTools:
         try:
             query = params.get("query", "")
             top_k = params.get("top_k", 10)
-            expand_depth = params.get("expand_depth", 1)
+            # Use the configured default instead of a hard-coded 1
+            expand_depth = params.get("expand_depth", self.max_expand_depth)
             use_multi_query = params.get("use_multi_query", True)
             context = params.get("context", None)

-            logger.info(f"Searching memories: {query} (top_k={top_k}, multi_query={use_multi_query})")
+            logger.info(f"Searching memories: {query} (top_k={top_k}, expand_depth={expand_depth}, multi_query={use_multi_query})")

             # 0. Ensure initialization
             await self._ensure_initialized()
@@ -474,9 +478,9 @@ class MemoryTools:
                     ids = metadata["memory_ids"]
                     # Make sure it is a list
                     if isinstance(ids, str):
-                        import json
+                        import orjson
                         try:
-                            ids = json.loads(ids)
+                            ids = orjson.loads(ids)
                         except:
                             ids = [ids]
                     if isinstance(ids, list):
@@ -631,29 +635,57 @@ class MemoryTools:
                 request_type="memory.multi_query"
             )

+            # Gather context info
             participants = context.get("participants", []) if context else []
+            chat_history = context.get("chat_history", "") if context else ""
+            sender = context.get("sender", "") if context else ""

-            prompt = f"""Generate 3-5 search statements for the query from different angles (JSON format)
+            # Process the chat history, keeping roughly the 5 most recent messages
+            recent_chat = ""
+            if chat_history:
+                lines = chat_history.strip().split('\n')
+                # Take the 5 most recent messages
+                recent_lines = lines[-5:] if len(lines) > 5 else lines
+                recent_chat = '\n'.join(recent_lines)

-**Query:** {query}
+            prompt = f"""Based on the chat context, generate 3-5 search statements for the query from different angles (JSON format)

+**Current query:** {query}
+**Sender:** {sender if sender else 'unknown'}
 **Participants:** {', '.join(participants) if participants else 'none'}

-**Principle:** for complex queries (e.g. "how does 杰瑞喵 rate the new memory system"), generate:
-1. The full query (weight 1.0)
-2. A standalone query per key concept (weight 0.8) - important!
-3. Subject + action (weight 0.6)
+**Recent chat log (last 5 messages):**
+{recent_chat if recent_chat else 'no chat history'}

-**Output JSON:**
+**Analysis principles:**
+1. **Context understanding**: use the chat history to understand the query's real intent
+2. **Coreference resolution**: detect and substitute pronouns such as "he", "she", "it", "that one"
+3. **Topic linkage**: combine recently discussed topics to generate more precise queries
+4. **Query decomposition**: split complex queries into several sub-queries
+
+**Generation strategy:**
+1. **Full query** (weight 1.0): the complete query grounded in context, with coreferences resolved
+2. **Key-concept queries** (weight 0.8): the core concepts in the query, especially entities mentioned in chat
+3. **Topic-expansion queries** (weight 0.7): related queries based on recent chat topics
+4. **Action/emotion queries** (weight 0.6): if actions or emotions are involved, generate related queries
+
+**Output JSON format:**
 ```json
-{{"queries": [{{"text": "query 1", "weight": 1.0}}, {{"text": "query 2", "weight": 0.8}}]}}
+{{"queries": [{{"text": "search statement", "weight": 1.0}}, {{"text": "search statement", "weight": 0.8}}]}}
-```"""
+```
+
+**Examples:**
+- Query: "how is he doing?" + chat mentions "小明 got sick" → "小明's recovery status"
+- Query: "that project" + chat discusses "memory-system development" → "memory-system project progress"
+"""

             response, _ = await llm.generate_response_async(prompt, temperature=0.3, max_tokens=250)

-            import json, re
+            import orjson, re
             response = re.sub(r'```json\s*', '', response)
             response = re.sub(r'```\s*$', '', response).strip()
-            data = json.loads(response)
+            data = orjson.loads(response)
             queries = data.get("queries", [])

             result = [(item.get("text", "").strip(), float(item.get("weight", 0.5)))
@@ -799,9 +831,9 @@ class MemoryTools:
                 # Make sure it is a list
                 if isinstance(ids, str):
-                    import json
+                    import orjson
                     try:
-                        ids = json.loads(ids)
+                        ids = orjson.loads(ids)
                     except Exception as e:
                         logger.warning(f"JSON parse failed: {e}")
                         ids = [ids]
@@ -910,9 +942,9 @@ class MemoryTools:
                 # Extract the memory IDs
                 neighbor_memory_ids = neighbor_node_data.get("memory_ids", [])
                 if isinstance(neighbor_memory_ids, str):
-                    import json
+                    import orjson
                    try:
-                        neighbor_memory_ids = json.loads(neighbor_memory_ids)
+                        neighbor_memory_ids = orjson.loads(neighbor_memory_ids)
                     except:
                         neighbor_memory_ids = [neighbor_memory_ids]
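
The fence-strip-then-parse step above is the load-bearing part of multi-query generation; a self-contained sketch of the same pattern, with an invented `response` standing in for raw LLM output:

```python
import re
import orjson

# Raw LLM output often arrives wrapped in a ```json fence.
response = '```json\n{"queries": [{"text": "memory-system progress", "weight": 1.0}]}\n```'

# Strip the opening and closing fences, then parse.
response = re.sub(r'```json\s*', '', response)
response = re.sub(r'```\s*$', '', response).strip()

data = orjson.loads(response)
queries = [(q.get("text", "").strip(), float(q.get("weight", 0.5)))
           for q in data.get("queries", [])]
assert queries == [("memory-system progress", 1.0)]
```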

View File

@@ -7,7 +7,7 @@
 """

 import atexit
-import json
+import orjson
 import os
 import threading
 from typing import Any, ClassVar
@@ -100,10 +100,10 @@ class PluginStorage:
             if os.path.exists(self.file_path):
                 with open(self.file_path, encoding="utf-8") as f:
                     content = f.read()
-                    self._data = json.loads(content) if content else {}
+                    self._data = orjson.loads(content) if content else {}
             else:
                 self._data = {}
-        except (json.JSONDecodeError, Exception) as e:
+        except (orjson.JSONDecodeError, Exception) as e:
             logger.warning(f"Loading data from '{self.file_path}' failed: {e}; initializing with empty data.")
             self._data = {}
@@ -125,7 +125,7 @@ class PluginStorage:
         try:
             with open(self.file_path, "w", encoding="utf-8") as f:
-                json.dump(self._data, f, indent=4, ensure_ascii=False)
+                f.write(orjson.dumps(self._data, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS).decode('utf-8'))
             self._dirty = False  # reset the flag after saving
             logger.debug(f"Data for plugin '{self.name}' was saved to disk successfully.")
         except Exception as e:

View File

@@ -5,7 +5,7 @@ MCP Client Manager
 """

 import asyncio
-import json
+import orjson
 import shutil
 from pathlib import Path
 from typing import Any
@@ -89,7 +89,7 @@ class MCPClientManager:
         try:
             with open(self.config_path, encoding="utf-8") as f:
-                config_data = json.load(f)
+                config_data = orjson.loads(f.read())

             servers = {}
             mcp_servers = config_data.get("mcpServers", {})
@@ -106,7 +106,7 @@ class MCPClientManager:
             logger.info(f"Loaded {len(servers)} MCP server configuration(s) successfully")
             return servers

-        except json.JSONDecodeError as e:
+        except orjson.JSONDecodeError as e:
             logger.error(f"Failed to parse the MCP config file: {e}")
             return {}
         except Exception as e:

View File

@@ -0,0 +1,414 @@
"""
流式工具历史记录管理器
用于在聊天流级别管理工具调用历史,支持智能缓存和上下文感知
"""
import time
from typing import Any, Optional
from dataclasses import dataclass, asdict, field
import orjson
from src.common.logger import get_logger
from src.common.cache_manager import tool_cache
logger = get_logger("stream_tool_history")
@dataclass
class ToolCallRecord:
"""工具调用记录"""
tool_name: str
args: dict[str, Any]
result: Optional[dict[str, Any]] = None
status: str = "success" # success, error, pending
timestamp: float = field(default_factory=time.time)
execution_time: Optional[float] = None # 执行耗时(秒)
cache_hit: bool = False # 是否命中缓存
result_preview: str = "" # 结果预览
error_message: str = "" # 错误信息
def __post_init__(self):
"""后处理:生成结果预览"""
if self.result and not self.result_preview:
content = self.result.get("content", "")
if isinstance(content, str):
self.result_preview = content[:500] + ("..." if len(content) > 500 else "")
elif isinstance(content, (list, dict)):
try:
self.result_preview = orjson.dumps(content, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')[:500] + "..."
except Exception:
self.result_preview = str(content)[:500] + "..."
else:
self.result_preview = str(content)[:500] + "..."
class StreamToolHistoryManager:
"""流式工具历史记录管理器
提供以下功能:
1. 工具调用历史的持久化管理
2. 智能缓存集成和结果去重
3. 上下文感知的历史记录检索
4. 性能监控和统计
"""
def __init__(self, chat_id: str, max_history: int = 20, enable_memory_cache: bool = True):
"""初始化历史记录管理器
Args:
chat_id: 聊天ID用于隔离不同聊天流的历史记录
max_history: 最大历史记录数量
enable_memory_cache: 是否启用内存缓存
"""
self.chat_id = chat_id
self.max_history = max_history
self.enable_memory_cache = enable_memory_cache
# 内存中的历史记录,按时间顺序排列
self._history: list[ToolCallRecord] = []
# 性能统计
self._stats = {
"total_calls": 0,
"cache_hits": 0,
"cache_misses": 0,
"total_execution_time": 0.0,
"average_execution_time": 0.0,
}
logger.info(f"[{chat_id}] 工具历史记录管理器初始化完成,最大历史: {max_history}")
async def add_tool_call(self, record: ToolCallRecord) -> None:
"""添加工具调用记录
Args:
record: 工具调用记录
"""
# 维护历史记录大小
if len(self._history) >= self.max_history:
# 移除最旧的记录
removed_record = self._history.pop(0)
logger.debug(f"[{self.chat_id}] 移除旧记录: {removed_record.tool_name}")
# 添加新记录
self._history.append(record)
# 更新统计
self._stats["total_calls"] += 1
if record.cache_hit:
self._stats["cache_hits"] += 1
else:
self._stats["cache_misses"] += 1
if record.execution_time is not None:
self._stats["total_execution_time"] += record.execution_time
self._stats["average_execution_time"] = self._stats["total_execution_time"] / self._stats["total_calls"]
logger.debug(f"[{self.chat_id}] 添加工具调用记录: {record.tool_name}, 缓存命中: {record.cache_hit}")
async def get_cached_result(self, tool_name: str, args: dict[str, Any]) -> Optional[dict[str, Any]]:
"""从缓存或历史记录中获取结果
Args:
tool_name: 工具名称
args: 工具参数
Returns:
缓存的结果如果不存在则返回None
"""
# 首先检查内存中的历史记录
if self.enable_memory_cache:
memory_result = self._search_memory_cache(tool_name, args)
if memory_result:
logger.info(f"[{self.chat_id}] 内存缓存命中: {tool_name}")
return memory_result
# 然后检查全局缓存系统
try:
# 这里需要工具实例来获取文件路径,但为了解耦,我们先尝试从历史记录中推断
tool_file_path = self._infer_tool_path(tool_name)
# 尝试语义缓存(如果可以推断出语义查询参数)
semantic_query = self._extract_semantic_query(tool_name, args)
cached_result = await tool_cache.get(
tool_name=tool_name,
function_args=args,
tool_file_path=tool_file_path,
semantic_query=semantic_query,
)
if cached_result:
logger.info(f"[{self.chat_id}] 全局缓存命中: {tool_name}")
# 将结果同步到内存缓存
if self.enable_memory_cache:
record = ToolCallRecord(
tool_name=tool_name,
args=args,
result=cached_result,
status="success",
cache_hit=True,
timestamp=time.time(),
)
await self.add_tool_call(record)
return cached_result
except Exception as e:
logger.warning(f"[{self.chat_id}] 缓存查询失败: {e}")
return None
async def cache_result(self, tool_name: str, args: dict[str, Any], result: dict[str, Any],
execution_time: Optional[float] = None,
tool_file_path: Optional[str] = None,
ttl: Optional[int] = None) -> None:
"""缓存工具调用结果
Args:
tool_name: 工具名称
args: 工具参数
result: 执行结果
execution_time: 执行耗时
tool_file_path: 工具文件路径
ttl: 缓存TTL
"""
# 添加到内存历史记录
record = ToolCallRecord(
tool_name=tool_name,
args=args,
result=result,
status="success",
execution_time=execution_time,
cache_hit=False,
timestamp=time.time(),
)
await self.add_tool_call(record)
# 同步到全局缓存系统
try:
if tool_file_path is None:
tool_file_path = self._infer_tool_path(tool_name)
# 尝试语义缓存
semantic_query = self._extract_semantic_query(tool_name, args)
await tool_cache.set(
tool_name=tool_name,
function_args=args,
tool_file_path=tool_file_path,
data=result,
ttl=ttl,
semantic_query=semantic_query,
)
logger.debug(f"[{self.chat_id}] 结果已缓存: {tool_name}")
except Exception as e:
logger.warning(f"[{self.chat_id}] 缓存设置失败: {e}")
async def get_recent_history(self, count: int = 5, status_filter: Optional[str] = None) -> list[ToolCallRecord]:
"""获取最近的历史记录
Args:
count: 返回的记录数量
status_filter: 状态过滤器可选值success, error, pending
Returns:
历史记录列表
"""
history = self._history.copy()
# 应用状态过滤
if status_filter:
history = [record for record in history if record.status == status_filter]
# 返回最近的记录
return history[-count:] if history else []
def format_for_prompt(self, max_records: int = 5, include_results: bool = True) -> str:
"""格式化历史记录为提示词
Args:
max_records: 最大记录数量
include_results: 是否包含结果预览
Returns:
格式化的提示词字符串
"""
if not self._history:
return ""
recent_records = self._history[-max_records:]
lines = ["## 🔧 最近工具调用记录"]
for i, record in enumerate(recent_records, 1):
status_icon = "" if record.status == "success" else "" if record.status == "error" else ""
# 格式化参数
args_preview = self._format_args_preview(record.args)
# 基础信息
lines.append(f"{i}. {status_icon} **{record.tool_name}**({args_preview})")
# 添加执行时间和缓存信息
if record.execution_time is not None:
time_info = f"{record.execution_time:.2f}s"
cache_info = "🎯缓存" if record.cache_hit else "🔍执行"
lines.append(f" ⏱️ {time_info} | {cache_info}")
# 添加结果预览
if include_results and record.result_preview:
lines.append(f" 📝 结果: {record.result_preview}")
# 添加错误信息
if record.status == "error" and record.error_message:
lines.append(f" ❌ 错误: {record.error_message}")
# 添加统计信息
if self._stats["total_calls"] > 0:
cache_hit_rate = (self._stats["cache_hits"] / self._stats["total_calls"]) * 100
avg_time = self._stats["average_execution_time"]
lines.append(f"\n📊 工具统计: 总计{self._stats['total_calls']}次 | 缓存命中率{cache_hit_rate:.1f}% | 平均耗时{avg_time:.2f}s")
return "\n".join(lines)
def get_stats(self) -> dict[str, Any]:
"""获取性能统计信息
Returns:
统计信息字典
"""
cache_hit_rate = 0.0
if self._stats["total_calls"] > 0:
cache_hit_rate = (self._stats["cache_hits"] / self._stats["total_calls"]) * 100
return {
**self._stats,
"cache_hit_rate": cache_hit_rate,
"history_size": len(self._history),
"chat_id": self.chat_id,
}
def clear_history(self) -> None:
"""清除历史记录"""
self._history.clear()
logger.info(f"[{self.chat_id}] 工具历史记录已清除")
def _search_memory_cache(self, tool_name: str, args: dict[str, Any]) -> Optional[dict[str, Any]]:
"""在内存历史记录中搜索缓存
Args:
tool_name: 工具名称
args: 工具参数
Returns:
匹配的结果如果不存在则返回None
"""
for record in reversed(self._history): # 从最新的开始搜索
if (record.tool_name == tool_name and
record.status == "success" and
record.args == args):
return record.result
return None
def _infer_tool_path(self, tool_name: str) -> str:
"""推断工具文件路径
Args:
tool_name: 工具名称
Returns:
推断的文件路径
"""
# 基于工具名称推断路径,这是一个简化的实现
# 在实际使用中,可能需要更复杂的映射逻辑
tool_path_mapping = {
"web_search": "src/plugins/built_in/web_search_tool/tools/web_search.py",
"memory_create": "src/memory_graph/tools/memory_tools.py",
"memory_search": "src/memory_graph/tools/memory_tools.py",
"user_profile_update": "src/plugins/built_in/affinity_flow_chatter/tools/user_profile_tool.py",
"chat_stream_impression_update": "src/plugins/built_in/affinity_flow_chatter/tools/chat_stream_impression_tool.py",
}
return tool_path_mapping.get(tool_name, f"src/plugins/tools/{tool_name}.py")
def _extract_semantic_query(self, tool_name: str, args: dict[str, Any]) -> Optional[str]:
"""提取语义查询参数
Args:
tool_name: 工具名称
args: 工具参数
Returns:
语义查询字符串如果不存在则返回None
"""
# 为不同工具定义语义查询参数映射
semantic_query_mapping = {
"web_search": "query",
"memory_search": "query",
"knowledge_search": "query",
}
query_key = semantic_query_mapping.get(tool_name)
if query_key and query_key in args:
return str(args[query_key])
return None
def _format_args_preview(self, args: dict[str, Any], max_length: int = 100) -> str:
"""格式化参数预览
Args:
args: 参数字典
max_length: 最大长度
Returns:
格式化的参数预览字符串
"""
if not args:
return ""
try:
args_str = orjson.dumps(args, option=orjson.OPT_SORT_KEYS).decode('utf-8')
if len(args_str) > max_length:
args_str = args_str[:max_length] + "..."
return args_str
except Exception:
# 如果序列化失败,使用简单格式
parts = []
for k, v in list(args.items())[:3]: # 最多显示3个参数
parts.append(f"{k}={str(v)[:20]}")
result = ", ".join(parts)
if len(parts) >= 3 or len(result) > max_length:
result += "..."
return result
# 全局管理器字典按chat_id索引
_stream_managers: dict[str, StreamToolHistoryManager] = {}
def get_stream_tool_history_manager(chat_id: str) -> StreamToolHistoryManager:
"""获取指定聊天的工具历史记录管理器
Args:
chat_id: 聊天ID
Returns:
工具历史记录管理器实例
"""
if chat_id not in _stream_managers:
_stream_managers[chat_id] = StreamToolHistoryManager(chat_id)
return _stream_managers[chat_id]
def cleanup_stream_manager(chat_id: str) -> None:
"""清理指定聊天的管理器
Args:
chat_id: 聊天ID
"""
if chat_id in _stream_managers:
del _stream_managers[chat_id]
logger.info(f"已清理聊天 {chat_id} 的工具历史记录管理器")

View File

@@ -3,7 +3,6 @@ import time
 from typing import Any

 from src.chat.utils.prompt import Prompt, global_prompt_manager
-from src.common.cache_manager import tool_cache
 from src.common.logger import get_logger
 from src.config.config import global_config, model_config
 from src.llm_models.payload_content import ToolCall
@@ -11,6 +10,8 @@ from src.llm_models.utils_model import LLMRequest
 from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance
 from src.plugin_system.base.base_tool import BaseTool
 from src.plugin_system.core.global_announcement_manager import global_announcement_manager
+from src.plugin_system.core.stream_tool_history import get_stream_tool_history_manager, ToolCallRecord
+from dataclasses import asdict

 logger = get_logger("tool_use")
@@ -36,15 +37,29 @@ def init_tool_executor_prompt():
 {tool_history}

-## 🔧 Tool usage
-Judge from context whether tools are needed. Each tool has a detailed description of its purpose and parameters; decide whether to call it based on the tool definition.
+## 🔧 Tool decision guide
+**Core principles:**
+- Judge intelligently from context whether a tool is needed
+- Each tool has a detailed description of its purpose and parameters
+- Avoid re-running tools already present in the history (unless the arguments differ)
+- Prefer existing cached results over repeated calls
+
+**About the history:**
+- Shown above are the **previous** tool-call records
+- Consult them to avoid re-calling a tool with identical arguments
+- If the history already contains a relevant result, consider answering directly without a tool call

 **⚠️ Special reminder for memory creation:**
 When creating a memory, subject must be the **real sender name** shown in the conversation history!
 - ✅ Correct: extract "Prou" from "Prou(12345678): ..." as the subject
 - ❌ Wrong: using generic words such as "the user" or "the other party"

+**Tool-call strategy:**
+1. **Avoid duplicate calls**: check the history; if the same tool was recently called with identical arguments, skip it
+2. **Pick tools smartly**: choose the most suitable tool for the message, and avoid overuse
+3. **Tune arguments**: keep tool arguments concise and effective, without redundant information
+
 **Execution instructions:**
 - A tool is needed → call the corresponding tool function directly
 - No tool needed → output "No tool needed"
@@ -81,9 +96,8 @@ class ToolExecutor:
         """Pending second-step tool calls, in the form {tool_name: step_two_definition}"""
         self._log_prefix_initialized = False

-        # Tool-call history
-        self.tool_call_history: list[dict[str, Any]] = []
-        """Tool-call history containing tool names, arguments, and results"""
+        # Stream tool-history manager
+        self.history_manager = get_stream_tool_history_manager(chat_id)

         # logger.info(f"{self.log_prefix}Tool executor initialized")  # moved into async initialization
@@ -125,7 +139,7 @@ class ToolExecutor:
         bot_name = global_config.bot.nickname

         # Build the tool-call history text
-        tool_history = self._format_tool_history()
+        tool_history = self.history_manager.format_for_prompt(max_records=5, include_results=True)

         # Fetch the persona info
         personality_core = global_config.personality.personality_core
@@ -183,82 +197,6 @@ class ToolExecutor:
         return tool_definitions

-    def _format_tool_history(self, max_history: int = 5) -> str:
-        """Format the tool-call history as text
-
-        Args:
-            max_history: maximum number of history records to show
-
-        Returns:
-            The formatted tool-history text
-        """
-        if not self.tool_call_history:
-            return ""
-
-        # Only keep the most recent entries
-        recent_history = self.tool_call_history[-max_history:]
-        history_lines = ["Historical tool-call records:"]
-
-        for i, record in enumerate(recent_history, 1):
-            tool_name = record.get("tool_name", "unknown")
-            args = record.get("args", {})
-            result_preview = record.get("result_preview", "")
-            status = record.get("status", "success")
-
-            # Format the arguments
-            args_str = ", ".join([f"{k}={v}" for k, v in args.items()])
-
-            # Format the record
-            status_emoji = "✅" if status == "success" else "❌"
-            history_lines.append(f"{i}. {status_emoji} {tool_name}({args_str})")
-            if result_preview:
-                # Cap the preview length
-                if len(result_preview) > 200:
-                    result_preview = result_preview[:200] + "..."
-                history_lines.append(f"   Result: {result_preview}")
-
-        return "\n".join(history_lines)
-
-    def _add_tool_to_history(self, tool_name: str, args: dict, result: dict | None, status: str = "success"):
-        """Append a tool call to the history
-
-        Args:
-            tool_name: tool name
-            args: tool arguments
-            result: tool result
-            status: execution status (success/error)
-        """
-        # Generate the result preview
-        result_preview = ""
-        if result:
-            content = result.get("content", "")
-            if isinstance(content, str):
-                result_preview = content
-            elif isinstance(content, list | dict):
-                import json
-                try:
-                    result_preview = json.dumps(content, ensure_ascii=False)
-                except Exception:
-                    result_preview = str(content)
-            else:
-                result_preview = str(content)
-
-        record = {
-            "tool_name": tool_name,
-            "args": args,
-            "result_preview": result_preview,
-            "status": status,
-            "timestamp": time.time(),
-        }
-        self.tool_call_history.append(record)
-
-        # Cap the history size to avoid unbounded memory growth
-        max_history_size = 5
-        if len(self.tool_call_history) > max_history_size:
-            self.tool_call_history = self.tool_call_history[-max_history_size:]
-
     async def execute_tool_calls(self, tool_calls: list[ToolCall] | None) -> tuple[list[dict[str, Any]], list[str]]:
         """Execute tool calls
@@ -320,10 +258,20 @@ class ToolExecutor:
                         logger.debug(f"{self.log_prefix}Tool {tool_name} result content: {preview}...")

                     # Record to history
-                    self._add_tool_to_history(tool_name, tool_args, result, status="success")
+                    await self.history_manager.add_tool_call(ToolCallRecord(
+                        tool_name=tool_name,
+                        args=tool_args,
+                        result=result,
+                        status="success"
+                    ))
                 else:
                     # The tool returned an empty result; record it anyway
-                    self._add_tool_to_history(tool_name, tool_args, None, status="success")
+                    await self.history_manager.add_tool_call(ToolCallRecord(
+                        tool_name=tool_name,
+                        args=tool_args,
+                        result=None,
+                        status="success"
+                    ))

             except Exception as e:
                 logger.error(f"{self.log_prefix}Tool {tool_name} execution failed: {e}")
@@ -338,62 +286,72 @@ class ToolExecutor:
                 tool_results.append(error_info)

                 # Record the failure to history
-                self._add_tool_to_history(tool_name, tool_args, None, status="error")
+                await self.history_manager.add_tool_call(ToolCallRecord(
+                    tool_name=tool_name,
+                    args=tool_args,
+                    result=None,
+                    status="error",
+                    error_message=str(e)
+                ))

         return tool_results, used_tools

     async def execute_tool_call(
         self, tool_call: ToolCall, tool_instance: BaseTool | None = None
     ) -> dict[str, Any] | None:
-        """Execute a single tool call, with cache handling"""
+        """Execute a single tool call, integrated with the stream history manager"""
+        start_time = time.time()
         function_args = tool_call.args or {}
         tool_instance = tool_instance or get_tool_instance(tool_call.func_name, self.chat_stream)

-        # If the tool does not exist or caching is disabled, execute directly
-        if not tool_instance or not tool_instance.enable_cache:
-            return await self._original_execute_tool_call(tool_call, tool_instance)
-
-        # --- cache logic start ---
-        try:
-            tool_file_path = inspect.getfile(tool_instance.__class__)
-            semantic_query = None
-            if tool_instance.semantic_cache_query_key:
-                semantic_query = function_args.get(tool_instance.semantic_cache_query_key)
-            cached_result = await tool_cache.get(
-                tool_name=tool_call.func_name,
-                function_args=function_args,
-                tool_file_path=tool_file_path,
-                semantic_query=semantic_query,
-            )
-            if cached_result:
-                logger.info(f"{self.log_prefix}Using cached result, skipping execution of tool {tool_call.func_name}")
-                return cached_result
-        except Exception as e:
-            logger.error(f"{self.log_prefix}Error while checking the tool cache: {e}")
+        # Try to fetch a cached result from the history manager
+        if tool_instance and tool_instance.enable_cache:
+            try:
+                cached_result = await self.history_manager.get_cached_result(
+                    tool_name=tool_call.func_name,
+                    args=function_args
+                )
+                if cached_result:
+                    execution_time = time.time() - start_time
+                    logger.info(f"{self.log_prefix}Using cached result, skipping execution of tool {tool_call.func_name}")
+                    # Record the cache hit to history
+                    await self.history_manager.add_tool_call(ToolCallRecord(
+                        tool_name=tool_call.func_name,
+                        args=function_args,
+                        result=cached_result,
+                        status="success",
+                        execution_time=execution_time,
+                        cache_hit=True
+                    ))
+                    return cached_result
+            except Exception as e:
+                logger.error(f"{self.log_prefix}Error while checking the history cache: {e}")

-        # Cache miss: run the original tool call
+        # Cache miss: run the tool call
         result = await self._original_execute_tool_call(tool_call, tool_instance)

-        # Store the result in the cache
-        try:
-            tool_file_path = inspect.getfile(tool_instance.__class__)
-            semantic_query = None
-            if tool_instance.semantic_cache_query_key:
-                semantic_query = function_args.get(tool_instance.semantic_cache_query_key)
-            await tool_cache.set(
-                tool_name=tool_call.func_name,
-                function_args=function_args,
-                tool_file_path=tool_file_path,
-                data=result,
-                ttl=tool_instance.cache_ttl,
-                semantic_query=semantic_query,
-            )
-        except Exception as e:
-            logger.error(f"{self.log_prefix}Error while setting the tool cache: {e}")
-        # --- cache logic end ---
+        # Record the execution result with the history manager
+        execution_time = time.time() - start_time
+        if tool_instance and result and tool_instance.enable_cache:
+            try:
+                tool_file_path = inspect.getfile(tool_instance.__class__)
+                semantic_query = None
+                if tool_instance.semantic_cache_query_key:
+                    semantic_query = function_args.get(tool_instance.semantic_cache_query_key)
+                await self.history_manager.cache_result(
+                    tool_name=tool_call.func_name,
+                    args=function_args,
+                    result=result,
+                    execution_time=execution_time,
+                    tool_file_path=tool_file_path,
+                    ttl=tool_instance.cache_ttl
+                )
+            except Exception as e:
+                logger.error(f"{self.log_prefix}Error while caching the result to the history manager: {e}")

         return result
@@ -528,21 +486,31 @@ class ToolExecutor:
             logger.info(f"{self.log_prefix}Direct tool execution succeeded: {tool_name}")

             # Record to history
-            self._add_tool_to_history(tool_name, tool_args, result, status="success")
+            await self.history_manager.add_tool_call(ToolCallRecord(
+                tool_name=tool_name,
+                args=tool_args,
+                result=result,
+                status="success"
+            ))

             return tool_info
         except Exception as e:
             logger.error(f"{self.log_prefix}Direct tool execution failed {tool_name}: {e}")

             # Record the failure to history
-            self._add_tool_to_history(tool_name, tool_args, None, status="error")
+            await self.history_manager.add_tool_call(ToolCallRecord(
+                tool_name=tool_name,
+                args=tool_args,
+                result=None,
+                status="error",
+                error_message=str(e)
+            ))

             return None

     def clear_tool_history(self):
         """Clear the tool-call history"""
-        self.tool_call_history.clear()
-        logger.debug(f"{self.log_prefix}Tool-call history cleared")
+        self.history_manager.clear_history()
@@ -550,7 +518,17 @@ class ToolExecutor:
     def get_tool_history(self) -> list[dict[str, Any]]:
         """Get the tool-call history

         Returns:
             The tool-call history list
         """
-        return self.tool_call_history.copy()
+        # Return the most recent history records
+        records = self.history_manager.get_recent_history(count=10)
+        return [asdict(record) for record in records]
+
+    def get_tool_stats(self) -> dict[str, Any]:
+        """Get tool statistics
+
+        Returns:
+            The tool-statistics dictionary
+        """
+        return self.history_manager.get_stats()


 """

View File

@@ -652,7 +652,7 @@ class ChatterPlanFilter:
                 enhanced_memories = await memory_manager.search_memories(
                     query=query,
                     top_k=5,
-                    optimize_query=False,  # use the keyword query directly
+                    use_multi_query=False,  # use the keyword query directly
                 )

                 if not enhanced_memories:

View File

@@ -3,7 +3,7 @@
 When the scheduled task fires, it gathers information, calls the LLM for a decision, and generates a reply based on that decision
 """

-import json
+import orjson
 from datetime import datetime
 from typing import Any, Literal

View File

@@ -3,7 +3,7 @@
 Records and manages the IDs of comments already replied to, avoiding duplicate replies
 """

-import json
+import orjson
 import time
 from pathlib import Path
 from typing import Any
@@ -71,7 +71,7 @@ class ReplyTrackerService:
                     self.replied_comments = {}
                     return

-                data = json.loads(file_content)
+                data = orjson.loads(file_content)
                 if self._validate_data(data):
                     self.replied_comments = data
                     logger.info(
@@ -81,7 +81,7 @@ class ReplyTrackerService:
                 else:
                     logger.error("Loaded data has an invalid format; a new record will be created")
                     self.replied_comments = {}
-        except json.JSONDecodeError as e:
+        except orjson.JSONDecodeError as e:
             logger.error(f"Failed to parse the reply-record file: {e}")
             self._backup_corrupted_file()
             self.replied_comments = {}
@@ -118,7 +118,7 @@ class ReplyTrackerService:
             # Write to a temporary file first
             with open(temp_file, "w", encoding="utf-8") as f:
-                json.dump(self.replied_comments, f, ensure_ascii=False, indent=2)
+                f.write(orjson.dumps(self.replied_comments, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS).decode('utf-8'))

             # If the write succeeded, rename to the final file
             if temp_file.stat().st_size > 0:  # make sure the write succeeded

View File

@@ -1,6 +1,6 @@
 import asyncio
 import inspect
-import json
+import orjson
 from typing import ClassVar, List

 import websockets as Server
@@ -44,10 +44,10 @@ async def message_recv(server_connection: Server.ServerConnection):
         # Log the raw message only in debug mode
         if logger.level <= 10:  # DEBUG level
             logger.debug(f"{raw_message[:1500]}..." if (len(raw_message) > 1500) else raw_message)
-        decoded_raw_message: dict = json.loads(raw_message)
+        decoded_raw_message: dict = orjson.loads(raw_message)

         try:
             # First try to parse the raw message
-            decoded_raw_message: dict = json.loads(raw_message)
+            decoded_raw_message: dict = orjson.loads(raw_message)

             # Check whether this is a chunked message (from MMC)
             if chunker.is_chunk_message(decoded_raw_message):
@@ -71,7 +71,7 @@ async def message_recv(server_connection: Server.ServerConnection):
             elif post_type is None:
                 await put_response(decoded_raw_message)

-        except json.JSONDecodeError as e:
+        except orjson.JSONDecodeError as e:
             logger.error(f"Message parsing failed: {e}")
             logger.debug(f"Raw message: {raw_message[:500]}...")
         except Exception as e:

View File

@@ -5,7 +5,7 @@
 """

 import asyncio
-import json
+import orjson
 import time
 import uuid
 from typing import Any, Dict, List, Optional, Union
@@ -34,7 +34,7 @@ class MessageChunker:
         """Decide whether a message needs chunking"""
         try:
             if isinstance(message, dict):
-                message_str = json.dumps(message, ensure_ascii=False)
+                message_str = orjson.dumps(message, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')
             else:
                 message_str = message
             return len(message_str.encode("utf-8")) > self.max_chunk_size
@@ -58,7 +58,7 @@ class MessageChunker:
         try:
             # Normalize to a string
             if isinstance(message, dict):
-                message_str = json.dumps(message, ensure_ascii=False)
+                message_str = orjson.dumps(message, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')
             else:
                 message_str = message
@@ -116,7 +116,7 @@ class MessageChunker:
         """Decide whether this is a chunk message"""
         try:
             if isinstance(message, str):
-                data = json.loads(message)
+                data = orjson.loads(message)
             else:
                 data = message
@@ -126,7 +126,7 @@ class MessageChunker:
                 and "__mmc_chunk_data__" in data
                 and "__mmc_is_chunked__" in data
             )
-        except (json.JSONDecodeError, TypeError):
+        except (orjson.JSONDecodeError, TypeError):
             return False
@@ -187,7 +187,7 @@ class MessageReassembler:
         try:
             # Normalize to a dict
             if isinstance(message, str):
-                chunk_data = json.loads(message)
+                chunk_data = orjson.loads(message)
             else:
                 chunk_data = message
@@ -197,8 +197,8 @@ class MessageReassembler:
             if "_original_message" in chunk_data:
                 # This is a wrapped non-chunked message; unwrap and return it
                 try:
-                    return json.loads(chunk_data["_original_message"])
-                except json.JSONDecodeError:
+                    return orjson.loads(chunk_data["_original_message"])
+                except orjson.JSONDecodeError:
                     return {"text_message": chunk_data["_original_message"]}
             else:
                 return chunk_data
@@ -251,14 +251,14 @@ class MessageReassembler:
                 # Try to deserialize the reassembled message
                 try:
-                    return json.loads(reassembled_message)
-                except json.JSONDecodeError:
+                    return orjson.loads(reassembled_message)
+                except orjson.JSONDecodeError:
                     # If it cannot be deserialized as JSON, return it as a text message
                     return {"text_message": reassembled_message}

             return None

-        except (json.JSONDecodeError, KeyError, TypeError) as e:
+        except (orjson.JSONDecodeError, KeyError, TypeError) as e:
             logger.error(f"Error while handling a chunk message: {e}")
             return None
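
Since `orjson.dumps` already yields UTF-8 bytes, the size check in `should_chunk` could skip the decode/encode round trip entirely; a hedged sketch of that shortcut (the 64 KiB limit is invented, the real value is configured on MessageChunker):

```python
import orjson

MAX_CHUNK_SIZE = 64 * 1024  # hypothetical limit; the real value lives on MessageChunker

def needs_chunking(message: dict) -> bool:
    # orjson.dumps returns UTF-8 bytes, so len() already measures the encoded size.
    return len(orjson.dumps(message, option=orjson.OPT_NON_STR_KEYS)) > MAX_CHUNK_SIZE

print(needs_chunking({"type": "text", "data": "x" * 100_000}))  # True
```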

View File

@@ -1,5 +1,5 @@
 import base64
-import json
+import orjson
 import time
 import uuid
 from pathlib import Path
@@ -783,7 +783,7 @@ class MessageHandler:
         # Check the JSON message format
         if not message_data or "data" not in message_data:
             logger.warning("Malformed JSON message")
-            return Seg(type="json", data=json.dumps(message_data))
+            return Seg(type="json", data=orjson.dumps(message_data).decode('utf-8'))

         try:
             # Try to parse json_data into a Python object
@@ -1146,13 +1146,13 @@ class MessageHandler:
             return None
         forward_message_id = forward_message_data.get("id")
         request_uuid = str(uuid.uuid4())
-        payload = json.dumps(
+        payload = orjson.dumps(
             {
                 "action": "get_forward_msg",
                 "params": {"message_id": forward_message_id},
                 "echo": request_uuid,
             }
-        )
+        ).decode('utf-8')
         try:
             connection = self.get_server_connection()
             if not connection:
@@ -1167,9 +1167,9 @@ class MessageHandler:
             logger.error(f"Failed to fetch the forwarded message: {str(e)}")
             return None
         logger.debug(
-            f"Raw forwarded-message format: {json.dumps(response)[:80]}..."
-            if len(json.dumps(response)) > 80
-            else json.dumps(response)
+            f"Raw forwarded-message format: {orjson.dumps(response).decode('utf-8')[:80]}..."
+            if len(orjson.dumps(response).decode('utf-8')) > 80
+            else orjson.dumps(response).decode('utf-8')
         )
         response_data: Dict = response.get("data")
         if not response_data:

View File

@@ -1,5 +1,5 @@
 import asyncio
-import json
+import orjson
 import time
 from typing import ClassVar, Optional, Tuple
@@ -241,7 +241,7 @@ class NoticeHandler:
         message_base: MessageBase = MessageBase(
             message_info=message_info,
             message_segment=handled_message,
-            raw_message=json.dumps(raw_message),
+            raw_message=orjson.dumps(raw_message).decode('utf-8'),
         )

         if system_notice:
@@ -602,7 +602,7 @@ class NoticeHandler:
             message_base: MessageBase = MessageBase(
                 message_info=message_info,
                 message_segment=seg_message,
-                raw_message=json.dumps(
+                raw_message=orjson.dumps(
                     {
                         "post_type": "notice",
                         "notice_type": "group_ban",
@@ -611,7 +611,7 @@ class NoticeHandler:
                         "user_id": user_id,
                         "operator_id": None,  # a naturally lifted mute has no operator
                     }
-                ),
+                ).decode('utf-8'),
             )

             await self.put_notice(message_base)

View File

@@ -1,4 +1,4 @@
-import json
+import orjson
 import random
 import time
 import uuid
@@ -605,7 +605,7 @@ class SendHandler:
     async def send_message_to_napcat(self, action: str, params: dict, timeout: float = 20.0) -> dict:
         request_uuid = str(uuid.uuid4())
-        payload = json.dumps({"action": action, "params": params, "echo": request_uuid})
+        payload = orjson.dumps({"action": action, "params": params, "echo": request_uuid}).decode('utf-8')

         # Get the current connection
         connection = self.get_server_connection()
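
The same action/params/echo payload shape recurs for every Napcat call in this file and the utils below; a hedged helper sketch (the name `napcat_payload` is illustrative, not from the codebase):

```python
import uuid
import orjson

def napcat_payload(action: str, params: dict) -> tuple[str, str]:
    """Build a OneBot/Napcat request payload; returns (payload, echo uuid)."""
    echo = str(uuid.uuid4())
    payload = orjson.dumps({"action": action, "params": params, "echo": echo}).decode("utf-8")
    return payload, echo

# Usage mirroring send_message_to_napcat above:
payload, echo = napcat_payload("get_group_info", {"group_id": 123456})
```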

View File

@@ -1,6 +1,6 @@
 import base64
 import io
-import json
+import orjson
 import ssl
 import uuid
 from typing import List, Optional, Tuple, Union
@@ -34,7 +34,7 @@ async def get_group_info(websocket: Server.ServerConnection, group_id: int) -> d
     """
     logger.debug("Fetching group info")
     request_uuid = str(uuid.uuid4())
-    payload = json.dumps({"action": "get_group_info", "params": {"group_id": group_id}, "echo": request_uuid})
+    payload = orjson.dumps({"action": "get_group_info", "params": {"group_id": group_id}, "echo": request_uuid}).decode('utf-8')
     try:
         await websocket.send(payload)
         socket_response: dict = await get_response(request_uuid)
@@ -56,7 +56,7 @@ async def get_group_detail_info(websocket: Server.ServerConnection, group_id: in
     """
     logger.debug("Fetching detailed group info")
     request_uuid = str(uuid.uuid4())
-    payload = json.dumps({"action": "get_group_detail_info", "params": {"group_id": group_id}, "echo": request_uuid})
+    payload = orjson.dumps({"action": "get_group_detail_info", "params": {"group_id": group_id}, "echo": request_uuid}).decode('utf-8')
     try:
         await websocket.send(payload)
         socket_response: dict = await get_response(request_uuid)
@@ -78,13 +78,13 @@ async def get_member_info(websocket: Server.ServerConnection, group_id: int, use
     """
     logger.debug("Fetching group member info")
     request_uuid = str(uuid.uuid4())
-    payload = json.dumps(
+    payload = orjson.dumps(
         {
             "action": "get_group_member_info",
             "params": {"group_id": group_id, "user_id": user_id, "no_cache": True},
             "echo": request_uuid,
         }
-    )
+    ).decode('utf-8')
     try:
         await websocket.send(payload)
         socket_response: dict = await get_response(request_uuid)
@@ -146,7 +146,7 @@ async def get_self_info(websocket: Server.ServerConnection) -> dict | None:
     """
     logger.debug("Fetching self info")
     request_uuid = str(uuid.uuid4())
-    payload = json.dumps({"action": "get_login_info", "params": {}, "echo": request_uuid})
+    payload = orjson.dumps({"action": "get_login_info", "params": {}, "echo": request_uuid}).decode('utf-8')
     try:
         await websocket.send(payload)
         response: dict = await get_response(request_uuid)
@@ -183,7 +183,7 @@ async def get_stranger_info(websocket: Server.ServerConnection, user_id: int) ->
     """
     logger.debug("Fetching stranger info")
     request_uuid = str(uuid.uuid4())
-    payload = json.dumps({"action": "get_stranger_info", "params": {"user_id": user_id}, "echo": request_uuid})
+    payload = orjson.dumps({"action": "get_stranger_info", "params": {"user_id": user_id}, "echo": request_uuid}).decode('utf-8')
     try:
         await websocket.send(payload)
         response: dict = await get_response(request_uuid)
@@ -208,7 +208,7 @@ async def get_message_detail(websocket: Server.ServerConnection, message_id: Uni
     """
     logger.debug("Fetching message details")
     request_uuid = str(uuid.uuid4())
-    payload = json.dumps({"action": "get_msg", "params": {"message_id": message_id}, "echo": request_uuid})
+    payload = orjson.dumps({"action": "get_msg", "params": {"message_id": message_id}, "echo": request_uuid}).decode('utf-8')
     try:
         await websocket.send(payload)
         response: dict = await get_response(request_uuid, 30)  # raise the timeout to 30 seconds
@@ -236,13 +236,13 @@ async def get_record_detail(
     """
     logger.debug("Fetching voice-message details")
     request_uuid = str(uuid.uuid4())
payload = json.dumps( payload = orjson.dumps(
{ {
"action": "get_record", "action": "get_record",
"params": {"file": file, "file_id": file_id, "out_format": "wav"}, "params": {"file": file, "file_id": file_id, "out_format": "wav"},
"echo": request_uuid, "echo": request_uuid,
} }
) ).decode('utf-8')
try: try:
await websocket.send(payload) await websocket.send(payload)
response: dict = await get_response(request_uuid, 30) # 增加超时时间到30秒 response: dict = await get_response(request_uuid, 30) # 增加超时时间到30秒
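
Each of these helpers relies on the same echo-correlation convention: a fresh UUID rides along as "echo", and get_response waits until the websocket reader delivers the reply carrying that UUID. A hypothetical condensation of the pattern (the pending table and the reader task are assumptions, not the bot's actual API):

import asyncio
import uuid

import orjson

# Hypothetical pending-futures table; a separate reader task is assumed to
# resolve pending[echo] when a reply with a matching "echo" field arrives.
pending: dict[str, asyncio.Future] = {}

async def request(ws, action: str, params: dict, timeout: float = 20.0) -> dict:
    echo = str(uuid.uuid4())
    pending[echo] = asyncio.get_running_loop().create_future()
    await ws.send(orjson.dumps({"action": action, "params": params, "echo": echo}).decode("utf-8"))
    try:
        return await asyncio.wait_for(pending[echo], timeout)
    finally:
        pending.pop(echo, None)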

View File

@@ -39,15 +39,23 @@ class ExaSearchEngine(BaseSearchEngine):
return self.api_manager.is_available() return self.api_manager.is_available()
async def search(self, args: dict[str, Any]) -> list[dict[str, Any]]: async def search(self, args: dict[str, Any]) -> list[dict[str, Any]]:
"""执行Exa搜索""" """执行优化的Exa搜索使用answer模式"""
if not self.is_available(): if not self.is_available():
return [] return []
query = args["query"] query = args["query"]
num_results = args.get("num_results", 3) num_results = min(args.get("num_results", 5), 5) # 默认5个结果,但限制最多5个
time_range = args.get("time_range", "any") time_range = args.get("time_range", "any")
exa_args = {"num_results": num_results, "text": True, "highlights": True} # 优化的搜索参数 - 更注重答案质量
exa_args = {
"num_results": num_results,
"text": True,
"highlights": True,
"summary": True, # 启用自动摘要
}
# 时间范围过滤
if time_range != "any": if time_range != "any":
today = datetime.now() today = datetime.now()
start_date = today - timedelta(days=7 if time_range == "week" else 30) start_date = today - timedelta(days=7 if time_range == "week" else 30)
@@ -61,18 +69,89 @@ class ExaSearchEngine(BaseSearchEngine):
return [] return []
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
# 使用search_and_contents获取完整内容优化为answer模式
func = functools.partial(exa_client.search_and_contents, query, **exa_args) func = functools.partial(exa_client.search_and_contents, query, **exa_args)
search_response = await loop.run_in_executor(None, func) search_response = await loop.run_in_executor(None, func)
return [ # 优化结果处理 - 更注重答案质量
{ results = []
for res in search_response.results:
# 获取最佳内容片段
highlights = getattr(res, "highlights", [])
summary = getattr(res, "summary", "")
text = getattr(res, "text", "")
# 智能内容选择:摘要 > 高亮 > 文本开头
if summary and len(summary) > 50:
snippet = summary.strip()
elif highlights:
snippet = " ".join(highlights).strip()
elif text:
snippet = text[:300] + "..." if len(text) > 300 else text
else:
snippet = "内容获取失败"
# 只保留有意义的摘要
if len(snippet) < 30:
snippet = text[:200] + "..." if text and len(text) > 200 else snippet
results.append({
"title": res.title, "title": res.title,
"url": res.url, "url": res.url,
"snippet": " ".join(getattr(res, "highlights", [])) or (getattr(res, "text", "")[:250] + "..."), "snippet": snippet,
"provider": "Exa", "provider": "Exa",
} "answer_focused": True, # 标记为答案导向的搜索
for res in search_response.results })
]
return results
except Exception as e: except Exception as e:
logger.error(f"Exa 搜索失败: {e}") logger.error(f"Exa answer模式搜索失败: {e}")
return []
async def answer_search(self, args: dict[str, Any]) -> list[dict[str, Any]]:
"""执行Exa快速答案搜索 - 最精简的搜索模式"""
if not self.is_available():
return []
query = args["query"]
num_results = min(args.get("num_results", 3), 3) # answer模式默认3个结果,专注质量
# 精简的搜索参数 - 专注快速答案
exa_args = {
"num_results": num_results,
"text": False, # 不需要全文
"highlights": True, # 只要关键高亮
"summary": True, # 优先摘要
}
try:
exa_client = self.api_manager.get_next_client()
if not exa_client:
return []
loop = asyncio.get_running_loop()
func = functools.partial(exa_client.search_and_contents, query, **exa_args)
search_response = await loop.run_in_executor(None, func)
# 极简结果处理 - 只保留最核心信息
results = []
for res in search_response.results:
summary = getattr(res, "summary", "")
highlights = getattr(res, "highlights", [])
# 优先使用摘要,否则使用高亮
answer_text = summary.strip() if summary and len(summary) > 30 else " ".join(highlights).strip()
if answer_text and len(answer_text) > 20:
results.append({
"title": res.title,
"url": res.url,
"snippet": answer_text[:400] + "..." if len(answer_text) > 400 else answer_text,
"provider": "Exa-Answer",
"answer_mode": True # 标记为纯答案模式
})
return results
except Exception as e:
logger.error(f"Exa快速答案搜索失败: {e}")
return [] return []
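
Both methods share one idea: pick the densest content field available. A condensed restatement of that priority chain (pick_snippet is a hypothetical helper; the thresholds are copied from the diff):

def pick_snippet(summary: str, highlights: list[str], text: str) -> str:
    # Priority: substantive summary > highlight fragments > opening of text.
    if summary and len(summary) > 50:
        return summary.strip()
    if highlights:
        return " ".join(highlights).strip()
    if text:
        return text[:300] + "..." if len(text) > 300 else text
    return "内容获取失败"

answer_search then tightens the same chain: it drops full text entirely and discards anything under roughly 20-30 characters, trading recall for snippet quality.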

View File

@@ -1,7 +1,7 @@
""" """
Metaso Search Engine (Chat Completions Mode) Metaso Search Engine (Chat Completions Mode)
""" """
import json import orjson
from typing import Any from typing import Any
import httpx import httpx
@@ -43,12 +43,12 @@ class MetasoClient:
if data_str == "[DONE]": if data_str == "[DONE]":
break break
try: try:
data = json.loads(data_str) data = orjson.loads(data_str)
delta = data.get("choices", [{}])[0].get("delta", {}) delta = data.get("choices", [{}])[0].get("delta", {})
content_chunk = delta.get("content") content_chunk = delta.get("content")
if content_chunk: if content_chunk:
full_response_content += content_chunk full_response_content += content_chunk
except json.JSONDecodeError: except orjson.JSONDecodeError:
logger.warning(f"Metaso stream: could not decode JSON line: {data_str}") logger.warning(f"Metaso stream: could not decode JSON line: {data_str}")
continue continue
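
The except clause keeps working after the swap because orjson.JSONDecodeError is a subclass of both ValueError and json.JSONDecodeError. A self-contained sketch of the stream loop (the sample lines are fabricated):

import orjson

full_response_content = ""
for data_str in ('{"choices": [{"delta": {"content": "hi"}}]}', "[DONE]"):
    if data_str == "[DONE]":
        break
    try:
        delta = orjson.loads(data_str).get("choices", [{}])[0].get("delta", {})
        if delta.get("content"):
            full_response_content += delta["content"]
    except orjson.JSONDecodeError:
        continue  # skip malformed stream lines, as the handler above does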

View File

@@ -41,6 +41,13 @@ class WebSurfingTool(BaseTool):
False, False,
["any", "week", "month"], ["any", "week", "month"],
), ),
(
"answer_mode",
ToolParamType.BOOLEAN,
"是否启用答案模式仅适用于Exa搜索引擎。启用后将返回更精简、直接的答案减少冗余信息。默认为False。",
False,
None,
),
] # type: ignore ] # type: ignore
def __init__(self, plugin_config=None, chat_stream=None): def __init__(self, plugin_config=None, chat_stream=None):
@@ -97,12 +104,18 @@ class WebSurfingTool(BaseTool):
) -> dict[str, Any]: ) -> dict[str, Any]:
"""并行搜索策略:同时使用所有启用的搜索引擎""" """并行搜索策略:同时使用所有启用的搜索引擎"""
search_tasks = [] search_tasks = []
answer_mode = function_args.get("answer_mode", False)
for engine_name in enabled_engines: for engine_name in enabled_engines:
engine = self.engines.get(engine_name) engine = self.engines.get(engine_name)
if engine and engine.is_available(): if engine and engine.is_available():
custom_args = function_args.copy() custom_args = function_args.copy()
custom_args["num_results"] = custom_args.get("num_results", 5) custom_args["num_results"] = custom_args.get("num_results", 5)
# 如果启用了answer模式且是Exa引擎,使用answer_search方法
if answer_mode and engine_name == "exa" and hasattr(engine, 'answer_search'):
search_tasks.append(engine.answer_search(custom_args))
else:
search_tasks.append(engine.search(custom_args)) search_tasks.append(engine.search(custom_args))
if not search_tasks: if not search_tasks:
@@ -137,16 +150,22 @@ class WebSurfingTool(BaseTool):
self, function_args: dict[str, Any], enabled_engines: list[str] self, function_args: dict[str, Any], enabled_engines: list[str]
) -> dict[str, Any]: ) -> dict[str, Any]:
"""回退搜索策略:按顺序尝试搜索引擎,失败则尝试下一个""" """回退搜索策略:按顺序尝试搜索引擎,失败则尝试下一个"""
answer_mode = function_args.get("answer_mode", False)
for engine_name in enabled_engines: for engine_name in enabled_engines:
engine = self.engines.get(engine_name) engine = self.engines.get(engine_name)
if not engine or not engine.is_available(): if not engine or not engine.is_available():
continue continue
try: try:
custom_args = function_args.copy() custom_args = function_args.copy()
custom_args["num_results"] = custom_args.get("num_results", 5) custom_args["num_results"] = custom_args.get("num_results", 5)
# 如果启用了answer模式且是Exa引擎,使用answer_search方法
if answer_mode and engine_name == "exa" and hasattr(engine, 'answer_search'):
logger.info("使用Exa答案模式进行搜索fallback策略")
results = await engine.answer_search(custom_args)
else:
results = await engine.search(custom_args) results = await engine.search(custom_args)
if results: # 如果有结果,直接返回 if results: # 如果有结果,直接返回
@@ -164,17 +183,25 @@ class WebSurfingTool(BaseTool):
async def _execute_single_search(self, function_args: dict[str, Any], enabled_engines: list[str]) -> dict[str, Any]: async def _execute_single_search(self, function_args: dict[str, Any], enabled_engines: list[str]) -> dict[str, Any]:
"""单一搜索策略:只使用第一个可用的搜索引擎""" """单一搜索策略:只使用第一个可用的搜索引擎"""
answer_mode = function_args.get("answer_mode", False)
for engine_name in enabled_engines: for engine_name in enabled_engines:
engine = self.engines.get(engine_name) engine = self.engines.get(engine_name)
if not engine or not engine.is_available(): if not engine or not engine.is_available():
continue continue
try: try:
custom_args = function_args.copy() custom_args = function_args.copy()
custom_args["num_results"] = custom_args.get("num_results", 5) custom_args["num_results"] = custom_args.get("num_results", 5)
# 如果启用了answer模式且是Exa引擎,使用answer_search方法
if answer_mode and engine_name == "exa" and hasattr(engine, 'answer_search'):
logger.info("使用Exa答案模式进行搜索")
results = await engine.answer_search(custom_args)
else:
results = await engine.search(custom_args) results = await engine.search(custom_args)
if results:
formatted_content = format_search_results(results) formatted_content = format_search_results(results)
return { return {
"type": "web_search_result", "type": "web_search_result",

View File

@@ -5,7 +5,7 @@
""" """
import asyncio import asyncio
import json import orjson
import logging import logging
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path

View File

@@ -4,7 +4,7 @@
直接从存储的数据文件生成可视化,无需启动完整的记忆管理器 直接从存储的数据文件生成可视化,无需启动完整的记忆管理器
""" """
import json import orjson
import sys import sys
from pathlib import Path from pathlib import Path
from datetime import datetime from datetime import datetime
@@ -122,7 +122,7 @@ def load_graph_data(file_path: Optional[Path] = None) -> Dict[str, Any]:
print(f"📂 加载图数据: {graph_file}") print(f"📂 加载图数据: {graph_file}")
with open(graph_file, 'r', encoding='utf-8') as f: with open(graph_file, 'r', encoding='utf-8') as f:
data = json.load(f) data = orjson.loads(f.read())
# 解析数据 # 解析数据
nodes_dict = {} nodes_dict = {}
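
A small nuance in the loader change: orjson.loads accepts the str produced by a text-mode read, so the code above works as written, but the idiomatic orjson pattern reads bytes and skips the decode entirely. A minimal sketch (the path is illustrative):

from pathlib import Path

import orjson

graph_file = Path("data/memory_graph.json")  # illustrative path
data = orjson.loads(graph_file.read_bytes())  # bytes in, no manual decode needed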