This commit is contained in:
tt-P607
2025-08-27 19:35:37 +08:00
11 changed files with 696 additions and 169 deletions

View File

@@ -4,10 +4,9 @@ from typing import List, Dict, Any
from dataclasses import dataclass
import threading
import chromadb
from chromadb.config import Settings
from src.common.logger import get_logger
from src.chat.utils.utils import get_embedding
from src.common.vector_db import vector_db_service
logger = get_logger("vector_instant_memory_v2")
@@ -45,10 +44,7 @@ class VectorInstantMemoryV2:
self.chat_id = chat_id
self.retention_hours = retention_hours
self.cleanup_interval = cleanup_interval
# ChromaDB相关
self.client = None
self.collection = None
self.collection_name = "instant_memory"
# 清理任务相关
self.cleanup_task = None
@@ -61,22 +57,16 @@ class VectorInstantMemoryV2:
logger.info(f"向量瞬时记忆系统V2初始化完成: {chat_id} (保留{retention_hours}小时)")
def _init_chroma(self):
"""初始化ChromaDB连接"""
"""使用全局服务初始化向量数据库集合"""
try:
db_path = f"./data/memory_vectors/{self.chat_id}"
self.client = chromadb.PersistentClient(
path=db_path,
settings=Settings(anonymized_telemetry=False)
)
self.collection = self.client.get_or_create_collection(
name="chat_messages",
# 现在我们只获取集合,而不是创建新的客户端
vector_db_service.get_or_create_collection(
name=self.collection_name,
metadata={"hnsw:space": "cosine"}
)
logger.info(f"向量记忆数据库初始化成功: {db_path}")
logger.info(f"向量记忆集合 '{self.collection_name}' 已准备就绪")
except Exception as e:
logger.error(f"ChromaDB初始化失败: {e}")
self.client = None
self.collection = None
logger.error(f"获取向量记忆集合失败: {e}")
def _start_cleanup_task(self):
"""启动定时清理任务"""
@@ -95,35 +85,23 @@ class VectorInstantMemoryV2:
def _cleanup_expired_messages(self):
"""清理过期的聊天记录"""
if not self.collection:
return
try:
# 计算过期时间戳
expire_time = time.time() - (self.retention_hours * 3600)
# 查询所有记录
all_results = self.collection.get(
where={"chat_id": self.chat_id},
include=["metadatas"]
# 使用 where 条件来删除过期记录
# 注意: ChromaDB 的 where 过滤器目前对 timestamp 的 $lt 操作支持可能有限
# 一个更可靠的方法是 get() -> filter -> delete()
# 但为了简化,我们先尝试直接 delete
# TODO: 确认 ChromaDB 对 $lt 在 metadata 上的支持。如果不支持,需要实现 get-filter-delete 模式。
vector_db_service.delete(
collection_name=self.collection_name,
where={
"chat_id": self.chat_id,
"timestamp": {"$lt": expire_time}
}
)
# 找出过期的记录ID
expired_ids = []
metadatas = all_results.get("metadatas") or []
ids = all_results.get("ids") or []
for i, metadata in enumerate(metadatas):
if metadata and isinstance(metadata, dict):
timestamp = metadata.get("timestamp", 0)
if isinstance(timestamp, (int, float)) and timestamp < expire_time:
if i < len(ids):
expired_ids.append(ids[i])
# 批量删除过期记录
if expired_ids:
self.collection.delete(ids=expired_ids)
logger.info(f"清理了 {len(expired_ids)} 条过期聊天记录")
logger.info(f"已为 chat_id '{self.chat_id}' 触发过期记录清理")
except Exception as e:
logger.error(f"清理过期记录失败: {e}")
@@ -139,7 +117,7 @@ class VectorInstantMemoryV2:
Returns:
bool: 是否存储成功
"""
if not self.collection or not content.strip():
if not content.strip():
return False
try:
@@ -149,10 +127,8 @@ class VectorInstantMemoryV2:
logger.warning(f"消息向量生成失败: {content[:50]}...")
return False
# 生成唯一消息ID
message_id = f"{self.chat_id}_{int(time.time() * 1000)}_{hash(content) % 10000}"
# 创建消息对象
message = ChatMessage(
message_id=message_id,
chat_id=self.chat_id,
@@ -161,8 +137,9 @@ class VectorInstantMemoryV2:
sender=sender
)
# 存储到ChromaDB
self.collection.add(
# 使用新的服务存储
vector_db_service.add(
collection_name=self.collection_name,
embeddings=[message_vector],
documents=[content],
metadatas=[{
@@ -194,23 +171,23 @@ class VectorInstantMemoryV2:
Returns:
List[Dict]: 相似消息列表包含content、similarity、timestamp等信息
"""
if not self.collection or not query.strip():
if not query.strip():
return []
try:
# 生成查询向量
query_vector = await get_embedding(query)
if not query_vector:
return []
# 向量相似度搜索
results = self.collection.query(
# 使用新的服务进行查询
results = vector_db_service.query(
collection_name=self.collection_name,
query_embeddings=[query_vector],
n_results=top_k,
where={"chat_id": self.chat_id}
)
if not results['documents'] or not results['documents'][0]:
if not results.get('documents') or not results['documents'][0]:
return []
# 处理搜索结果
@@ -311,15 +288,18 @@ class VectorInstantMemoryV2:
"cleanup_interval": self.cleanup_interval,
"system_status": "running" if self.is_running else "stopped",
"total_messages": 0,
"db_status": "connected" if self.collection else "disconnected"
"db_status": "connected"
}
if self.collection:
try:
result = self.collection.count()
stats["total_messages"] = result
except Exception:
stats["total_messages"] = "查询失败"
try:
# 注意count() 现在没有 chat_id 过滤,返回的是整个集合的数量
# 若要精确计数,需要 get(where={"chat_id": ...}) 然后 len(results['ids'])
# 这里为了简化,暂时显示集合总数
result = vector_db_service.count(collection_name=self.collection_name)
stats["total_messages"] = result
except Exception:
stats["total_messages"] = "查询失败"
stats["db_status"] = "disconnected"
return stats