修复代码格式和文件名大小写问题
This commit is contained in:
@@ -17,7 +17,7 @@ from rich.traceback import install
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from sqlalchemy import select,insert,update,delete
|
||||
from sqlalchemy import select, insert, update, delete
|
||||
from src.common.database.sqlalchemy_models import Messages, GraphNodes, GraphEdges # SQLAlchemy Models导入
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
@@ -41,6 +41,7 @@ def cosine_similarity(v1, v2):
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
|
||||
def calculate_information_content(text):
|
||||
"""计算文本的信息量(熵)"""
|
||||
char_count = Counter(text)
|
||||
@@ -783,7 +784,9 @@ class Hippocampus:
|
||||
|
||||
return result
|
||||
|
||||
async def get_activate_from_text(self, text: str, max_depth: int = 3, fast_retrieval: bool = False) -> tuple[float, list[str]]:
|
||||
async def get_activate_from_text(
|
||||
self, text: str, max_depth: int = 3, fast_retrieval: bool = False
|
||||
) -> tuple[float, list[str]]:
|
||||
"""从文本中提取关键词并获取相关记忆。
|
||||
|
||||
Args:
|
||||
@@ -951,10 +954,10 @@ class EntorhinalCortex:
|
||||
current_memorized_times = message.get("memorized_times", 0)
|
||||
with get_db_session() as session:
|
||||
session.execute(
|
||||
update(Messages)
|
||||
.where(Messages.message_id == message["message_id"])
|
||||
.values(memorized_times=current_memorized_times + 1)
|
||||
)
|
||||
update(Messages)
|
||||
.where(Messages.message_id == message["message_id"])
|
||||
.values(memorized_times=current_memorized_times + 1)
|
||||
)
|
||||
session.commit()
|
||||
return messages # 直接返回原始的消息列表
|
||||
|
||||
@@ -1040,7 +1043,6 @@ class EntorhinalCortex:
|
||||
for i in range(0, len(nodes_to_create), batch_size):
|
||||
batch = nodes_to_create[i : i + batch_size]
|
||||
session.execute(insert(GraphNodes), batch)
|
||||
|
||||
|
||||
if nodes_to_update:
|
||||
batch_size = 100
|
||||
@@ -1052,11 +1054,9 @@ class EntorhinalCortex:
|
||||
.where(GraphNodes.concept == node_data["concept"])
|
||||
.values(**{k: v for k, v in node_data.items() if k != "concept"})
|
||||
)
|
||||
|
||||
|
||||
if nodes_to_delete:
|
||||
session.execute(delete(GraphNodes).where(GraphNodes.concept.in_(nodes_to_delete)))
|
||||
|
||||
|
||||
# 处理边的信息
|
||||
db_edges = list(session.execute(select(GraphEdges)).scalars())
|
||||
@@ -1112,7 +1112,6 @@ class EntorhinalCortex:
|
||||
for i in range(0, len(edges_to_create), batch_size):
|
||||
batch = edges_to_create[i : i + batch_size]
|
||||
session.execute(insert(GraphEdges), batch)
|
||||
|
||||
|
||||
if edges_to_update:
|
||||
batch_size = 100
|
||||
@@ -1126,7 +1125,6 @@ class EntorhinalCortex:
|
||||
)
|
||||
.values(**{k: v for k, v in edge_data.items() if k not in ["source", "target"]})
|
||||
)
|
||||
|
||||
|
||||
if edges_to_delete:
|
||||
for source, target in edges_to_delete:
|
||||
@@ -1137,12 +1135,10 @@ class EntorhinalCortex:
|
||||
# 提交事务
|
||||
session.commit()
|
||||
|
||||
|
||||
|
||||
end_time = time.time()
|
||||
logger.info(f"[同步] 总耗时: {end_time - start_time:.2f}秒")
|
||||
logger.info(f"[同步] 同步了 {len(memory_nodes)} 个节点和 {len(memory_edges)} 条边")
|
||||
|
||||
|
||||
async def resync_memory_to_db(self):
|
||||
"""清空数据库并重新同步所有记忆数据"""
|
||||
start_time = time.time()
|
||||
@@ -1153,7 +1149,7 @@ class EntorhinalCortex:
|
||||
clear_start = time.time()
|
||||
session.execute(delete(GraphNodes))
|
||||
session.execute(delete(GraphEdges))
|
||||
|
||||
|
||||
clear_end = time.time()
|
||||
logger.info(f"[数据库] 清空数据库耗时: {clear_end - clear_start:.2f}秒")
|
||||
|
||||
@@ -1211,7 +1207,7 @@ class EntorhinalCortex:
|
||||
for i in range(0, len(nodes_data), batch_size):
|
||||
batch = nodes_data[i : i + batch_size]
|
||||
session.execute(insert(GraphNodes), batch)
|
||||
|
||||
|
||||
node_end = time.time()
|
||||
logger.info(f"[数据库] 写入 {len(nodes_data)} 个节点耗时: {node_end - node_start:.2f}秒")
|
||||
|
||||
@@ -1223,7 +1219,7 @@ class EntorhinalCortex:
|
||||
batch = edges_data[i : i + batch_size]
|
||||
session.execute(insert(GraphEdges), batch)
|
||||
session.commit()
|
||||
|
||||
|
||||
edge_end = time.time()
|
||||
logger.info(f"[数据库] 写入 {len(edges_data)} 条边耗时: {edge_end - edge_start:.2f}秒")
|
||||
|
||||
@@ -1264,10 +1260,7 @@ class EntorhinalCortex:
|
||||
if not node.last_modified:
|
||||
update_data["last_modified"] = current_time
|
||||
|
||||
session.execute(
|
||||
update(GraphNodes).where(GraphNodes.concept == concept).values(**update_data)
|
||||
)
|
||||
|
||||
session.execute(update(GraphNodes).where(GraphNodes.concept == concept).values(**update_data))
|
||||
|
||||
# 获取时间信息(如果不存在则使用当前时间)
|
||||
created_time = node.created_time or current_time
|
||||
@@ -1303,7 +1296,6 @@ class EntorhinalCortex:
|
||||
.where((GraphEdges.source == source) & (GraphEdges.target == target))
|
||||
.values(**update_data)
|
||||
)
|
||||
|
||||
|
||||
# 获取时间信息(如果不存在则使用当前时间)
|
||||
created_time = edge.created_time or current_time
|
||||
@@ -1325,8 +1317,10 @@ class ParahippocampalGyrus:
|
||||
def __init__(self, hippocampus: Hippocampus):
|
||||
self.hippocampus = hippocampus
|
||||
self.memory_graph = hippocampus.memory_graph
|
||||
|
||||
self.memory_modify_model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="memory.modify")
|
||||
|
||||
self.memory_modify_model = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils, request_type="memory.modify"
|
||||
)
|
||||
|
||||
async def memory_compress(self, messages: list, compress_rate=0.1):
|
||||
"""压缩和总结消息内容,生成记忆主题和摘要。
|
||||
@@ -1623,14 +1617,20 @@ class ParahippocampalGyrus:
|
||||
similarity = self._calculate_item_similarity(memory_items[i], memory_items[j])
|
||||
if similarity > 0.8: # 相似度阈值
|
||||
# 合并相似记忆项
|
||||
longer_item = memory_items[i] if len(memory_items[i]) > len(memory_items[j]) else memory_items[j]
|
||||
shorter_item = memory_items[j] if len(memory_items[i]) > len(memory_items[j]) else memory_items[i]
|
||||
longer_item = (
|
||||
memory_items[i] if len(memory_items[i]) > len(memory_items[j]) else memory_items[j]
|
||||
)
|
||||
shorter_item = (
|
||||
memory_items[j] if len(memory_items[i]) > len(memory_items[j]) else memory_items[i]
|
||||
)
|
||||
|
||||
# 保留更长的记忆项,标记短的用于删除
|
||||
if shorter_item not in items_to_remove:
|
||||
items_to_remove.append(shorter_item)
|
||||
merged_count += 1
|
||||
logger.debug(f"[整合] 在节点 {node} 中合并相似记忆: {shorter_item[:30]}... -> {longer_item[:30]}...")
|
||||
logger.debug(
|
||||
f"[整合] 在节点 {node} 中合并相似记忆: {shorter_item[:30]}... -> {longer_item[:30]}..."
|
||||
)
|
||||
|
||||
# 移除被合并的记忆项
|
||||
if items_to_remove:
|
||||
@@ -1657,11 +1657,11 @@ class ParahippocampalGyrus:
|
||||
|
||||
# 检查是否有变化需要同步到数据库
|
||||
has_changes = (
|
||||
edge_changes["weakened"] or
|
||||
edge_changes["removed"] or
|
||||
node_changes["reduced"] or
|
||||
node_changes["removed"] or
|
||||
merged_count > 0
|
||||
edge_changes["weakened"]
|
||||
or edge_changes["removed"]
|
||||
or node_changes["reduced"]
|
||||
or node_changes["removed"]
|
||||
or merged_count > 0
|
||||
)
|
||||
|
||||
if has_changes:
|
||||
@@ -1773,7 +1773,9 @@ class HippocampusManager:
|
||||
response = []
|
||||
return response
|
||||
|
||||
async def get_activate_from_text(self, text: str, max_depth: int = 3, fast_retrieval: bool = False) -> tuple[float, list[str]]:
|
||||
async def get_activate_from_text(
|
||||
self, text: str, max_depth: int = 3, fast_retrieval: bool = False
|
||||
) -> tuple[float, list[str]]:
|
||||
"""从文本中获取激活值的公共接口"""
|
||||
if not self._initialized:
|
||||
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
|
||||
@@ -1797,6 +1799,6 @@ class HippocampusManager:
|
||||
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
|
||||
return self._hippocampus.get_all_node_names()
|
||||
|
||||
|
||||
# 创建全局实例
|
||||
hippocampus_manager = HippocampusManager()
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ import os
|
||||
from typing import Dict, Any
|
||||
|
||||
# 添加项目路径
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), '../../../'))
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), "../../../"))
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
@@ -19,68 +19,64 @@ from src.plugin_system.base.component_types import ComponentType
|
||||
|
||||
logger = get_logger("action_diagnostics")
|
||||
|
||||
|
||||
class ActionDiagnostics:
|
||||
"""Action组件诊断器"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.required_actions = ["no_reply", "reply", "emoji", "at_user"]
|
||||
|
||||
|
||||
def check_plugin_loading(self) -> Dict[str, Any]:
|
||||
"""检查插件加载状态"""
|
||||
logger.info("开始检查插件加载状态...")
|
||||
|
||||
|
||||
result = {
|
||||
"plugins_loaded": False,
|
||||
"total_plugins": 0,
|
||||
"loaded_plugins": [],
|
||||
"failed_plugins": [],
|
||||
"core_actions_plugin": None
|
||||
"core_actions_plugin": None,
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
# 加载所有插件
|
||||
plugin_manager.load_all_plugins()
|
||||
|
||||
|
||||
# 获取插件统计信息
|
||||
stats = plugin_manager.get_stats()
|
||||
result["plugins_loaded"] = True
|
||||
result["total_plugins"] = stats.get("total_plugins", 0)
|
||||
|
||||
|
||||
# 检查是否有core_actions插件
|
||||
for plugin_name in plugin_manager.loaded_plugins:
|
||||
result["loaded_plugins"].append(plugin_name)
|
||||
if "core_actions" in plugin_name.lower():
|
||||
result["core_actions_plugin"] = plugin_name
|
||||
|
||||
|
||||
logger.info(f"插件加载成功,总数: {result['total_plugins']}")
|
||||
logger.info(f"已加载插件: {result['loaded_plugins']}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"插件加载失败: {e}")
|
||||
result["error"] = str(e)
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def check_action_registry(self) -> Dict[str, Any]:
|
||||
"""检查Action注册状态"""
|
||||
logger.info("开始检查Action组件注册状态...")
|
||||
|
||||
result = {
|
||||
"registered_actions": [],
|
||||
"missing_actions": [],
|
||||
"default_actions": {},
|
||||
"total_actions": 0
|
||||
}
|
||||
|
||||
|
||||
result = {"registered_actions": [], "missing_actions": [], "default_actions": {}, "total_actions": 0}
|
||||
|
||||
try:
|
||||
# 获取所有注册的Action
|
||||
all_components = component_registry.get_all_components(ComponentType.ACTION)
|
||||
result["total_actions"] = len(all_components)
|
||||
|
||||
|
||||
for name, info in all_components.items():
|
||||
result["registered_actions"].append(name)
|
||||
logger.debug(f"已注册Action: {name} (插件: {info.plugin_name})")
|
||||
|
||||
|
||||
# 检查必需的Action是否存在
|
||||
for required_action in self.required_actions:
|
||||
if required_action not in all_components:
|
||||
@@ -88,32 +84,32 @@ class ActionDiagnostics:
|
||||
logger.warning(f"缺失必需Action: {required_action}")
|
||||
else:
|
||||
logger.info(f"找到必需Action: {required_action}")
|
||||
|
||||
|
||||
# 获取默认Action
|
||||
default_actions = component_registry.get_default_actions()
|
||||
result["default_actions"] = {name: info.plugin_name for name, info in default_actions.items()}
|
||||
|
||||
|
||||
logger.info(f"总注册Action数量: {result['total_actions']}")
|
||||
logger.info(f"缺失Action: {result['missing_actions']}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Action注册检查失败: {e}")
|
||||
result["error"] = str(e)
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def check_specific_action(self, action_name: str) -> Dict[str, Any]:
|
||||
"""检查特定Action的详细信息"""
|
||||
logger.info(f"检查Action详细信息: {action_name}")
|
||||
|
||||
|
||||
result = {
|
||||
"exists": False,
|
||||
"component_info": None,
|
||||
"component_class": None,
|
||||
"is_default": False,
|
||||
"plugin_name": None
|
||||
"plugin_name": None,
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
# 检查组件信息
|
||||
component_info = component_registry.get_component_info(action_name, ComponentType.ACTION)
|
||||
@@ -123,14 +119,14 @@ class ActionDiagnostics:
|
||||
"name": component_info.name,
|
||||
"description": component_info.description,
|
||||
"plugin_name": component_info.plugin_name,
|
||||
"version": component_info.version
|
||||
"version": component_info.version,
|
||||
}
|
||||
result["plugin_name"] = component_info.plugin_name
|
||||
logger.info(f"找到Action组件信息: {action_name}")
|
||||
else:
|
||||
logger.warning(f"未找到Action组件信息: {action_name}")
|
||||
return result
|
||||
|
||||
|
||||
# 检查组件类
|
||||
component_class = component_registry.get_component_class(action_name, ComponentType.ACTION)
|
||||
if component_class:
|
||||
@@ -138,36 +134,32 @@ class ActionDiagnostics:
|
||||
logger.info(f"找到Action组件类: {component_class.__name__}")
|
||||
else:
|
||||
logger.warning(f"未找到Action组件类: {action_name}")
|
||||
|
||||
|
||||
# 检查是否为默认Action
|
||||
default_actions = component_registry.get_default_actions()
|
||||
result["is_default"] = action_name in default_actions
|
||||
|
||||
|
||||
logger.info(f"Action {action_name} 检查完成: 存在={result['exists']}, 默认={result['is_default']}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"检查Action {action_name} 失败: {e}")
|
||||
result["error"] = str(e)
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def attempt_fix_missing_actions(self) -> Dict[str, Any]:
|
||||
"""尝试修复缺失的Action"""
|
||||
logger.info("尝试修复缺失的Action组件...")
|
||||
|
||||
result = {
|
||||
"fixed_actions": [],
|
||||
"still_missing": [],
|
||||
"errors": []
|
||||
}
|
||||
|
||||
|
||||
result = {"fixed_actions": [], "still_missing": [], "errors": []}
|
||||
|
||||
try:
|
||||
# 重新加载插件
|
||||
plugin_manager.load_all_plugins()
|
||||
|
||||
|
||||
# 再次检查Action注册状态
|
||||
registry_check = self.check_action_registry()
|
||||
|
||||
|
||||
for required_action in self.required_actions:
|
||||
if required_action in registry_check["missing_actions"]:
|
||||
try:
|
||||
@@ -182,107 +174,100 @@ class ActionDiagnostics:
|
||||
logger.error(error_msg)
|
||||
result["errors"].append(error_msg)
|
||||
result["still_missing"].append(required_action)
|
||||
|
||||
|
||||
logger.info(f"Action修复完成: 已修复={result['fixed_actions']}, 仍缺失={result['still_missing']}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Action修复过程失败: {e}"
|
||||
logger.error(error_msg)
|
||||
result["errors"].append(error_msg)
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _register_no_reply_action(self):
|
||||
"""手动注册no_reply Action"""
|
||||
try:
|
||||
from src.plugins.built_in.core_actions.no_reply import NoReplyAction
|
||||
from src.plugin_system.base.component_types import ActionInfo
|
||||
|
||||
|
||||
# 创建Action信息
|
||||
action_info = ActionInfo(
|
||||
name="no_reply",
|
||||
description="暂时不回复消息",
|
||||
plugin_name="built_in.core_actions",
|
||||
version="1.0.0"
|
||||
name="no_reply", description="暂时不回复消息", plugin_name="built_in.core_actions", version="1.0.0"
|
||||
)
|
||||
|
||||
|
||||
# 注册Action
|
||||
success = component_registry._register_action_component(action_info, NoReplyAction)
|
||||
if success:
|
||||
logger.info("手动注册no_reply Action成功")
|
||||
else:
|
||||
raise Exception("注册失败")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"手动注册no_reply Action失败: {e}") from e
|
||||
|
||||
|
||||
def run_full_diagnosis(self) -> Dict[str, Any]:
|
||||
"""运行完整诊断"""
|
||||
logger.info("🔧 开始Action组件完整诊断")
|
||||
logger.info("=" * 60)
|
||||
|
||||
|
||||
diagnosis_result = {
|
||||
"plugin_status": {},
|
||||
"registry_status": {},
|
||||
"action_details": {},
|
||||
"fix_attempts": {},
|
||||
"summary": {}
|
||||
"summary": {},
|
||||
}
|
||||
|
||||
|
||||
# 1. 检查插件加载
|
||||
logger.info("\n📦 步骤1: 检查插件加载状态")
|
||||
diagnosis_result["plugin_status"] = self.check_plugin_loading()
|
||||
|
||||
|
||||
# 2. 检查Action注册
|
||||
logger.info("\n📋 步骤2: 检查Action注册状态")
|
||||
diagnosis_result["registry_status"] = self.check_action_registry()
|
||||
|
||||
|
||||
# 3. 检查特定Action详细信息
|
||||
logger.info("\n🔍 步骤3: 检查特定Action详细信息")
|
||||
diagnosis_result["action_details"] = {}
|
||||
for action in self.required_actions:
|
||||
diagnosis_result["action_details"][action] = self.check_specific_action(action)
|
||||
|
||||
|
||||
# 4. 尝试修复缺失的Action
|
||||
if diagnosis_result["registry_status"].get("missing_actions"):
|
||||
logger.info("\n🔧 步骤4: 尝试修复缺失的Action")
|
||||
diagnosis_result["fix_attempts"] = self.attempt_fix_missing_actions()
|
||||
|
||||
|
||||
# 5. 生成诊断摘要
|
||||
logger.info("\n📊 步骤5: 生成诊断摘要")
|
||||
diagnosis_result["summary"] = self._generate_summary(diagnosis_result)
|
||||
|
||||
|
||||
self._print_diagnosis_results(diagnosis_result)
|
||||
|
||||
|
||||
return diagnosis_result
|
||||
|
||||
|
||||
def _generate_summary(self, diagnosis_result: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""生成诊断摘要"""
|
||||
summary = {
|
||||
"overall_status": "unknown",
|
||||
"critical_issues": [],
|
||||
"recommendations": []
|
||||
}
|
||||
|
||||
summary = {"overall_status": "unknown", "critical_issues": [], "recommendations": []}
|
||||
|
||||
try:
|
||||
# 检查插件加载状态
|
||||
if not diagnosis_result["plugin_status"].get("plugins_loaded"):
|
||||
summary["critical_issues"].append("插件加载失败")
|
||||
summary["recommendations"].append("检查插件系统配置")
|
||||
|
||||
|
||||
# 检查必需Action
|
||||
missing_actions = diagnosis_result["registry_status"].get("missing_actions", [])
|
||||
if "no_reply" in missing_actions:
|
||||
summary["critical_issues"].append("缺失no_reply Action")
|
||||
summary["recommendations"].append("检查core_actions插件是否正确加载")
|
||||
|
||||
|
||||
# 检查修复结果
|
||||
if diagnosis_result.get("fix_attempts"):
|
||||
still_missing = diagnosis_result["fix_attempts"].get("still_missing", [])
|
||||
if still_missing:
|
||||
summary["critical_issues"].append(f"修复后仍缺失Action: {still_missing}")
|
||||
summary["recommendations"].append("需要手动修复插件注册问题")
|
||||
|
||||
|
||||
# 确定整体状态
|
||||
if not summary["critical_issues"]:
|
||||
summary["overall_status"] = "healthy"
|
||||
@@ -290,103 +275,106 @@ class ActionDiagnostics:
|
||||
summary["overall_status"] = "warning"
|
||||
else:
|
||||
summary["overall_status"] = "critical"
|
||||
|
||||
|
||||
except Exception as e:
|
||||
summary["critical_issues"].append(f"摘要生成失败: {e}")
|
||||
summary["overall_status"] = "error"
|
||||
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
def _print_diagnosis_results(self, diagnosis_result: Dict[str, Any]):
|
||||
"""打印诊断结果"""
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("📈 诊断结果摘要")
|
||||
logger.info("=" * 60)
|
||||
|
||||
|
||||
summary = diagnosis_result.get("summary", {})
|
||||
overall_status = summary.get("overall_status", "unknown")
|
||||
|
||||
|
||||
# 状态指示器
|
||||
status_indicators = {
|
||||
"healthy": "✅ 系统健康",
|
||||
"warning": "⚠️ 存在警告",
|
||||
"critical": "❌ 存在严重问题",
|
||||
"error": "💥 诊断出错",
|
||||
"unknown": "❓ 状态未知"
|
||||
"unknown": "❓ 状态未知",
|
||||
}
|
||||
|
||||
|
||||
logger.info(f"🎯 整体状态: {status_indicators.get(overall_status, overall_status)}")
|
||||
|
||||
|
||||
# 关键问题
|
||||
critical_issues = summary.get("critical_issues", [])
|
||||
if critical_issues:
|
||||
logger.info("\n🚨 关键问题:")
|
||||
for issue in critical_issues:
|
||||
logger.info(f" • {issue}")
|
||||
|
||||
|
||||
# 建议
|
||||
recommendations = summary.get("recommendations", [])
|
||||
if recommendations:
|
||||
logger.info("\n💡 建议:")
|
||||
for rec in recommendations:
|
||||
logger.info(f" • {rec}")
|
||||
|
||||
|
||||
# 详细状态
|
||||
plugin_status = diagnosis_result.get("plugin_status", {})
|
||||
if plugin_status.get("plugins_loaded"):
|
||||
logger.info(f"\n📦 插件状态: 已加载 {plugin_status.get('total_plugins', 0)} 个插件")
|
||||
else:
|
||||
logger.info("\n📦 插件状态: ❌ 插件加载失败")
|
||||
|
||||
|
||||
registry_status = diagnosis_result.get("registry_status", {})
|
||||
total_actions = registry_status.get("total_actions", 0)
|
||||
missing_actions = registry_status.get("missing_actions", [])
|
||||
logger.info(f"📋 Action状态: 已注册 {total_actions} 个,缺失 {len(missing_actions)} 个")
|
||||
|
||||
|
||||
if missing_actions:
|
||||
logger.info(f" 缺失的Action: {missing_actions}")
|
||||
|
||||
|
||||
logger.info("\n" + "=" * 60)
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
diagnostics = ActionDiagnostics()
|
||||
|
||||
|
||||
try:
|
||||
result = diagnostics.run_full_diagnosis()
|
||||
|
||||
|
||||
# 保存诊断结果
|
||||
import orjson
|
||||
|
||||
with open("action_diagnosis_results.json", "w", encoding="utf-8") as f:
|
||||
f.write(orjson.dumps(
|
||||
result, option=orjson.OPT_INDENT_2).decode('utf-8')
|
||||
)
|
||||
f.write(orjson.dumps(result, option=orjson.OPT_INDENT_2).decode("utf-8"))
|
||||
|
||||
logger.info("📄 诊断结果已保存到: action_diagnosis_results.json")
|
||||
|
||||
|
||||
# 根据诊断结果返回适当的退出代码
|
||||
summary = result.get("summary", {})
|
||||
overall_status = summary.get("overall_status", "unknown")
|
||||
|
||||
|
||||
if overall_status == "healthy":
|
||||
return 0
|
||||
elif overall_status == "warning":
|
||||
return 1
|
||||
else:
|
||||
return 2
|
||||
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("❌ 诊断被用户中断")
|
||||
return 3
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 诊断执行失败: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return 4
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
exit_code = main()
|
||||
sys.exit(exit_code)
|
||||
|
||||
@@ -12,9 +12,10 @@ from src.config.config import global_config
|
||||
|
||||
logger = get_logger("async_instant_memory_wrapper")
|
||||
|
||||
|
||||
class AsyncInstantMemoryWrapper:
|
||||
"""异步瞬时记忆包装器"""
|
||||
|
||||
|
||||
def __init__(self, chat_id: str):
|
||||
self.chat_id = chat_id
|
||||
self.llm_memory = None
|
||||
@@ -32,6 +33,7 @@ class AsyncInstantMemoryWrapper:
|
||||
if self.llm_memory is None and self.llm_memory_enabled:
|
||||
try:
|
||||
from src.chat.memory_system.instant_memory import InstantMemory
|
||||
|
||||
self.llm_memory = InstantMemory(self.chat_id)
|
||||
logger.info(f"LLM瞬时记忆系统已初始化: {self.chat_id}")
|
||||
except Exception as e:
|
||||
@@ -43,80 +45,76 @@ class AsyncInstantMemoryWrapper:
|
||||
if self.vector_memory is None and self.vector_memory_enabled:
|
||||
try:
|
||||
from src.chat.memory_system.vector_instant_memory import VectorInstantMemoryV2
|
||||
|
||||
self.vector_memory = VectorInstantMemoryV2(self.chat_id)
|
||||
logger.info(f"向量瞬时记忆系统已初始化: {self.chat_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"向量瞬时记忆系统初始化失败: {e}")
|
||||
self.vector_memory_enabled = False # 初始化失败则禁用
|
||||
|
||||
self.vector_memory_enabled = False # 初始化失败则禁用
|
||||
|
||||
def _get_cache_key(self, operation: str, content: str) -> str:
|
||||
"""生成缓存键"""
|
||||
return f"{operation}_{self.chat_id}_{hash(content)}"
|
||||
|
||||
|
||||
def _is_cache_valid(self, cache_key: str) -> bool:
|
||||
"""检查缓存是否有效"""
|
||||
if cache_key not in self.cache:
|
||||
return False
|
||||
|
||||
|
||||
_, timestamp = self.cache[cache_key]
|
||||
return time.time() - timestamp < self.cache_ttl
|
||||
|
||||
|
||||
def _get_cached_result(self, cache_key: str) -> Optional[Any]:
|
||||
"""获取缓存结果"""
|
||||
if self._is_cache_valid(cache_key):
|
||||
result, _ = self.cache[cache_key]
|
||||
return result
|
||||
return None
|
||||
|
||||
|
||||
def _cache_result(self, cache_key: str, result: Any):
|
||||
"""缓存结果"""
|
||||
self.cache[cache_key] = (result, time.time())
|
||||
|
||||
|
||||
async def store_memory_async(self, content: str, timeout: Optional[float] = None) -> bool:
|
||||
"""异步存储记忆(带超时控制)"""
|
||||
if timeout is None:
|
||||
timeout = self.default_timeout
|
||||
|
||||
|
||||
success_count = 0
|
||||
|
||||
|
||||
# 异步存储到LLM记忆系统
|
||||
await self._ensure_llm_memory()
|
||||
if self.llm_memory:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self.llm_memory.create_and_store_memory(content),
|
||||
timeout=timeout
|
||||
)
|
||||
await asyncio.wait_for(self.llm_memory.create_and_store_memory(content), timeout=timeout)
|
||||
success_count += 1
|
||||
logger.debug(f"LLM记忆存储成功: {content[:50]}...")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"LLM记忆存储超时: {content[:50]}...")
|
||||
except Exception as e:
|
||||
logger.error(f"LLM记忆存储失败: {e}")
|
||||
|
||||
|
||||
# 异步存储到向量记忆系统
|
||||
await self._ensure_vector_memory()
|
||||
if self.vector_memory:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self.vector_memory.store_message(content),
|
||||
timeout=timeout
|
||||
)
|
||||
await asyncio.wait_for(self.vector_memory.store_message(content), timeout=timeout)
|
||||
success_count += 1
|
||||
logger.debug(f"向量记忆存储成功: {content[:50]}...")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"向量记忆存储超时: {content[:50]}...")
|
||||
except Exception as e:
|
||||
logger.error(f"向量记忆存储失败: {e}")
|
||||
|
||||
|
||||
return success_count > 0
|
||||
|
||||
async def retrieve_memory_async(self, query: str, timeout: Optional[float] = None,
|
||||
use_cache: bool = True) -> Optional[Any]:
|
||||
|
||||
async def retrieve_memory_async(
|
||||
self, query: str, timeout: Optional[float] = None, use_cache: bool = True
|
||||
) -> Optional[Any]:
|
||||
"""异步检索记忆(带缓存和超时控制)"""
|
||||
if timeout is None:
|
||||
timeout = self.default_timeout
|
||||
|
||||
|
||||
# 检查缓存
|
||||
if use_cache:
|
||||
cache_key = self._get_cache_key("retrieve", query)
|
||||
@@ -124,17 +122,17 @@ class AsyncInstantMemoryWrapper:
|
||||
if cached_result is not None:
|
||||
logger.debug(f"记忆检索命中缓存: {query[:30]}...")
|
||||
return cached_result
|
||||
|
||||
|
||||
# 尝试多种记忆系统
|
||||
results = []
|
||||
|
||||
|
||||
# 从向量记忆系统检索(优先,速度快)
|
||||
await self._ensure_vector_memory()
|
||||
if self.vector_memory:
|
||||
try:
|
||||
vector_result = await asyncio.wait_for(
|
||||
self.vector_memory.get_memory_for_context(query),
|
||||
timeout=timeout * 0.6 # 给向量系统60%的时间
|
||||
timeout=timeout * 0.6, # 给向量系统60%的时间
|
||||
)
|
||||
if vector_result:
|
||||
results.append(vector_result)
|
||||
@@ -143,14 +141,14 @@ class AsyncInstantMemoryWrapper:
|
||||
logger.warning(f"向量记忆检索超时: {query[:30]}...")
|
||||
except Exception as e:
|
||||
logger.error(f"向量记忆检索失败: {e}")
|
||||
|
||||
|
||||
# 从LLM记忆系统检索(备用,更准确但较慢)
|
||||
await self._ensure_llm_memory()
|
||||
if self.llm_memory and len(results) == 0: # 只有向量检索失败时才使用LLM
|
||||
try:
|
||||
llm_result = await asyncio.wait_for(
|
||||
self.llm_memory.get_memory(query),
|
||||
timeout=timeout * 0.4 # 给LLM系统40%的时间
|
||||
timeout=timeout * 0.4, # 给LLM系统40%的时间
|
||||
)
|
||||
if llm_result:
|
||||
results.extend(llm_result)
|
||||
@@ -159,7 +157,7 @@ class AsyncInstantMemoryWrapper:
|
||||
logger.warning(f"LLM记忆检索超时: {query[:30]}...")
|
||||
except Exception as e:
|
||||
logger.error(f"LLM记忆检索失败: {e}")
|
||||
|
||||
|
||||
# 合并结果
|
||||
final_result = None
|
||||
if results:
|
||||
@@ -178,42 +176,43 @@ class AsyncInstantMemoryWrapper:
|
||||
final_result.append(r)
|
||||
else:
|
||||
final_result = results[0] # 使用第一个结果
|
||||
|
||||
|
||||
# 缓存结果
|
||||
if use_cache and final_result is not None:
|
||||
cache_key = self._get_cache_key("retrieve", query)
|
||||
self._cache_result(cache_key, final_result)
|
||||
|
||||
|
||||
return final_result
|
||||
|
||||
|
||||
async def get_memory_with_fallback(self, query: str, max_timeout: float = 2.0) -> str:
|
||||
"""获取记忆的回退方法,保证不会长时间阻塞"""
|
||||
try:
|
||||
# 首先尝试快速检索
|
||||
result = await self.retrieve_memory_async(query, timeout=max_timeout)
|
||||
|
||||
|
||||
if result:
|
||||
if isinstance(result, list):
|
||||
return "\n".join(str(item) for item in result)
|
||||
return str(result)
|
||||
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"记忆检索完全失败: {e}")
|
||||
return ""
|
||||
|
||||
|
||||
def store_memory_background(self, content: str):
|
||||
"""在后台存储记忆(发后即忘模式)"""
|
||||
|
||||
async def background_store():
|
||||
try:
|
||||
await self.store_memory_async(content, timeout=10.0) # 后台任务可以用更长超时
|
||||
except Exception as e:
|
||||
logger.error(f"后台记忆存储失败: {e}")
|
||||
|
||||
|
||||
# 创建后台任务
|
||||
asyncio.create_task(background_store())
|
||||
|
||||
|
||||
def get_status(self) -> Dict[str, Any]:
|
||||
"""获取包装器状态"""
|
||||
return {
|
||||
@@ -222,23 +221,26 @@ class AsyncInstantMemoryWrapper:
|
||||
"vector_memory_available": self.vector_memory is not None,
|
||||
"cache_entries": len(self.cache),
|
||||
"cache_ttl": self.cache_ttl,
|
||||
"default_timeout": self.default_timeout
|
||||
"default_timeout": self.default_timeout,
|
||||
}
|
||||
|
||||
|
||||
def clear_cache(self):
|
||||
"""清理缓存"""
|
||||
self.cache.clear()
|
||||
logger.info(f"记忆缓存已清理: {self.chat_id}")
|
||||
|
||||
|
||||
# 缓存包装器实例,避免重复创建
|
||||
_wrapper_cache: Dict[str, AsyncInstantMemoryWrapper] = {}
|
||||
|
||||
|
||||
def get_async_instant_memory(chat_id: str) -> AsyncInstantMemoryWrapper:
|
||||
"""获取异步瞬时记忆包装器实例"""
|
||||
if chat_id not in _wrapper_cache:
|
||||
_wrapper_cache[chat_id] = AsyncInstantMemoryWrapper(chat_id)
|
||||
return _wrapper_cache[chat_id]
|
||||
|
||||
|
||||
def clear_wrapper_cache():
|
||||
"""清理包装器缓存"""
|
||||
global _wrapper_cache
|
||||
|
||||
@@ -15,9 +15,11 @@ from src.chat.memory_system.async_instant_memory_wrapper import get_async_instan
|
||||
|
||||
logger = get_logger("async_memory_optimizer")
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryTask:
|
||||
"""记忆任务数据结构"""
|
||||
|
||||
task_id: str
|
||||
task_type: str # "store", "retrieve", "build"
|
||||
chat_id: str
|
||||
@@ -25,14 +27,15 @@ class MemoryTask:
|
||||
priority: int = 1 # 1=低优先级, 2=中优先级, 3=高优先级
|
||||
callback: Optional[Callable] = None
|
||||
created_at: float = None
|
||||
|
||||
|
||||
def __post_init__(self):
|
||||
if self.created_at is None:
|
||||
self.created_at = time.time()
|
||||
|
||||
|
||||
class AsyncMemoryQueue:
|
||||
"""异步记忆任务队列管理器"""
|
||||
|
||||
|
||||
def __init__(self, max_workers: int = 3):
|
||||
self.max_workers = max_workers
|
||||
self.executor = ThreadPoolExecutor(max_workers=max_workers)
|
||||
@@ -42,56 +45,56 @@ class AsyncMemoryQueue:
|
||||
self.failed_tasks: Dict[str, str] = {}
|
||||
self.is_running = False
|
||||
self.worker_tasks: List[asyncio.Task] = []
|
||||
|
||||
|
||||
async def start(self):
|
||||
"""启动异步队列处理器"""
|
||||
if self.is_running:
|
||||
return
|
||||
|
||||
|
||||
self.is_running = True
|
||||
# 启动多个工作协程
|
||||
for i in range(self.max_workers):
|
||||
worker = asyncio.create_task(self._worker(f"worker-{i}"))
|
||||
self.worker_tasks.append(worker)
|
||||
|
||||
|
||||
logger.info(f"异步记忆队列已启动,工作线程数: {self.max_workers}")
|
||||
|
||||
|
||||
async def stop(self):
|
||||
"""停止队列处理器"""
|
||||
self.is_running = False
|
||||
|
||||
|
||||
# 等待所有工作任务完成
|
||||
for task in self.worker_tasks:
|
||||
task.cancel()
|
||||
|
||||
|
||||
await asyncio.gather(*self.worker_tasks, return_exceptions=True)
|
||||
self.executor.shutdown(wait=True)
|
||||
logger.info("异步记忆队列已停止")
|
||||
|
||||
|
||||
async def _worker(self, worker_name: str):
|
||||
"""工作协程,处理队列中的任务"""
|
||||
logger.info(f"记忆处理工作线程 {worker_name} 启动")
|
||||
|
||||
|
||||
while self.is_running:
|
||||
try:
|
||||
# 等待任务,超时1秒避免永久阻塞
|
||||
task = await asyncio.wait_for(self.task_queue.get(), timeout=1.0)
|
||||
|
||||
|
||||
# 执行任务
|
||||
await self._execute_task(task, worker_name)
|
||||
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# 超时正常,继续下一次循环
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"工作线程 {worker_name} 处理任务时出错: {e}")
|
||||
|
||||
|
||||
async def _execute_task(self, task: MemoryTask, worker_name: str):
|
||||
"""执行具体的记忆任务"""
|
||||
try:
|
||||
logger.debug(f"[{worker_name}] 开始处理任务: {task.task_type} - {task.task_id}")
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
# 根据任务类型执行不同的处理逻辑
|
||||
result = None
|
||||
if task.task_type == "store":
|
||||
@@ -102,13 +105,13 @@ class AsyncMemoryQueue:
|
||||
result = await self._handle_build_task(task)
|
||||
else:
|
||||
raise ValueError(f"未知的任务类型: {task.task_type}")
|
||||
|
||||
|
||||
# 记录完成的任务
|
||||
self.completed_tasks[task.task_id] = result
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
|
||||
logger.debug(f"[{worker_name}] 任务完成: {task.task_id} (耗时: {execution_time:.2f}s)")
|
||||
|
||||
|
||||
# 执行回调函数
|
||||
if task.callback:
|
||||
try:
|
||||
@@ -118,12 +121,12 @@ class AsyncMemoryQueue:
|
||||
task.callback(result)
|
||||
except Exception as e:
|
||||
logger.error(f"任务回调执行失败: {e}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"任务执行失败: {e}"
|
||||
logger.error(f"[{worker_name}] {error_msg}")
|
||||
self.failed_tasks[task.task_id] = error_msg
|
||||
|
||||
|
||||
# 执行错误回调
|
||||
if task.callback:
|
||||
try:
|
||||
@@ -133,7 +136,7 @@ class AsyncMemoryQueue:
|
||||
task.callback(None)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
async def _handle_store_task(self, task: MemoryTask) -> Any:
|
||||
"""处理记忆存储任务"""
|
||||
# 这里需要根据具体的记忆系统来实现
|
||||
@@ -141,7 +144,7 @@ class AsyncMemoryQueue:
|
||||
try:
|
||||
# 获取包装器实例
|
||||
memory_wrapper = get_async_instant_memory(task.chat_id)
|
||||
|
||||
|
||||
# 使用包装器中的llm_memory实例
|
||||
if memory_wrapper and memory_wrapper.llm_memory:
|
||||
await memory_wrapper.llm_memory.create_and_store_memory(task.content)
|
||||
@@ -152,13 +155,13 @@ class AsyncMemoryQueue:
|
||||
except Exception as e:
|
||||
logger.error(f"记忆存储失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def _handle_retrieve_task(self, task: MemoryTask) -> Any:
|
||||
"""处理记忆检索任务"""
|
||||
try:
|
||||
# 获取包装器实例
|
||||
memory_wrapper = get_async_instant_memory(task.chat_id)
|
||||
|
||||
|
||||
# 使用包装器中的llm_memory实例
|
||||
if memory_wrapper and memory_wrapper.llm_memory:
|
||||
memories = await memory_wrapper.llm_memory.get_memory(task.content)
|
||||
@@ -169,14 +172,14 @@ class AsyncMemoryQueue:
|
||||
except Exception as e:
|
||||
logger.error(f"记忆检索失败: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def _handle_build_task(self, task: MemoryTask) -> Any:
|
||||
"""处理记忆构建任务(海马体系统)"""
|
||||
try:
|
||||
# 延迟导入避免循环依赖
|
||||
if global_config.memory.enable_memory:
|
||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||
|
||||
|
||||
if hippocampus_manager._initialized:
|
||||
await hippocampus_manager.build_memory()
|
||||
return True
|
||||
@@ -184,22 +187,22 @@ class AsyncMemoryQueue:
|
||||
except Exception as e:
|
||||
logger.error(f"记忆构建失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def add_task(self, task: MemoryTask) -> str:
    """Put a task on the queue and register it in ``self.running_tasks``.

    Args:
        task: The memory task to enqueue.

    Returns:
        The ``task_id`` of the enqueued task.
    """
    await self.task_queue.put(task)
    self.running_tasks[task.task_id] = task
    logger.debug(f"任务已加入队列: {task.task_type} - {task.task_id}")
    return task.task_id
|
||||
|
||||
|
||||
def get_task_result(self, task_id: str) -> Optional[Any]:
    """Return the stored result of a completed task without blocking.

    Args:
        task_id: Identifier of the task to look up.

    Returns:
        The value recorded in ``self.completed_tasks`` for ``task_id``, or
        ``None`` when no result is recorded under that id.
    """
    return self.completed_tasks.get(task_id)
|
||||
|
||||
|
||||
def is_task_completed(self, task_id: str) -> bool:
    """Report whether a task has finished, successfully or not.

    Args:
        task_id: Identifier of the task to check.

    Returns:
        ``True`` when ``task_id`` appears in ``self.completed_tasks`` or in
        ``self.failed_tasks``; ``False`` otherwise.
    """
    return task_id in self.completed_tasks or task_id in self.failed_tasks
|
||||
|
||||
|
||||
def get_queue_status(self) -> Dict[str, Any]:
|
||||
"""获取队列状态"""
|
||||
return {
|
||||
@@ -208,30 +211,30 @@ class AsyncMemoryQueue:
|
||||
"running_tasks": len(self.running_tasks),
|
||||
"completed_tasks": len(self.completed_tasks),
|
||||
"failed_tasks": len(self.failed_tasks),
|
||||
"worker_count": len(self.worker_tasks)
|
||||
"worker_count": len(self.worker_tasks),
|
||||
}
|
||||
|
||||
|
||||
class NonBlockingMemoryManager:
|
||||
"""非阻塞记忆管理器"""
|
||||
|
||||
|
||||
def __init__(self):
    """Create the manager with a 3-worker task queue and an empty result cache."""
    self.queue = AsyncMemoryQueue(max_workers=3)
    # Cached results keyed by cache key.
    self.cache: Dict[str, Any] = {}
    # time.time() stamp recorded when each cache entry was stored.
    self.cache_ttl: Dict[str, float] = {}
    self.cache_timeout = 300  # cache entries stay valid for 5 minutes (seconds)
|
||||
|
||||
|
||||
async def initialize(self):
    """Initialize the manager by starting its task queue."""
    await self.queue.start()
    logger.info("非阻塞记忆管理器已初始化")
|
||||
|
||||
|
||||
async def shutdown(self):
    """Shut the manager down by stopping its task queue."""
    await self.queue.stop()
    logger.info("非阻塞记忆管理器已关闭")
|
||||
|
||||
async def store_memory_async(self, chat_id: str, content: str,
|
||||
callback: Optional[Callable] = None) -> str:
|
||||
|
||||
async def store_memory_async(self, chat_id: str, content: str, callback: Optional[Callable] = None) -> str:
|
||||
"""异步存储记忆(非阻塞)"""
|
||||
task = MemoryTask(
|
||||
task_id=f"store_{chat_id}_{int(time.time() * 1000)}",
|
||||
@@ -239,13 +242,12 @@ class NonBlockingMemoryManager:
|
||||
chat_id=chat_id,
|
||||
content=content,
|
||||
priority=1, # 存储优先级较低
|
||||
callback=callback
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
|
||||
return await self.queue.add_task(task)
|
||||
|
||||
async def retrieve_memory_async(self, chat_id: str, query: str,
|
||||
callback: Optional[Callable] = None) -> str:
|
||||
|
||||
async def retrieve_memory_async(self, chat_id: str, query: str, callback: Optional[Callable] = None) -> str:
|
||||
"""异步检索记忆(非阻塞)"""
|
||||
# 先检查缓存
|
||||
cache_key = f"retrieve_{chat_id}_{hash(query)}"
|
||||
@@ -257,18 +259,18 @@ class NonBlockingMemoryManager:
|
||||
else:
|
||||
callback(result)
|
||||
return "cache_hit"
|
||||
|
||||
|
||||
task = MemoryTask(
|
||||
task_id=f"retrieve_{chat_id}_{int(time.time() * 1000)}",
|
||||
task_type="retrieve",
|
||||
chat_id=chat_id,
|
||||
content=query,
|
||||
priority=2, # 检索优先级中等
|
||||
callback=self._create_cache_callback(cache_key, callback)
|
||||
callback=self._create_cache_callback(cache_key, callback),
|
||||
)
|
||||
|
||||
|
||||
return await self.queue.add_task(task)
|
||||
|
||||
|
||||
async def build_memory_async(self, callback: Optional[Callable] = None) -> str:
|
||||
"""异步构建记忆(非阻塞)"""
|
||||
task = MemoryTask(
|
||||
@@ -277,70 +279,72 @@ class NonBlockingMemoryManager:
|
||||
chat_id="system",
|
||||
content="",
|
||||
priority=1, # 构建优先级较低,避免影响用户体验
|
||||
callback=callback
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
|
||||
return await self.queue.add_task(task)
|
||||
|
||||
|
||||
def _is_cache_valid(self, cache_key: str) -> bool:
    """Check whether a cache entry exists and has not expired.

    Args:
        cache_key: Key to look up in ``self.cache``.

    Returns:
        ``False`` when the key is not cached. Otherwise ``True`` only if
        fewer than ``self.cache_timeout`` seconds have passed since the stamp
        stored in ``self.cache_ttl``; a missing stamp counts as time 0.
    """
    if cache_key not in self.cache:
        return False

    return time.time() - self.cache_ttl.get(cache_key, 0) < self.cache_timeout
|
||||
|
||||
|
||||
def _create_cache_callback(self, cache_key: str, original_callback: Optional[Callable]):
    """Build a callback that caches a result before forwarding it.

    Args:
        cache_key: Key under which a non-``None`` result is cached.
        original_callback: Optional callback to invoke with the result
            afterwards; may be a plain function or a coroutine function.

    Returns:
        A coroutine function taking the task result as its only argument.
    """

    async def cache_callback(result):
        # Only real results are cached; None is never stored.
        if result is not None:
            self.cache[cache_key] = result
            self.cache_ttl[cache_key] = time.time()

        # Forward to the caller's callback, awaiting it when it is async.
        if original_callback:
            if asyncio.iscoroutinefunction(original_callback):
                await original_callback(result)
            else:
                original_callback(result)

    return cache_callback
|
||||
|
||||
|
||||
def get_cached_memory(self, chat_id: str, query: str) -> Optional[Any]:
    """Return a cached retrieval result synchronously, if one is still valid.

    Args:
        chat_id: Chat the query belongs to.
        query: Query text; its ``hash()`` is part of the cache key.

    Returns:
        The cached value when ``self._is_cache_valid`` accepts the key,
        otherwise ``None``.
    """
    cache_key = f"retrieve_{chat_id}_{hash(query)}"
    if self._is_cache_valid(cache_key):
        return self.cache[cache_key]
    return None
|
||||
|
||||
|
||||
def get_status(self) -> Dict[str, Any]:
    """Return the queue status extended with cache statistics.

    Returns:
        The dict produced by ``self.queue.get_queue_status()`` with two extra
        keys: ``cache_entries`` (number of cached items) and
        ``cache_timeout`` (the configured timeout value).
    """
    status = self.queue.get_queue_status()
    status.update({"cache_entries": len(self.cache), "cache_timeout": self.cache_timeout})
    return status
|
||||
|
||||
|
||||
# 全局实例
|
||||
async_memory_manager = NonBlockingMemoryManager()
|
||||
|
||||
|
||||
# 便捷函数
|
||||
async def store_memory_nonblocking(chat_id: str, content: str) -> str:
    """Convenience wrapper that queues a memory-store task on the global manager.

    Args:
        chat_id: Chat the memory belongs to.
        content: Text to store.

    Returns:
        Whatever ``async_memory_manager.store_memory_async`` returns.
    """
    return await async_memory_manager.store_memory_async(chat_id, content)
|
||||
|
||||
|
||||
async def retrieve_memory_nonblocking(chat_id: str, query: str) -> Optional[Any]:
    """Convenience wrapper for cache-first, non-blocking memory retrieval.

    Args:
        chat_id: Chat to search in.
        query: Query text.

    Returns:
        The cached result when one is available. Otherwise an asynchronous
        retrieval is queued and ``None`` is returned, meaning the result is
        not available yet.
    """
    # Serve from the cache when possible.
    cached_result = async_memory_manager.get_cached_memory(chat_id, query)
    if cached_result is not None:
        return cached_result

    # Cache miss: queue an asynchronous retrieval.
    await async_memory_manager.retrieve_memory_async(chat_id, query)
    return None  # None tells the caller the result must be fetched asynchronously
|
||||
|
||||
|
||||
async def build_memory_nonblocking() -> str:
    """Convenience wrapper that queues a memory-build task on the global manager.

    Returns:
        Whatever ``async_memory_manager.build_memory_async`` returns.
    """
    return await async_memory_manager.build_memory_async()
|
||||
|
||||
@@ -14,8 +14,10 @@ from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.config.config import model_config
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class MemoryItem:
|
||||
def __init__(self, memory_id: str, chat_id: str, memory_text: str, keywords: list[str]):
|
||||
self.memory_id = memory_id
|
||||
@@ -24,6 +26,8 @@ class MemoryItem:
|
||||
self.keywords: list[str] = keywords
|
||||
self.create_time: float = time.time()
|
||||
self.last_view_time: float = time.time()
|
||||
|
||||
|
||||
class InstantMemory:
|
||||
def __init__(self, chat_id):
|
||||
self.chat_id = chat_id
|
||||
@@ -105,13 +109,13 @@ class InstantMemory:
|
||||
async def store_memory(self, memory_item: MemoryItem):
|
||||
with get_db_session() as session:
|
||||
memory = Memory(
|
||||
memory_id=memory_item.memory_id,
|
||||
chat_id=memory_item.chat_id,
|
||||
memory_text=memory_item.memory_text,
|
||||
keywords=orjson.dumps(memory_item.keywords).decode('utf-8'),
|
||||
create_time=memory_item.create_time,
|
||||
last_view_time=memory_item.last_view_time,
|
||||
)
|
||||
memory_id=memory_item.memory_id,
|
||||
chat_id=memory_item.chat_id,
|
||||
memory_text=memory_item.memory_text,
|
||||
keywords=orjson.dumps(memory_item.keywords).decode("utf-8"),
|
||||
create_time=memory_item.create_time,
|
||||
last_view_time=memory_item.last_view_time,
|
||||
)
|
||||
session.add(memory)
|
||||
session.commit()
|
||||
|
||||
@@ -160,12 +164,14 @@ class InstantMemory:
|
||||
if start_time and end_time:
|
||||
start_ts = start_time.timestamp()
|
||||
end_ts = end_time.timestamp()
|
||||
|
||||
query = session.execute(select(Memory).where(
|
||||
(Memory.chat_id == self.chat_id)
|
||||
& (Memory.create_time >= start_ts)
|
||||
& (Memory.create_time < end_ts)
|
||||
)).scalars()
|
||||
|
||||
query = session.execute(
|
||||
select(Memory).where(
|
||||
(Memory.chat_id == self.chat_id)
|
||||
& (Memory.create_time >= start_ts)
|
||||
& (Memory.create_time < end_ts)
|
||||
)
|
||||
).scalars()
|
||||
else:
|
||||
query = session.execute(select(Memory).where(Memory.chat_id == self.chat_id)).scalars()
|
||||
for mem in query:
|
||||
@@ -209,12 +215,14 @@ class InstantMemory:
|
||||
try:
|
||||
dt = datetime.strptime(time_str, "%Y-%m-%d %H:%M:%S")
|
||||
return dt, dt + timedelta(hours=1)
|
||||
except Exception: ...
|
||||
except Exception:
|
||||
...
|
||||
# 具体日期
|
||||
try:
|
||||
dt = datetime.strptime(time_str, "%Y-%m-%d")
|
||||
return dt, dt + timedelta(days=1)
|
||||
except Exception: ...
|
||||
except Exception:
|
||||
...
|
||||
# 相对时间
|
||||
if time_str == "今天":
|
||||
start = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
@@ -15,6 +15,7 @@ logger = get_logger("vector_instant_memory_v2")
|
||||
@dataclass
|
||||
class ChatMessage:
|
||||
"""聊天消息数据结构"""
|
||||
|
||||
message_id: str
|
||||
chat_id: str
|
||||
content: str
|
||||
@@ -25,51 +26,49 @@ class ChatMessage:
|
||||
|
||||
class VectorInstantMemoryV2:
|
||||
"""重构的向量瞬时记忆系统 V2
|
||||
|
||||
|
||||
新设计理念:
|
||||
1. 全量存储 - 所有聊天记录都存储为向量
|
||||
2. 定时清理 - 定期清理过期记录
|
||||
3. 实时匹配 - 新消息与历史记录做向量相似度匹配
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, chat_id: str, retention_hours: int = 24, cleanup_interval: int = 3600):
    """Initialize the vector instant-memory system.

    Stores the configuration, prepares the vector collection, and starts the
    periodic cleanup task.

    Args:
        chat_id: Chat ID.
        retention_hours: How long memories are kept, in hours.
        cleanup_interval: Interval between cleanups, in seconds.
    """
    self.chat_id = chat_id
    self.retention_hours = retention_hours
    self.cleanup_interval = cleanup_interval
    self.collection_name = "instant_memory"

    # Cleanup-task state; is_running is set before the task is started.
    self.cleanup_task = None
    self.is_running = True

    # Bring the system up.
    self._init_chroma()
    self._start_cleanup_task()

    logger.info(f"向量瞬时记忆系统V2初始化完成: {chat_id} (保留{retention_hours}小时)")
|
||||
|
||||
|
||||
def _init_chroma(self):
    """Prepare the vector collection through the shared vector DB service.

    Only fetches or creates the collection named ``self.collection_name``
    (with metadata ``{"hnsw:space": "cosine"}``); no new client is created
    here. Failures are logged and not re-raised.
    """
    try:
        vector_db_service.get_or_create_collection(name=self.collection_name, metadata={"hnsw:space": "cosine"})
        logger.info(f"向量记忆集合 '{self.collection_name}' 已准备就绪")
    except Exception as e:
        logger.error(f"获取向量记忆集合失败: {e}")
|
||||
|
||||
|
||||
def _start_cleanup_task(self):
|
||||
"""启动定时清理任务"""
|
||||
|
||||
def cleanup_worker():
|
||||
while self.is_running:
|
||||
try:
|
||||
@@ -78,11 +77,11 @@ class VectorInstantMemoryV2:
|
||||
except Exception as e:
|
||||
logger.error(f"清理任务异常: {e}")
|
||||
time.sleep(60) # 异常时等待1分钟再继续
|
||||
|
||||
|
||||
self.cleanup_task = threading.Thread(target=cleanup_worker, daemon=True)
|
||||
self.cleanup_task.start()
|
||||
logger.info(f"定时清理任务已启动,间隔{self.cleanup_interval}秒")
|
||||
|
||||
|
||||
def _cleanup_expired_messages(self):
|
||||
"""清理过期的聊天记录"""
|
||||
try:
|
||||
@@ -91,211 +90,208 @@ class VectorInstantMemoryV2:
|
||||
# 采用 get -> filter -> delete 模式,避免复杂的 where 查询
|
||||
# 1. 获取当前 chat_id 的所有文档
|
||||
results = vector_db_service.get(
|
||||
collection_name=self.collection_name,
|
||||
where={"chat_id": self.chat_id},
|
||||
include=["metadatas"]
|
||||
collection_name=self.collection_name, where={"chat_id": self.chat_id}, include=["metadatas"]
|
||||
)
|
||||
|
||||
if not results or not results.get('ids'):
|
||||
if not results or not results.get("ids"):
|
||||
logger.info(f"chat_id '{self.chat_id}' 没有找到任何记录,无需清理")
|
||||
return
|
||||
|
||||
# 2. 在内存中过滤出过期的文档
|
||||
expired_ids = []
|
||||
metadatas = results.get('metadatas', [])
|
||||
ids = results.get('ids', [])
|
||||
metadatas = results.get("metadatas", [])
|
||||
ids = results.get("ids", [])
|
||||
|
||||
for i, metadata in enumerate(metadatas):
|
||||
if metadata and metadata.get('timestamp', float('inf')) < expire_time:
|
||||
if metadata and metadata.get("timestamp", float("inf")) < expire_time:
|
||||
expired_ids.append(ids[i])
|
||||
|
||||
# 3. 如果有过期文档,根据 ID 进行删除
|
||||
if expired_ids:
|
||||
vector_db_service.delete(
|
||||
collection_name=self.collection_name,
|
||||
ids=expired_ids
|
||||
)
|
||||
vector_db_service.delete(collection_name=self.collection_name, ids=expired_ids)
|
||||
logger.info(f"为 chat_id '{self.chat_id}' 清理了 {len(expired_ids)} 条过期记录")
|
||||
else:
|
||||
logger.info(f"chat_id '{self.chat_id}' 没有需要清理的过期记录")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"清理过期记录失败: {e}")
|
||||
|
||||
|
||||
async def store_message(self, content: str, sender: str = "user") -> bool:
    """Store a chat message in the vector collection.

    Args:
        content: Message text.
        sender: Sender label.

    Returns:
        ``True`` when the message was stored. ``False`` when the content is
        blank, the embedding could not be generated, or any exception occurred
        (exceptions are logged, not raised).
    """
    if not content.strip():
        return False

    try:
        # Embed the message text.
        message_vector = await get_embedding(content)
        if not message_vector:
            logger.warning(f"消息向量生成失败: {content[:50]}...")
            return False

        # ID combines chat id, a millisecond timestamp, and a short content hash.
        message_id = f"{self.chat_id}_{int(time.time() * 1000)}_{hash(content) % 10000}"

        message = ChatMessage(
            message_id=message_id, chat_id=self.chat_id, content=content, timestamp=time.time(), sender=sender
        )

        # Persist through the shared vector DB service.
        vector_db_service.add(
            collection_name=self.collection_name,
            embeddings=[message_vector],
            documents=[content],
            metadatas=[
                {
                    "message_id": message.message_id,
                    "chat_id": message.chat_id,
                    "timestamp": message.timestamp,
                    "sender": message.sender,
                    "message_type": message.message_type,
                }
            ],
            ids=[message_id],
        )

        logger.debug(f"消息已存储: {content[:50]}...")
        return True

    except Exception as e:
        logger.error(f"存储消息失败: {e}")
        return False
|
||||
|
||||
async def find_similar_messages(
    self, query: str, top_k: int = 5, similarity_threshold: float = 0.7
) -> List[Dict[str, Any]]:
    """Find stored messages similar to the query.

    Args:
        query: Query text.
        top_k: Maximum number of results requested from the vector service.
        similarity_threshold: Results with a lower similarity are dropped.

    Returns:
        A list of dicts with keys ``content``, ``similarity``, ``timestamp``,
        ``sender``, ``message_id`` and ``time_ago``, sorted by descending
        similarity. An empty list is returned for a blank query, a failed
        embedding, no hits, or any exception (which is logged).
    """
    if not query.strip():
        return []

    try:
        query_vector = await get_embedding(query)
        if not query_vector:
            return []

        # Query through the shared vector DB service, restricted to this chat.
        results = vector_db_service.query(
            collection_name=self.collection_name,
            query_embeddings=[query_vector],
            n_results=top_k,
            where={"chat_id": self.chat_id},
        )

        if not results.get("documents") or not results["documents"][0]:
            return []

        # Results are nested per query; only one query was sent, so take index 0.
        similar_messages = []
        documents = results["documents"][0]
        distances = results["distances"][0] if results["distances"] else []
        metadatas = results["metadatas"][0] if results["metadatas"] else []

        for i, doc in enumerate(documents):
            # The service returns a distance; convert it to a similarity.
            # A missing distance defaults to 1.0, i.e. similarity 0.
            distance = distances[i] if i < len(distances) else 1.0
            similarity = 1 - distance

            # Skip results below the threshold.
            if similarity < similarity_threshold:
                continue

            metadata = metadatas[i] if i < len(metadatas) else {}

            # Read the timestamp defensively: metadata may not be a dict and
            # the stored value may not be numeric.
            timestamp = metadata.get("timestamp", 0) if isinstance(metadata, dict) else 0
            timestamp = float(timestamp) if isinstance(timestamp, (int, float)) else 0.0

            similar_messages.append(
                {
                    "content": doc,
                    "similarity": similarity,
                    "timestamp": timestamp,
                    "sender": metadata.get("sender", "unknown") if isinstance(metadata, dict) else "unknown",
                    "message_id": metadata.get("message_id", "") if isinstance(metadata, dict) else "",
                    "time_ago": self._format_time_ago(timestamp),
                }
            )

        # Most similar first.
        similar_messages.sort(key=lambda x: x["similarity"], reverse=True)

        logger.debug(f"找到 {len(similar_messages)} 条相似消息 (查询: {query[:30]}...)")
        return similar_messages

    except Exception as e:
        logger.error(f"查找相似消息失败: {e}")
        return []
|
||||
|
||||
|
||||
def _format_time_ago(self, timestamp: float) -> str:
    """Format how long ago a timestamp was, as a short Chinese phrase.

    Args:
        timestamp: Epoch time in seconds, compared against ``time.time()``.

    Returns:
        ``"未知时间"`` for a non-positive timestamp; otherwise the elapsed
        time in whole seconds, minutes, hours or days; ``"时间格式错误"`` if
        computing the difference raises.
    """
    if timestamp <= 0:
        return "未知时间"

    try:
        now = time.time()
        diff = now - timestamp

        if diff < 60:
            return f"{int(diff)}秒前"
        elif diff < 3600:
            return f"{int(diff / 60)}分钟前"
        elif diff < 86400:
            return f"{int(diff / 3600)}小时前"
        else:
            return f"{int(diff / 86400)}天前"
    except Exception:
        return "时间格式错误"
|
||||
|
||||
|
||||
async def get_memory_for_context(self, current_message: str, context_size: int = 3) -> str:
    """Build a memory context string related to the current message.

    Args:
        current_message: The current message text.
        context_size: Number of similar messages to request.

    Returns:
        A header line followed by one line per similar message (time ago,
        sender, content, similarity to two decimals), or an empty string when
        nothing similar is found.
    """
    similar_messages = await self.find_similar_messages(
        current_message,
        top_k=context_size,
        similarity_threshold=0.6,  # lower threshold than the default to get more context
    )

    if not similar_messages:
        return ""

    # One formatted line per matching message.
    context_lines = []
    for msg in similar_messages:
        context_lines.append(
            f"[{msg['time_ago']}] {msg['sender']}: {msg['content']} (相似度: {msg['similarity']:.2f})"
        )

    return "相关的历史记忆:\n" + "\n".join(context_lines)
|
||||
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""获取记忆系统统计信息"""
|
||||
stats = {
|
||||
@@ -304,9 +300,9 @@ class VectorInstantMemoryV2:
|
||||
"cleanup_interval": self.cleanup_interval,
|
||||
"system_status": "running" if self.is_running else "stopped",
|
||||
"total_messages": 0,
|
||||
"db_status": "connected"
|
||||
"db_status": "connected",
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
# 注意:count() 现在没有 chat_id 过滤,返回的是整个集合的数量
|
||||
# 若要精确计数,需要 get(where={"chat_id": ...}) 然后 len(results['ids'])
|
||||
@@ -316,9 +312,9 @@ class VectorInstantMemoryV2:
|
||||
except Exception:
|
||||
stats["total_messages"] = "查询失败"
|
||||
stats["db_status"] = "disconnected"
|
||||
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
def stop(self):
|
||||
"""停止记忆系统"""
|
||||
self.is_running = False
|
||||
@@ -337,26 +333,26 @@ def create_vector_memory_v2(chat_id: str, retention_hours: int = 24) -> VectorIn
|
||||
async def demo():
    """Usage demo: store a few messages, query them, print stats, then stop."""
    memory = VectorInstantMemoryV2("demo_chat")

    # Store a few test messages.
    await memory.store_message("今天天气不错,出去散步了", "用户")
    await memory.store_message("刚才买了个冰淇淋,很好吃", "用户")
    await memory.store_message("明天要开会,有点紧张", "用户")

    # Look up similar messages.
    similar = await memory.find_similar_messages("天气怎么样")
    print("相似消息:", similar)

    # Fetch the formatted memory context.
    context = await memory.get_memory_for_context("今天心情如何")
    print("记忆上下文:", context)

    # Show the system statistics.
    stats = memory.get_stats()
    print("系统状态:", stats)

    memory.stop()


if __name__ == "__main__":
    asyncio.run(demo())
|
||||
|
||||
Reference in New Issue
Block a user