feat: 将 JSON 处理库从 json 更改为 orjson,以提高性能和兼容性

This commit is contained in:
Windpicker-owo
2025-11-06 12:47:56 +08:00
parent e29266582d
commit 17c1d4b4f9
18 changed files with 83 additions and 78 deletions

View File

@@ -323,8 +323,8 @@ class GlobalNoticeManager:
return message.additional_config.get("is_notice", False) return message.additional_config.get("is_notice", False)
elif isinstance(message.additional_config, str): elif isinstance(message.additional_config, str):
# 兼容JSON字符串格式 # 兼容JSON字符串格式
import json import orjson
config = json.loads(message.additional_config) config = orjson.loads(message.additional_config)
return config.get("is_notice", False) return config.get("is_notice", False)
# 检查消息类型或其他标识 # 检查消息类型或其他标识
@@ -349,8 +349,8 @@ class GlobalNoticeManager:
if isinstance(message.additional_config, dict): if isinstance(message.additional_config, dict):
return message.additional_config.get("notice_type") return message.additional_config.get("notice_type")
elif isinstance(message.additional_config, str): elif isinstance(message.additional_config, str):
import json import orjson
config = json.loads(message.additional_config) config = orjson.loads(message.additional_config)
return config.get("notice_type") return config.get("notice_type")
return None return None
except Exception: except Exception:

View File

@@ -137,6 +137,7 @@ class MemoryManager:
graph_store=self.graph_store, graph_store=self.graph_store,
persistence_manager=self.persistence, persistence_manager=self.persistence,
embedding_generator=self.embedding_generator, embedding_generator=self.embedding_generator,
max_expand_depth=getattr(self.config, 'max_expand_depth', 1), # 从配置读取默认深度
) )
self._initialized = True self._initialized = True

View File

@@ -102,8 +102,8 @@ class VectorStore:
# 处理额外的元数据,将 list 转换为 JSON 字符串 # 处理额外的元数据,将 list 转换为 JSON 字符串
for key, value in node.metadata.items(): for key, value in node.metadata.items():
if isinstance(value, (list, dict)): if isinstance(value, (list, dict)):
import json import orjson
metadata[key] = json.dumps(value, ensure_ascii=False) metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')
elif isinstance(value, (str, int, float, bool)) or value is None: elif isinstance(value, (str, int, float, bool)) or value is None:
metadata[key] = value metadata[key] = value
else: else:
@@ -141,7 +141,7 @@ class VectorStore:
try: try:
# 准备元数据 # 准备元数据
import json import orjson
metadatas = [] metadatas = []
for n in valid_nodes: for n in valid_nodes:
metadata = { metadata = {
@@ -151,7 +151,7 @@ class VectorStore:
} }
for key, value in n.metadata.items(): for key, value in n.metadata.items():
if isinstance(value, (list, dict)): if isinstance(value, (list, dict)):
metadata[key] = json.dumps(value, ensure_ascii=False) metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')
elif isinstance(value, (str, int, float, bool)) or value is None: elif isinstance(value, (str, int, float, bool)) or value is None:
metadata[key] = value # type: ignore metadata[key] = value # type: ignore
else: else:
@@ -207,7 +207,7 @@ class VectorStore:
) )
# 解析结果 # 解析结果
import json import orjson
similar_nodes = [] similar_nodes = []
if results["ids"] and results["ids"][0]: if results["ids"] and results["ids"][0]:
for i, node_id in enumerate(results["ids"][0]): for i, node_id in enumerate(results["ids"][0]):
@@ -223,7 +223,7 @@ class VectorStore:
for key, value in list(metadata.items()): for key, value in list(metadata.items()):
if isinstance(value, str) and (value.startswith('[') or value.startswith('{')): if isinstance(value, str) and (value.startswith('[') or value.startswith('{')):
try: try:
metadata[key] = json.loads(value) metadata[key] = orjson.loads(value)
except: except:
pass # 保持原值 pass # 保持原值

View File

@@ -34,6 +34,7 @@ class MemoryTools:
graph_store: GraphStore, graph_store: GraphStore,
persistence_manager: PersistenceManager, persistence_manager: PersistenceManager,
embedding_generator: Optional[EmbeddingGenerator] = None, embedding_generator: Optional[EmbeddingGenerator] = None,
max_expand_depth: int = 1,
): ):
""" """
初始化工具集 初始化工具集
@@ -43,11 +44,13 @@ class MemoryTools:
graph_store: 图存储 graph_store: 图存储
persistence_manager: 持久化管理器 persistence_manager: 持久化管理器
embedding_generator: 嵌入生成器(可选) embedding_generator: 嵌入生成器(可选)
max_expand_depth: 图扩展深度的默认值(从配置读取)
""" """
self.vector_store = vector_store self.vector_store = vector_store
self.graph_store = graph_store self.graph_store = graph_store
self.persistence_manager = persistence_manager self.persistence_manager = persistence_manager
self._initialized = False self._initialized = False
self.max_expand_depth = max_expand_depth # 保存配置的默认值
# 初始化组件 # 初始化组件
self.extractor = MemoryExtractor() self.extractor = MemoryExtractor()
@@ -448,11 +451,12 @@ class MemoryTools:
try: try:
query = params.get("query", "") query = params.get("query", "")
top_k = params.get("top_k", 10) top_k = params.get("top_k", 10)
expand_depth = params.get("expand_depth", 1) # 使用配置中的默认值而不是硬编码的 1
expand_depth = params.get("expand_depth", self.max_expand_depth)
use_multi_query = params.get("use_multi_query", True) use_multi_query = params.get("use_multi_query", True)
context = params.get("context", None) context = params.get("context", None)
logger.info(f"搜索记忆: {query} (top_k={top_k}, multi_query={use_multi_query})") logger.info(f"搜索记忆: {query} (top_k={top_k}, expand_depth={expand_depth}, multi_query={use_multi_query})")
# 0. 确保初始化 # 0. 确保初始化
await self._ensure_initialized() await self._ensure_initialized()
@@ -474,9 +478,9 @@ class MemoryTools:
ids = metadata["memory_ids"] ids = metadata["memory_ids"]
# 确保是列表 # 确保是列表
if isinstance(ids, str): if isinstance(ids, str):
import json import orjson
try: try:
ids = json.loads(ids) ids = orjson.loads(ids)
except: except:
ids = [ids] ids = [ids]
if isinstance(ids, list): if isinstance(ids, list):
@@ -649,11 +653,11 @@ class MemoryTools:
response, _ = await llm.generate_response_async(prompt, temperature=0.3, max_tokens=250) response, _ = await llm.generate_response_async(prompt, temperature=0.3, max_tokens=250)
import json, re import orjson, re
response = re.sub(r'```json\s*', '', response) response = re.sub(r'```json\s*', '', response)
response = re.sub(r'```\s*$', '', response).strip() response = re.sub(r'```\s*$', '', response).strip()
data = json.loads(response) data = orjson.loads(response)
queries = data.get("queries", []) queries = data.get("queries", [])
result = [(item.get("text", "").strip(), float(item.get("weight", 0.5))) result = [(item.get("text", "").strip(), float(item.get("weight", 0.5)))
@@ -799,9 +803,9 @@ class MemoryTools:
# 确保是列表 # 确保是列表
if isinstance(ids, str): if isinstance(ids, str):
import json import orjson
try: try:
ids = json.loads(ids) ids = orjson.loads(ids)
except Exception as e: except Exception as e:
logger.warning(f"JSON 解析失败: {e}") logger.warning(f"JSON 解析失败: {e}")
ids = [ids] ids = [ids]
@@ -910,9 +914,9 @@ class MemoryTools:
# 提取记忆ID # 提取记忆ID
neighbor_memory_ids = neighbor_node_data.get("memory_ids", []) neighbor_memory_ids = neighbor_node_data.get("memory_ids", [])
if isinstance(neighbor_memory_ids, str): if isinstance(neighbor_memory_ids, str):
import json import orjson
try: try:
neighbor_memory_ids = json.loads(neighbor_memory_ids) neighbor_memory_ids = orjson.loads(neighbor_memory_ids)
except: except:
neighbor_memory_ids = [neighbor_memory_ids] neighbor_memory_ids = [neighbor_memory_ids]

View File

@@ -7,7 +7,7 @@
""" """
import atexit import atexit
import json import orjson
import os import os
import threading import threading
from typing import Any, ClassVar from typing import Any, ClassVar
@@ -100,10 +100,10 @@ class PluginStorage:
if os.path.exists(self.file_path): if os.path.exists(self.file_path):
with open(self.file_path, encoding="utf-8") as f: with open(self.file_path, encoding="utf-8") as f:
content = f.read() content = f.read()
self._data = json.loads(content) if content else {} self._data = orjson.loads(content) if content else {}
else: else:
self._data = {} self._data = {}
except (json.JSONDecodeError, Exception) as e: except (orjson.JSONDecodeError, Exception) as e:
logger.warning(f"'{self.file_path}' 加载数据失败: {e},将初始化为空数据。") logger.warning(f"'{self.file_path}' 加载数据失败: {e},将初始化为空数据。")
self._data = {} self._data = {}
@@ -125,7 +125,7 @@ class PluginStorage:
try: try:
with open(self.file_path, "w", encoding="utf-8") as f: with open(self.file_path, "w", encoding="utf-8") as f:
json.dump(self._data, f, indent=4, ensure_ascii=False) f.write(orjson.dumps(self._data, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS).decode('utf-8'))
self._dirty = False # 保存后重置标志 self._dirty = False # 保存后重置标志
logger.debug(f"插件 '{self.name}' 的数据已成功保存到磁盘。") logger.debug(f"插件 '{self.name}' 的数据已成功保存到磁盘。")
except Exception as e: except Exception as e:

View File

@@ -5,7 +5,7 @@ MCP Client Manager
""" """
import asyncio import asyncio
import json import orjson
import shutil import shutil
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@@ -89,7 +89,7 @@ class MCPClientManager:
try: try:
with open(self.config_path, encoding="utf-8") as f: with open(self.config_path, encoding="utf-8") as f:
config_data = json.load(f) config_data = orjson.loads(f.read())
servers = {} servers = {}
mcp_servers = config_data.get("mcpServers", {}) mcp_servers = config_data.get("mcpServers", {})
@@ -106,7 +106,7 @@ class MCPClientManager:
logger.info(f"成功加载 {len(servers)} 个 MCP 服务器配置") logger.info(f"成功加载 {len(servers)} 个 MCP 服务器配置")
return servers return servers
except json.JSONDecodeError as e: except orjson.JSONDecodeError as e:
logger.error(f"解析 MCP 配置文件失败: {e}") logger.error(f"解析 MCP 配置文件失败: {e}")
return {} return {}
except Exception as e: except Exception as e:

View File

@@ -236,10 +236,10 @@ class ToolExecutor:
if isinstance(content, str): if isinstance(content, str):
result_preview = content result_preview = content
elif isinstance(content, list | dict): elif isinstance(content, list | dict):
import json import orjson
try: try:
result_preview = json.dumps(content, ensure_ascii=False) result_preview = orjson.dumps(content, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')
except Exception: except Exception:
result_preview = str(content) result_preview = str(content)
else: else:

View File

@@ -3,7 +3,7 @@
当定时任务触发时负责搜集信息、调用LLM决策、并根据决策生成回复 当定时任务触发时负责搜集信息、调用LLM决策、并根据决策生成回复
""" """
import json import orjson
from datetime import datetime from datetime import datetime
from typing import Any, Literal from typing import Any, Literal

View File

@@ -3,7 +3,7 @@
负责记录和管理已回复过的评论ID避免重复回复 负责记录和管理已回复过的评论ID避免重复回复
""" """
import json import orjson
import time import time
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@@ -71,7 +71,7 @@ class ReplyTrackerService:
self.replied_comments = {} self.replied_comments = {}
return return
data = json.loads(file_content) data = orjson.loads(file_content)
if self._validate_data(data): if self._validate_data(data):
self.replied_comments = data self.replied_comments = data
logger.info( logger.info(
@@ -81,7 +81,7 @@ class ReplyTrackerService:
else: else:
logger.error("加载的数据格式无效,将创建新的记录") logger.error("加载的数据格式无效,将创建新的记录")
self.replied_comments = {} self.replied_comments = {}
except json.JSONDecodeError as e: except orjson.JSONDecodeError as e:
logger.error(f"解析回复记录文件失败: {e}") logger.error(f"解析回复记录文件失败: {e}")
self._backup_corrupted_file() self._backup_corrupted_file()
self.replied_comments = {} self.replied_comments = {}
@@ -118,7 +118,7 @@ class ReplyTrackerService:
# 先写入临时文件 # 先写入临时文件
with open(temp_file, "w", encoding="utf-8") as f: with open(temp_file, "w", encoding="utf-8") as f:
json.dump(self.replied_comments, f, ensure_ascii=False, indent=2) f.write(orjson.dumps(self.replied_comments, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS).decode('utf-8'))
# 如果写入成功,重命名为正式文件 # 如果写入成功,重命名为正式文件
if temp_file.stat().st_size > 0: # 确保写入成功 if temp_file.stat().st_size > 0: # 确保写入成功

View File

@@ -1,6 +1,6 @@
import asyncio import asyncio
import inspect import inspect
import json import orjson
from typing import ClassVar, List from typing import ClassVar, List
import websockets as Server import websockets as Server
@@ -44,10 +44,10 @@ async def message_recv(server_connection: Server.ServerConnection):
# 只在debug模式下记录原始消息 # 只在debug模式下记录原始消息
if logger.level <= 10: # DEBUG level if logger.level <= 10: # DEBUG level
logger.debug(f"{raw_message[:1500]}..." if (len(raw_message) > 1500) else raw_message) logger.debug(f"{raw_message[:1500]}..." if (len(raw_message) > 1500) else raw_message)
decoded_raw_message: dict = json.loads(raw_message) decoded_raw_message: dict = orjson.loads(raw_message)
try: try:
# 首先尝试解析原始消息 # 首先尝试解析原始消息
decoded_raw_message: dict = json.loads(raw_message) decoded_raw_message: dict = orjson.loads(raw_message)
# 检查是否是切片消息 (来自 MMC) # 检查是否是切片消息 (来自 MMC)
if chunker.is_chunk_message(decoded_raw_message): if chunker.is_chunk_message(decoded_raw_message):
@@ -71,7 +71,7 @@ async def message_recv(server_connection: Server.ServerConnection):
elif post_type is None: elif post_type is None:
await put_response(decoded_raw_message) await put_response(decoded_raw_message)
except json.JSONDecodeError as e: except orjson.JSONDecodeError as e:
logger.error(f"消息解析失败: {e}") logger.error(f"消息解析失败: {e}")
logger.debug(f"原始消息: {raw_message[:500]}...") logger.debug(f"原始消息: {raw_message[:500]}...")
except Exception as e: except Exception as e:

View File

@@ -5,7 +5,7 @@
""" """
import asyncio import asyncio
import json import orjson
import time import time
import uuid import uuid
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
@@ -34,7 +34,7 @@ class MessageChunker:
"""判断消息是否需要切片""" """判断消息是否需要切片"""
try: try:
if isinstance(message, dict): if isinstance(message, dict):
message_str = json.dumps(message, ensure_ascii=False) message_str = orjson.dumps(message, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')
else: else:
message_str = message message_str = message
return len(message_str.encode("utf-8")) > self.max_chunk_size return len(message_str.encode("utf-8")) > self.max_chunk_size
@@ -58,7 +58,7 @@ class MessageChunker:
try: try:
# 统一转换为字符串 # 统一转换为字符串
if isinstance(message, dict): if isinstance(message, dict):
message_str = json.dumps(message, ensure_ascii=False) message_str = orjson.dumps(message, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')
else: else:
message_str = message message_str = message
@@ -116,7 +116,7 @@ class MessageChunker:
"""判断是否是切片消息""" """判断是否是切片消息"""
try: try:
if isinstance(message, str): if isinstance(message, str):
data = json.loads(message) data = orjson.loads(message)
else: else:
data = message data = message
@@ -126,7 +126,7 @@ class MessageChunker:
and "__mmc_chunk_data__" in data and "__mmc_chunk_data__" in data
and "__mmc_is_chunked__" in data and "__mmc_is_chunked__" in data
) )
except (json.JSONDecodeError, TypeError): except (orjson.JSONDecodeError, TypeError):
return False return False
@@ -187,7 +187,7 @@ class MessageReassembler:
try: try:
# 统一转换为字典 # 统一转换为字典
if isinstance(message, str): if isinstance(message, str):
chunk_data = json.loads(message) chunk_data = orjson.loads(message)
else: else:
chunk_data = message chunk_data = message
@@ -197,8 +197,8 @@ class MessageReassembler:
if "_original_message" in chunk_data: if "_original_message" in chunk_data:
# 这是一个被包装的非切片消息,解包返回 # 这是一个被包装的非切片消息,解包返回
try: try:
return json.loads(chunk_data["_original_message"]) return orjson.loads(chunk_data["_original_message"])
except json.JSONDecodeError: except orjson.JSONDecodeError:
return {"text_message": chunk_data["_original_message"]} return {"text_message": chunk_data["_original_message"]}
else: else:
return chunk_data return chunk_data
@@ -251,14 +251,14 @@ class MessageReassembler:
# 尝试反序列化重组后的消息 # 尝试反序列化重组后的消息
try: try:
return json.loads(reassembled_message) return orjson.loads(reassembled_message)
except json.JSONDecodeError: except orjson.JSONDecodeError:
# 如果不能反序列化为JSON则作为文本消息返回 # 如果不能反序列化为JSON则作为文本消息返回
return {"text_message": reassembled_message} return {"text_message": reassembled_message}
return None return None
except (json.JSONDecodeError, KeyError, TypeError) as e: except (orjson.JSONDecodeError, KeyError, TypeError) as e:
logger.error(f"处理切片消息时出错: {e}") logger.error(f"处理切片消息时出错: {e}")
return None return None

View File

@@ -1,5 +1,5 @@
import base64 import base64
import json import orjson
import time import time
import uuid import uuid
from pathlib import Path from pathlib import Path
@@ -783,7 +783,7 @@ class MessageHandler:
# 检查JSON消息格式 # 检查JSON消息格式
if not message_data or "data" not in message_data: if not message_data or "data" not in message_data:
logger.warning("JSON消息格式不正确") logger.warning("JSON消息格式不正确")
return Seg(type="json", data=json.dumps(message_data)) return Seg(type="json", data=orjson.dumps(message_data).decode('utf-8'))
try: try:
# 尝试将json_data解析为Python对象 # 尝试将json_data解析为Python对象
@@ -1146,13 +1146,13 @@ class MessageHandler:
return None return None
forward_message_id = forward_message_data.get("id") forward_message_id = forward_message_data.get("id")
request_uuid = str(uuid.uuid4()) request_uuid = str(uuid.uuid4())
payload = json.dumps( payload = orjson.dumps(
{ {
"action": "get_forward_msg", "action": "get_forward_msg",
"params": {"message_id": forward_message_id}, "params": {"message_id": forward_message_id},
"echo": request_uuid, "echo": request_uuid,
} }
) ).decode('utf-8')
try: try:
connection = self.get_server_connection() connection = self.get_server_connection()
if not connection: if not connection:
@@ -1167,9 +1167,9 @@ class MessageHandler:
logger.error(f"获取转发消息失败: {str(e)}") logger.error(f"获取转发消息失败: {str(e)}")
return None return None
logger.debug( logger.debug(
f"转发消息原始格式:{json.dumps(response)[:80]}..." f"转发消息原始格式:{orjson.dumps(response).decode('utf-8')[:80]}..."
if len(json.dumps(response)) > 80 if len(orjson.dumps(response).decode('utf-8')) > 80
else json.dumps(response) else orjson.dumps(response).decode('utf-8')
) )
response_data: Dict = response.get("data") response_data: Dict = response.get("data")
if not response_data: if not response_data:

View File

@@ -1,5 +1,5 @@
import asyncio import asyncio
import json import orjson
import time import time
from typing import ClassVar, Optional, Tuple from typing import ClassVar, Optional, Tuple
@@ -241,7 +241,7 @@ class NoticeHandler:
message_base: MessageBase = MessageBase( message_base: MessageBase = MessageBase(
message_info=message_info, message_info=message_info,
message_segment=handled_message, message_segment=handled_message,
raw_message=json.dumps(raw_message), raw_message=orjson.dumps(raw_message).decode('utf-8'),
) )
if system_notice: if system_notice:
@@ -602,7 +602,7 @@ class NoticeHandler:
message_base: MessageBase = MessageBase( message_base: MessageBase = MessageBase(
message_info=message_info, message_info=message_info,
message_segment=seg_message, message_segment=seg_message,
raw_message=json.dumps( raw_message=orjson.dumps(
{ {
"post_type": "notice", "post_type": "notice",
"notice_type": "group_ban", "notice_type": "group_ban",
@@ -611,7 +611,7 @@ class NoticeHandler:
"user_id": user_id, "user_id": user_id,
"operator_id": None, # 自然解除禁言没有操作者 "operator_id": None, # 自然解除禁言没有操作者
} }
), ).decode('utf-8'),
) )
await self.put_notice(message_base) await self.put_notice(message_base)

View File

@@ -1,4 +1,4 @@
import json import orjson
import random import random
import time import time
import uuid import uuid
@@ -605,7 +605,7 @@ class SendHandler:
async def send_message_to_napcat(self, action: str, params: dict, timeout: float = 20.0) -> dict: async def send_message_to_napcat(self, action: str, params: dict, timeout: float = 20.0) -> dict:
request_uuid = str(uuid.uuid4()) request_uuid = str(uuid.uuid4())
payload = json.dumps({"action": action, "params": params, "echo": request_uuid}) payload = orjson.dumps({"action": action, "params": params, "echo": request_uuid}).decode('utf-8')
# 获取当前连接 # 获取当前连接
connection = self.get_server_connection() connection = self.get_server_connection()

View File

@@ -1,6 +1,6 @@
import base64 import base64
import io import io
import json import orjson
import ssl import ssl
import uuid import uuid
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
@@ -34,7 +34,7 @@ async def get_group_info(websocket: Server.ServerConnection, group_id: int) -> d
""" """
logger.debug("获取群聊信息中") logger.debug("获取群聊信息中")
request_uuid = str(uuid.uuid4()) request_uuid = str(uuid.uuid4())
payload = json.dumps({"action": "get_group_info", "params": {"group_id": group_id}, "echo": request_uuid}) payload = orjson.dumps({"action": "get_group_info", "params": {"group_id": group_id}, "echo": request_uuid}).decode('utf-8')
try: try:
await websocket.send(payload) await websocket.send(payload)
socket_response: dict = await get_response(request_uuid) socket_response: dict = await get_response(request_uuid)
@@ -56,7 +56,7 @@ async def get_group_detail_info(websocket: Server.ServerConnection, group_id: in
""" """
logger.debug("获取群详细信息中") logger.debug("获取群详细信息中")
request_uuid = str(uuid.uuid4()) request_uuid = str(uuid.uuid4())
payload = json.dumps({"action": "get_group_detail_info", "params": {"group_id": group_id}, "echo": request_uuid}) payload = orjson.dumps({"action": "get_group_detail_info", "params": {"group_id": group_id}, "echo": request_uuid}).decode('utf-8')
try: try:
await websocket.send(payload) await websocket.send(payload)
socket_response: dict = await get_response(request_uuid) socket_response: dict = await get_response(request_uuid)
@@ -78,13 +78,13 @@ async def get_member_info(websocket: Server.ServerConnection, group_id: int, use
""" """
logger.debug("获取群成员信息中") logger.debug("获取群成员信息中")
request_uuid = str(uuid.uuid4()) request_uuid = str(uuid.uuid4())
payload = json.dumps( payload = orjson.dumps(
{ {
"action": "get_group_member_info", "action": "get_group_member_info",
"params": {"group_id": group_id, "user_id": user_id, "no_cache": True}, "params": {"group_id": group_id, "user_id": user_id, "no_cache": True},
"echo": request_uuid, "echo": request_uuid,
} }
) ).decode('utf-8')
try: try:
await websocket.send(payload) await websocket.send(payload)
socket_response: dict = await get_response(request_uuid) socket_response: dict = await get_response(request_uuid)
@@ -146,7 +146,7 @@ async def get_self_info(websocket: Server.ServerConnection) -> dict | None:
""" """
logger.debug("获取自身信息中") logger.debug("获取自身信息中")
request_uuid = str(uuid.uuid4()) request_uuid = str(uuid.uuid4())
payload = json.dumps({"action": "get_login_info", "params": {}, "echo": request_uuid}) payload = orjson.dumps({"action": "get_login_info", "params": {}, "echo": request_uuid}).decode('utf-8')
try: try:
await websocket.send(payload) await websocket.send(payload)
response: dict = await get_response(request_uuid) response: dict = await get_response(request_uuid)
@@ -183,7 +183,7 @@ async def get_stranger_info(websocket: Server.ServerConnection, user_id: int) ->
""" """
logger.debug("获取陌生人信息中") logger.debug("获取陌生人信息中")
request_uuid = str(uuid.uuid4()) request_uuid = str(uuid.uuid4())
payload = json.dumps({"action": "get_stranger_info", "params": {"user_id": user_id}, "echo": request_uuid}) payload = orjson.dumps({"action": "get_stranger_info", "params": {"user_id": user_id}, "echo": request_uuid}).decode('utf-8')
try: try:
await websocket.send(payload) await websocket.send(payload)
response: dict = await get_response(request_uuid) response: dict = await get_response(request_uuid)
@@ -208,7 +208,7 @@ async def get_message_detail(websocket: Server.ServerConnection, message_id: Uni
""" """
logger.debug("获取消息详情中") logger.debug("获取消息详情中")
request_uuid = str(uuid.uuid4()) request_uuid = str(uuid.uuid4())
payload = json.dumps({"action": "get_msg", "params": {"message_id": message_id}, "echo": request_uuid}) payload = orjson.dumps({"action": "get_msg", "params": {"message_id": message_id}, "echo": request_uuid}).decode('utf-8')
try: try:
await websocket.send(payload) await websocket.send(payload)
response: dict = await get_response(request_uuid, 30) # 增加超时时间到30秒 response: dict = await get_response(request_uuid, 30) # 增加超时时间到30秒
@@ -236,13 +236,13 @@ async def get_record_detail(
""" """
logger.debug("获取语音消息详情中") logger.debug("获取语音消息详情中")
request_uuid = str(uuid.uuid4()) request_uuid = str(uuid.uuid4())
payload = json.dumps( payload = orjson.dumps(
{ {
"action": "get_record", "action": "get_record",
"params": {"file": file, "file_id": file_id, "out_format": "wav"}, "params": {"file": file, "file_id": file_id, "out_format": "wav"},
"echo": request_uuid, "echo": request_uuid,
} }
) ).decode('utf-8')
try: try:
await websocket.send(payload) await websocket.send(payload)
response: dict = await get_response(request_uuid, 30) # 增加超时时间到30秒 response: dict = await get_response(request_uuid, 30) # 增加超时时间到30秒

View File

@@ -1,7 +1,7 @@
""" """
Metaso Search Engine (Chat Completions Mode) Metaso Search Engine (Chat Completions Mode)
""" """
import json import orjson
from typing import Any from typing import Any
import httpx import httpx
@@ -43,12 +43,12 @@ class MetasoClient:
if data_str == "[DONE]": if data_str == "[DONE]":
break break
try: try:
data = json.loads(data_str) data = orjson.loads(data_str)
delta = data.get("choices", [{}])[0].get("delta", {}) delta = data.get("choices", [{}])[0].get("delta", {})
content_chunk = delta.get("content") content_chunk = delta.get("content")
if content_chunk: if content_chunk:
full_response_content += content_chunk full_response_content += content_chunk
except json.JSONDecodeError: except orjson.JSONDecodeError:
logger.warning(f"Metaso stream: could not decode JSON line: {data_str}") logger.warning(f"Metaso stream: could not decode JSON line: {data_str}")
continue continue

View File

@@ -5,7 +5,7 @@
""" """
import asyncio import asyncio
import json import orjson
import logging import logging
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path

View File

@@ -4,7 +4,7 @@
直接从存储的数据文件生成可视化,无需启动完整的记忆管理器 直接从存储的数据文件生成可视化,无需启动完整的记忆管理器
""" """
import json import orjson
import sys import sys
from pathlib import Path from pathlib import Path
from datetime import datetime from datetime import datetime
@@ -122,7 +122,7 @@ def load_graph_data(file_path: Optional[Path] = None) -> Dict[str, Any]:
print(f"📂 加载图数据: {graph_file}") print(f"📂 加载图数据: {graph_file}")
with open(graph_file, 'r', encoding='utf-8') as f: with open(graph_file, 'r', encoding='utf-8') as f:
data = json.load(f) data = orjson.loads(f.read())
# 解析数据 # 解析数据
nodes_dict = {} nodes_dict = {}