feat: 将 JSON 处理库从 json 更改为 orjson,以提高性能和兼容性

This commit is contained in:
Windpicker-owo
2025-11-06 12:47:56 +08:00
parent e29266582d
commit 17c1d4b4f9
18 changed files with 83 additions and 78 deletions

View File

@@ -137,6 +137,7 @@ class MemoryManager:
graph_store=self.graph_store,
persistence_manager=self.persistence,
embedding_generator=self.embedding_generator,
max_expand_depth=getattr(self.config, 'max_expand_depth', 1), # 从配置读取默认深度
)
self._initialized = True

View File

@@ -102,8 +102,8 @@ class VectorStore:
# 处理额外的元数据,将 list / dict 转换为 JSON 字符串
for key, value in node.metadata.items():
if isinstance(value, (list, dict)):
import json
metadata[key] = json.dumps(value, ensure_ascii=False)
import orjson
metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')
elif isinstance(value, (str, int, float, bool)) or value is None:
metadata[key] = value
else:
@@ -141,7 +141,7 @@ class VectorStore:
try:
# 准备元数据
import json
import orjson
metadatas = []
for n in valid_nodes:
metadata = {
@@ -151,7 +151,7 @@ class VectorStore:
}
for key, value in n.metadata.items():
if isinstance(value, (list, dict)):
metadata[key] = json.dumps(value, ensure_ascii=False)
metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')
elif isinstance(value, (str, int, float, bool)) or value is None:
metadata[key] = value # type: ignore
else:
@@ -207,7 +207,7 @@ class VectorStore:
)
# 解析结果
import json
import orjson
similar_nodes = []
if results["ids"] and results["ids"][0]:
for i, node_id in enumerate(results["ids"][0]):
@@ -223,7 +223,7 @@ class VectorStore:
for key, value in list(metadata.items()):
if isinstance(value, str) and (value.startswith('[') or value.startswith('{')):
try:
metadata[key] = json.loads(value)
metadata[key] = orjson.loads(value)
except:
pass # 保持原值

View File

@@ -34,6 +34,7 @@ class MemoryTools:
graph_store: GraphStore,
persistence_manager: PersistenceManager,
embedding_generator: Optional[EmbeddingGenerator] = None,
max_expand_depth: int = 1,
):
"""
初始化工具集
@@ -43,11 +44,13 @@ class MemoryTools:
graph_store: 图存储
persistence_manager: 持久化管理器
embedding_generator: 嵌入生成器(可选)
max_expand_depth: 图扩展深度的默认值(从配置读取)
"""
self.vector_store = vector_store
self.graph_store = graph_store
self.persistence_manager = persistence_manager
self._initialized = False
self.max_expand_depth = max_expand_depth # 保存配置的默认值
# 初始化组件
self.extractor = MemoryExtractor()
@@ -448,11 +451,12 @@ class MemoryTools:
try:
query = params.get("query", "")
top_k = params.get("top_k", 10)
expand_depth = params.get("expand_depth", 1)
# 使用配置中的默认值而不是硬编码的 1
expand_depth = params.get("expand_depth", self.max_expand_depth)
use_multi_query = params.get("use_multi_query", True)
context = params.get("context", None)
logger.info(f"搜索记忆: {query} (top_k={top_k}, multi_query={use_multi_query})")
logger.info(f"搜索记忆: {query} (top_k={top_k}, expand_depth={expand_depth}, multi_query={use_multi_query})")
# 0. 确保初始化
await self._ensure_initialized()
@@ -474,9 +478,9 @@ class MemoryTools:
ids = metadata["memory_ids"]
# 确保是列表
if isinstance(ids, str):
import json
import orjson
try:
ids = json.loads(ids)
ids = orjson.loads(ids)
except:
ids = [ids]
if isinstance(ids, list):
@@ -649,11 +653,11 @@ class MemoryTools:
response, _ = await llm.generate_response_async(prompt, temperature=0.3, max_tokens=250)
import json, re
import orjson, re
response = re.sub(r'```json\s*', '', response)
response = re.sub(r'```\s*$', '', response).strip()
data = json.loads(response)
data = orjson.loads(response)
queries = data.get("queries", [])
result = [(item.get("text", "").strip(), float(item.get("weight", 0.5)))
@@ -799,9 +803,9 @@ class MemoryTools:
# 确保是列表
if isinstance(ids, str):
import json
import orjson
try:
ids = json.loads(ids)
ids = orjson.loads(ids)
except Exception as e:
logger.warning(f"JSON 解析失败: {e}")
ids = [ids]
@@ -910,9 +914,9 @@ class MemoryTools:
# 提取记忆ID
neighbor_memory_ids = neighbor_node_data.get("memory_ids", [])
if isinstance(neighbor_memory_ids, str):
import json
import orjson
try:
neighbor_memory_ids = json.loads(neighbor_memory_ids)
neighbor_memory_ids = orjson.loads(neighbor_memory_ids)
except:
neighbor_memory_ids = [neighbor_memory_ids]