fix: 增强系统健壮性,优化核心组件的错误处理
在多个核心模块中增加了前置检查和更详细的错误处理逻辑,以防止因组件初始化失败或外部数据格式不符预期而导致的运行时崩溃。 主要变更: - **记忆系统:** 在执行记忆构建、融合和价值评估前,增加对相关引擎是否初始化的检查。优化了向量数据库初始化失败时的处理流程,并增强了向量搜索结果的解析安全性。 - **插件系统:** 统一并优化了LLM可用工具定义的获取方式,增加了错误捕获。调整了工具执行器以安全地解析LLM返回的工具调用请求。 - **向量数据库:** 当ChromaDB连接失败时,显式抛出 `ConnectionError`,使上层调用者能更好地捕获和处理该特定问题。
This commit is contained in:
@@ -222,8 +222,13 @@ class MemorySystem:
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.unified_storage = VectorMemoryStorage(storage_config)
|
try:
|
||||||
logger.info("✅ Vector DB存储系统初始化成功")
|
self.unified_storage = VectorMemoryStorage(storage_config)
|
||||||
|
logger.info("✅ Vector DB存储系统初始化成功")
|
||||||
|
except Exception as storage_error:
|
||||||
|
logger.error(f"❌ Vector DB存储系统初始化失败: {storage_error}", exc_info=True)
|
||||||
|
self.unified_storage = None # 确保在失败时为None
|
||||||
|
raise
|
||||||
except Exception as storage_error:
|
except Exception as storage_error:
|
||||||
logger.error(f"❌ Vector DB存储系统初始化失败: {storage_error}", exc_info=True)
|
logger.error(f"❌ Vector DB存储系统初始化失败: {storage_error}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
@@ -405,6 +410,8 @@ class MemorySystem:
|
|||||||
logger.debug(f"海马体采样模式:使用价值评分 {value_score:.2f}")
|
logger.debug(f"海马体采样模式:使用价值评分 {value_score:.2f}")
|
||||||
|
|
||||||
# 2. 构建记忆块(所有记忆统一使用 global 作用域,实现完全共享)
|
# 2. 构建记忆块(所有记忆统一使用 global 作用域,实现完全共享)
|
||||||
|
if not self.memory_builder:
|
||||||
|
raise RuntimeError("Memory builder is not initialized.")
|
||||||
memory_chunks = await self.memory_builder.build_memories(
|
memory_chunks = await self.memory_builder.build_memories(
|
||||||
conversation_text,
|
conversation_text,
|
||||||
normalized_context,
|
normalized_context,
|
||||||
@@ -419,6 +426,8 @@ class MemorySystem:
|
|||||||
|
|
||||||
# 3. 记忆融合与去重(包含与历史记忆的融合)
|
# 3. 记忆融合与去重(包含与历史记忆的融合)
|
||||||
existing_candidates = await self._collect_fusion_candidates(memory_chunks)
|
existing_candidates = await self._collect_fusion_candidates(memory_chunks)
|
||||||
|
if not self.fusion_engine:
|
||||||
|
raise RuntimeError("Fusion engine is not initialized.")
|
||||||
fused_chunks = await self.fusion_engine.fuse_memories(memory_chunks, existing_candidates)
|
fused_chunks = await self.fusion_engine.fuse_memories(memory_chunks, existing_candidates)
|
||||||
|
|
||||||
# 4. 存储记忆到统一存储
|
# 4. 存储记忆到统一存储
|
||||||
@@ -537,7 +546,12 @@ class MemorySystem:
|
|||||||
if isinstance(result, Exception):
|
if isinstance(result, Exception):
|
||||||
logger.warning("融合候选向量搜索失败: %s", result)
|
logger.warning("融合候选向量搜索失败: %s", result)
|
||||||
continue
|
continue
|
||||||
for memory_id, similarity in result:
|
if not result or not isinstance(result, list):
|
||||||
|
continue
|
||||||
|
for item in result:
|
||||||
|
if not isinstance(item, tuple) or len(item) != 2:
|
||||||
|
continue
|
||||||
|
memory_id, similarity = item
|
||||||
if memory_id in new_memory_ids:
|
if memory_id in new_memory_ids:
|
||||||
continue
|
continue
|
||||||
if similarity is None or similarity < min_threshold:
|
if similarity is None or similarity < min_threshold:
|
||||||
@@ -810,7 +824,11 @@ class MemorySystem:
|
|||||||
importance_score = (importance_enum.value - 1) / 3.0
|
importance_score = (importance_enum.value - 1) / 3.0
|
||||||
else:
|
else:
|
||||||
# 如果已经是数值,直接使用
|
# 如果已经是数值,直接使用
|
||||||
importance_score = float(importance_enum) if importance_enum else 0.5
|
importance_score = (
|
||||||
|
float(importance_enum.value)
|
||||||
|
if hasattr(importance_enum, "value")
|
||||||
|
else (float(importance_enum) if isinstance(importance_enum, int) else 0.5)
|
||||||
|
)
|
||||||
|
|
||||||
# 4. 访问频率得分(归一化,访问10次以上得满分)
|
# 4. 访问频率得分(归一化,访问10次以上得满分)
|
||||||
access_count = memory.metadata.access_count
|
access_count = memory.metadata.access_count
|
||||||
@@ -1397,6 +1415,9 @@ class MemorySystem:
|
|||||||
}}
|
}}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if not self.value_assessment_model:
|
||||||
|
logger.warning("Value assessment model is not initialized, returning default value.")
|
||||||
|
return 0.5
|
||||||
response, _ = await self.value_assessment_model.generate_response_async(prompt, temperature=0.3)
|
response, _ = await self.value_assessment_model.generate_response_async(prompt, temperature=0.3)
|
||||||
|
|
||||||
# 解析响应
|
# 解析响应
|
||||||
@@ -1490,10 +1511,11 @@ class MemorySystem:
|
|||||||
def _populate_memory_fingerprints(self) -> None:
|
def _populate_memory_fingerprints(self) -> None:
|
||||||
"""基于当前缓存构建记忆指纹映射"""
|
"""基于当前缓存构建记忆指纹映射"""
|
||||||
self._memory_fingerprints.clear()
|
self._memory_fingerprints.clear()
|
||||||
for memory in self.unified_storage.memory_cache.values():
|
if self.unified_storage:
|
||||||
fingerprint = self._build_memory_fingerprint(memory)
|
for memory in self.unified_storage.memory_cache.values():
|
||||||
key = self._fingerprint_key(memory.user_id, fingerprint)
|
fingerprint = self._build_memory_fingerprint(memory)
|
||||||
self._memory_fingerprints[key] = memory.memory_id
|
key = self._fingerprint_key(memory.user_id, fingerprint)
|
||||||
|
self._memory_fingerprints[key] = memory.memory_id
|
||||||
|
|
||||||
def _register_memory_fingerprints(self, memories: list[MemoryChunk]) -> None:
|
def _register_memory_fingerprints(self, memories: list[MemoryChunk]) -> None:
|
||||||
for memory in memories:
|
for memory in memories:
|
||||||
@@ -1575,7 +1597,7 @@ class MemorySystem:
|
|||||||
|
|
||||||
# 保存存储数据
|
# 保存存储数据
|
||||||
if self.unified_storage:
|
if self.unified_storage:
|
||||||
await self.unified_storage.save_storage()
|
pass
|
||||||
|
|
||||||
# 记忆融合引擎维护
|
# 记忆融合引擎维护
|
||||||
if self.fusion_engine:
|
if self.fusion_engine:
|
||||||
@@ -1655,7 +1677,7 @@ class MemorySystem:
|
|||||||
"""重建向量存储(如果需要)"""
|
"""重建向量存储(如果需要)"""
|
||||||
try:
|
try:
|
||||||
# 检查是否有记忆缓存数据
|
# 检查是否有记忆缓存数据
|
||||||
if not hasattr(self.unified_storage, "memory_cache") or not self.unified_storage.memory_cache:
|
if not self.unified_storage or not hasattr(self.unified_storage, "memory_cache") or not self.unified_storage.memory_cache:
|
||||||
logger.info("无记忆缓存数据,跳过向量存储重建")
|
logger.info("无记忆缓存数据,跳过向量存储重建")
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -1684,7 +1706,8 @@ class MemorySystem:
|
|||||||
for i in range(0, len(memories_to_rebuild), batch_size):
|
for i in range(0, len(memories_to_rebuild), batch_size):
|
||||||
batch = memories_to_rebuild[i : i + batch_size]
|
batch = memories_to_rebuild[i : i + batch_size]
|
||||||
try:
|
try:
|
||||||
await self.unified_storage.store_memories(batch)
|
if self.unified_storage:
|
||||||
|
await self.unified_storage.store_memories(batch)
|
||||||
rebuild_count += len(batch)
|
rebuild_count += len(batch)
|
||||||
|
|
||||||
if rebuild_count % 50 == 0:
|
if rebuild_count % 50 == 0:
|
||||||
@@ -1707,7 +1730,7 @@ class MemorySystem:
|
|||||||
|
|
||||||
|
|
||||||
# 全局记忆系统实例
|
# 全局记忆系统实例
|
||||||
memory_system: MemorySystem = None
|
memory_system: MemorySystem | None = None
|
||||||
|
|
||||||
|
|
||||||
def get_memory_system() -> MemorySystem:
|
def get_memory_system() -> MemorySystem:
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ class ChromaDBImpl(VectorDBBase):
|
|||||||
logger.error(f"ChromaDB 初始化失败: {e}")
|
logger.error(f"ChromaDB 初始化失败: {e}")
|
||||||
self.client = None
|
self.client = None
|
||||||
self._initialized = False
|
self._initialized = False
|
||||||
|
raise ConnectionError(f"ChromaDB 初始化失败: {e}") from e
|
||||||
|
|
||||||
def get_or_create_collection(self, name: str, **kwargs: Any) -> Any:
|
def get_or_create_collection(self, name: str, **kwargs: Any) -> Any:
|
||||||
if not self.client:
|
if not self.client:
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
from typing import Any
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.plugin_system.base.base_tool import BaseTool
|
from src.plugin_system.base.base_tool import BaseTool
|
||||||
from src.plugin_system.base.component_types import ComponentType
|
from src.plugin_system.base.component_types import ComponentType
|
||||||
@@ -20,13 +21,22 @@ def get_tool_instance(tool_name: str) -> BaseTool | None:
|
|||||||
return tool_class(plugin_config) if tool_class else None
|
return tool_class(plugin_config) if tool_class else None
|
||||||
|
|
||||||
|
|
||||||
def get_llm_available_tool_definitions():
|
def get_llm_available_tool_definitions() -> list[dict[str, Any]]:
|
||||||
"""获取LLM可用的工具定义列表
|
"""获取LLM可用的工具定义列表
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Tuple[str, Dict[str, Any]]]: 工具定义列表,为[("tool_name", 定义)]
|
list[dict[str, Any]]: 工具定义列表
|
||||||
"""
|
"""
|
||||||
from src.plugin_system.core import component_registry
|
from src.plugin_system.core import component_registry
|
||||||
|
|
||||||
llm_available_tools = component_registry.get_llm_available_tools()
|
llm_available_tools = component_registry.get_llm_available_tools()
|
||||||
|
tool_definitions = []
|
||||||
|
for tool_name, tool_class in llm_available_tools.items():
|
||||||
|
try:
|
||||||
|
# 调用类方法 get_tool_definition 获取定义
|
||||||
|
definition = tool_class.get_tool_definition()
|
||||||
|
tool_definitions.append(definition)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取工具 {tool_name} 的定义失败: {e}")
|
||||||
|
return tool_definitions
|
||||||
|
|
||||||
|
|||||||
@@ -113,10 +113,14 @@ class ToolExecutor:
|
|||||||
logger.debug(f"{self.log_prefix}开始LLM工具调用分析")
|
logger.debug(f"{self.log_prefix}开始LLM工具调用分析")
|
||||||
|
|
||||||
# 调用LLM进行工具决策
|
# 调用LLM进行工具决策
|
||||||
response, (reasoning_content, model_name, tool_calls) = await self.llm_model.generate_response_async(
|
response, llm_extra_info = await self.llm_model.generate_response_async(
|
||||||
prompt=prompt, tools=tools, raise_when_empty=False
|
prompt=prompt, tools=tools, raise_when_empty=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tool_calls = None
|
||||||
|
if llm_extra_info and isinstance(llm_extra_info, tuple) and len(llm_extra_info) == 3:
|
||||||
|
_, _, tool_calls = llm_extra_info
|
||||||
|
|
||||||
# 执行工具调用
|
# 执行工具调用
|
||||||
tool_results, used_tools = await self.execute_tool_calls(tool_calls)
|
tool_results, used_tools = await self.execute_tool_calls(tool_calls)
|
||||||
|
|
||||||
@@ -133,7 +137,9 @@ class ToolExecutor:
|
|||||||
user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id)
|
user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id)
|
||||||
|
|
||||||
# 获取基础工具定义(包括二步工具的第一步)
|
# 获取基础工具定义(包括二步工具的第一步)
|
||||||
tool_definitions = [definition for name, definition in all_tools if name not in user_disabled_tools]
|
tool_definitions = [
|
||||||
|
definition for definition in all_tools if definition.get("function", {}).get("name") not in user_disabled_tools
|
||||||
|
]
|
||||||
|
|
||||||
# 检查是否有待处理的二步工具第二步调用
|
# 检查是否有待处理的二步工具第二步调用
|
||||||
pending_step_two = getattr(self, "_pending_step_two_tools", {})
|
pending_step_two = getattr(self, "_pending_step_two_tools", {})
|
||||||
@@ -282,20 +288,7 @@ class ToolExecutor:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 检查是否是MCP工具
|
# 检查是否是MCP工具
|
||||||
try:
|
pass
|
||||||
from src.plugin_system.utils.mcp_tool_provider import mcp_tool_provider
|
|
||||||
if function_name in mcp_tool_provider.mcp_tools:
|
|
||||||
logger.info(f"{self.log_prefix}执行MCP工具: {function_name}")
|
|
||||||
result = await mcp_tool_provider.call_mcp_tool(function_name, function_args)
|
|
||||||
return {
|
|
||||||
"tool_call_id": tool_call.call_id,
|
|
||||||
"role": "tool",
|
|
||||||
"name": function_name,
|
|
||||||
"type": "function",
|
|
||||||
"content": result.get("content", ""),
|
|
||||||
}
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug(f"检查MCP工具时出错: {e}")
|
|
||||||
|
|
||||||
function_args["llm_called"] = True # 标记为LLM调用
|
function_args["llm_called"] = True # 标记为LLM调用
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user