Merge branch 'feature/memory-graph-system' of https://github.com/MoFox-Studio/MoFox_Bot into feature/memory-graph-system
This commit is contained in:
@@ -323,8 +323,8 @@ class GlobalNoticeManager:
|
||||
return message.additional_config.get("is_notice", False)
|
||||
elif isinstance(message.additional_config, str):
|
||||
# 兼容JSON字符串格式
|
||||
import json
|
||||
config = json.loads(message.additional_config)
|
||||
import orjson
|
||||
config = orjson.loads(message.additional_config)
|
||||
return config.get("is_notice", False)
|
||||
|
||||
# 检查消息类型或其他标识
|
||||
@@ -349,8 +349,8 @@ class GlobalNoticeManager:
|
||||
if isinstance(message.additional_config, dict):
|
||||
return message.additional_config.get("notice_type")
|
||||
elif isinstance(message.additional_config, str):
|
||||
import json
|
||||
config = json.loads(message.additional_config)
|
||||
import orjson
|
||||
config = orjson.loads(message.additional_config)
|
||||
return config.get("notice_type")
|
||||
return None
|
||||
except Exception:
|
||||
|
||||
@@ -553,18 +553,56 @@ class DefaultReplyer:
|
||||
if user_info_obj:
|
||||
sender_name = getattr(user_info_obj, "user_nickname", "") or getattr(user_info_obj, "user_cardname", "")
|
||||
|
||||
# 获取参与者信息
|
||||
participants = []
|
||||
try:
|
||||
# 尝试从聊天流中获取参与者信息
|
||||
if hasattr(stream, 'chat_history_manager'):
|
||||
history_manager = stream.chat_history_manager
|
||||
# 获取最近的参与者列表
|
||||
recent_records = history_manager.get_memory_chat_history(
|
||||
user_id=getattr(stream, "user_id", ""),
|
||||
count=10,
|
||||
memory_types=["chat_message", "system_message"]
|
||||
)
|
||||
# 提取唯一的参与者名称
|
||||
for record in recent_records[:5]: # 最近5条记录
|
||||
content = record.get("content", {})
|
||||
participant = content.get("participant_name")
|
||||
if participant and participant not in participants:
|
||||
participants.append(participant)
|
||||
|
||||
# 如果消息包含发送者信息,也添加到参与者列表
|
||||
if content.get("sender_name") and content.get("sender_name") not in participants:
|
||||
participants.append(content.get("sender_name"))
|
||||
except Exception as e:
|
||||
logger.debug(f"获取参与者信息失败: {e}")
|
||||
|
||||
# 如果发送者不在参与者列表中,添加进去
|
||||
if sender_name and sender_name not in participants:
|
||||
participants.insert(0, sender_name)
|
||||
|
||||
# 格式化聊天历史为更友好的格式
|
||||
formatted_history = ""
|
||||
if chat_history:
|
||||
# 移除过长的历史记录,只保留最近部分
|
||||
lines = chat_history.strip().split('\n')
|
||||
recent_lines = lines[-10:] if len(lines) > 10 else lines
|
||||
formatted_history = '\n'.join(recent_lines)
|
||||
|
||||
query_context = {
|
||||
"chat_history": chat_history if chat_history else "",
|
||||
"chat_history": formatted_history,
|
||||
"sender": sender_name,
|
||||
"participants": participants,
|
||||
}
|
||||
|
||||
# 使用记忆管理器的智能检索(自动优化查询)
|
||||
# 使用记忆管理器的智能检索(多查询策略)
|
||||
memories = await manager.search_memories(
|
||||
query=target,
|
||||
top_k=10,
|
||||
min_importance=0.3,
|
||||
include_forgotten=False,
|
||||
optimize_query=True,
|
||||
use_multi_query=True,
|
||||
context=query_context,
|
||||
)
|
||||
|
||||
@@ -667,32 +705,46 @@ class DefaultReplyer:
|
||||
return ""
|
||||
|
||||
try:
|
||||
# 使用工具执行器获取信息
|
||||
# 首先获取当前的历史记录(在执行新工具调用之前)
|
||||
tool_history_str = self.tool_executor.history_manager.format_for_prompt(max_records=3, include_results=True)
|
||||
|
||||
# 然后执行工具调用
|
||||
tool_results, _, _ = await self.tool_executor.execute_from_chat_message(
|
||||
sender=sender, target_message=target, chat_history=chat_history, return_details=False
|
||||
)
|
||||
|
||||
info_parts = []
|
||||
|
||||
# 显示之前的工具调用历史(不包括当前这次调用)
|
||||
if tool_history_str:
|
||||
info_parts.append(tool_history_str)
|
||||
|
||||
# 显示当前工具调用的结果(简要信息)
|
||||
if tool_results:
|
||||
tool_info_str = "以下是你通过工具获取到的实时信息:\n"
|
||||
current_results_parts = ["## 🔧 刚获取的工具信息"]
|
||||
for tool_result in tool_results:
|
||||
tool_name = tool_result.get("tool_name", "unknown")
|
||||
content = tool_result.get("content", "")
|
||||
result_type = tool_result.get("type", "tool_result")
|
||||
|
||||
tool_info_str += f"- 【{tool_name}】{result_type}: {content}\n"
|
||||
# 不进行截断,让工具自己处理结果长度
|
||||
current_results_parts.append(f"- **{tool_name}**: {content}")
|
||||
|
||||
tool_info_str += "以上是你获取到的实时信息,请在回复时参考这些信息。"
|
||||
info_parts.append("\n".join(current_results_parts))
|
||||
logger.info(f"获取到 {len(tool_results)} 个工具结果")
|
||||
|
||||
return tool_info_str
|
||||
else:
|
||||
logger.debug("未获取到任何工具结果")
|
||||
# 如果没有任何信息,返回空字符串
|
||||
if not info_parts:
|
||||
logger.debug("未获取到任何工具结果或历史记录")
|
||||
return ""
|
||||
|
||||
return "\n\n".join(info_parts)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工具信息获取失败: {e}")
|
||||
return ""
|
||||
|
||||
|
||||
def _parse_reply_target(self, target_message: str) -> tuple[str, str]:
|
||||
"""解析回复目标消息 - 使用共享工具"""
|
||||
from src.chat.utils.prompt import Prompt
|
||||
|
||||
@@ -57,8 +57,16 @@ class CacheManager:
|
||||
# 嵌入模型
|
||||
self.embedding_model = LLMRequest(model_config.model_task_config.embedding)
|
||||
|
||||
# 工具调用统计
|
||||
self.tool_stats = {
|
||||
"total_tool_calls": 0,
|
||||
"cache_hits_by_tool": {}, # 按工具名称统计缓存命中
|
||||
"execution_times_by_tool": {}, # 按工具名称统计执行时间
|
||||
"most_used_tools": {}, # 最常用的工具
|
||||
}
|
||||
|
||||
self._initialized = True
|
||||
logger.info("缓存管理器已初始化: L1 (内存+FAISS), L2 (数据库+ChromaDB)")
|
||||
logger.info("缓存管理器已初始化: L1 (内存+FAISS), L2 (数据库+ChromaDB) + 工具统计")
|
||||
|
||||
@staticmethod
|
||||
def _validate_embedding(embedding_result: Any) -> np.ndarray | None:
|
||||
@@ -363,58 +371,205 @@ class CacheManager:
|
||||
|
||||
def get_health_stats(self) -> dict[str, Any]:
|
||||
"""获取缓存健康统计信息"""
|
||||
from src.common.memory_utils import format_size
|
||||
|
||||
# 简化的健康统计,不包含内存监控(因为相关属性未定义)
|
||||
return {
|
||||
"l1_count": len(self.l1_kv_cache),
|
||||
"l1_memory": self.l1_current_memory,
|
||||
"l1_memory_formatted": format_size(self.l1_current_memory),
|
||||
"l1_max_memory": self.l1_max_memory,
|
||||
"l1_memory_usage_percent": round((self.l1_current_memory / self.l1_max_memory) * 100, 2),
|
||||
"l1_max_size": self.l1_max_size,
|
||||
"l1_size_usage_percent": round((len(self.l1_kv_cache) / self.l1_max_size) * 100, 2),
|
||||
"average_item_size": self.l1_current_memory // len(self.l1_kv_cache) if self.l1_kv_cache else 0,
|
||||
"average_item_size_formatted": format_size(self.l1_current_memory // len(self.l1_kv_cache)) if self.l1_kv_cache else "0 B",
|
||||
"largest_item_size": max(self.l1_size_map.values()) if self.l1_size_map else 0,
|
||||
"largest_item_size_formatted": format_size(max(self.l1_size_map.values())) if self.l1_size_map else "0 B",
|
||||
"l1_vector_count": self.l1_vector_index.ntotal if hasattr(self.l1_vector_index, 'ntotal') else 0,
|
||||
"tool_stats": {
|
||||
"total_tool_calls": self.tool_stats.get("total_tool_calls", 0),
|
||||
"tracked_tools": len(self.tool_stats.get("most_used_tools", {})),
|
||||
"cache_hits": sum(data.get("hits", 0) for data in self.tool_stats.get("cache_hits_by_tool", {}).values()),
|
||||
"cache_misses": sum(data.get("misses", 0) for data in self.tool_stats.get("cache_hits_by_tool", {}).values()),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def check_health(self) -> tuple[bool, list[str]]:
|
||||
"""检查缓存健康状态
|
||||
|
||||
|
||||
Returns:
|
||||
(is_healthy, warnings) - 是否健康,警告列表
|
||||
"""
|
||||
warnings = []
|
||||
|
||||
# 检查内存使用
|
||||
memory_usage = (self.l1_current_memory / self.l1_max_memory) * 100
|
||||
if memory_usage > 90:
|
||||
warnings.append(f"⚠️ L1缓存内存使用率过高: {memory_usage:.1f}%")
|
||||
elif memory_usage > 75:
|
||||
warnings.append(f"⚡ L1缓存内存使用率较高: {memory_usage:.1f}%")
|
||||
|
||||
# 检查条目数
|
||||
size_usage = (len(self.l1_kv_cache) / self.l1_max_size) * 100
|
||||
if size_usage > 90:
|
||||
warnings.append(f"⚠️ L1缓存条目数过多: {size_usage:.1f}%")
|
||||
|
||||
# 检查平均条目大小
|
||||
if self.l1_kv_cache:
|
||||
avg_size = self.l1_current_memory // len(self.l1_kv_cache)
|
||||
if avg_size > 100 * 1024: # >100KB
|
||||
from src.common.memory_utils import format_size
|
||||
warnings.append(f"⚡ 平均缓存条目过大: {format_size(avg_size)}")
|
||||
|
||||
# 检查最大单条目
|
||||
if self.l1_size_map:
|
||||
max_size = max(self.l1_size_map.values())
|
||||
if max_size > 500 * 1024: # >500KB
|
||||
from src.common.memory_utils import format_size
|
||||
warnings.append(f"⚠️ 发现超大缓存条目: {format_size(max_size)}")
|
||||
|
||||
|
||||
# 检查L1缓存大小
|
||||
l1_size = len(self.l1_kv_cache)
|
||||
if l1_size > 1000: # 如果超过1000个条目
|
||||
warnings.append(f"⚠️ L1缓存条目数较多: {l1_size}")
|
||||
|
||||
# 检查向量索引大小
|
||||
vector_count = self.l1_vector_index.ntotal if hasattr(self.l1_vector_index, 'ntotal') else 0
|
||||
if isinstance(vector_count, int) and vector_count > 500:
|
||||
warnings.append(f"⚠️ 向量索引条目数较多: {vector_count}")
|
||||
|
||||
# 检查工具统计健康
|
||||
total_calls = self.tool_stats.get("total_tool_calls", 0)
|
||||
if total_calls > 0:
|
||||
total_hits = sum(data.get("hits", 0) for data in self.tool_stats.get("cache_hits_by_tool", {}).values())
|
||||
cache_hit_rate = (total_hits / total_calls) * 100
|
||||
if cache_hit_rate < 50: # 缓存命中率低于50%
|
||||
warnings.append(f"⚡ 整体缓存命中率较低: {cache_hit_rate:.1f}%")
|
||||
|
||||
return len(warnings) == 0, warnings
|
||||
|
||||
async def get_tool_result_with_stats(self,
|
||||
tool_name: str,
|
||||
function_args: dict[str, Any],
|
||||
tool_file_path: str | Path,
|
||||
semantic_query: str | None = None) -> tuple[Any | None, bool]:
|
||||
"""获取工具结果并更新统计信息
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
function_args: 函数参数
|
||||
tool_file_path: 工具文件路径
|
||||
semantic_query: 语义查询字符串
|
||||
|
||||
Returns:
|
||||
Tuple[结果, 是否命中缓存]
|
||||
"""
|
||||
# 更新总调用次数
|
||||
self.tool_stats["total_tool_calls"] += 1
|
||||
|
||||
# 更新工具使用统计
|
||||
if tool_name not in self.tool_stats["most_used_tools"]:
|
||||
self.tool_stats["most_used_tools"][tool_name] = 0
|
||||
self.tool_stats["most_used_tools"][tool_name] += 1
|
||||
|
||||
# 尝试获取缓存
|
||||
result = await self.get(tool_name, function_args, tool_file_path, semantic_query)
|
||||
|
||||
# 更新缓存命中统计
|
||||
if tool_name not in self.tool_stats["cache_hits_by_tool"]:
|
||||
self.tool_stats["cache_hits_by_tool"][tool_name] = {"hits": 0, "misses": 0}
|
||||
|
||||
if result is not None:
|
||||
self.tool_stats["cache_hits_by_tool"][tool_name]["hits"] += 1
|
||||
logger.info(f"工具缓存命中: {tool_name}")
|
||||
return result, True
|
||||
else:
|
||||
self.tool_stats["cache_hits_by_tool"][tool_name]["misses"] += 1
|
||||
return None, False
|
||||
|
||||
async def set_tool_result_with_stats(self,
|
||||
tool_name: str,
|
||||
function_args: dict[str, Any],
|
||||
tool_file_path: str | Path,
|
||||
data: Any,
|
||||
execution_time: float | None = None,
|
||||
ttl: int | None = None,
|
||||
semantic_query: str | None = None):
|
||||
"""存储工具结果并更新统计信息
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
function_args: 函数参数
|
||||
tool_file_path: 工具文件路径
|
||||
data: 结果数据
|
||||
execution_time: 执行时间
|
||||
ttl: 缓存TTL
|
||||
semantic_query: 语义查询字符串
|
||||
"""
|
||||
# 更新执行时间统计
|
||||
if execution_time is not None:
|
||||
if tool_name not in self.tool_stats["execution_times_by_tool"]:
|
||||
self.tool_stats["execution_times_by_tool"][tool_name] = []
|
||||
self.tool_stats["execution_times_by_tool"][tool_name].append(execution_time)
|
||||
|
||||
# 只保留最近100次的执行时间记录
|
||||
if len(self.tool_stats["execution_times_by_tool"][tool_name]) > 100:
|
||||
self.tool_stats["execution_times_by_tool"][tool_name] = \
|
||||
self.tool_stats["execution_times_by_tool"][tool_name][-100:]
|
||||
|
||||
# 存储到缓存
|
||||
await self.set(tool_name, function_args, tool_file_path, data, ttl, semantic_query)
|
||||
|
||||
def get_tool_performance_stats(self) -> dict[str, Any]:
|
||||
"""获取工具性能统计信息
|
||||
|
||||
Returns:
|
||||
统计信息字典
|
||||
"""
|
||||
stats = self.tool_stats.copy()
|
||||
|
||||
# 计算平均执行时间
|
||||
avg_times = {}
|
||||
for tool_name, times in stats["execution_times_by_tool"].items():
|
||||
if times:
|
||||
avg_times[tool_name] = {
|
||||
"average": sum(times) / len(times),
|
||||
"min": min(times),
|
||||
"max": max(times),
|
||||
"count": len(times),
|
||||
}
|
||||
|
||||
# 计算缓存命中率
|
||||
cache_hit_rates = {}
|
||||
for tool_name, hit_data in stats["cache_hits_by_tool"].items():
|
||||
total = hit_data["hits"] + hit_data["misses"]
|
||||
if total > 0:
|
||||
cache_hit_rates[tool_name] = {
|
||||
"hit_rate": (hit_data["hits"] / total) * 100,
|
||||
"hits": hit_data["hits"],
|
||||
"misses": hit_data["misses"],
|
||||
"total": total,
|
||||
}
|
||||
|
||||
# 按使用频率排序工具
|
||||
most_used = sorted(stats["most_used_tools"].items(), key=lambda x: x[1], reverse=True)
|
||||
|
||||
return {
|
||||
"total_tool_calls": stats["total_tool_calls"],
|
||||
"average_execution_times": avg_times,
|
||||
"cache_hit_rates": cache_hit_rates,
|
||||
"most_used_tools": most_used[:10], # 前10个最常用工具
|
||||
"cache_health": self.get_health_stats(),
|
||||
}
|
||||
|
||||
def get_tool_recommendations(self) -> dict[str, Any]:
|
||||
"""获取工具优化建议
|
||||
|
||||
Returns:
|
||||
优化建议字典
|
||||
"""
|
||||
recommendations = []
|
||||
|
||||
# 分析缓存命中率低的工具
|
||||
cache_hit_rates = {}
|
||||
for tool_name, hit_data in self.tool_stats["cache_hits_by_tool"].items():
|
||||
total = hit_data["hits"] + hit_data["misses"]
|
||||
if total >= 5: # 至少调用5次才分析
|
||||
hit_rate = (hit_data["hits"] / total) * 100
|
||||
cache_hit_rates[tool_name] = hit_rate
|
||||
|
||||
if hit_rate < 30: # 缓存命中率低于30%
|
||||
recommendations.append({
|
||||
"tool": tool_name,
|
||||
"type": "low_cache_hit_rate",
|
||||
"message": f"工具 {tool_name} 的缓存命中率仅为 {hit_rate:.1f}%,建议检查缓存配置或参数变化频率",
|
||||
"severity": "medium" if hit_rate > 10 else "high",
|
||||
})
|
||||
|
||||
# 分析执行时间长的工具
|
||||
for tool_name, times in self.tool_stats["execution_times_by_tool"].items():
|
||||
if len(times) >= 3: # 至少3次执行才分析
|
||||
avg_time = sum(times) / len(times)
|
||||
if avg_time > 5.0: # 平均执行时间超过5秒
|
||||
recommendations.append({
|
||||
"tool": tool_name,
|
||||
"type": "slow_execution",
|
||||
"message": f"工具 {tool_name} 平均执行时间较长 ({avg_time:.2f}s),建议优化算法或增加缓存",
|
||||
"severity": "medium" if avg_time < 10.0 else "high",
|
||||
})
|
||||
|
||||
return {
|
||||
"recommendations": recommendations,
|
||||
"summary": {
|
||||
"total_issues": len(recommendations),
|
||||
"high_priority": len([r for r in recommendations if r["severity"] == "high"]),
|
||||
"medium_priority": len([r for r in recommendations if r["severity"] == "medium"]),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# 全局实例
|
||||
tool_cache = CacheManager()
|
||||
|
||||
@@ -137,6 +137,7 @@ class MemoryManager:
|
||||
graph_store=self.graph_store,
|
||||
persistence_manager=self.persistence,
|
||||
embedding_generator=self.embedding_generator,
|
||||
max_expand_depth=getattr(self.config, 'max_expand_depth', 1), # 从配置读取默认深度
|
||||
)
|
||||
|
||||
self._initialized = True
|
||||
@@ -362,18 +363,15 @@ class MemoryManager:
|
||||
|
||||
# 构建上下文信息
|
||||
chat_history = context.get("chat_history", "") if context else ""
|
||||
sender = context.get("sender", "") if context else ""
|
||||
participants = context.get("participants", []) if context else []
|
||||
participants_str = "、".join(participants) if participants else "无"
|
||||
|
||||
|
||||
prompt = f"""你是记忆检索助手。为提高检索准确率,请为查询生成3-5个不同角度的搜索语句。
|
||||
|
||||
**核心原则(重要!):**
|
||||
对于包含多个概念的复杂查询(如"杰瑞喵如何评价新的记忆系统"),应该生成:
|
||||
对于包含多个概念的复杂查询(如"小明如何评价小王"),应该生成:
|
||||
1. 完整查询(包含所有要素)- 权重1.0
|
||||
2. 每个关键概念的独立查询(如"新的记忆系统")- 权重0.8,避免被主体淹没!
|
||||
3. 主体+动作组合(如"杰瑞喵 评价")- 权重0.6
|
||||
4. 泛化查询(如"记忆系统")- 权重0.7
|
||||
2. 每个关键概念的独立查询(如"小明"、"小王")- 权重0.8,避免被主体淹没!
|
||||
3. 主体+动作组合(如"小明 评价")- 权重0.6
|
||||
4. 泛化查询(如"评价")- 权重0.7
|
||||
|
||||
**要求:**
|
||||
- 第一个必须是原始查询或同义改写
|
||||
@@ -381,9 +379,7 @@ class MemoryManager:
|
||||
- 查询简洁(5-20字)
|
||||
- 直接输出JSON,不要添加说明
|
||||
|
||||
**已知参与者:** {participants_str}
|
||||
**对话上下文:** {chat_history[-300:] if chat_history else "无"}
|
||||
**当前查询:** {sender}: {query}
|
||||
|
||||
**输出JSON格式:**
|
||||
```json
|
||||
@@ -436,7 +432,6 @@ class MemoryManager:
|
||||
time_range: Optional[Tuple[datetime, datetime]] = None,
|
||||
min_importance: float = 0.0,
|
||||
include_forgotten: bool = False,
|
||||
optimize_query: bool = True,
|
||||
use_multi_query: bool = True,
|
||||
expand_depth: int = 1,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
@@ -457,7 +452,6 @@ class MemoryManager:
|
||||
time_range: 时间范围过滤 (start, end)
|
||||
min_importance: 最小重要性
|
||||
include_forgotten: 是否包含已遗忘的记忆
|
||||
optimize_query: 是否使用小模型优化查询(已弃用,被 use_multi_query 替代)
|
||||
use_multi_query: 是否使用多查询策略(推荐,默认True)
|
||||
expand_depth: 图扩展深度(0=禁用, 1=推荐, 2-3=深度探索)
|
||||
context: 查询上下文(用于优化)
|
||||
|
||||
@@ -102,8 +102,8 @@ class VectorStore:
|
||||
# 处理额外的元数据,将 list 转换为 JSON 字符串
|
||||
for key, value in node.metadata.items():
|
||||
if isinstance(value, (list, dict)):
|
||||
import json
|
||||
metadata[key] = json.dumps(value, ensure_ascii=False)
|
||||
import orjson
|
||||
metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')
|
||||
elif isinstance(value, (str, int, float, bool)) or value is None:
|
||||
metadata[key] = value
|
||||
else:
|
||||
@@ -141,7 +141,7 @@ class VectorStore:
|
||||
|
||||
try:
|
||||
# 准备元数据
|
||||
import json
|
||||
import orjson
|
||||
metadatas = []
|
||||
for n in valid_nodes:
|
||||
metadata = {
|
||||
@@ -151,7 +151,7 @@ class VectorStore:
|
||||
}
|
||||
for key, value in n.metadata.items():
|
||||
if isinstance(value, (list, dict)):
|
||||
metadata[key] = json.dumps(value, ensure_ascii=False)
|
||||
metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')
|
||||
elif isinstance(value, (str, int, float, bool)) or value is None:
|
||||
metadata[key] = value # type: ignore
|
||||
else:
|
||||
@@ -207,7 +207,7 @@ class VectorStore:
|
||||
)
|
||||
|
||||
# 解析结果
|
||||
import json
|
||||
import orjson
|
||||
similar_nodes = []
|
||||
if results["ids"] and results["ids"][0]:
|
||||
for i, node_id in enumerate(results["ids"][0]):
|
||||
@@ -223,7 +223,7 @@ class VectorStore:
|
||||
for key, value in list(metadata.items()):
|
||||
if isinstance(value, str) and (value.startswith('[') or value.startswith('{')):
|
||||
try:
|
||||
metadata[key] = json.loads(value)
|
||||
metadata[key] = orjson.loads(value)
|
||||
except:
|
||||
pass # 保持原值
|
||||
|
||||
|
||||
@@ -34,6 +34,7 @@ class MemoryTools:
|
||||
graph_store: GraphStore,
|
||||
persistence_manager: PersistenceManager,
|
||||
embedding_generator: Optional[EmbeddingGenerator] = None,
|
||||
max_expand_depth: int = 1,
|
||||
):
|
||||
"""
|
||||
初始化工具集
|
||||
@@ -43,11 +44,13 @@ class MemoryTools:
|
||||
graph_store: 图存储
|
||||
persistence_manager: 持久化管理器
|
||||
embedding_generator: 嵌入生成器(可选)
|
||||
max_expand_depth: 图扩展深度的默认值(从配置读取)
|
||||
"""
|
||||
self.vector_store = vector_store
|
||||
self.graph_store = graph_store
|
||||
self.persistence_manager = persistence_manager
|
||||
self._initialized = False
|
||||
self.max_expand_depth = max_expand_depth # 保存配置的默认值
|
||||
|
||||
# 初始化组件
|
||||
self.extractor = MemoryExtractor()
|
||||
@@ -448,11 +451,12 @@ class MemoryTools:
|
||||
try:
|
||||
query = params.get("query", "")
|
||||
top_k = params.get("top_k", 10)
|
||||
expand_depth = params.get("expand_depth", 1)
|
||||
# 使用配置中的默认值而不是硬编码的 1
|
||||
expand_depth = params.get("expand_depth", self.max_expand_depth)
|
||||
use_multi_query = params.get("use_multi_query", True)
|
||||
context = params.get("context", None)
|
||||
|
||||
logger.info(f"搜索记忆: {query} (top_k={top_k}, multi_query={use_multi_query})")
|
||||
logger.info(f"搜索记忆: {query} (top_k={top_k}, expand_depth={expand_depth}, multi_query={use_multi_query})")
|
||||
|
||||
# 0. 确保初始化
|
||||
await self._ensure_initialized()
|
||||
@@ -474,9 +478,9 @@ class MemoryTools:
|
||||
ids = metadata["memory_ids"]
|
||||
# 确保是列表
|
||||
if isinstance(ids, str):
|
||||
import json
|
||||
import orjson
|
||||
try:
|
||||
ids = json.loads(ids)
|
||||
ids = orjson.loads(ids)
|
||||
except:
|
||||
ids = [ids]
|
||||
if isinstance(ids, list):
|
||||
@@ -625,35 +629,63 @@ class MemoryTools:
|
||||
try:
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
|
||||
|
||||
llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils_small,
|
||||
request_type="memory.multi_query"
|
||||
)
|
||||
|
||||
participants = context.get("participants", []) if context else []
|
||||
prompt = f"""为查询生成3-5个不同角度的搜索语句(JSON格式)。
|
||||
|
||||
**查询:** {query}
|
||||
# 获取上下文信息
|
||||
participants = context.get("participants", []) if context else []
|
||||
chat_history = context.get("chat_history", "") if context else ""
|
||||
sender = context.get("sender", "") if context else ""
|
||||
|
||||
# 处理聊天历史,提取最近5条左右的对话
|
||||
recent_chat = ""
|
||||
if chat_history:
|
||||
lines = chat_history.strip().split('\n')
|
||||
# 取最近5条消息
|
||||
recent_lines = lines[-5:] if len(lines) > 5 else lines
|
||||
recent_chat = '\n'.join(recent_lines)
|
||||
|
||||
prompt = f"""基于聊天上下文为查询生成3-5个不同角度的搜索语句(JSON格式)。
|
||||
|
||||
**当前查询:** {query}
|
||||
**发送者:** {sender if sender else '未知'}
|
||||
**参与者:** {', '.join(participants) if participants else '无'}
|
||||
|
||||
**原则:** 对复杂查询(如"杰瑞喵如何评价新的记忆系统"),应生成:
|
||||
1. 完整查询(权重1.0)
|
||||
2. 每个关键概念独立查询(权重0.8)- 重要!
|
||||
3. 主体+动作(权重0.6)
|
||||
**最近聊天记录(最近5条):**
|
||||
{recent_chat if recent_chat else '无聊天历史'}
|
||||
|
||||
**输出JSON:**
|
||||
**分析原则:**
|
||||
1. **上下文理解**:根据聊天历史理解查询的真实意图
|
||||
2. **指代消解**:识别并代换"他"、"她"、"它"、"那个"等指代词
|
||||
3. **话题关联**:结合最近讨论的话题生成更精准的查询
|
||||
4. **查询分解**:对复杂查询分解为多个子查询
|
||||
|
||||
**生成策略:**
|
||||
1. **完整查询**(权重1.0):结合上下文的完整查询,包含指代消解
|
||||
2. **关键概念查询**(权重0.8):查询中的核心概念,特别是聊天中提到的实体
|
||||
3. **话题扩展查询**(权重0.7):基于最近聊天话题的相关查询
|
||||
4. **动作/情感查询**(权重0.6):如果涉及情感或动作,生成相关查询
|
||||
|
||||
**输出JSON格式:**
|
||||
```json
|
||||
{{"queries": [{{"text": "查询1", "weight": 1.0}}, {{"text": "查询2", "weight": 0.8}}]}}
|
||||
```"""
|
||||
{{"queries": [{{"text": "查询语句", "weight": 1.0}}, {{"text": "查询语句", "weight": 0.8}}]}}
|
||||
```
|
||||
|
||||
**示例:**
|
||||
- 查询:"他怎么样了?" + 聊天中提到"小明生病了" → "小明身体恢复情况"
|
||||
- 查询:"那个项目" + 聊天中讨论"记忆系统开发" → "记忆系统项目进展"
|
||||
"""
|
||||
|
||||
response, _ = await llm.generate_response_async(prompt, temperature=0.3, max_tokens=250)
|
||||
|
||||
import json, re
|
||||
import orjson, re
|
||||
response = re.sub(r'```json\s*', '', response)
|
||||
response = re.sub(r'```\s*$', '', response).strip()
|
||||
|
||||
data = json.loads(response)
|
||||
data = orjson.loads(response)
|
||||
queries = data.get("queries", [])
|
||||
|
||||
result = [(item.get("text", "").strip(), float(item.get("weight", 0.5)))
|
||||
@@ -799,9 +831,9 @@ class MemoryTools:
|
||||
|
||||
# 确保是列表
|
||||
if isinstance(ids, str):
|
||||
import json
|
||||
import orjson
|
||||
try:
|
||||
ids = json.loads(ids)
|
||||
ids = orjson.loads(ids)
|
||||
except Exception as e:
|
||||
logger.warning(f"JSON 解析失败: {e}")
|
||||
ids = [ids]
|
||||
@@ -910,9 +942,9 @@ class MemoryTools:
|
||||
# 提取记忆ID
|
||||
neighbor_memory_ids = neighbor_node_data.get("memory_ids", [])
|
||||
if isinstance(neighbor_memory_ids, str):
|
||||
import json
|
||||
import orjson
|
||||
try:
|
||||
neighbor_memory_ids = json.loads(neighbor_memory_ids)
|
||||
neighbor_memory_ids = orjson.loads(neighbor_memory_ids)
|
||||
except:
|
||||
neighbor_memory_ids = [neighbor_memory_ids]
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
"""
|
||||
|
||||
import atexit
|
||||
import json
|
||||
import orjson
|
||||
import os
|
||||
import threading
|
||||
from typing import Any, ClassVar
|
||||
@@ -100,10 +100,10 @@ class PluginStorage:
|
||||
if os.path.exists(self.file_path):
|
||||
with open(self.file_path, encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
self._data = json.loads(content) if content else {}
|
||||
self._data = orjson.loads(content) if content else {}
|
||||
else:
|
||||
self._data = {}
|
||||
except (json.JSONDecodeError, Exception) as e:
|
||||
except (orjson.JSONDecodeError, Exception) as e:
|
||||
logger.warning(f"从 '{self.file_path}' 加载数据失败: {e},将初始化为空数据。")
|
||||
self._data = {}
|
||||
|
||||
@@ -125,7 +125,7 @@ class PluginStorage:
|
||||
|
||||
try:
|
||||
with open(self.file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(self._data, f, indent=4, ensure_ascii=False)
|
||||
f.write(orjson.dumps(self._data, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS).decode('utf-8'))
|
||||
self._dirty = False # 保存后重置标志
|
||||
logger.debug(f"插件 '{self.name}' 的数据已成功保存到磁盘。")
|
||||
except Exception as e:
|
||||
|
||||
@@ -5,7 +5,7 @@ MCP Client Manager
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import orjson
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
@@ -89,7 +89,7 @@ class MCPClientManager:
|
||||
|
||||
try:
|
||||
with open(self.config_path, encoding="utf-8") as f:
|
||||
config_data = json.load(f)
|
||||
config_data = orjson.loads(f.read())
|
||||
|
||||
servers = {}
|
||||
mcp_servers = config_data.get("mcpServers", {})
|
||||
@@ -106,7 +106,7 @@ class MCPClientManager:
|
||||
logger.info(f"成功加载 {len(servers)} 个 MCP 服务器配置")
|
||||
return servers
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
except orjson.JSONDecodeError as e:
|
||||
logger.error(f"解析 MCP 配置文件失败: {e}")
|
||||
return {}
|
||||
except Exception as e:
|
||||
|
||||
414
src/plugin_system/core/stream_tool_history.py
Normal file
414
src/plugin_system/core/stream_tool_history.py
Normal file
@@ -0,0 +1,414 @@
|
||||
"""
|
||||
流式工具历史记录管理器
|
||||
用于在聊天流级别管理工具调用历史,支持智能缓存和上下文感知
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
from dataclasses import dataclass, asdict, field
|
||||
import orjson
|
||||
from src.common.logger import get_logger
|
||||
from src.common.cache_manager import tool_cache
|
||||
|
||||
logger = get_logger("stream_tool_history")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCallRecord:
|
||||
"""工具调用记录"""
|
||||
tool_name: str
|
||||
args: dict[str, Any]
|
||||
result: Optional[dict[str, Any]] = None
|
||||
status: str = "success" # success, error, pending
|
||||
timestamp: float = field(default_factory=time.time)
|
||||
execution_time: Optional[float] = None # 执行耗时(秒)
|
||||
cache_hit: bool = False # 是否命中缓存
|
||||
result_preview: str = "" # 结果预览
|
||||
error_message: str = "" # 错误信息
|
||||
|
||||
def __post_init__(self):
|
||||
"""后处理:生成结果预览"""
|
||||
if self.result and not self.result_preview:
|
||||
content = self.result.get("content", "")
|
||||
if isinstance(content, str):
|
||||
self.result_preview = content[:500] + ("..." if len(content) > 500 else "")
|
||||
elif isinstance(content, (list, dict)):
|
||||
try:
|
||||
self.result_preview = orjson.dumps(content, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')[:500] + "..."
|
||||
except Exception:
|
||||
self.result_preview = str(content)[:500] + "..."
|
||||
else:
|
||||
self.result_preview = str(content)[:500] + "..."
|
||||
|
||||
|
||||
class StreamToolHistoryManager:
|
||||
"""流式工具历史记录管理器
|
||||
|
||||
提供以下功能:
|
||||
1. 工具调用历史的持久化管理
|
||||
2. 智能缓存集成和结果去重
|
||||
3. 上下文感知的历史记录检索
|
||||
4. 性能监控和统计
|
||||
"""
|
||||
|
||||
def __init__(self, chat_id: str, max_history: int = 20, enable_memory_cache: bool = True):
|
||||
"""初始化历史记录管理器
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID,用于隔离不同聊天流的历史记录
|
||||
max_history: 最大历史记录数量
|
||||
enable_memory_cache: 是否启用内存缓存
|
||||
"""
|
||||
self.chat_id = chat_id
|
||||
self.max_history = max_history
|
||||
self.enable_memory_cache = enable_memory_cache
|
||||
|
||||
# 内存中的历史记录,按时间顺序排列
|
||||
self._history: list[ToolCallRecord] = []
|
||||
|
||||
# 性能统计
|
||||
self._stats = {
|
||||
"total_calls": 0,
|
||||
"cache_hits": 0,
|
||||
"cache_misses": 0,
|
||||
"total_execution_time": 0.0,
|
||||
"average_execution_time": 0.0,
|
||||
}
|
||||
|
||||
logger.info(f"[{chat_id}] 工具历史记录管理器初始化完成,最大历史: {max_history}")
|
||||
|
||||
async def add_tool_call(self, record: ToolCallRecord) -> None:
|
||||
"""添加工具调用记录
|
||||
|
||||
Args:
|
||||
record: 工具调用记录
|
||||
"""
|
||||
# 维护历史记录大小
|
||||
if len(self._history) >= self.max_history:
|
||||
# 移除最旧的记录
|
||||
removed_record = self._history.pop(0)
|
||||
logger.debug(f"[{self.chat_id}] 移除旧记录: {removed_record.tool_name}")
|
||||
|
||||
# 添加新记录
|
||||
self._history.append(record)
|
||||
|
||||
# 更新统计
|
||||
self._stats["total_calls"] += 1
|
||||
if record.cache_hit:
|
||||
self._stats["cache_hits"] += 1
|
||||
else:
|
||||
self._stats["cache_misses"] += 1
|
||||
|
||||
if record.execution_time is not None:
|
||||
self._stats["total_execution_time"] += record.execution_time
|
||||
self._stats["average_execution_time"] = self._stats["total_execution_time"] / self._stats["total_calls"]
|
||||
|
||||
logger.debug(f"[{self.chat_id}] 添加工具调用记录: {record.tool_name}, 缓存命中: {record.cache_hit}")
|
||||
|
||||
async def get_cached_result(self, tool_name: str, args: dict[str, Any]) -> Optional[dict[str, Any]]:
|
||||
"""从缓存或历史记录中获取结果
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
args: 工具参数
|
||||
|
||||
Returns:
|
||||
缓存的结果,如果不存在则返回None
|
||||
"""
|
||||
# 首先检查内存中的历史记录
|
||||
if self.enable_memory_cache:
|
||||
memory_result = self._search_memory_cache(tool_name, args)
|
||||
if memory_result:
|
||||
logger.info(f"[{self.chat_id}] 内存缓存命中: {tool_name}")
|
||||
return memory_result
|
||||
|
||||
# 然后检查全局缓存系统
|
||||
try:
|
||||
# 这里需要工具实例来获取文件路径,但为了解耦,我们先尝试从历史记录中推断
|
||||
tool_file_path = self._infer_tool_path(tool_name)
|
||||
|
||||
# 尝试语义缓存(如果可以推断出语义查询参数)
|
||||
semantic_query = self._extract_semantic_query(tool_name, args)
|
||||
|
||||
cached_result = await tool_cache.get(
|
||||
tool_name=tool_name,
|
||||
function_args=args,
|
||||
tool_file_path=tool_file_path,
|
||||
semantic_query=semantic_query,
|
||||
)
|
||||
|
||||
if cached_result:
|
||||
logger.info(f"[{self.chat_id}] 全局缓存命中: {tool_name}")
|
||||
|
||||
# 将结果同步到内存缓存
|
||||
if self.enable_memory_cache:
|
||||
record = ToolCallRecord(
|
||||
tool_name=tool_name,
|
||||
args=args,
|
||||
result=cached_result,
|
||||
status="success",
|
||||
cache_hit=True,
|
||||
timestamp=time.time(),
|
||||
)
|
||||
await self.add_tool_call(record)
|
||||
|
||||
return cached_result
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[{self.chat_id}] 缓存查询失败: {e}")
|
||||
|
||||
return None
|
||||
|
||||
async def cache_result(self, tool_name: str, args: dict[str, Any], result: dict[str, Any],
|
||||
execution_time: Optional[float] = None,
|
||||
tool_file_path: Optional[str] = None,
|
||||
ttl: Optional[int] = None) -> None:
|
||||
"""缓存工具调用结果
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
args: 工具参数
|
||||
result: 执行结果
|
||||
execution_time: 执行耗时
|
||||
tool_file_path: 工具文件路径
|
||||
ttl: 缓存TTL
|
||||
"""
|
||||
# 添加到内存历史记录
|
||||
record = ToolCallRecord(
|
||||
tool_name=tool_name,
|
||||
args=args,
|
||||
result=result,
|
||||
status="success",
|
||||
execution_time=execution_time,
|
||||
cache_hit=False,
|
||||
timestamp=time.time(),
|
||||
)
|
||||
await self.add_tool_call(record)
|
||||
|
||||
# 同步到全局缓存系统
|
||||
try:
|
||||
if tool_file_path is None:
|
||||
tool_file_path = self._infer_tool_path(tool_name)
|
||||
|
||||
# 尝试语义缓存
|
||||
semantic_query = self._extract_semantic_query(tool_name, args)
|
||||
|
||||
await tool_cache.set(
|
||||
tool_name=tool_name,
|
||||
function_args=args,
|
||||
tool_file_path=tool_file_path,
|
||||
data=result,
|
||||
ttl=ttl,
|
||||
semantic_query=semantic_query,
|
||||
)
|
||||
|
||||
logger.debug(f"[{self.chat_id}] 结果已缓存: {tool_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[{self.chat_id}] 缓存设置失败: {e}")
|
||||
|
||||
async def get_recent_history(self, count: int = 5, status_filter: Optional[str] = None) -> list[ToolCallRecord]:
|
||||
"""获取最近的历史记录
|
||||
|
||||
Args:
|
||||
count: 返回的记录数量
|
||||
status_filter: 状态过滤器,可选值:success, error, pending
|
||||
|
||||
Returns:
|
||||
历史记录列表
|
||||
"""
|
||||
history = self._history.copy()
|
||||
|
||||
# 应用状态过滤
|
||||
if status_filter:
|
||||
history = [record for record in history if record.status == status_filter]
|
||||
|
||||
# 返回最近的记录
|
||||
return history[-count:] if history else []
|
||||
|
||||
def format_for_prompt(self, max_records: int = 5, include_results: bool = True) -> str:
|
||||
"""格式化历史记录为提示词
|
||||
|
||||
Args:
|
||||
max_records: 最大记录数量
|
||||
include_results: 是否包含结果预览
|
||||
|
||||
Returns:
|
||||
格式化的提示词字符串
|
||||
"""
|
||||
if not self._history:
|
||||
return ""
|
||||
|
||||
recent_records = self._history[-max_records:]
|
||||
|
||||
lines = ["## 🔧 最近工具调用记录"]
|
||||
for i, record in enumerate(recent_records, 1):
|
||||
status_icon = "✅" if record.status == "success" else "❌" if record.status == "error" else "⏳"
|
||||
|
||||
# 格式化参数
|
||||
args_preview = self._format_args_preview(record.args)
|
||||
|
||||
# 基础信息
|
||||
lines.append(f"{i}. {status_icon} **{record.tool_name}**({args_preview})")
|
||||
|
||||
# 添加执行时间和缓存信息
|
||||
if record.execution_time is not None:
|
||||
time_info = f"{record.execution_time:.2f}s"
|
||||
cache_info = "🎯缓存" if record.cache_hit else "🔍执行"
|
||||
lines.append(f" ⏱️ {time_info} | {cache_info}")
|
||||
|
||||
# 添加结果预览
|
||||
if include_results and record.result_preview:
|
||||
lines.append(f" 📝 结果: {record.result_preview}")
|
||||
|
||||
# 添加错误信息
|
||||
if record.status == "error" and record.error_message:
|
||||
lines.append(f" ❌ 错误: {record.error_message}")
|
||||
|
||||
# 添加统计信息
|
||||
if self._stats["total_calls"] > 0:
|
||||
cache_hit_rate = (self._stats["cache_hits"] / self._stats["total_calls"]) * 100
|
||||
avg_time = self._stats["average_execution_time"]
|
||||
lines.append(f"\n📊 工具统计: 总计{self._stats['total_calls']}次 | 缓存命中率{cache_hit_rate:.1f}% | 平均耗时{avg_time:.2f}s")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def get_stats(self) -> dict[str, Any]:
|
||||
"""获取性能统计信息
|
||||
|
||||
Returns:
|
||||
统计信息字典
|
||||
"""
|
||||
cache_hit_rate = 0.0
|
||||
if self._stats["total_calls"] > 0:
|
||||
cache_hit_rate = (self._stats["cache_hits"] / self._stats["total_calls"]) * 100
|
||||
|
||||
return {
|
||||
**self._stats,
|
||||
"cache_hit_rate": cache_hit_rate,
|
||||
"history_size": len(self._history),
|
||||
"chat_id": self.chat_id,
|
||||
}
|
||||
|
||||
def clear_history(self) -> None:
|
||||
"""清除历史记录"""
|
||||
self._history.clear()
|
||||
logger.info(f"[{self.chat_id}] 工具历史记录已清除")
|
||||
|
||||
def _search_memory_cache(self, tool_name: str, args: dict[str, Any]) -> Optional[dict[str, Any]]:
|
||||
"""在内存历史记录中搜索缓存
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
args: 工具参数
|
||||
|
||||
Returns:
|
||||
匹配的结果,如果不存在则返回None
|
||||
"""
|
||||
for record in reversed(self._history): # 从最新的开始搜索
|
||||
if (record.tool_name == tool_name and
|
||||
record.status == "success" and
|
||||
record.args == args):
|
||||
return record.result
|
||||
return None
|
||||
|
||||
def _infer_tool_path(self, tool_name: str) -> str:
|
||||
"""推断工具文件路径
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
|
||||
Returns:
|
||||
推断的文件路径
|
||||
"""
|
||||
# 基于工具名称推断路径,这是一个简化的实现
|
||||
# 在实际使用中,可能需要更复杂的映射逻辑
|
||||
tool_path_mapping = {
|
||||
"web_search": "src/plugins/built_in/web_search_tool/tools/web_search.py",
|
||||
"memory_create": "src/memory_graph/tools/memory_tools.py",
|
||||
"memory_search": "src/memory_graph/tools/memory_tools.py",
|
||||
"user_profile_update": "src/plugins/built_in/affinity_flow_chatter/tools/user_profile_tool.py",
|
||||
"chat_stream_impression_update": "src/plugins/built_in/affinity_flow_chatter/tools/chat_stream_impression_tool.py",
|
||||
}
|
||||
|
||||
return tool_path_mapping.get(tool_name, f"src/plugins/tools/{tool_name}.py")
|
||||
|
||||
def _extract_semantic_query(self, tool_name: str, args: dict[str, Any]) -> Optional[str]:
|
||||
"""提取语义查询参数
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
args: 工具参数
|
||||
|
||||
Returns:
|
||||
语义查询字符串,如果不存在则返回None
|
||||
"""
|
||||
# 为不同工具定义语义查询参数映射
|
||||
semantic_query_mapping = {
|
||||
"web_search": "query",
|
||||
"memory_search": "query",
|
||||
"knowledge_search": "query",
|
||||
}
|
||||
|
||||
query_key = semantic_query_mapping.get(tool_name)
|
||||
if query_key and query_key in args:
|
||||
return str(args[query_key])
|
||||
|
||||
return None
|
||||
|
||||
def _format_args_preview(self, args: dict[str, Any], max_length: int = 100) -> str:
|
||||
"""格式化参数预览
|
||||
|
||||
Args:
|
||||
args: 参数字典
|
||||
max_length: 最大长度
|
||||
|
||||
Returns:
|
||||
格式化的参数预览字符串
|
||||
"""
|
||||
if not args:
|
||||
return ""
|
||||
|
||||
try:
|
||||
args_str = orjson.dumps(args, option=orjson.OPT_SORT_KEYS).decode('utf-8')
|
||||
if len(args_str) > max_length:
|
||||
args_str = args_str[:max_length] + "..."
|
||||
return args_str
|
||||
except Exception:
|
||||
# 如果序列化失败,使用简单格式
|
||||
parts = []
|
||||
for k, v in list(args.items())[:3]: # 最多显示3个参数
|
||||
parts.append(f"{k}={str(v)[:20]}")
|
||||
result = ", ".join(parts)
|
||||
if len(parts) >= 3 or len(result) > max_length:
|
||||
result += "..."
|
||||
return result
|
||||
|
||||
|
||||
# 全局管理器字典,按chat_id索引
|
||||
_stream_managers: dict[str, StreamToolHistoryManager] = {}
|
||||
|
||||
|
||||
def get_stream_tool_history_manager(chat_id: str) -> StreamToolHistoryManager:
|
||||
"""获取指定聊天的工具历史记录管理器
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
|
||||
Returns:
|
||||
工具历史记录管理器实例
|
||||
"""
|
||||
if chat_id not in _stream_managers:
|
||||
_stream_managers[chat_id] = StreamToolHistoryManager(chat_id)
|
||||
return _stream_managers[chat_id]
|
||||
|
||||
|
||||
def cleanup_stream_manager(chat_id: str) -> None:
|
||||
"""清理指定聊天的管理器
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
"""
|
||||
if chat_id in _stream_managers:
|
||||
del _stream_managers[chat_id]
|
||||
logger.info(f"已清理聊天 {chat_id} 的工具历史记录管理器")
|
||||
@@ -3,7 +3,6 @@ import time
|
||||
from typing import Any
|
||||
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.common.cache_manager import tool_cache
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.payload_content import ToolCall
|
||||
@@ -11,6 +10,8 @@ from src.llm_models.utils_model import LLMRequest
|
||||
from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance
|
||||
from src.plugin_system.base.base_tool import BaseTool
|
||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||
from src.plugin_system.core.stream_tool_history import get_stream_tool_history_manager, ToolCallRecord
|
||||
from dataclasses import asdict
|
||||
|
||||
logger = get_logger("tool_use")
|
||||
|
||||
@@ -36,15 +37,29 @@ def init_tool_executor_prompt():
|
||||
|
||||
{tool_history}
|
||||
|
||||
## 🔧 工具使用
|
||||
## 🔧 工具决策指南
|
||||
|
||||
根据上下文判断是否需要使用工具。每个工具都有详细的description说明其用途和参数,请根据工具定义决定是否调用。
|
||||
**核心原则:**
|
||||
- 根据上下文智能判断是否需要使用工具
|
||||
- 每个工具都有详细的description说明其用途和参数
|
||||
- 避免重复调用历史记录中已执行的工具(除非参数不同)
|
||||
- 优先考虑使用已有的缓存结果,避免重复调用
|
||||
|
||||
**历史记录说明:**
|
||||
- 上方显示的是**之前**的工具调用记录
|
||||
- 请参考历史记录避免重复调用相同参数的工具
|
||||
- 如果历史记录中已有相关结果,可以考虑直接回答而不调用工具
|
||||
|
||||
**⚠️ 记忆创建特别提醒:**
|
||||
创建记忆时,subject(主体)必须使用对话历史中显示的**真实发送人名字**!
|
||||
- ✅ 正确:从"Prou(12345678): ..."中提取"Prou"作为subject
|
||||
- ❌ 错误:使用"用户"、"对方"等泛指词
|
||||
|
||||
**工具调用策略:**
|
||||
1. **避免重复调用**:查看历史记录,如果最近已调用过相同工具且参数一致,无需重复调用
|
||||
2. **智能选择工具**:根据消息内容选择最合适的工具,避免过度使用
|
||||
3. **参数优化**:确保工具参数简洁有效,避免冗余信息
|
||||
|
||||
**执行指令:**
|
||||
- 需要使用工具 → 直接调用相应的工具函数
|
||||
- 不需要工具 → 输出 "No tool needed"
|
||||
@@ -81,9 +96,8 @@ class ToolExecutor:
|
||||
"""待处理的第二步工具调用,格式为 {tool_name: step_two_definition}"""
|
||||
self._log_prefix_initialized = False
|
||||
|
||||
# 工具调用历史
|
||||
self.tool_call_history: list[dict[str, Any]] = []
|
||||
"""工具调用历史,包含工具名称、参数和结果"""
|
||||
# 流式工具历史记录管理器
|
||||
self.history_manager = get_stream_tool_history_manager(chat_id)
|
||||
|
||||
# logger.info(f"{self.log_prefix}工具执行器初始化完成") # 移到异步初始化中
|
||||
|
||||
@@ -125,7 +139,7 @@ class ToolExecutor:
|
||||
bot_name = global_config.bot.nickname
|
||||
|
||||
# 构建工具调用历史文本
|
||||
tool_history = self._format_tool_history()
|
||||
tool_history = self.history_manager.format_for_prompt(max_records=5, include_results=True)
|
||||
|
||||
# 获取人设信息
|
||||
personality_core = global_config.personality.personality_core
|
||||
@@ -183,83 +197,7 @@ class ToolExecutor:
|
||||
|
||||
return tool_definitions
|
||||
|
||||
def _format_tool_history(self, max_history: int = 5) -> str:
|
||||
"""格式化工具调用历史为文本
|
||||
|
||||
Args:
|
||||
max_history: 最多显示的历史记录数量
|
||||
|
||||
Returns:
|
||||
格式化的工具历史文本
|
||||
"""
|
||||
if not self.tool_call_history:
|
||||
return ""
|
||||
|
||||
# 只取最近的几条历史
|
||||
recent_history = self.tool_call_history[-max_history:]
|
||||
|
||||
history_lines = ["历史工具调用记录:"]
|
||||
for i, record in enumerate(recent_history, 1):
|
||||
tool_name = record.get("tool_name", "unknown")
|
||||
args = record.get("args", {})
|
||||
result_preview = record.get("result_preview", "")
|
||||
status = record.get("status", "success")
|
||||
|
||||
# 格式化参数
|
||||
args_str = ", ".join([f"{k}={v}" for k, v in args.items()])
|
||||
|
||||
# 格式化记录
|
||||
status_emoji = "✓" if status == "success" else "✗"
|
||||
history_lines.append(f"{i}. {status_emoji} {tool_name}({args_str})")
|
||||
|
||||
if result_preview:
|
||||
# 限制结果预览长度
|
||||
if len(result_preview) > 200:
|
||||
result_preview = result_preview[:200] + "..."
|
||||
history_lines.append(f" 结果: {result_preview}")
|
||||
|
||||
return "\n".join(history_lines)
|
||||
|
||||
def _add_tool_to_history(self, tool_name: str, args: dict, result: dict | None, status: str = "success"):
|
||||
"""添加工具调用到历史记录
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
args: 工具参数
|
||||
result: 工具结果
|
||||
status: 执行状态 (success/error)
|
||||
"""
|
||||
# 生成结果预览
|
||||
result_preview = ""
|
||||
if result:
|
||||
content = result.get("content", "")
|
||||
if isinstance(content, str):
|
||||
result_preview = content
|
||||
elif isinstance(content, list | dict):
|
||||
import json
|
||||
|
||||
try:
|
||||
result_preview = json.dumps(content, ensure_ascii=False)
|
||||
except Exception:
|
||||
result_preview = str(content)
|
||||
else:
|
||||
result_preview = str(content)
|
||||
|
||||
record = {
|
||||
"tool_name": tool_name,
|
||||
"args": args,
|
||||
"result_preview": result_preview,
|
||||
"status": status,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
|
||||
self.tool_call_history.append(record)
|
||||
|
||||
# 限制历史记录数量,避免内存溢出
|
||||
max_history_size = 5
|
||||
if len(self.tool_call_history) > max_history_size:
|
||||
self.tool_call_history = self.tool_call_history[-max_history_size:]
|
||||
|
||||
|
||||
async def execute_tool_calls(self, tool_calls: list[ToolCall] | None) -> tuple[list[dict[str, Any]], list[str]]:
|
||||
"""执行工具调用
|
||||
|
||||
@@ -320,10 +258,20 @@ class ToolExecutor:
|
||||
logger.debug(f"{self.log_prefix}工具{tool_name}结果内容: {preview}...")
|
||||
|
||||
# 记录到历史
|
||||
self._add_tool_to_history(tool_name, tool_args, result, status="success")
|
||||
await self.history_manager.add_tool_call(ToolCallRecord(
|
||||
tool_name=tool_name,
|
||||
args=tool_args,
|
||||
result=result,
|
||||
status="success"
|
||||
))
|
||||
else:
|
||||
# 工具返回空结果也记录到历史
|
||||
self._add_tool_to_history(tool_name, tool_args, None, status="success")
|
||||
await self.history_manager.add_tool_call(ToolCallRecord(
|
||||
tool_name=tool_name,
|
||||
args=tool_args,
|
||||
result=None,
|
||||
status="success"
|
||||
))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}工具{tool_name}执行失败: {e}")
|
||||
@@ -338,62 +286,72 @@ class ToolExecutor:
|
||||
tool_results.append(error_info)
|
||||
|
||||
# 记录失败到历史
|
||||
self._add_tool_to_history(tool_name, tool_args, None, status="error")
|
||||
await self.history_manager.add_tool_call(ToolCallRecord(
|
||||
tool_name=tool_name,
|
||||
args=tool_args,
|
||||
result=None,
|
||||
status="error",
|
||||
error_message=str(e)
|
||||
))
|
||||
|
||||
return tool_results, used_tools
|
||||
|
||||
async def execute_tool_call(
|
||||
self, tool_call: ToolCall, tool_instance: BaseTool | None = None
|
||||
) -> dict[str, Any] | None:
|
||||
"""执行单个工具调用,并处理缓存"""
|
||||
"""执行单个工具调用,集成流式历史记录管理器"""
|
||||
|
||||
start_time = time.time()
|
||||
function_args = tool_call.args or {}
|
||||
tool_instance = tool_instance or get_tool_instance(tool_call.func_name, self.chat_stream)
|
||||
|
||||
# 如果工具不存在或未启用缓存,则直接执行
|
||||
if not tool_instance or not tool_instance.enable_cache:
|
||||
return await self._original_execute_tool_call(tool_call, tool_instance)
|
||||
# 尝试从历史记录管理器获取缓存结果
|
||||
if tool_instance and tool_instance.enable_cache:
|
||||
try:
|
||||
cached_result = await self.history_manager.get_cached_result(
|
||||
tool_name=tool_call.func_name,
|
||||
args=function_args
|
||||
)
|
||||
if cached_result:
|
||||
execution_time = time.time() - start_time
|
||||
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具 {tool_call.func_name} 执行")
|
||||
|
||||
# --- 缓存逻辑开始 ---
|
||||
try:
|
||||
tool_file_path = inspect.getfile(tool_instance.__class__)
|
||||
semantic_query = None
|
||||
if tool_instance.semantic_cache_query_key:
|
||||
semantic_query = function_args.get(tool_instance.semantic_cache_query_key)
|
||||
# 记录缓存命中到历史
|
||||
await self.history_manager.add_tool_call(ToolCallRecord(
|
||||
tool_name=tool_call.func_name,
|
||||
args=function_args,
|
||||
result=cached_result,
|
||||
status="success",
|
||||
execution_time=execution_time,
|
||||
cache_hit=True
|
||||
))
|
||||
|
||||
cached_result = await tool_cache.get(
|
||||
tool_name=tool_call.func_name,
|
||||
function_args=function_args,
|
||||
tool_file_path=tool_file_path,
|
||||
semantic_query=semantic_query,
|
||||
)
|
||||
if cached_result:
|
||||
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具 {tool_call.func_name} 执行")
|
||||
return cached_result
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}检查工具缓存时出错: {e}")
|
||||
return cached_result
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}检查历史缓存时出错: {e}")
|
||||
|
||||
# 缓存未命中,执行原始工具调用
|
||||
# 缓存未命中,执行工具调用
|
||||
result = await self._original_execute_tool_call(tool_call, tool_instance)
|
||||
|
||||
# 将结果存入缓存
|
||||
try:
|
||||
tool_file_path = inspect.getfile(tool_instance.__class__)
|
||||
semantic_query = None
|
||||
if tool_instance.semantic_cache_query_key:
|
||||
semantic_query = function_args.get(tool_instance.semantic_cache_query_key)
|
||||
# 记录执行结果到历史管理器
|
||||
execution_time = time.time() - start_time
|
||||
if tool_instance and result and tool_instance.enable_cache:
|
||||
try:
|
||||
tool_file_path = inspect.getfile(tool_instance.__class__)
|
||||
semantic_query = None
|
||||
if tool_instance.semantic_cache_query_key:
|
||||
semantic_query = function_args.get(tool_instance.semantic_cache_query_key)
|
||||
|
||||
await tool_cache.set(
|
||||
tool_name=tool_call.func_name,
|
||||
function_args=function_args,
|
||||
tool_file_path=tool_file_path,
|
||||
data=result,
|
||||
ttl=tool_instance.cache_ttl,
|
||||
semantic_query=semantic_query,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}设置工具缓存时出错: {e}")
|
||||
# --- 缓存逻辑结束 ---
|
||||
await self.history_manager.cache_result(
|
||||
tool_name=tool_call.func_name,
|
||||
args=function_args,
|
||||
result=result,
|
||||
execution_time=execution_time,
|
||||
tool_file_path=tool_file_path,
|
||||
ttl=tool_instance.cache_ttl
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}缓存结果到历史管理器时出错: {e}")
|
||||
|
||||
return result
|
||||
|
||||
@@ -528,21 +486,31 @@ class ToolExecutor:
|
||||
logger.info(f"{self.log_prefix}直接工具执行成功: {tool_name}")
|
||||
|
||||
# 记录到历史
|
||||
self._add_tool_to_history(tool_name, tool_args, result, status="success")
|
||||
await self.history_manager.add_tool_call(ToolCallRecord(
|
||||
tool_name=tool_name,
|
||||
args=tool_args,
|
||||
result=result,
|
||||
status="success"
|
||||
))
|
||||
|
||||
return tool_info
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}直接工具执行失败 {tool_name}: {e}")
|
||||
# 记录失败到历史
|
||||
self._add_tool_to_history(tool_name, tool_args, None, status="error")
|
||||
await self.history_manager.add_tool_call(ToolCallRecord(
|
||||
tool_name=tool_name,
|
||||
args=tool_args,
|
||||
result=None,
|
||||
status="error",
|
||||
error_message=str(e)
|
||||
))
|
||||
|
||||
return None
|
||||
|
||||
def clear_tool_history(self):
|
||||
"""清除工具调用历史"""
|
||||
self.tool_call_history.clear()
|
||||
logger.debug(f"{self.log_prefix}已清除工具调用历史")
|
||||
self.history_manager.clear_history()
|
||||
|
||||
def get_tool_history(self) -> list[dict[str, Any]]:
|
||||
"""获取工具调用历史
|
||||
@@ -550,7 +518,17 @@ class ToolExecutor:
|
||||
Returns:
|
||||
工具调用历史列表
|
||||
"""
|
||||
return self.tool_call_history.copy()
|
||||
# 返回最近的历史记录
|
||||
records = self.history_manager.get_recent_history(count=10)
|
||||
return [asdict(record) for record in records]
|
||||
|
||||
def get_tool_stats(self) -> dict[str, Any]:
|
||||
"""获取工具统计信息
|
||||
|
||||
Returns:
|
||||
工具统计信息字典
|
||||
"""
|
||||
return self.history_manager.get_stats()
|
||||
|
||||
|
||||
"""
|
||||
|
||||
@@ -652,7 +652,7 @@ class ChatterPlanFilter:
|
||||
enhanced_memories = await memory_manager.search_memories(
|
||||
query=query,
|
||||
top_k=5,
|
||||
optimize_query=False, # 直接使用关键词查询
|
||||
use_multi_query=False, # 直接使用关键词查询
|
||||
)
|
||||
|
||||
if not enhanced_memories:
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
当定时任务触发时,负责搜集信息、调用LLM决策、并根据决策生成回复
|
||||
"""
|
||||
|
||||
import json
|
||||
import orjson
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
负责记录和管理已回复过的评论ID,避免重复回复
|
||||
"""
|
||||
|
||||
import json
|
||||
import orjson
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
@@ -71,7 +71,7 @@ class ReplyTrackerService:
|
||||
self.replied_comments = {}
|
||||
return
|
||||
|
||||
data = json.loads(file_content)
|
||||
data = orjson.loads(file_content)
|
||||
if self._validate_data(data):
|
||||
self.replied_comments = data
|
||||
logger.info(
|
||||
@@ -81,7 +81,7 @@ class ReplyTrackerService:
|
||||
else:
|
||||
logger.error("加载的数据格式无效,将创建新的记录")
|
||||
self.replied_comments = {}
|
||||
except json.JSONDecodeError as e:
|
||||
except orjson.JSONDecodeError as e:
|
||||
logger.error(f"解析回复记录文件失败: {e}")
|
||||
self._backup_corrupted_file()
|
||||
self.replied_comments = {}
|
||||
@@ -118,7 +118,7 @@ class ReplyTrackerService:
|
||||
|
||||
# 先写入临时文件
|
||||
with open(temp_file, "w", encoding="utf-8") as f:
|
||||
json.dump(self.replied_comments, f, ensure_ascii=False, indent=2)
|
||||
orjson.dumps(self.replied_comments, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS).decode('utf-8')
|
||||
|
||||
# 如果写入成功,重命名为正式文件
|
||||
if temp_file.stat().st_size > 0: # 确保写入成功
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import orjson
|
||||
from typing import ClassVar, List
|
||||
|
||||
import websockets as Server
|
||||
@@ -44,10 +44,10 @@ async def message_recv(server_connection: Server.ServerConnection):
|
||||
# 只在debug模式下记录原始消息
|
||||
if logger.level <= 10: # DEBUG level
|
||||
logger.debug(f"{raw_message[:1500]}..." if (len(raw_message) > 1500) else raw_message)
|
||||
decoded_raw_message: dict = json.loads(raw_message)
|
||||
decoded_raw_message: dict = orjson.loads(raw_message)
|
||||
try:
|
||||
# 首先尝试解析原始消息
|
||||
decoded_raw_message: dict = json.loads(raw_message)
|
||||
decoded_raw_message: dict = orjson.loads(raw_message)
|
||||
|
||||
# 检查是否是切片消息 (来自 MMC)
|
||||
if chunker.is_chunk_message(decoded_raw_message):
|
||||
@@ -71,7 +71,7 @@ async def message_recv(server_connection: Server.ServerConnection):
|
||||
elif post_type is None:
|
||||
await put_response(decoded_raw_message)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
except orjson.JSONDecodeError as e:
|
||||
logger.error(f"消息解析失败: {e}")
|
||||
logger.debug(f"原始消息: {raw_message[:500]}...")
|
||||
except Exception as e:
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import orjson
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
@@ -34,7 +34,7 @@ class MessageChunker:
|
||||
"""判断消息是否需要切片"""
|
||||
try:
|
||||
if isinstance(message, dict):
|
||||
message_str = json.dumps(message, ensure_ascii=False)
|
||||
message_str = orjson.dumps(message, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')
|
||||
else:
|
||||
message_str = message
|
||||
return len(message_str.encode("utf-8")) > self.max_chunk_size
|
||||
@@ -58,7 +58,7 @@ class MessageChunker:
|
||||
try:
|
||||
# 统一转换为字符串
|
||||
if isinstance(message, dict):
|
||||
message_str = json.dumps(message, ensure_ascii=False)
|
||||
message_str = orjson.dumps(message, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')
|
||||
else:
|
||||
message_str = message
|
||||
|
||||
@@ -116,7 +116,7 @@ class MessageChunker:
|
||||
"""判断是否是切片消息"""
|
||||
try:
|
||||
if isinstance(message, str):
|
||||
data = json.loads(message)
|
||||
data = orjson.loads(message)
|
||||
else:
|
||||
data = message
|
||||
|
||||
@@ -126,7 +126,7 @@ class MessageChunker:
|
||||
and "__mmc_chunk_data__" in data
|
||||
and "__mmc_is_chunked__" in data
|
||||
)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
except (orjson.JSONDecodeError, TypeError):
|
||||
return False
|
||||
|
||||
|
||||
@@ -187,7 +187,7 @@ class MessageReassembler:
|
||||
try:
|
||||
# 统一转换为字典
|
||||
if isinstance(message, str):
|
||||
chunk_data = json.loads(message)
|
||||
chunk_data = orjson.loads(message)
|
||||
else:
|
||||
chunk_data = message
|
||||
|
||||
@@ -197,8 +197,8 @@ class MessageReassembler:
|
||||
if "_original_message" in chunk_data:
|
||||
# 这是一个被包装的非切片消息,解包返回
|
||||
try:
|
||||
return json.loads(chunk_data["_original_message"])
|
||||
except json.JSONDecodeError:
|
||||
return orjson.loads(chunk_data["_original_message"])
|
||||
except orjson.JSONDecodeError:
|
||||
return {"text_message": chunk_data["_original_message"]}
|
||||
else:
|
||||
return chunk_data
|
||||
@@ -251,14 +251,14 @@ class MessageReassembler:
|
||||
|
||||
# 尝试反序列化重组后的消息
|
||||
try:
|
||||
return json.loads(reassembled_message)
|
||||
except json.JSONDecodeError:
|
||||
return orjson.loads(reassembled_message)
|
||||
except orjson.JSONDecodeError:
|
||||
# 如果不能反序列化为JSON,则作为文本消息返回
|
||||
return {"text_message": reassembled_message}
|
||||
|
||||
return None
|
||||
|
||||
except (json.JSONDecodeError, KeyError, TypeError) as e:
|
||||
except (orjson.JSONDecodeError, KeyError, TypeError) as e:
|
||||
logger.error(f"处理切片消息时出错: {e}")
|
||||
return None
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import base64
|
||||
import json
|
||||
import orjson
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
@@ -783,7 +783,7 @@ class MessageHandler:
|
||||
# 检查JSON消息格式
|
||||
if not message_data or "data" not in message_data:
|
||||
logger.warning("JSON消息格式不正确")
|
||||
return Seg(type="json", data=json.dumps(message_data))
|
||||
return Seg(type="json", data=orjson.dumps(message_data).decode('utf-8'))
|
||||
|
||||
try:
|
||||
# 尝试将json_data解析为Python对象
|
||||
@@ -1146,13 +1146,13 @@ class MessageHandler:
|
||||
return None
|
||||
forward_message_id = forward_message_data.get("id")
|
||||
request_uuid = str(uuid.uuid4())
|
||||
payload = json.dumps(
|
||||
payload = orjson.dumps(
|
||||
{
|
||||
"action": "get_forward_msg",
|
||||
"params": {"message_id": forward_message_id},
|
||||
"echo": request_uuid,
|
||||
}
|
||||
)
|
||||
).decode('utf-8')
|
||||
try:
|
||||
connection = self.get_server_connection()
|
||||
if not connection:
|
||||
@@ -1167,9 +1167,9 @@ class MessageHandler:
|
||||
logger.error(f"获取转发消息失败: {str(e)}")
|
||||
return None
|
||||
logger.debug(
|
||||
f"转发消息原始格式:{json.dumps(response)[:80]}..."
|
||||
if len(json.dumps(response)) > 80
|
||||
else json.dumps(response)
|
||||
f"转发消息原始格式:{orjson.dumps(response).decode('utf-8')[:80]}..."
|
||||
if len(orjson.dumps(response).decode('utf-8')) > 80
|
||||
else orjson.dumps(response).decode('utf-8')
|
||||
)
|
||||
response_data: Dict = response.get("data")
|
||||
if not response_data:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import asyncio
|
||||
import json
|
||||
import orjson
|
||||
import time
|
||||
from typing import ClassVar, Optional, Tuple
|
||||
|
||||
@@ -241,7 +241,7 @@ class NoticeHandler:
|
||||
message_base: MessageBase = MessageBase(
|
||||
message_info=message_info,
|
||||
message_segment=handled_message,
|
||||
raw_message=json.dumps(raw_message),
|
||||
raw_message=orjson.dumps(raw_message).decode('utf-8'),
|
||||
)
|
||||
|
||||
if system_notice:
|
||||
@@ -602,7 +602,7 @@ class NoticeHandler:
|
||||
message_base: MessageBase = MessageBase(
|
||||
message_info=message_info,
|
||||
message_segment=seg_message,
|
||||
raw_message=json.dumps(
|
||||
raw_message=orjson.dumps(
|
||||
{
|
||||
"post_type": "notice",
|
||||
"notice_type": "group_ban",
|
||||
@@ -611,7 +611,7 @@ class NoticeHandler:
|
||||
"user_id": user_id,
|
||||
"operator_id": None, # 自然解除禁言没有操作者
|
||||
}
|
||||
),
|
||||
).decode('utf-8'),
|
||||
)
|
||||
|
||||
await self.put_notice(message_base)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import json
|
||||
import orjson
|
||||
import random
|
||||
import time
|
||||
import uuid
|
||||
@@ -605,7 +605,7 @@ class SendHandler:
|
||||
|
||||
async def send_message_to_napcat(self, action: str, params: dict, timeout: float = 20.0) -> dict:
|
||||
request_uuid = str(uuid.uuid4())
|
||||
payload = json.dumps({"action": action, "params": params, "echo": request_uuid})
|
||||
payload = orjson.dumps({"action": action, "params": params, "echo": request_uuid}).decode('utf-8')
|
||||
|
||||
# 获取当前连接
|
||||
connection = self.get_server_connection()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import orjson
|
||||
import ssl
|
||||
import uuid
|
||||
from typing import List, Optional, Tuple, Union
|
||||
@@ -34,7 +34,7 @@ async def get_group_info(websocket: Server.ServerConnection, group_id: int) -> d
|
||||
"""
|
||||
logger.debug("获取群聊信息中")
|
||||
request_uuid = str(uuid.uuid4())
|
||||
payload = json.dumps({"action": "get_group_info", "params": {"group_id": group_id}, "echo": request_uuid})
|
||||
payload = orjson.dumps({"action": "get_group_info", "params": {"group_id": group_id}, "echo": request_uuid}).decode('utf-8')
|
||||
try:
|
||||
await websocket.send(payload)
|
||||
socket_response: dict = await get_response(request_uuid)
|
||||
@@ -56,7 +56,7 @@ async def get_group_detail_info(websocket: Server.ServerConnection, group_id: in
|
||||
"""
|
||||
logger.debug("获取群详细信息中")
|
||||
request_uuid = str(uuid.uuid4())
|
||||
payload = json.dumps({"action": "get_group_detail_info", "params": {"group_id": group_id}, "echo": request_uuid})
|
||||
payload = orjson.dumps({"action": "get_group_detail_info", "params": {"group_id": group_id}, "echo": request_uuid}).decode('utf-8')
|
||||
try:
|
||||
await websocket.send(payload)
|
||||
socket_response: dict = await get_response(request_uuid)
|
||||
@@ -78,13 +78,13 @@ async def get_member_info(websocket: Server.ServerConnection, group_id: int, use
|
||||
"""
|
||||
logger.debug("获取群成员信息中")
|
||||
request_uuid = str(uuid.uuid4())
|
||||
payload = json.dumps(
|
||||
payload = orjson.dumps(
|
||||
{
|
||||
"action": "get_group_member_info",
|
||||
"params": {"group_id": group_id, "user_id": user_id, "no_cache": True},
|
||||
"echo": request_uuid,
|
||||
}
|
||||
)
|
||||
).decode('utf-8')
|
||||
try:
|
||||
await websocket.send(payload)
|
||||
socket_response: dict = await get_response(request_uuid)
|
||||
@@ -146,7 +146,7 @@ async def get_self_info(websocket: Server.ServerConnection) -> dict | None:
|
||||
"""
|
||||
logger.debug("获取自身信息中")
|
||||
request_uuid = str(uuid.uuid4())
|
||||
payload = json.dumps({"action": "get_login_info", "params": {}, "echo": request_uuid})
|
||||
payload = orjson.dumps({"action": "get_login_info", "params": {}, "echo": request_uuid}).decode('utf-8')
|
||||
try:
|
||||
await websocket.send(payload)
|
||||
response: dict = await get_response(request_uuid)
|
||||
@@ -183,7 +183,7 @@ async def get_stranger_info(websocket: Server.ServerConnection, user_id: int) ->
|
||||
"""
|
||||
logger.debug("获取陌生人信息中")
|
||||
request_uuid = str(uuid.uuid4())
|
||||
payload = json.dumps({"action": "get_stranger_info", "params": {"user_id": user_id}, "echo": request_uuid})
|
||||
payload = orjson.dumps({"action": "get_stranger_info", "params": {"user_id": user_id}, "echo": request_uuid}).decode('utf-8')
|
||||
try:
|
||||
await websocket.send(payload)
|
||||
response: dict = await get_response(request_uuid)
|
||||
@@ -208,7 +208,7 @@ async def get_message_detail(websocket: Server.ServerConnection, message_id: Uni
|
||||
"""
|
||||
logger.debug("获取消息详情中")
|
||||
request_uuid = str(uuid.uuid4())
|
||||
payload = json.dumps({"action": "get_msg", "params": {"message_id": message_id}, "echo": request_uuid})
|
||||
payload = orjson.dumps({"action": "get_msg", "params": {"message_id": message_id}, "echo": request_uuid}).decode('utf-8')
|
||||
try:
|
||||
await websocket.send(payload)
|
||||
response: dict = await get_response(request_uuid, 30) # 增加超时时间到30秒
|
||||
@@ -236,13 +236,13 @@ async def get_record_detail(
|
||||
"""
|
||||
logger.debug("获取语音消息详情中")
|
||||
request_uuid = str(uuid.uuid4())
|
||||
payload = json.dumps(
|
||||
payload = orjson.dumps(
|
||||
{
|
||||
"action": "get_record",
|
||||
"params": {"file": file, "file_id": file_id, "out_format": "wav"},
|
||||
"echo": request_uuid,
|
||||
}
|
||||
)
|
||||
).decode('utf-8')
|
||||
try:
|
||||
await websocket.send(payload)
|
||||
response: dict = await get_response(request_uuid, 30) # 增加超时时间到30秒
|
||||
|
||||
@@ -39,15 +39,23 @@ class ExaSearchEngine(BaseSearchEngine):
|
||||
return self.api_manager.is_available()
|
||||
|
||||
async def search(self, args: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""执行Exa搜索"""
|
||||
"""执行优化的Exa搜索(使用answer模式)"""
|
||||
if not self.is_available():
|
||||
return []
|
||||
|
||||
query = args["query"]
|
||||
num_results = args.get("num_results", 3)
|
||||
num_results = min(args.get("num_results", 5), 5) # 默认5个结果,但限制最多5个
|
||||
time_range = args.get("time_range", "any")
|
||||
|
||||
exa_args = {"num_results": num_results, "text": True, "highlights": True}
|
||||
# 优化的搜索参数 - 更注重答案质量
|
||||
exa_args = {
|
||||
"num_results": num_results,
|
||||
"text": True,
|
||||
"highlights": True,
|
||||
"summary": True, # 启用自动摘要
|
||||
}
|
||||
|
||||
# 时间范围过滤
|
||||
if time_range != "any":
|
||||
today = datetime.now()
|
||||
start_date = today - timedelta(days=7 if time_range == "week" else 30)
|
||||
@@ -61,18 +69,89 @@ class ExaSearchEngine(BaseSearchEngine):
|
||||
return []
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
# 使用search_and_contents获取完整内容,优化为answer模式
|
||||
func = functools.partial(exa_client.search_and_contents, query, **exa_args)
|
||||
search_response = await loop.run_in_executor(None, func)
|
||||
|
||||
return [
|
||||
{
|
||||
# 优化结果处理 - 更注重答案质量
|
||||
results = []
|
||||
for res in search_response.results:
|
||||
# 获取最佳内容片段
|
||||
highlights = getattr(res, "highlights", [])
|
||||
summary = getattr(res, "summary", "")
|
||||
text = getattr(res, "text", "")
|
||||
|
||||
# 智能内容选择:摘要 > 高亮 > 文本开头
|
||||
if summary and len(summary) > 50:
|
||||
snippet = summary.strip()
|
||||
elif highlights:
|
||||
snippet = " ".join(highlights).strip()
|
||||
elif text:
|
||||
snippet = text[:300] + "..." if len(text) > 300 else text
|
||||
else:
|
||||
snippet = "内容获取失败"
|
||||
|
||||
# 只保留有意义的摘要
|
||||
if len(snippet) < 30:
|
||||
snippet = text[:200] + "..." if text and len(text) > 200 else snippet
|
||||
|
||||
results.append({
|
||||
"title": res.title,
|
||||
"url": res.url,
|
||||
"snippet": " ".join(getattr(res, "highlights", [])) or (getattr(res, "text", "")[:250] + "..."),
|
||||
"snippet": snippet,
|
||||
"provider": "Exa",
|
||||
}
|
||||
for res in search_response.results
|
||||
]
|
||||
"answer_focused": True, # 标记为答案导向的搜索
|
||||
})
|
||||
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.error(f"Exa 搜索失败: {e}")
|
||||
logger.error(f"Exa answer模式搜索失败: {e}")
|
||||
return []
|
||||
|
||||
async def answer_search(self, args: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""执行Exa快速答案搜索 - 最精简的搜索模式"""
|
||||
if not self.is_available():
|
||||
return []
|
||||
|
||||
query = args["query"]
|
||||
num_results = min(args.get("num_results", 3), 3) # answer模式默认3个结果,专注质量
|
||||
|
||||
# 精简的搜索参数 - 专注快速答案
|
||||
exa_args = {
|
||||
"num_results": num_results,
|
||||
"text": False, # 不需要全文
|
||||
"highlights": True, # 只要关键高亮
|
||||
"summary": True, # 优先摘要
|
||||
}
|
||||
|
||||
try:
|
||||
exa_client = self.api_manager.get_next_client()
|
||||
if not exa_client:
|
||||
return []
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
func = functools.partial(exa_client.search_and_contents, query, **exa_args)
|
||||
search_response = await loop.run_in_executor(None, func)
|
||||
|
||||
# 极简结果处理 - 只保留最核心信息
|
||||
results = []
|
||||
for res in search_response.results:
|
||||
summary = getattr(res, "summary", "")
|
||||
highlights = getattr(res, "highlights", [])
|
||||
|
||||
# 优先使用摘要,否则使用高亮
|
||||
answer_text = summary.strip() if summary and len(summary) > 30 else " ".join(highlights).strip()
|
||||
|
||||
if answer_text and len(answer_text) > 20:
|
||||
results.append({
|
||||
"title": res.title,
|
||||
"url": res.url,
|
||||
"snippet": answer_text[:400] + "..." if len(answer_text) > 400 else answer_text,
|
||||
"provider": "Exa-Answer",
|
||||
"answer_mode": True # 标记为纯答案模式
|
||||
})
|
||||
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.error(f"Exa快速答案搜索失败: {e}")
|
||||
return []
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
Metaso Search Engine (Chat Completions Mode)
|
||||
"""
|
||||
import json
|
||||
import orjson
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
@@ -43,12 +43,12 @@ class MetasoClient:
|
||||
if data_str == "[DONE]":
|
||||
break
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
data = orjson.loads(data_str)
|
||||
delta = data.get("choices", [{}])[0].get("delta", {})
|
||||
content_chunk = delta.get("content")
|
||||
if content_chunk:
|
||||
full_response_content += content_chunk
|
||||
except json.JSONDecodeError:
|
||||
except orjson.JSONDecodeError:
|
||||
logger.warning(f"Metaso stream: could not decode JSON line: {data_str}")
|
||||
continue
|
||||
|
||||
|
||||
@@ -41,6 +41,13 @@ class WebSurfingTool(BaseTool):
|
||||
False,
|
||||
["any", "week", "month"],
|
||||
),
|
||||
(
|
||||
"answer_mode",
|
||||
ToolParamType.BOOLEAN,
|
||||
"是否启用答案模式(仅适用于Exa搜索引擎)。启用后将返回更精简、直接的答案,减少冗余信息。默认为False。",
|
||||
False,
|
||||
None,
|
||||
),
|
||||
] # type: ignore
|
||||
|
||||
def __init__(self, plugin_config=None, chat_stream=None):
|
||||
@@ -97,13 +104,19 @@ class WebSurfingTool(BaseTool):
|
||||
) -> dict[str, Any]:
|
||||
"""并行搜索策略:同时使用所有启用的搜索引擎"""
|
||||
search_tasks = []
|
||||
answer_mode = function_args.get("answer_mode", False)
|
||||
|
||||
for engine_name in enabled_engines:
|
||||
engine = self.engines.get(engine_name)
|
||||
if engine and engine.is_available():
|
||||
custom_args = function_args.copy()
|
||||
custom_args["num_results"] = custom_args.get("num_results", 5)
|
||||
search_tasks.append(engine.search(custom_args))
|
||||
|
||||
# 如果启用了answer模式且是Exa引擎,使用answer_search方法
|
||||
if answer_mode and engine_name == "exa" and hasattr(engine, 'answer_search'):
|
||||
search_tasks.append(engine.answer_search(custom_args))
|
||||
else:
|
||||
search_tasks.append(engine.search(custom_args))
|
||||
|
||||
if not search_tasks:
|
||||
|
||||
@@ -137,17 +150,23 @@ class WebSurfingTool(BaseTool):
|
||||
self, function_args: dict[str, Any], enabled_engines: list[str]
|
||||
) -> dict[str, Any]:
|
||||
"""回退搜索策略:按顺序尝试搜索引擎,失败则尝试下一个"""
|
||||
answer_mode = function_args.get("answer_mode", False)
|
||||
|
||||
for engine_name in enabled_engines:
|
||||
engine = self.engines.get(engine_name)
|
||||
if not engine or not engine.is_available():
|
||||
|
||||
continue
|
||||
|
||||
try:
|
||||
custom_args = function_args.copy()
|
||||
custom_args["num_results"] = custom_args.get("num_results", 5)
|
||||
|
||||
results = await engine.search(custom_args)
|
||||
# 如果启用了answer模式且是Exa引擎,使用answer_search方法
|
||||
if answer_mode and engine_name == "exa" and hasattr(engine, 'answer_search'):
|
||||
logger.info("使用Exa答案模式进行搜索(fallback策略)")
|
||||
results = await engine.answer_search(custom_args)
|
||||
else:
|
||||
results = await engine.search(custom_args)
|
||||
|
||||
if results: # 如果有结果,直接返回
|
||||
formatted_content = format_search_results(results)
|
||||
@@ -164,22 +183,30 @@ class WebSurfingTool(BaseTool):
|
||||
|
||||
async def _execute_single_search(self, function_args: dict[str, Any], enabled_engines: list[str]) -> dict[str, Any]:
|
||||
"""单一搜索策略:只使用第一个可用的搜索引擎"""
|
||||
answer_mode = function_args.get("answer_mode", False)
|
||||
|
||||
for engine_name in enabled_engines:
|
||||
engine = self.engines.get(engine_name)
|
||||
if not engine or not engine.is_available():
|
||||
|
||||
continue
|
||||
|
||||
try:
|
||||
custom_args = function_args.copy()
|
||||
custom_args["num_results"] = custom_args.get("num_results", 5)
|
||||
|
||||
results = await engine.search(custom_args)
|
||||
formatted_content = format_search_results(results)
|
||||
return {
|
||||
"type": "web_search_result",
|
||||
"content": formatted_content,
|
||||
}
|
||||
# 如果启用了answer模式且是Exa引擎,使用answer_search方法
|
||||
if answer_mode and engine_name == "exa" and hasattr(engine, 'answer_search'):
|
||||
logger.info("使用Exa答案模式进行搜索")
|
||||
results = await engine.answer_search(custom_args)
|
||||
else:
|
||||
results = await engine.search(custom_args)
|
||||
|
||||
if results:
|
||||
formatted_content = format_search_results(results)
|
||||
return {
|
||||
"type": "web_search_result",
|
||||
"content": formatted_content,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{engine_name} 搜索失败: {e}")
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import orjson
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
直接从存储的数据文件生成可视化,无需启动完整的记忆管理器
|
||||
"""
|
||||
|
||||
import json
|
||||
import orjson
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
@@ -122,7 +122,7 @@ def load_graph_data(file_path: Optional[Path] = None) -> Dict[str, Any]:
|
||||
|
||||
print(f"📂 加载图数据: {graph_file}")
|
||||
with open(graph_file, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
data = orjson.loads(f.read())
|
||||
|
||||
# 解析数据
|
||||
nodes_dict = {}
|
||||
|
||||
Reference in New Issue
Block a user