修复代码格式和文件名大小写问题

This commit is contained in:
Windpicker-owo
2025-08-31 20:50:17 +08:00
parent df29014e41
commit 8149731925
218 changed files with 6913 additions and 8257 deletions

View File

@@ -16,7 +16,7 @@ from rich.traceback import install
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from sqlalchemy import select,insert,update,delete
from sqlalchemy import select, insert, update, delete
from src.common.database.sqlalchemy_models import Messages, GraphNodes, GraphEdges # SQLAlchemy Models导入
from src.common.logger import get_logger
from src.common.database.sqlalchemy_database_api import get_db_session
@@ -31,6 +31,7 @@ from src.chat.utils.utils import translate_timestamp_to_human_readable
install(extra_lines=3)
def calculate_information_content(text):
"""计算文本的信息量(熵)"""
char_count = Counter(text)
@@ -695,7 +696,9 @@ class Hippocampus:
return result
async def get_activate_from_text(self, text: str, max_depth: int = 3, fast_retrieval: bool = False) -> tuple[float, list[str]]:
async def get_activate_from_text(
self, text: str, max_depth: int = 3, fast_retrieval: bool = False
) -> tuple[float, list[str]]:
"""从文本中提取关键词并获取相关记忆。
Args:
@@ -863,10 +866,10 @@ class EntorhinalCortex:
current_memorized_times = message.get("memorized_times", 0)
with get_db_session() as session:
session.execute(
update(Messages)
.where(Messages.message_id == message["message_id"])
.values(memorized_times=current_memorized_times + 1)
)
update(Messages)
.where(Messages.message_id == message["message_id"])
.values(memorized_times=current_memorized_times + 1)
)
session.commit()
return messages # 直接返回原始的消息列表
@@ -951,7 +954,6 @@ class EntorhinalCortex:
for i in range(0, len(nodes_to_create), batch_size):
batch = nodes_to_create[i : i + batch_size]
session.execute(insert(GraphNodes), batch)
if nodes_to_update:
batch_size = 100
@@ -963,11 +965,9 @@ class EntorhinalCortex:
.where(GraphNodes.concept == node_data["concept"])
.values(**{k: v for k, v in node_data.items() if k != "concept"})
)
if nodes_to_delete:
session.execute(delete(GraphNodes).where(GraphNodes.concept.in_(nodes_to_delete)))
# 处理边的信息
db_edges = list(session.execute(select(GraphEdges)).scalars())
@@ -1023,7 +1023,6 @@ class EntorhinalCortex:
for i in range(0, len(edges_to_create), batch_size):
batch = edges_to_create[i : i + batch_size]
session.execute(insert(GraphEdges), batch)
if edges_to_update:
batch_size = 100
@@ -1037,7 +1036,6 @@ class EntorhinalCortex:
)
.values(**{k: v for k, v in edge_data.items() if k not in ["source", "target"]})
)
if edges_to_delete:
for source, target in edges_to_delete:
@@ -1048,12 +1046,10 @@ class EntorhinalCortex:
# 提交事务
session.commit()
end_time = time.time()
logger.info(f"[同步] 总耗时: {end_time - start_time:.2f}")
logger.info(f"[同步] 同步了 {len(memory_nodes)} 个节点和 {len(memory_edges)} 条边")
async def resync_memory_to_db(self):
"""清空数据库并重新同步所有记忆数据"""
start_time = time.time()
@@ -1064,7 +1060,7 @@ class EntorhinalCortex:
clear_start = time.time()
session.execute(delete(GraphNodes))
session.execute(delete(GraphEdges))
clear_end = time.time()
logger.info(f"[数据库] 清空数据库耗时: {clear_end - clear_start:.2f}")
@@ -1122,7 +1118,7 @@ class EntorhinalCortex:
for i in range(0, len(nodes_data), batch_size):
batch = nodes_data[i : i + batch_size]
session.execute(insert(GraphNodes), batch)
node_end = time.time()
logger.info(f"[数据库] 写入 {len(nodes_data)} 个节点耗时: {node_end - node_start:.2f}")
@@ -1134,7 +1130,7 @@ class EntorhinalCortex:
batch = edges_data[i : i + batch_size]
session.execute(insert(GraphEdges), batch)
session.commit()
edge_end = time.time()
logger.info(f"[数据库] 写入 {len(edges_data)} 条边耗时: {edge_end - edge_start:.2f}")
@@ -1170,10 +1166,7 @@ class EntorhinalCortex:
if not node.last_modified:
update_data["last_modified"] = current_time
session.execute(
update(GraphNodes).where(GraphNodes.concept == concept).values(**update_data)
)
session.execute(update(GraphNodes).where(GraphNodes.concept == concept).values(**update_data))
# 获取时间信息(如果不存在则使用当前时间)
created_time = node.created_time or current_time
@@ -1209,7 +1202,6 @@ class EntorhinalCortex:
.where((GraphEdges.source == source) & (GraphEdges.target == target))
.values(**update_data)
)
# 获取时间信息(如果不存在则使用当前时间)
created_time = edge.created_time or current_time
@@ -1231,8 +1223,10 @@ class ParahippocampalGyrus:
def __init__(self, hippocampus: Hippocampus):
self.hippocampus = hippocampus
self.memory_graph = hippocampus.memory_graph
self.memory_modify_model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="memory.modify")
self.memory_modify_model = LLMRequest(
model_set=model_config.model_task_config.utils, request_type="memory.modify"
)
async def memory_compress(self, messages: list, compress_rate=0.1):
"""压缩和总结消息内容,生成记忆主题和摘要。
@@ -1532,14 +1526,20 @@ class ParahippocampalGyrus:
similarity = self._calculate_item_similarity(memory_items[i], memory_items[j])
if similarity > 0.8: # 相似度阈值
# 合并相似记忆项
longer_item = memory_items[i] if len(memory_items[i]) > len(memory_items[j]) else memory_items[j]
shorter_item = memory_items[j] if len(memory_items[i]) > len(memory_items[j]) else memory_items[i]
longer_item = (
memory_items[i] if len(memory_items[i]) > len(memory_items[j]) else memory_items[j]
)
shorter_item = (
memory_items[j] if len(memory_items[i]) > len(memory_items[j]) else memory_items[i]
)
# 保留更长的记忆项,标记短的用于删除
if shorter_item not in items_to_remove:
items_to_remove.append(shorter_item)
merged_count += 1
logger.debug(f"[整合] 在节点 {node} 中合并相似记忆: {shorter_item[:30]}... -> {longer_item[:30]}...")
logger.debug(
f"[整合] 在节点 {node} 中合并相似记忆: {shorter_item[:30]}... -> {longer_item[:30]}..."
)
# 移除被合并的记忆项
if items_to_remove:
@@ -1566,11 +1566,11 @@ class ParahippocampalGyrus:
# 检查是否有变化需要同步到数据库
has_changes = (
edge_changes["weakened"] or
edge_changes["removed"] or
node_changes["reduced"] or
node_changes["removed"] or
merged_count > 0
edge_changes["weakened"]
or edge_changes["removed"]
or node_changes["reduced"]
or node_changes["removed"]
or merged_count > 0
)
if has_changes:
@@ -1696,7 +1696,9 @@ class HippocampusManager:
response = []
return response
async def get_activate_from_text(self, text: str, max_depth: int = 3, fast_retrieval: bool = False) -> tuple[float, list[str]]:
async def get_activate_from_text(
self, text: str, max_depth: int = 3, fast_retrieval: bool = False
) -> tuple[float, list[str]]:
"""从文本中获取激活值的公共接口"""
if not self._initialized:
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
@@ -1720,6 +1722,6 @@ class HippocampusManager:
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
return self._hippocampus.get_all_node_names()
# 创建全局实例
hippocampus_manager = HippocampusManager()

View File

@@ -10,7 +10,7 @@ import os
from typing import Dict, Any
# 添加项目路径
sys.path.append(os.path.join(os.path.dirname(__file__), '../../../'))
sys.path.append(os.path.join(os.path.dirname(__file__), "../../../"))
from src.common.logger import get_logger
from src.plugin_system.core.component_registry import component_registry
@@ -19,68 +19,64 @@ from src.plugin_system.base.component_types import ComponentType
logger = get_logger("action_diagnostics")
class ActionDiagnostics:
"""Action组件诊断器"""
def __init__(self):
self.required_actions = ["no_reply", "reply", "emoji", "at_user"]
def check_plugin_loading(self) -> Dict[str, Any]:
"""检查插件加载状态"""
logger.info("开始检查插件加载状态...")
result = {
"plugins_loaded": False,
"total_plugins": 0,
"loaded_plugins": [],
"failed_plugins": [],
"core_actions_plugin": None
"core_actions_plugin": None,
}
try:
# 加载所有插件
plugin_manager.load_all_plugins()
# 获取插件统计信息
stats = plugin_manager.get_stats()
result["plugins_loaded"] = True
result["total_plugins"] = stats.get("total_plugins", 0)
# 检查是否有core_actions插件
for plugin_name in plugin_manager.loaded_plugins:
result["loaded_plugins"].append(plugin_name)
if "core_actions" in plugin_name.lower():
result["core_actions_plugin"] = plugin_name
logger.info(f"插件加载成功,总数: {result['total_plugins']}")
logger.info(f"已加载插件: {result['loaded_plugins']}")
except Exception as e:
logger.error(f"插件加载失败: {e}")
result["error"] = str(e)
return result
def check_action_registry(self) -> Dict[str, Any]:
"""检查Action注册状态"""
logger.info("开始检查Action组件注册状态...")
result = {
"registered_actions": [],
"missing_actions": [],
"default_actions": {},
"total_actions": 0
}
result = {"registered_actions": [], "missing_actions": [], "default_actions": {}, "total_actions": 0}
try:
# 获取所有注册的Action
all_components = component_registry.get_all_components(ComponentType.ACTION)
result["total_actions"] = len(all_components)
for name, info in all_components.items():
result["registered_actions"].append(name)
logger.debug(f"已注册Action: {name} (插件: {info.plugin_name})")
# 检查必需的Action是否存在
for required_action in self.required_actions:
if required_action not in all_components:
@@ -88,32 +84,32 @@ class ActionDiagnostics:
logger.warning(f"缺失必需Action: {required_action}")
else:
logger.info(f"找到必需Action: {required_action}")
# 获取默认Action
default_actions = component_registry.get_default_actions()
result["default_actions"] = {name: info.plugin_name for name, info in default_actions.items()}
logger.info(f"总注册Action数量: {result['total_actions']}")
logger.info(f"缺失Action: {result['missing_actions']}")
except Exception as e:
logger.error(f"Action注册检查失败: {e}")
result["error"] = str(e)
return result
def check_specific_action(self, action_name: str) -> Dict[str, Any]:
"""检查特定Action的详细信息"""
logger.info(f"检查Action详细信息: {action_name}")
result = {
"exists": False,
"component_info": None,
"component_class": None,
"is_default": False,
"plugin_name": None
"plugin_name": None,
}
try:
# 检查组件信息
component_info = component_registry.get_component_info(action_name, ComponentType.ACTION)
@@ -123,14 +119,14 @@ class ActionDiagnostics:
"name": component_info.name,
"description": component_info.description,
"plugin_name": component_info.plugin_name,
"version": component_info.version
"version": component_info.version,
}
result["plugin_name"] = component_info.plugin_name
logger.info(f"找到Action组件信息: {action_name}")
else:
logger.warning(f"未找到Action组件信息: {action_name}")
return result
# 检查组件类
component_class = component_registry.get_component_class(action_name, ComponentType.ACTION)
if component_class:
@@ -138,36 +134,32 @@ class ActionDiagnostics:
logger.info(f"找到Action组件类: {component_class.__name__}")
else:
logger.warning(f"未找到Action组件类: {action_name}")
# 检查是否为默认Action
default_actions = component_registry.get_default_actions()
result["is_default"] = action_name in default_actions
logger.info(f"Action {action_name} 检查完成: 存在={result['exists']}, 默认={result['is_default']}")
except Exception as e:
logger.error(f"检查Action {action_name} 失败: {e}")
result["error"] = str(e)
return result
def attempt_fix_missing_actions(self) -> Dict[str, Any]:
"""尝试修复缺失的Action"""
logger.info("尝试修复缺失的Action组件...")
result = {
"fixed_actions": [],
"still_missing": [],
"errors": []
}
result = {"fixed_actions": [], "still_missing": [], "errors": []}
try:
# 重新加载插件
plugin_manager.load_all_plugins()
# 再次检查Action注册状态
registry_check = self.check_action_registry()
for required_action in self.required_actions:
if required_action in registry_check["missing_actions"]:
try:
@@ -182,107 +174,100 @@ class ActionDiagnostics:
logger.error(error_msg)
result["errors"].append(error_msg)
result["still_missing"].append(required_action)
logger.info(f"Action修复完成: 已修复={result['fixed_actions']}, 仍缺失={result['still_missing']}")
except Exception as e:
error_msg = f"Action修复过程失败: {e}"
logger.error(error_msg)
result["errors"].append(error_msg)
return result
def _register_no_reply_action(self):
"""手动注册no_reply Action"""
try:
from src.plugins.built_in.core_actions.no_reply import NoReplyAction
from src.plugin_system.base.component_types import ActionInfo
# 创建Action信息
action_info = ActionInfo(
name="no_reply",
description="暂时不回复消息",
plugin_name="built_in.core_actions",
version="1.0.0"
name="no_reply", description="暂时不回复消息", plugin_name="built_in.core_actions", version="1.0.0"
)
# 注册Action
success = component_registry._register_action_component(action_info, NoReplyAction)
if success:
logger.info("手动注册no_reply Action成功")
else:
raise Exception("注册失败")
except Exception as e:
raise Exception(f"手动注册no_reply Action失败: {e}") from e
def run_full_diagnosis(self) -> Dict[str, Any]:
"""运行完整诊断"""
logger.info("🔧 开始Action组件完整诊断")
logger.info("=" * 60)
diagnosis_result = {
"plugin_status": {},
"registry_status": {},
"action_details": {},
"fix_attempts": {},
"summary": {}
"summary": {},
}
# 1. 检查插件加载
logger.info("\n📦 步骤1: 检查插件加载状态")
diagnosis_result["plugin_status"] = self.check_plugin_loading()
# 2. 检查Action注册
logger.info("\n📋 步骤2: 检查Action注册状态")
diagnosis_result["registry_status"] = self.check_action_registry()
# 3. 检查特定Action详细信息
logger.info("\n🔍 步骤3: 检查特定Action详细信息")
diagnosis_result["action_details"] = {}
for action in self.required_actions:
diagnosis_result["action_details"][action] = self.check_specific_action(action)
# 4. 尝试修复缺失的Action
if diagnosis_result["registry_status"].get("missing_actions"):
logger.info("\n🔧 步骤4: 尝试修复缺失的Action")
diagnosis_result["fix_attempts"] = self.attempt_fix_missing_actions()
# 5. 生成诊断摘要
logger.info("\n📊 步骤5: 生成诊断摘要")
diagnosis_result["summary"] = self._generate_summary(diagnosis_result)
self._print_diagnosis_results(diagnosis_result)
return diagnosis_result
def _generate_summary(self, diagnosis_result: Dict[str, Any]) -> Dict[str, Any]:
"""生成诊断摘要"""
summary = {
"overall_status": "unknown",
"critical_issues": [],
"recommendations": []
}
summary = {"overall_status": "unknown", "critical_issues": [], "recommendations": []}
try:
# 检查插件加载状态
if not diagnosis_result["plugin_status"].get("plugins_loaded"):
summary["critical_issues"].append("插件加载失败")
summary["recommendations"].append("检查插件系统配置")
# 检查必需Action
missing_actions = diagnosis_result["registry_status"].get("missing_actions", [])
if "no_reply" in missing_actions:
summary["critical_issues"].append("缺失no_reply Action")
summary["recommendations"].append("检查core_actions插件是否正确加载")
# 检查修复结果
if diagnosis_result.get("fix_attempts"):
still_missing = diagnosis_result["fix_attempts"].get("still_missing", [])
if still_missing:
summary["critical_issues"].append(f"修复后仍缺失Action: {still_missing}")
summary["recommendations"].append("需要手动修复插件注册问题")
# 确定整体状态
if not summary["critical_issues"]:
summary["overall_status"] = "healthy"
@@ -290,103 +275,106 @@ class ActionDiagnostics:
summary["overall_status"] = "warning"
else:
summary["overall_status"] = "critical"
except Exception as e:
summary["critical_issues"].append(f"摘要生成失败: {e}")
summary["overall_status"] = "error"
return summary
def _print_diagnosis_results(self, diagnosis_result: Dict[str, Any]):
"""打印诊断结果"""
logger.info("\n" + "=" * 60)
logger.info("📈 诊断结果摘要")
logger.info("=" * 60)
summary = diagnosis_result.get("summary", {})
overall_status = summary.get("overall_status", "unknown")
# 状态指示器
status_indicators = {
"healthy": "✅ 系统健康",
"warning": "⚠️ 存在警告",
"critical": "❌ 存在严重问题",
"error": "💥 诊断出错",
"unknown": "❓ 状态未知"
"unknown": "❓ 状态未知",
}
logger.info(f"🎯 整体状态: {status_indicators.get(overall_status, overall_status)}")
# 关键问题
critical_issues = summary.get("critical_issues", [])
if critical_issues:
logger.info("\n🚨 关键问题:")
for issue in critical_issues:
logger.info(f"{issue}")
# 建议
recommendations = summary.get("recommendations", [])
if recommendations:
logger.info("\n💡 建议:")
for rec in recommendations:
logger.info(f"{rec}")
# 详细状态
plugin_status = diagnosis_result.get("plugin_status", {})
if plugin_status.get("plugins_loaded"):
logger.info(f"\n📦 插件状态: 已加载 {plugin_status.get('total_plugins', 0)} 个插件")
else:
logger.info("\n📦 插件状态: ❌ 插件加载失败")
registry_status = diagnosis_result.get("registry_status", {})
total_actions = registry_status.get("total_actions", 0)
missing_actions = registry_status.get("missing_actions", [])
logger.info(f"📋 Action状态: 已注册 {total_actions} 个,缺失 {len(missing_actions)}")
if missing_actions:
logger.info(f" 缺失的Action: {missing_actions}")
logger.info("\n" + "=" * 60)
def main():
"""主函数"""
diagnostics = ActionDiagnostics()
try:
result = diagnostics.run_full_diagnosis()
# 保存诊断结果
import orjson
with open("action_diagnosis_results.json", "w", encoding="utf-8") as f:
f.write(orjson.dumps(
result, option=orjson.OPT_INDENT_2).decode('utf-8')
)
f.write(orjson.dumps(result, option=orjson.OPT_INDENT_2).decode("utf-8"))
logger.info("📄 诊断结果已保存到: action_diagnosis_results.json")
# 根据诊断结果返回适当的退出代码
summary = result.get("summary", {})
overall_status = summary.get("overall_status", "unknown")
if overall_status == "healthy":
return 0
elif overall_status == "warning":
return 1
else:
return 2
except KeyboardInterrupt:
logger.info("❌ 诊断被用户中断")
return 3
except Exception as e:
logger.error(f"❌ 诊断执行失败: {e}")
import traceback
traceback.print_exc()
return 4
if __name__ == "__main__":
import logging
logging.basicConfig(level=logging.INFO)
exit_code = main()
sys.exit(exit_code)

View File

@@ -12,9 +12,10 @@ from src.config.config import global_config
logger = get_logger("async_instant_memory_wrapper")
class AsyncInstantMemoryWrapper:
"""异步瞬时记忆包装器"""
def __init__(self, chat_id: str):
self.chat_id = chat_id
self.llm_memory = None
@@ -32,6 +33,7 @@ class AsyncInstantMemoryWrapper:
if self.llm_memory is None and self.llm_memory_enabled:
try:
from src.chat.memory_system.instant_memory import InstantMemory
self.llm_memory = InstantMemory(self.chat_id)
logger.info(f"LLM瞬时记忆系统已初始化: {self.chat_id}")
except Exception as e:
@@ -43,80 +45,76 @@ class AsyncInstantMemoryWrapper:
if self.vector_memory is None and self.vector_memory_enabled:
try:
from src.chat.memory_system.vector_instant_memory import VectorInstantMemoryV2
self.vector_memory = VectorInstantMemoryV2(self.chat_id)
logger.info(f"向量瞬时记忆系统已初始化: {self.chat_id}")
except Exception as e:
logger.warning(f"向量瞬时记忆系统初始化失败: {e}")
self.vector_memory_enabled = False # 初始化失败则禁用
self.vector_memory_enabled = False # 初始化失败则禁用
def _get_cache_key(self, operation: str, content: str) -> str:
"""生成缓存键"""
return f"{operation}_{self.chat_id}_{hash(content)}"
def _is_cache_valid(self, cache_key: str) -> bool:
"""检查缓存是否有效"""
if cache_key not in self.cache:
return False
_, timestamp = self.cache[cache_key]
return time.time() - timestamp < self.cache_ttl
def _get_cached_result(self, cache_key: str) -> Optional[Any]:
"""获取缓存结果"""
if self._is_cache_valid(cache_key):
result, _ = self.cache[cache_key]
return result
return None
def _cache_result(self, cache_key: str, result: Any):
"""缓存结果"""
self.cache[cache_key] = (result, time.time())
async def store_memory_async(self, content: str, timeout: Optional[float] = None) -> bool:
"""异步存储记忆(带超时控制)"""
if timeout is None:
timeout = self.default_timeout
success_count = 0
# 异步存储到LLM记忆系统
await self._ensure_llm_memory()
if self.llm_memory:
try:
await asyncio.wait_for(
self.llm_memory.create_and_store_memory(content),
timeout=timeout
)
await asyncio.wait_for(self.llm_memory.create_and_store_memory(content), timeout=timeout)
success_count += 1
logger.debug(f"LLM记忆存储成功: {content[:50]}...")
except asyncio.TimeoutError:
logger.warning(f"LLM记忆存储超时: {content[:50]}...")
except Exception as e:
logger.error(f"LLM记忆存储失败: {e}")
# 异步存储到向量记忆系统
await self._ensure_vector_memory()
if self.vector_memory:
try:
await asyncio.wait_for(
self.vector_memory.store_message(content),
timeout=timeout
)
await asyncio.wait_for(self.vector_memory.store_message(content), timeout=timeout)
success_count += 1
logger.debug(f"向量记忆存储成功: {content[:50]}...")
except asyncio.TimeoutError:
logger.warning(f"向量记忆存储超时: {content[:50]}...")
except Exception as e:
logger.error(f"向量记忆存储失败: {e}")
return success_count > 0
async def retrieve_memory_async(self, query: str, timeout: Optional[float] = None,
use_cache: bool = True) -> Optional[Any]:
async def retrieve_memory_async(
self, query: str, timeout: Optional[float] = None, use_cache: bool = True
) -> Optional[Any]:
"""异步检索记忆(带缓存和超时控制)"""
if timeout is None:
timeout = self.default_timeout
# 检查缓存
if use_cache:
cache_key = self._get_cache_key("retrieve", query)
@@ -124,17 +122,17 @@ class AsyncInstantMemoryWrapper:
if cached_result is not None:
logger.debug(f"记忆检索命中缓存: {query[:30]}...")
return cached_result
# 尝试多种记忆系统
results = []
# 从向量记忆系统检索(优先,速度快)
await self._ensure_vector_memory()
if self.vector_memory:
try:
vector_result = await asyncio.wait_for(
self.vector_memory.get_memory_for_context(query),
timeout=timeout * 0.6 # 给向量系统60%的时间
timeout=timeout * 0.6, # 给向量系统60%的时间
)
if vector_result:
results.append(vector_result)
@@ -143,14 +141,14 @@ class AsyncInstantMemoryWrapper:
logger.warning(f"向量记忆检索超时: {query[:30]}...")
except Exception as e:
logger.error(f"向量记忆检索失败: {e}")
# 从LLM记忆系统检索备用更准确但较慢
await self._ensure_llm_memory()
if self.llm_memory and len(results) == 0: # 只有向量检索失败时才使用LLM
try:
llm_result = await asyncio.wait_for(
self.llm_memory.get_memory(query),
timeout=timeout * 0.4 # 给LLM系统40%的时间
timeout=timeout * 0.4, # 给LLM系统40%的时间
)
if llm_result:
results.extend(llm_result)
@@ -159,7 +157,7 @@ class AsyncInstantMemoryWrapper:
logger.warning(f"LLM记忆检索超时: {query[:30]}...")
except Exception as e:
logger.error(f"LLM记忆检索失败: {e}")
# 合并结果
final_result = None
if results:
@@ -178,42 +176,43 @@ class AsyncInstantMemoryWrapper:
final_result.append(r)
else:
final_result = results[0] # 使用第一个结果
# 缓存结果
if use_cache and final_result is not None:
cache_key = self._get_cache_key("retrieve", query)
self._cache_result(cache_key, final_result)
return final_result
async def get_memory_with_fallback(self, query: str, max_timeout: float = 2.0) -> str:
"""获取记忆的回退方法,保证不会长时间阻塞"""
try:
# 首先尝试快速检索
result = await self.retrieve_memory_async(query, timeout=max_timeout)
if result:
if isinstance(result, list):
return "\n".join(str(item) for item in result)
return str(result)
return ""
except Exception as e:
logger.error(f"记忆检索完全失败: {e}")
return ""
def store_memory_background(self, content: str):
"""在后台存储记忆(发后即忘模式)"""
async def background_store():
try:
await self.store_memory_async(content, timeout=10.0) # 后台任务可以用更长超时
except Exception as e:
logger.error(f"后台记忆存储失败: {e}")
# 创建后台任务
asyncio.create_task(background_store())
def get_status(self) -> Dict[str, Any]:
"""获取包装器状态"""
return {
@@ -222,23 +221,26 @@ class AsyncInstantMemoryWrapper:
"vector_memory_available": self.vector_memory is not None,
"cache_entries": len(self.cache),
"cache_ttl": self.cache_ttl,
"default_timeout": self.default_timeout
"default_timeout": self.default_timeout,
}
def clear_cache(self):
"""清理缓存"""
self.cache.clear()
logger.info(f"记忆缓存已清理: {self.chat_id}")
# 缓存包装器实例,避免重复创建
_wrapper_cache: Dict[str, AsyncInstantMemoryWrapper] = {}
def get_async_instant_memory(chat_id: str) -> AsyncInstantMemoryWrapper:
"""获取异步瞬时记忆包装器实例"""
if chat_id not in _wrapper_cache:
_wrapper_cache[chat_id] = AsyncInstantMemoryWrapper(chat_id)
return _wrapper_cache[chat_id]
def clear_wrapper_cache():
"""清理包装器缓存"""
global _wrapper_cache

View File

@@ -15,9 +15,11 @@ from src.chat.memory_system.async_instant_memory_wrapper import get_async_instan
logger = get_logger("async_memory_optimizer")
@dataclass
class MemoryTask:
"""记忆任务数据结构"""
task_id: str
task_type: str # "store", "retrieve", "build"
chat_id: str
@@ -25,14 +27,15 @@ class MemoryTask:
priority: int = 1 # 1=低优先级, 2=中优先级, 3=高优先级
callback: Optional[Callable] = None
created_at: float = None
def __post_init__(self):
if self.created_at is None:
self.created_at = time.time()
class AsyncMemoryQueue:
"""异步记忆任务队列管理器"""
def __init__(self, max_workers: int = 3):
self.max_workers = max_workers
self.executor = ThreadPoolExecutor(max_workers=max_workers)
@@ -42,56 +45,56 @@ class AsyncMemoryQueue:
self.failed_tasks: Dict[str, str] = {}
self.is_running = False
self.worker_tasks: List[asyncio.Task] = []
async def start(self):
"""启动异步队列处理器"""
if self.is_running:
return
self.is_running = True
# 启动多个工作协程
for i in range(self.max_workers):
worker = asyncio.create_task(self._worker(f"worker-{i}"))
self.worker_tasks.append(worker)
logger.info(f"异步记忆队列已启动,工作线程数: {self.max_workers}")
async def stop(self):
"""停止队列处理器"""
self.is_running = False
# 等待所有工作任务完成
for task in self.worker_tasks:
task.cancel()
await asyncio.gather(*self.worker_tasks, return_exceptions=True)
self.executor.shutdown(wait=True)
logger.info("异步记忆队列已停止")
async def _worker(self, worker_name: str):
"""工作协程,处理队列中的任务"""
logger.info(f"记忆处理工作线程 {worker_name} 启动")
while self.is_running:
try:
# 等待任务超时1秒避免永久阻塞
task = await asyncio.wait_for(self.task_queue.get(), timeout=1.0)
# 执行任务
await self._execute_task(task, worker_name)
except asyncio.TimeoutError:
# 超时正常,继续下一次循环
continue
except Exception as e:
logger.error(f"工作线程 {worker_name} 处理任务时出错: {e}")
async def _execute_task(self, task: MemoryTask, worker_name: str):
"""执行具体的记忆任务"""
try:
logger.debug(f"[{worker_name}] 开始处理任务: {task.task_type} - {task.task_id}")
start_time = time.time()
# 根据任务类型执行不同的处理逻辑
result = None
if task.task_type == "store":
@@ -102,13 +105,13 @@ class AsyncMemoryQueue:
result = await self._handle_build_task(task)
else:
raise ValueError(f"未知的任务类型: {task.task_type}")
# 记录完成的任务
self.completed_tasks[task.task_id] = result
execution_time = time.time() - start_time
logger.debug(f"[{worker_name}] 任务完成: {task.task_id} (耗时: {execution_time:.2f}s)")
# 执行回调函数
if task.callback:
try:
@@ -118,12 +121,12 @@ class AsyncMemoryQueue:
task.callback(result)
except Exception as e:
logger.error(f"任务回调执行失败: {e}")
except Exception as e:
error_msg = f"任务执行失败: {e}"
logger.error(f"[{worker_name}] {error_msg}")
self.failed_tasks[task.task_id] = error_msg
# 执行错误回调
if task.callback:
try:
@@ -133,7 +136,7 @@ class AsyncMemoryQueue:
task.callback(None)
except Exception:
pass
async def _handle_store_task(self, task: MemoryTask) -> Any:
"""处理记忆存储任务"""
# 这里需要根据具体的记忆系统来实现
@@ -141,7 +144,7 @@ class AsyncMemoryQueue:
try:
# 获取包装器实例
memory_wrapper = get_async_instant_memory(task.chat_id)
# 使用包装器中的llm_memory实例
if memory_wrapper and memory_wrapper.llm_memory:
await memory_wrapper.llm_memory.create_and_store_memory(task.content)
@@ -152,13 +155,13 @@ class AsyncMemoryQueue:
except Exception as e:
logger.error(f"记忆存储失败: {e}")
return False
async def _handle_retrieve_task(self, task: MemoryTask) -> Any:
"""处理记忆检索任务"""
try:
# 获取包装器实例
memory_wrapper = get_async_instant_memory(task.chat_id)
# 使用包装器中的llm_memory实例
if memory_wrapper and memory_wrapper.llm_memory:
memories = await memory_wrapper.llm_memory.get_memory(task.content)
@@ -169,14 +172,14 @@ class AsyncMemoryQueue:
except Exception as e:
logger.error(f"记忆检索失败: {e}")
return []
async def _handle_build_task(self, task: MemoryTask) -> Any:
"""处理记忆构建任务(海马体系统)"""
try:
# 延迟导入避免循环依赖
if global_config.memory.enable_memory:
from src.chat.memory_system.Hippocampus import hippocampus_manager
if hippocampus_manager._initialized:
await hippocampus_manager.build_memory()
return True
@@ -184,22 +187,22 @@ class AsyncMemoryQueue:
except Exception as e:
logger.error(f"记忆构建失败: {e}")
return False
async def add_task(self, task: MemoryTask) -> str:
"""添加任务到队列"""
await self.task_queue.put(task)
self.running_tasks[task.task_id] = task
logger.debug(f"任务已加入队列: {task.task_type} - {task.task_id}")
return task.task_id
def get_task_result(self, task_id: str) -> Optional[Any]:
"""获取任务结果(非阻塞)"""
return self.completed_tasks.get(task_id)
def is_task_completed(self, task_id: str) -> bool:
"""检查任务是否完成"""
return task_id in self.completed_tasks or task_id in self.failed_tasks
def get_queue_status(self) -> Dict[str, Any]:
"""获取队列状态"""
return {
@@ -208,30 +211,30 @@ class AsyncMemoryQueue:
"running_tasks": len(self.running_tasks),
"completed_tasks": len(self.completed_tasks),
"failed_tasks": len(self.failed_tasks),
"worker_count": len(self.worker_tasks)
"worker_count": len(self.worker_tasks),
}
class NonBlockingMemoryManager:
"""非阻塞记忆管理器"""
def __init__(self):
self.queue = AsyncMemoryQueue(max_workers=3)
self.cache: Dict[str, Any] = {}
self.cache_ttl: Dict[str, float] = {}
self.cache_timeout = 300 # 缓存5分钟
async def initialize(self):
"""初始化管理器"""
await self.queue.start()
logger.info("非阻塞记忆管理器已初始化")
async def shutdown(self):
"""关闭管理器"""
await self.queue.stop()
logger.info("非阻塞记忆管理器已关闭")
async def store_memory_async(self, chat_id: str, content: str,
callback: Optional[Callable] = None) -> str:
async def store_memory_async(self, chat_id: str, content: str, callback: Optional[Callable] = None) -> str:
"""异步存储记忆(非阻塞)"""
task = MemoryTask(
task_id=f"store_{chat_id}_{int(time.time() * 1000)}",
@@ -239,13 +242,12 @@ class NonBlockingMemoryManager:
chat_id=chat_id,
content=content,
priority=1, # 存储优先级较低
callback=callback
callback=callback,
)
return await self.queue.add_task(task)
async def retrieve_memory_async(self, chat_id: str, query: str,
callback: Optional[Callable] = None) -> str:
async def retrieve_memory_async(self, chat_id: str, query: str, callback: Optional[Callable] = None) -> str:
"""异步检索记忆(非阻塞)"""
# 先检查缓存
cache_key = f"retrieve_{chat_id}_{hash(query)}"
@@ -257,18 +259,18 @@ class NonBlockingMemoryManager:
else:
callback(result)
return "cache_hit"
task = MemoryTask(
task_id=f"retrieve_{chat_id}_{int(time.time() * 1000)}",
task_type="retrieve",
chat_id=chat_id,
content=query,
priority=2, # 检索优先级中等
callback=self._create_cache_callback(cache_key, callback)
callback=self._create_cache_callback(cache_key, callback),
)
return await self.queue.add_task(task)
async def build_memory_async(self, callback: Optional[Callable] = None) -> str:
"""异步构建记忆(非阻塞)"""
task = MemoryTask(
@@ -277,70 +279,72 @@ class NonBlockingMemoryManager:
chat_id="system",
content="",
priority=1, # 构建优先级较低,避免影响用户体验
callback=callback
callback=callback,
)
return await self.queue.add_task(task)
def _is_cache_valid(self, cache_key: str) -> bool:
"""检查缓存是否有效"""
if cache_key not in self.cache:
return False
return time.time() - self.cache_ttl.get(cache_key, 0) < self.cache_timeout
def _create_cache_callback(self, cache_key: str, original_callback: Optional[Callable]):
"""创建带缓存的回调函数"""
async def cache_callback(result):
# 存储到缓存
if result is not None:
self.cache[cache_key] = result
self.cache_ttl[cache_key] = time.time()
# 执行原始回调
if original_callback:
if asyncio.iscoroutinefunction(original_callback):
await original_callback(result)
else:
original_callback(result)
return cache_callback
def get_cached_memory(self, chat_id: str, query: str) -> Optional[Any]:
"""获取缓存的记忆(同步,立即返回)"""
cache_key = f"retrieve_{chat_id}_{hash(query)}"
if self._is_cache_valid(cache_key):
return self.cache[cache_key]
return None
def get_status(self) -> Dict[str, Any]:
"""获取管理器状态"""
status = self.queue.get_queue_status()
status.update({
"cache_entries": len(self.cache),
"cache_timeout": self.cache_timeout
})
status.update({"cache_entries": len(self.cache), "cache_timeout": self.cache_timeout})
return status
# 全局实例
async_memory_manager = NonBlockingMemoryManager()
# 便捷函数
async def store_memory_nonblocking(chat_id: str, content: str) -> str:
"""非阻塞存储记忆的便捷函数"""
return await async_memory_manager.store_memory_async(chat_id, content)
async def retrieve_memory_nonblocking(chat_id: str, query: str) -> Optional[Any]:
"""非阻塞检索记忆的便捷函数,支持缓存"""
# 先尝试从缓存获取
cached_result = async_memory_manager.get_cached_memory(chat_id, query)
if cached_result is not None:
return cached_result
# 缓存未命中,启动异步检索
await async_memory_manager.retrieve_memory_async(chat_id, query)
return None # 返回None表示需要异步获取
async def build_memory_nonblocking() -> str:
"""非阻塞构建记忆的便捷函数"""
return await async_memory_manager.build_memory_async()

View File

@@ -14,8 +14,10 @@ from src.common.database.sqlalchemy_database_api import get_db_session
from src.config.config import model_config
from sqlalchemy import select
logger = get_logger(__name__)
class MemoryItem:
def __init__(self, memory_id: str, chat_id: str, memory_text: str, keywords: list[str]):
self.memory_id = memory_id
@@ -24,6 +26,8 @@ class MemoryItem:
self.keywords: list[str] = keywords
self.create_time: float = time.time()
self.last_view_time: float = time.time()
class InstantMemory:
def __init__(self, chat_id):
self.chat_id = chat_id
@@ -105,13 +109,13 @@ class InstantMemory:
async def store_memory(self, memory_item: MemoryItem):
with get_db_session() as session:
memory = Memory(
memory_id=memory_item.memory_id,
chat_id=memory_item.chat_id,
memory_text=memory_item.memory_text,
keywords=orjson.dumps(memory_item.keywords).decode('utf-8'),
create_time=memory_item.create_time,
last_view_time=memory_item.last_view_time,
)
memory_id=memory_item.memory_id,
chat_id=memory_item.chat_id,
memory_text=memory_item.memory_text,
keywords=orjson.dumps(memory_item.keywords).decode("utf-8"),
create_time=memory_item.create_time,
last_view_time=memory_item.last_view_time,
)
session.add(memory)
session.commit()
@@ -160,12 +164,14 @@ class InstantMemory:
if start_time and end_time:
start_ts = start_time.timestamp()
end_ts = end_time.timestamp()
query = session.execute(select(Memory).where(
(Memory.chat_id == self.chat_id)
& (Memory.create_time >= start_ts)
& (Memory.create_time < end_ts)
)).scalars()
query = session.execute(
select(Memory).where(
(Memory.chat_id == self.chat_id)
& (Memory.create_time >= start_ts)
& (Memory.create_time < end_ts)
)
).scalars()
else:
query = session.execute(select(Memory).where(Memory.chat_id == self.chat_id)).scalars()
for mem in query:
@@ -209,12 +215,14 @@ class InstantMemory:
try:
dt = datetime.strptime(time_str, "%Y-%m-%d %H:%M:%S")
return dt, dt + timedelta(hours=1)
except Exception: ...
except Exception:
...
# 具体日期
try:
dt = datetime.strptime(time_str, "%Y-%m-%d")
return dt, dt + timedelta(days=1)
except Exception: ...
except Exception:
...
# 相对时间
if time_str == "今天":
start = now.replace(hour=0, minute=0, second=0, microsecond=0)

View File

@@ -15,6 +15,7 @@ logger = get_logger("vector_instant_memory_v2")
@dataclass
class ChatMessage:
"""聊天消息数据结构"""
message_id: str
chat_id: str
content: str
@@ -25,51 +26,49 @@ class ChatMessage:
class VectorInstantMemoryV2:
"""重构的向量瞬时记忆系统 V2
新设计理念:
1. 全量存储 - 所有聊天记录都存储为向量
2. 定时清理 - 定期清理过期记录
3. 实时匹配 - 新消息与历史记录做向量相似度匹配
"""
def __init__(self, chat_id: str, retention_hours: int = 24, cleanup_interval: int = 3600):
"""
初始化向量瞬时记忆系统
Args:
chat_id: 聊天ID
retention_hours: 记忆保留时长(小时)
retention_hours: 记忆保留时长(小时)
cleanup_interval: 清理间隔(秒)
"""
self.chat_id = chat_id
self.retention_hours = retention_hours
self.cleanup_interval = cleanup_interval
self.collection_name = "instant_memory"
# 清理任务相关
self.cleanup_task = None
self.is_running = True
# 初始化系统
self._init_chroma()
self._start_cleanup_task()
logger.info(f"向量瞬时记忆系统V2初始化完成: {chat_id} (保留{retention_hours}小时)")
def _init_chroma(self):
"""使用全局服务初始化向量数据库集合"""
try:
# 现在我们只获取集合,而不是创建新的客户端
vector_db_service.get_or_create_collection(
name=self.collection_name,
metadata={"hnsw:space": "cosine"}
)
vector_db_service.get_or_create_collection(name=self.collection_name, metadata={"hnsw:space": "cosine"})
logger.info(f"向量记忆集合 '{self.collection_name}' 已准备就绪")
except Exception as e:
logger.error(f"获取向量记忆集合失败: {e}")
def _start_cleanup_task(self):
"""启动定时清理任务"""
def cleanup_worker():
while self.is_running:
try:
@@ -78,11 +77,11 @@ class VectorInstantMemoryV2:
except Exception as e:
logger.error(f"清理任务异常: {e}")
time.sleep(60) # 异常时等待1分钟再继续
self.cleanup_task = threading.Thread(target=cleanup_worker, daemon=True)
self.cleanup_task.start()
logger.info(f"定时清理任务已启动,间隔{self.cleanup_interval}")
def _cleanup_expired_messages(self):
"""清理过期的聊天记录"""
try:
@@ -91,211 +90,208 @@ class VectorInstantMemoryV2:
# 采用 get -> filter -> delete 模式,避免复杂的 where 查询
# 1. 获取当前 chat_id 的所有文档
results = vector_db_service.get(
collection_name=self.collection_name,
where={"chat_id": self.chat_id},
include=["metadatas"]
collection_name=self.collection_name, where={"chat_id": self.chat_id}, include=["metadatas"]
)
if not results or not results.get('ids'):
if not results or not results.get("ids"):
logger.info(f"chat_id '{self.chat_id}' 没有找到任何记录,无需清理")
return
# 2. 在内存中过滤出过期的文档
expired_ids = []
metadatas = results.get('metadatas', [])
ids = results.get('ids', [])
metadatas = results.get("metadatas", [])
ids = results.get("ids", [])
for i, metadata in enumerate(metadatas):
if metadata and metadata.get('timestamp', float('inf')) < expire_time:
if metadata and metadata.get("timestamp", float("inf")) < expire_time:
expired_ids.append(ids[i])
# 3. 如果有过期文档,根据 ID 进行删除
if expired_ids:
vector_db_service.delete(
collection_name=self.collection_name,
ids=expired_ids
)
vector_db_service.delete(collection_name=self.collection_name, ids=expired_ids)
logger.info(f"为 chat_id '{self.chat_id}' 清理了 {len(expired_ids)} 条过期记录")
else:
logger.info(f"chat_id '{self.chat_id}' 没有需要清理的过期记录")
except Exception as e:
logger.error(f"清理过期记录失败: {e}")
async def store_message(self, content: str, sender: str = "user") -> bool:
"""
存储聊天消息到向量库
Args:
content: 消息内容
sender: 发送者
Returns:
bool: 是否存储成功
"""
if not content.strip():
return False
try:
# 生成消息向量
message_vector = await get_embedding(content)
if not message_vector:
logger.warning(f"消息向量生成失败: {content[:50]}...")
return False
message_id = f"{self.chat_id}_{int(time.time() * 1000)}_{hash(content) % 10000}"
message = ChatMessage(
message_id=message_id,
chat_id=self.chat_id,
content=content,
timestamp=time.time(),
sender=sender
message_id=message_id, chat_id=self.chat_id, content=content, timestamp=time.time(), sender=sender
)
# 使用新的服务存储
vector_db_service.add(
collection_name=self.collection_name,
embeddings=[message_vector],
documents=[content],
metadatas=[{
"message_id": message.message_id,
"chat_id": message.chat_id,
"timestamp": message.timestamp,
"sender": message.sender,
"message_type": message.message_type
}],
ids=[message_id]
metadatas=[
{
"message_id": message.message_id,
"chat_id": message.chat_id,
"timestamp": message.timestamp,
"sender": message.sender,
"message_type": message.message_type,
}
],
ids=[message_id],
)
logger.debug(f"消息已存储: {content[:50]}...")
return True
except Exception as e:
logger.error(f"存储消息失败: {e}")
return False
async def find_similar_messages(self, query: str, top_k: int = 5, similarity_threshold: float = 0.7) -> List[Dict[str, Any]]:
async def find_similar_messages(
self, query: str, top_k: int = 5, similarity_threshold: float = 0.7
) -> List[Dict[str, Any]]:
"""
查找与查询相似的历史消息
Args:
query: 查询内容
top_k: 返回的最相似消息数量
similarity_threshold: 相似度阈值
Returns:
List[Dict]: 相似消息列表包含content、similarity、timestamp等信息
"""
if not query.strip():
return []
try:
query_vector = await get_embedding(query)
if not query_vector:
return []
# 使用新的服务进行查询
results = vector_db_service.query(
collection_name=self.collection_name,
query_embeddings=[query_vector],
n_results=top_k,
where={"chat_id": self.chat_id}
where={"chat_id": self.chat_id},
)
if not results.get('documents') or not results['documents'][0]:
if not results.get("documents") or not results["documents"][0]:
return []
# 处理搜索结果
similar_messages = []
documents = results['documents'][0]
distances = results['distances'][0] if results['distances'] else []
metadatas = results['metadatas'][0] if results['metadatas'] else []
documents = results["documents"][0]
distances = results["distances"][0] if results["distances"] else []
metadatas = results["metadatas"][0] if results["metadatas"] else []
for i, doc in enumerate(documents):
# 计算相似度ChromaDB返回距离需转换
distance = distances[i] if i < len(distances) else 1.0
similarity = 1 - distance
# 过滤低相似度结果
if similarity < similarity_threshold:
continue
# 获取元数据
metadata = metadatas[i] if i < len(metadatas) else {}
# 安全获取timestamp
timestamp = metadata.get("timestamp", 0) if isinstance(metadata, dict) else 0
timestamp = float(timestamp) if isinstance(timestamp, (int, float)) else 0.0
similar_messages.append({
"content": doc,
"similarity": similarity,
"timestamp": timestamp,
"sender": metadata.get("sender", "unknown") if isinstance(metadata, dict) else "unknown",
"message_id": metadata.get("message_id", "") if isinstance(metadata, dict) else "",
"time_ago": self._format_time_ago(timestamp)
})
similar_messages.append(
{
"content": doc,
"similarity": similarity,
"timestamp": timestamp,
"sender": metadata.get("sender", "unknown") if isinstance(metadata, dict) else "unknown",
"message_id": metadata.get("message_id", "") if isinstance(metadata, dict) else "",
"time_ago": self._format_time_ago(timestamp),
}
)
# 按相似度排序
similar_messages.sort(key=lambda x: x["similarity"], reverse=True)
logger.debug(f"找到 {len(similar_messages)} 条相似消息 (查询: {query[:30]}...)")
return similar_messages
except Exception as e:
logger.error(f"查找相似消息失败: {e}")
return []
def _format_time_ago(self, timestamp: float) -> str:
"""格式化时间差显示"""
if timestamp <= 0:
return "未知时间"
try:
now = time.time()
diff = now - timestamp
if diff < 60:
return f"{int(diff)}秒前"
elif diff < 3600:
return f"{int(diff/60)}分钟前"
return f"{int(diff / 60)}分钟前"
elif diff < 86400:
return f"{int(diff/3600)}小时前"
return f"{int(diff / 3600)}小时前"
else:
return f"{int(diff/86400)}天前"
return f"{int(diff / 86400)}天前"
except Exception:
return "时间格式错误"
async def get_memory_for_context(self, current_message: str, context_size: int = 3) -> str:
"""
获取与当前消息相关的记忆上下文
Args:
current_message: 当前消息
context_size: 上下文消息数量
Returns:
str: 格式化的记忆上下文
"""
similar_messages = await self.find_similar_messages(
current_message,
current_message,
top_k=context_size,
similarity_threshold=0.6 # 降低阈值以获得更多上下文
similarity_threshold=0.6, # 降低阈值以获得更多上下文
)
if not similar_messages:
return ""
# 格式化上下文
context_lines = []
for msg in similar_messages:
context_lines.append(
f"[{msg['time_ago']}] {msg['sender']}: {msg['content']} (相似度: {msg['similarity']:.2f})"
)
return "相关的历史记忆:\n" + "\n".join(context_lines)
def get_stats(self) -> Dict[str, Any]:
"""获取记忆系统统计信息"""
stats = {
@@ -304,9 +300,9 @@ class VectorInstantMemoryV2:
"cleanup_interval": self.cleanup_interval,
"system_status": "running" if self.is_running else "stopped",
"total_messages": 0,
"db_status": "connected"
"db_status": "connected",
}
try:
# 注意count() 现在没有 chat_id 过滤,返回的是整个集合的数量
# 若要精确计数,需要 get(where={"chat_id": ...}) 然后 len(results['ids'])
@@ -316,9 +312,9 @@ class VectorInstantMemoryV2:
except Exception:
stats["total_messages"] = "查询失败"
stats["db_status"] = "disconnected"
return stats
def stop(self):
"""停止记忆系统"""
self.is_running = False
@@ -337,26 +333,26 @@ def create_vector_memory_v2(chat_id: str, retention_hours: int = 24) -> VectorIn
async def demo():
"""使用演示"""
memory = VectorInstantMemoryV2("demo_chat")
# 存储一些测试消息
await memory.store_message("今天天气不错,出去散步了", "用户")
await memory.store_message("刚才买了个冰淇淋,很好吃", "用户")
await memory.store_message("刚才买了个冰淇淋,很好吃", "用户")
await memory.store_message("明天要开会,有点紧张", "用户")
# 查找相似消息
similar = await memory.find_similar_messages("天气怎么样")
print("相似消息:", similar)
# 获取上下文
context = await memory.get_memory_for_context("今天心情如何")
print("记忆上下文:", context)
# 查看统计信息
stats = memory.get_stats()
print("系统状态:", stats)
memory.stop()
if __name__ == "__main__":
asyncio.run(demo())
asyncio.run(demo())