This commit is contained in:
tt-P607
2025-08-27 19:35:37 +08:00
11 changed files with 696 additions and 169 deletions

View File

@@ -0,0 +1,124 @@
# 自动化工具缓存系统使用指南
为了提升性能并减少不必要的重复计算或API调用MMC内置了一套强大且易于使用的自动化工具缓存系统。该系统同时支持传统的**精确缓存**和先进的**语义缓存**。工具开发者无需编写任何手动缓存逻辑,只需在工具类中设置几个属性,即可轻松启用和配置缓存行为。
## 核心概念
- **精确缓存 (KV Cache)**: 当一个工具被调用时,系统会根据工具名称和所有参数生成一个唯一的键。只有当**下一次调用的工具名和所有参数与之前完全一致**时,才会命中缓存。
- **语义缓存 (Vector Cache)**: 它不要求参数完全一致,而是理解参数的**语义和意图**。例如,`"查询深圳今天的天气"``"今天深圳天气怎么样"` 这两个不同的查询,在语义上是高度相似的。如果启用了语义缓存,第二个查询就能成功命中由第一个查询产生的缓存结果。
## 如何为你的工具启用缓存
为你的工具(必须继承自 `BaseTool`)启用缓存非常简单,只需在你的工具类定义中添加以下一个或多个属性即可:
### 1. `enable_cache: bool`
这是启用缓存的总开关。
- **类型**: `bool`
- **默认值**: `False`
- **作用**: 设置为 `True` 即可为该工具启用缓存功能。如果为 `False`,后续的所有缓存配置都将无效。
**示例**:
```python
class MyAwesomeTool(BaseTool):
# ... 其他定义 ...
enable_cache: bool = True
```
### 2. `cache_ttl: int`
设置缓存的生存时间Time-To-Live
- **类型**: `int`
- **单位**: 秒
- **默认值**: `3600` (1小时)
- **作用**: 定义缓存条目在被视为过期之前可以存活多长时间。
**示例**:
```python
class MyLongTermCacheTool(BaseTool):
# ... 其他定义 ...
enable_cache: bool = True
cache_ttl: int = 86400 # 缓存24小时
```
### 3. `semantic_cache_query_key: Optional[str]`
启用语义缓存的关键。
- **类型**: `Optional[str]`
- **默认值**: `None`
- **作用**:
- 将此属性的值设置为你工具的某个**参数的名称**(字符串)。
- 自动化缓存系统在工作时,会提取该参数的值,将其转换为向量,并进行语义相似度搜索。
- 如果该值为 `None`,则此工具**仅使用精确缓存**。
**示例**:
```python
class WebSurfingTool(BaseTool):
name: str = "web_search"
parameters = [
("query", ToolParamType.STRING, "要搜索的关键词或问题。", True, None),
# ... 其他参数 ...
]
# --- 缓存配置 ---
enable_cache: bool = True
cache_ttl: int = 7200 # 缓存2小时
semantic_cache_query_key: str = "query" # <-- 关键!
```
在上面的例子中,`web_search` 工具的 `"query"` 参数值(例如,用户输入的搜索词)将被用于语义缓存搜索。
## 完整示例
假设我们有一个调用外部API来获取股票价格的工具。由于股价在短时间内相对稳定且查询意图可能相似如 "苹果股价" vs "AAPL股价"),因此非常适合使用缓存。
```python
# in your_plugin/tools/stock_checker.py
from src.plugin_system import BaseTool, ToolParamType
class StockCheckerTool(BaseTool):
"""
一个用于查询股票价格的工具。
"""
name: str = "get_stock_price"
description: str = "获取指定公司或股票代码的最新价格。"
available_for_llm: bool = True
parameters = [
("symbol", ToolParamType.STRING, "公司名称或股票代码 (e.g., 'AAPL', '苹果')", True, None),
]
# --- 缓存配置 ---
# 1. 开启缓存
enable_cache: bool = True
# 2. 股价信息缓存10分钟
cache_ttl: int = 600
# 3. 使用 "symbol" 参数进行语义搜索
semantic_cache_query_key: str = "symbol"
# --------------------
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
symbol = function_args.get("symbol")
# ... 这里是你调用外部API获取股票价格的逻辑 ...
# price = await some_stock_api.get_price(symbol)
price = 123.45 # 示例价格
return {
"type": "stock_price_result",
"content": f"{symbol} 的当前价格是 ${price}"
}
```
通过以上简单的三行配置,`StockCheckerTool` 现在就拥有了强大的自动化缓存能力:
- 当用户查询 `"苹果"` 时,工具会执行并缓存结果。
- 在接下来的10分钟内如果再次查询 `"苹果"`,将直接从精确缓存返回结果。
- 更智能的是,如果另一个用户查询 `"AAPL"`,语义缓存系统会识别出 `"AAPL"``"苹果"` 在语义上高度相关大概率也会直接返回缓存的结果而无需再次调用API。
---
现在你可以专注于实现工具的核心逻辑把缓存的复杂性交给MMC的自动化系统来处理。

View File

@@ -0,0 +1,128 @@
# 统一向量数据库服务使用指南
本文档旨在说明如何在 `mmc` 项目中使用新集成的统一向量数据库服务。该服务提供了一个标准化的接口,用于与底层向量数据库(当前为 ChromaDB进行交互同时确保了代码的解耦和未来的可扩展性。
## 核心设计理念
1. **统一入口**: 所有对向量数据库的操作都应通过全局单例 `vector_db_service` 进行。
2. **抽象接口**: 服务遵循 `VectorDBBase` 抽象基类定义的接口,未来可以轻松替换为其他向量数据库(如 Milvus, FAISS而无需修改业务代码。
3. **单例模式**: 整个应用程序共享一个数据库客户端实例,避免了资源浪费和管理混乱。
4. **数据隔离**: 使用不同的 `collection` 名称来隔离不同业务模块(如语义缓存、瞬时记忆)的数据。在 `collection` 内部,使用 `metadata` 字段(如 `chat_id`)来隔离不同用户或会话的数据。
## 如何使用
### 1. 导入服务
在任何需要使用向量数据库的文件中,只需导入全局服务实例:
```python
from src.common.vector_db import vector_db_service
```
### 2. 主要操作
`vector_db_service` 对象提供了所有你需要的方法,这些方法都定义在 `VectorDBBase` 中。
#### a. 获取或创建集合 (Collection)
在操作数据之前,你需要先指定一个集合。如果集合不存在,它将被自动创建。
```python
# 为语义缓存创建一个集合
vector_db_service.get_or_create_collection(name="semantic_cache")
# 为瞬时记忆创建一个集合
vector_db_service.get_or_create_collection(
name="instant_memory",
metadata={"hnsw:space": "cosine"} # 可以传入特定于实现的参数
)
```
#### b. 添加数据
使用 `add` 方法向指定集合中添加向量、文档和元数据。
```python
collection_name = "instant_memory"
chat_id = "user_123"
message_id = "msg_abc"
embedding_vector = [0.1, 0.2, 0.3, ...] # 你的 embedding 向量
content = "你好,这是一个测试消息"
vector_db_service.add(
collection_name=collection_name,
embeddings=[embedding_vector],
documents=[content],
metadatas=[{
"chat_id": chat_id,
"timestamp": 1678886400.0,
"sender": "user"
}],
ids=[message_id]
)
```
#### c. 查询数据
使用 `query` 方法来查找相似的向量。你可以使用 `where` 子句来过滤元数据。
```python
query_vector = [0.11, 0.22, 0.33, ...] # 用于查询的向量
collection_name = "instant_memory"
chat_id_to_query = "user_123"
results = vector_db_service.query(
collection_name=collection_name,
query_embeddings=[query_vector],
n_results=5, # 返回最相似的5个结果
where={"chat_id": chat_id_to_query} # **重要**: 使用 where 来隔离不同聊天的数据
)
# results 的结构:
# {
# 'ids': [['msg_abc']],
# 'distances': [[0.0123]],
# 'metadatas': [[{'chat_id': 'user_123', ...}]],
# 'embeddings': None,
# 'documents': [['你好,这是一个测试消息']]
# }
print(results)
```
#### d. 删除数据
你可以根据 `id``where` 条件来删除数据。
```python
# 根据 ID 删除
vector_db_service.delete(
collection_name="instant_memory",
ids=["msg_abc"]
)
# 根据 where 条件删除 (例如,删除某个用户的所有记忆)
vector_db_service.delete(
collection_name="instant_memory",
where={"chat_id": "user_123"}
)
```
#### e. 获取集合数量
使用 `count` 方法获取一个集合中的条目总数。
```python
count = vector_db_service.count(collection_name="semantic_cache")
print(f"语义缓存集合中有 {count} 条数据。")
```
**注意**: `count` 方法目前返回整个集合的条目数,不会根据 `where` 条件进行过滤。
### 3. 代码位置
- **抽象基类**: [`mmc/src/common/vector_db/base.py`](mmc/src/common/vector_db/base.py)
- **ChromaDB 实现**: [`mmc/src/common/vector_db/chromadb_impl.py`](mmc/src/common/vector_db/chromadb_impl.py)
- **服务入口**: [`mmc/src/common/vector_db/__init__.py`](mmc/src/common/vector_db/__init__.py)
---
这份完整的文档应该能帮助您和团队的其他成员正确地使用新的向量数据库服务。如果您有任何其他问题,请随时提出。

View File

@@ -4,10 +4,9 @@ from typing import List, Dict, Any
from dataclasses import dataclass from dataclasses import dataclass
import threading import threading
import chromadb
from chromadb.config import Settings
from src.common.logger import get_logger from src.common.logger import get_logger
from src.chat.utils.utils import get_embedding from src.chat.utils.utils import get_embedding
from src.common.vector_db import vector_db_service
logger = get_logger("vector_instant_memory_v2") logger = get_logger("vector_instant_memory_v2")
@@ -45,10 +44,7 @@ class VectorInstantMemoryV2:
self.chat_id = chat_id self.chat_id = chat_id
self.retention_hours = retention_hours self.retention_hours = retention_hours
self.cleanup_interval = cleanup_interval self.cleanup_interval = cleanup_interval
self.collection_name = "instant_memory"
# ChromaDB相关
self.client = None
self.collection = None
# 清理任务相关 # 清理任务相关
self.cleanup_task = None self.cleanup_task = None
@@ -61,22 +57,16 @@ class VectorInstantMemoryV2:
logger.info(f"向量瞬时记忆系统V2初始化完成: {chat_id} (保留{retention_hours}小时)") logger.info(f"向量瞬时记忆系统V2初始化完成: {chat_id} (保留{retention_hours}小时)")
def _init_chroma(self): def _init_chroma(self):
"""初始化ChromaDB连接""" """使用全局服务初始化向量数据库集合"""
try: try:
db_path = f"./data/memory_vectors/{self.chat_id}" # 现在我们只获取集合,而不是创建新的客户端
self.client = chromadb.PersistentClient( vector_db_service.get_or_create_collection(
path=db_path, name=self.collection_name,
settings=Settings(anonymized_telemetry=False)
)
self.collection = self.client.get_or_create_collection(
name="chat_messages",
metadata={"hnsw:space": "cosine"} metadata={"hnsw:space": "cosine"}
) )
logger.info(f"向量记忆数据库初始化成功: {db_path}") logger.info(f"向量记忆集合 '{self.collection_name}' 已准备就绪")
except Exception as e: except Exception as e:
logger.error(f"ChromaDB初始化失败: {e}") logger.error(f"获取向量记忆集合失败: {e}")
self.client = None
self.collection = None
def _start_cleanup_task(self): def _start_cleanup_task(self):
"""启动定时清理任务""" """启动定时清理任务"""
@@ -95,35 +85,23 @@ class VectorInstantMemoryV2:
def _cleanup_expired_messages(self): def _cleanup_expired_messages(self):
"""清理过期的聊天记录""" """清理过期的聊天记录"""
if not self.collection:
return
try: try:
# 计算过期时间戳
expire_time = time.time() - (self.retention_hours * 3600) expire_time = time.time() - (self.retention_hours * 3600)
# 查询所有记录 # 使用 where 条件来删除过期记录
all_results = self.collection.get( # 注意: ChromaDB 的 where 过滤器目前对 timestamp 的 $lt 操作支持可能有限
where={"chat_id": self.chat_id}, # 一个更可靠的方法是 get() -> filter -> delete()
include=["metadatas"] # 但为了简化,我们先尝试直接 delete
# TODO: 确认 ChromaDB 对 $lt 在 metadata 上的支持。如果不支持,需要实现 get-filter-delete 模式。
vector_db_service.delete(
collection_name=self.collection_name,
where={
"chat_id": self.chat_id,
"timestamp": {"$lt": expire_time}
}
) )
logger.info(f"已为 chat_id '{self.chat_id}' 触发过期记录清理")
# 找出过期的记录ID
expired_ids = []
metadatas = all_results.get("metadatas") or []
ids = all_results.get("ids") or []
for i, metadata in enumerate(metadatas):
if metadata and isinstance(metadata, dict):
timestamp = metadata.get("timestamp", 0)
if isinstance(timestamp, (int, float)) and timestamp < expire_time:
if i < len(ids):
expired_ids.append(ids[i])
# 批量删除过期记录
if expired_ids:
self.collection.delete(ids=expired_ids)
logger.info(f"清理了 {len(expired_ids)} 条过期聊天记录")
except Exception as e: except Exception as e:
logger.error(f"清理过期记录失败: {e}") logger.error(f"清理过期记录失败: {e}")
@@ -139,7 +117,7 @@ class VectorInstantMemoryV2:
Returns: Returns:
bool: 是否存储成功 bool: 是否存储成功
""" """
if not self.collection or not content.strip(): if not content.strip():
return False return False
try: try:
@@ -149,10 +127,8 @@ class VectorInstantMemoryV2:
logger.warning(f"消息向量生成失败: {content[:50]}...") logger.warning(f"消息向量生成失败: {content[:50]}...")
return False return False
# 生成唯一消息ID
message_id = f"{self.chat_id}_{int(time.time() * 1000)}_{hash(content) % 10000}" message_id = f"{self.chat_id}_{int(time.time() * 1000)}_{hash(content) % 10000}"
# 创建消息对象
message = ChatMessage( message = ChatMessage(
message_id=message_id, message_id=message_id,
chat_id=self.chat_id, chat_id=self.chat_id,
@@ -161,8 +137,9 @@ class VectorInstantMemoryV2:
sender=sender sender=sender
) )
# 存储到ChromaDB # 使用新的服务存储
self.collection.add( vector_db_service.add(
collection_name=self.collection_name,
embeddings=[message_vector], embeddings=[message_vector],
documents=[content], documents=[content],
metadatas=[{ metadatas=[{
@@ -194,23 +171,23 @@ class VectorInstantMemoryV2:
Returns: Returns:
List[Dict]: 相似消息列表包含content、similarity、timestamp等信息 List[Dict]: 相似消息列表包含content、similarity、timestamp等信息
""" """
if not self.collection or not query.strip(): if not query.strip():
return [] return []
try: try:
# 生成查询向量
query_vector = await get_embedding(query) query_vector = await get_embedding(query)
if not query_vector: if not query_vector:
return [] return []
# 向量相似度搜索 # 使用新的服务进行查询
results = self.collection.query( results = vector_db_service.query(
collection_name=self.collection_name,
query_embeddings=[query_vector], query_embeddings=[query_vector],
n_results=top_k, n_results=top_k,
where={"chat_id": self.chat_id} where={"chat_id": self.chat_id}
) )
if not results['documents'] or not results['documents'][0]: if not results.get('documents') or not results['documents'][0]:
return [] return []
# 处理搜索结果 # 处理搜索结果
@@ -311,15 +288,18 @@ class VectorInstantMemoryV2:
"cleanup_interval": self.cleanup_interval, "cleanup_interval": self.cleanup_interval,
"system_status": "running" if self.is_running else "stopped", "system_status": "running" if self.is_running else "stopped",
"total_messages": 0, "total_messages": 0,
"db_status": "connected" if self.collection else "disconnected" "db_status": "connected"
} }
if self.collection: try:
try: # 注意count() 现在没有 chat_id 过滤,返回的是整个集合的数量
result = self.collection.count() # 若要精确计数,需要 get(where={"chat_id": ...}) 然后 len(results['ids'])
stats["total_messages"] = result # 这里为了简化,暂时显示集合总数
except Exception: result = vector_db_service.count(collection_name=self.collection_name)
stats["total_messages"] = "查询失败" stats["total_messages"] = result
except Exception:
stats["total_messages"] = "查询失败"
stats["db_status"] = "disconnected"
return stats return stats

View File

@@ -4,13 +4,13 @@ import hashlib
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
import faiss import faiss
import chromadb
from typing import Any, Dict, Optional, Union from typing import Any, Dict, Optional, Union
from src.common.logger import get_logger from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config from src.config.config import global_config, model_config
from src.common.database.sqlalchemy_models import CacheEntries from src.common.database.sqlalchemy_models import CacheEntries
from src.common.database.sqlalchemy_database_api import db_query, db_save from src.common.database.sqlalchemy_database_api import db_query, db_save
from src.common.vector_db import vector_db_service
logger = get_logger("cache_manager") logger = get_logger("cache_manager")
@@ -28,25 +28,23 @@ class CacheManager:
cls._instance = super(CacheManager, cls).__new__(cls) cls._instance = super(CacheManager, cls).__new__(cls)
return cls._instance return cls._instance
def __init__(self, default_ttl: int = 3600, chroma_path: str = "data/chroma_db"): def __init__(self, default_ttl: int = 3600):
""" """
初始化缓存管理器。 初始化缓存管理器。
""" """
if not hasattr(self, '_initialized'): if not hasattr(self, '_initialized'):
self.default_ttl = default_ttl self.default_ttl = default_ttl
self.semantic_cache_collection_name = "semantic_cache"
# L1 缓存 (内存) # L1 缓存 (内存)
self.l1_kv_cache: Dict[str, Dict[str, Any]] = {} self.l1_kv_cache: Dict[str, Dict[str, Any]] = {}
embedding_dim = global_config.lpmm_knowledge.embedding_dimension embedding_dim = global_config.lpmm_knowledge.embedding_dimension
self.l1_vector_index = faiss.IndexFlatIP(embedding_dim) self.l1_vector_index = faiss.IndexFlatIP(embedding_dim)
self.l1_vector_id_to_key: Dict[int, str] = {} self.l1_vector_id_to_key: Dict[int, str] = {}
# 语义缓存 (ChromaDB) # L2 向量缓存 (使用新的服务)
vector_db_service.get_or_create_collection(self.semantic_cache_collection_name)
self.chroma_client = chromadb.PersistentClient(path=chroma_path)
self.chroma_collection = self.chroma_client.get_or_create_collection(name="semantic_cache")
# 嵌入模型 # 嵌入模型
self.embedding_model = LLMRequest(model_config.model_task_config.embedding) self.embedding_model = LLMRequest(model_config.model_task_config.embedding)
@@ -152,18 +150,20 @@ class CacheManager:
return self.l1_kv_cache[l1_hit_key]["data"] return self.l1_kv_cache[l1_hit_key]["data"]
# 步骤 2b: L2 精确缓存 (数据库) # 步骤 2b: L2 精确缓存 (数据库)
cache_results = await db_query( cache_results_obj = await db_query(
model_class=CacheEntries, model_class=CacheEntries,
query_type="get", query_type="get",
filters={"cache_key": key}, filters={"cache_key": key},
single_result=True single_result=True
) )
if cache_results: if cache_results_obj:
expires_at = cache_results["expires_at"] # 使用 getattr 安全访问属性,避免 Pylance 类型检查错误
expires_at = getattr(cache_results_obj, "expires_at", 0)
if time.time() < expires_at: if time.time() < expires_at:
logger.info(f"命中L2键值缓存: {key}") logger.info(f"命中L2键值缓存: {key}")
data = orjson.loads(cache_results["cache_value"]) cache_value = getattr(cache_results_obj, "cache_value", "{}")
data = orjson.loads(cache_value)
# 更新访问统计 # 更新访问统计
await db_query( await db_query(
@@ -172,7 +172,7 @@ class CacheManager:
filters={"cache_key": key}, filters={"cache_key": key},
data={ data={
"last_accessed": time.time(), "last_accessed": time.time(),
"access_count": cache_results["access_count"] + 1 "access_count": getattr(cache_results_obj, "access_count", 0) + 1
} }
) )
@@ -187,29 +187,35 @@ class CacheManager:
filters={"cache_key": key} filters={"cache_key": key}
) )
# 步骤 2c: L2 语义缓存 (ChromaDB) # 步骤 2c: L2 语义缓存 (VectorDB Service)
if query_embedding is not None and self.chroma_collection: if query_embedding is not None:
try: try:
results = self.chroma_collection.query(query_embeddings=query_embedding.tolist(), n_results=1) results = vector_db_service.query(
if results and results['ids'] and results['ids'][0]: collection_name=self.semantic_cache_collection_name,
distance = results['distances'][0][0] if results['distances'] and results['distances'][0] else 'N/A' query_embeddings=query_embedding.tolist(),
n_results=1
)
if results and results.get('ids') and results['ids'][0]:
distance = results['distances'][0][0] if results.get('distances') and results['distances'][0] else 'N/A'
logger.debug(f"L2语义搜索找到最相似的结果: id={results['ids'][0]}, 距离={distance}") logger.debug(f"L2语义搜索找到最相似的结果: id={results['ids'][0]}, 距离={distance}")
if distance != 'N/A' and distance < 0.75: if distance != 'N/A' and distance < 0.75:
l2_hit_key = results['ids'][0][0] if isinstance(results['ids'][0], list) else results['ids'][0] l2_hit_key = results['ids'][0][0] if isinstance(results['ids'][0], list) else results['ids'][0]
logger.info(f"命中L2语义缓存: key='{l2_hit_key}', 距离={distance:.4f}") logger.info(f"命中L2语义缓存: key='{l2_hit_key}', 距离={distance:.4f}")
# 从数据库获取缓存数据 # 从数据库获取缓存数据
semantic_cache_results = await db_query( semantic_cache_results_obj = await db_query(
model_class=CacheEntries, model_class=CacheEntries,
query_type="get", query_type="get",
filters={"cache_key": l2_hit_key}, filters={"cache_key": l2_hit_key},
single_result=True single_result=True
) )
if semantic_cache_results: if semantic_cache_results_obj:
expires_at = semantic_cache_results["expires_at"] expires_at = getattr(semantic_cache_results_obj, "expires_at", 0)
if time.time() < expires_at: if time.time() < expires_at:
data = orjson.loads(semantic_cache_results["cache_value"]) cache_value = getattr(semantic_cache_results_obj, "cache_value", "{}")
data = orjson.loads(cache_value)
logger.debug(f"L2语义缓存返回的数据: {data}") logger.debug(f"L2语义缓存返回的数据: {data}")
# 回填 L1 # 回填 L1
@@ -218,13 +224,13 @@ class CacheManager:
try: try:
new_id = self.l1_vector_index.ntotal new_id = self.l1_vector_index.ntotal
faiss.normalize_L2(query_embedding) faiss.normalize_L2(query_embedding)
self.l1_vector_index.add(x=query_embedding) self.l1_vector_index.add(x=query_embedding) # type: ignore
self.l1_vector_id_to_key[new_id] = key self.l1_vector_id_to_key[new_id] = key
except Exception as e: except Exception as e:
logger.error(f"回填L1向量索引时发生错误: {e}") logger.error(f"回填L1向量索引时发生错误: {e}")
return data return data
except Exception as e: except Exception as e:
logger.warning(f"ChromaDB查询失败: {e}") logger.warning(f"VectorDB Service 查询失败: {e}")
logger.debug(f"缓存未命中: {key}") logger.debug(f"缓存未命中: {key}")
return None return None
@@ -261,22 +267,27 @@ class CacheManager:
) )
# 写入语义缓存 # 写入语义缓存
if semantic_query and self.embedding_model and self.chroma_collection: if semantic_query and self.embedding_model:
try: try:
embedding_result = await self.embedding_model.get_embedding(semantic_query) embedding_result = await self.embedding_model.get_embedding(semantic_query)
if embedding_result: if embedding_result:
# embedding_result是一个元组(embedding_vector, model_name),取第一个元素
embedding_vector = embedding_result[0] if isinstance(embedding_result, tuple) else embedding_result embedding_vector = embedding_result[0] if isinstance(embedding_result, tuple) else embedding_result
validated_embedding = self._validate_embedding(embedding_vector) validated_embedding = self._validate_embedding(embedding_vector)
if validated_embedding is not None: if validated_embedding is not None:
embedding = np.array([validated_embedding], dtype='float32') embedding = np.array([validated_embedding], dtype='float32')
# 写入 L1 Vector # 写入 L1 Vector
new_id = self.l1_vector_index.ntotal new_id = self.l1_vector_index.ntotal
faiss.normalize_L2(embedding) faiss.normalize_L2(embedding)
self.l1_vector_index.add(x=embedding) self.l1_vector_index.add(x=embedding) # type: ignore
self.l1_vector_id_to_key[new_id] = key self.l1_vector_id_to_key[new_id] = key
# 写入 L2 Vector
self.chroma_collection.add(embeddings=embedding.tolist(), ids=[key]) # 写入 L2 Vector (使用新的服务)
vector_db_service.add(
collection_name=self.semantic_cache_collection_name,
embeddings=embedding.tolist(),
ids=[key]
)
except Exception as e: except Exception as e:
logger.warning(f"语义缓存写入失败: {e}") logger.warning(f"语义缓存写入失败: {e}")
@@ -298,15 +309,14 @@ class CacheManager:
filters={} # 删除所有记录 filters={} # 删除所有记录
) )
# 清空ChromaDB # 清空 VectorDB
if self.chroma_collection: try:
try: vector_db_service.delete_collection(name=self.semantic_cache_collection_name)
self.chroma_client.delete_collection(name="semantic_cache") vector_db_service.get_or_create_collection(name=self.semantic_cache_collection_name)
self.chroma_collection = self.chroma_client.get_or_create_collection(name="semantic_cache") except Exception as e:
except Exception as e: logger.warning(f"清空 VectorDB 集合失败: {e}")
logger.warning(f"清空ChromaDB失败: {e}")
logger.info("L2 (数据库 & ChromaDB) 缓存已清空。") logger.info("L2 (数据库 & VectorDB) 缓存已清空。")
async def clear_all(self): async def clear_all(self):
"""清空所有缓存。""" """清空所有缓存。"""

View File

@@ -4,9 +4,11 @@ from datetime import datetime
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
import json import json
from pathlib import Path from pathlib import Path
import inspect
from .logger import get_logger from .logger import get_logger
from src.config.config import global_config from src.config.config import global_config
from src.common.cache_manager import tool_cache
logger = get_logger("tool_history") logger = get_logger("tool_history")
@@ -113,34 +115,6 @@ class ToolHistoryManager:
except Exception as e: except Exception as e:
logger.error(f"记录工具调用时发生错误: {e}") logger.error(f"记录工具调用时发生错误: {e}")
def find_cached_result(self, tool_name: str, args: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""查找匹配的缓存记录
Args:
tool_name: 工具名称
args: 工具调用参数
Returns:
Optional[Dict[str, Any]]: 如果找到匹配的缓存记录则返回结果否则返回None
"""
# 检查是否启用历史记录
if not global_config.tool.history.enable_history:
return None
# 清理输入参数中的敏感信息以便比较
sanitized_input_args = self._sanitize_args(args)
# 按时间倒序遍历历史记录
for record in reversed(self._history):
if (record["tool_name"] == tool_name and
record["status"] == "completed" and
record["ttl_count"] < record.get("ttl", 5)):
# 比较参数是否匹配
if self._sanitize_args(record["arguments"]) == sanitized_input_args:
logger.info(f"工具 {tool_name} 命中缓存记录")
return record["result"]
return None
def _sanitize_args(self, args: Dict[str, Any]) -> Dict[str, Any]: def _sanitize_args(self, args: Dict[str, Any]) -> Dict[str, Any]:
"""清理参数中的敏感信息""" """清理参数中的敏感信息"""
sensitive_keys = ['api_key', 'token', 'password', 'secret'] sensitive_keys = ['api_key', 'token', 'password', 'secret']
@@ -327,27 +301,78 @@ class ToolHistoryManager:
def wrap_tool_executor(): def wrap_tool_executor():
""" """
包装工具执行器以添加历史记录功能 包装工具执行器以添加历史记录和缓存功能
这个函数应该在系统启动时被调用一次 这个函数应该在系统启动时被调用一次
""" """
from src.plugin_system.core.tool_use import ToolExecutor from src.plugin_system.core.tool_use import ToolExecutor
from src.plugin_system.apis.tool_api import get_tool_instance
original_execute = ToolExecutor.execute_tool_call original_execute = ToolExecutor.execute_tool_call
history_manager = ToolHistoryManager() history_manager = ToolHistoryManager()
async def wrapped_execute_tool_call(self, tool_call, tool_instance=None): async def wrapped_execute_tool_call(self, tool_call, tool_instance=None):
start_time = time.time() start_time = time.time()
# 确保我们有 tool_instance
if not tool_instance:
tool_instance = get_tool_instance(tool_call.func_name)
# 首先检查缓存 # 如果没有 tool_instance就无法进行缓存检查直接执行
if cached_result := history_manager.find_cached_result(tool_call.func_name, tool_call.args): if not tool_instance:
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具 {tool_call.func_name} 执行") result = await original_execute(self, tool_call, None)
return cached_result execution_time = time.time() - start_time
history_manager.record_tool_call(
tool_name=tool_call.func_name,
args=tool_call.args,
result=result,
execution_time=execution_time,
status="completed",
chat_id=getattr(self, 'chat_id', None),
ttl=5 # Default TTL
)
return result
# 新的缓存逻辑
if tool_instance.enable_cache:
try:
tool_file_path = inspect.getfile(tool_instance.__class__)
semantic_query = None
if tool_instance.semantic_cache_query_key:
semantic_query = tool_call.args.get(tool_instance.semantic_cache_query_key)
cached_result = await tool_cache.get(
tool_name=tool_call.func_name,
function_args=tool_call.args,
tool_file_path=tool_file_path,
semantic_query=semantic_query
)
if cached_result:
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具 {tool_call.func_name} 执行")
return cached_result
except Exception as e:
logger.error(f"{self.log_prefix}检查工具缓存时出错: {e}")
try: try:
result = await original_execute(self, tool_call, tool_instance) result = await original_execute(self, tool_call, tool_instance)
execution_time = time.time() - start_time execution_time = time.time() - start_time
# 获取工具的ttl值 # 缓存结果
ttl = getattr(tool_instance, 'history_ttl', 5) if tool_instance else 5 if tool_instance.enable_cache:
try:
tool_file_path = inspect.getfile(tool_instance.__class__)
semantic_query = None
if tool_instance.semantic_cache_query_key:
semantic_query = tool_call.args.get(tool_instance.semantic_cache_query_key)
await tool_cache.set(
tool_name=tool_call.func_name,
function_args=tool_call.args,
tool_file_path=tool_file_path,
data=result,
ttl=tool_instance.cache_ttl,
semantic_query=semantic_query
)
except Exception as e:
logger.error(f"{self.log_prefix}设置工具缓存时出错: {e}")
# 记录成功的调用 # 记录成功的调用
history_manager.record_tool_call( history_manager.record_tool_call(
@@ -357,16 +382,13 @@ def wrap_tool_executor():
execution_time=execution_time, execution_time=execution_time,
status="completed", status="completed",
chat_id=getattr(self, 'chat_id', None), chat_id=getattr(self, 'chat_id', None),
ttl=ttl ttl=tool_instance.history_ttl
) )
return result return result
except Exception as e: except Exception as e:
execution_time = time.time() - start_time execution_time = time.time() - start_time
# 获取工具的ttl值
ttl = getattr(tool_instance, 'history_ttl', 5) if tool_instance else 5
# 记录失败的调用 # 记录失败的调用
history_manager.record_tool_call( history_manager.record_tool_call(
tool_name=tool_call.func_name, tool_name=tool_call.func_name,
@@ -375,7 +397,7 @@ def wrap_tool_executor():
execution_time=execution_time, execution_time=execution_time,
status="error", status="error",
chat_id=getattr(self, 'chat_id', None), chat_id=getattr(self, 'chat_id', None),
ttl=ttl ttl=tool_instance.history_ttl
) )
raise raise

View File

@@ -0,0 +1,19 @@
from .base import VectorDBBase
from .chromadb_impl import ChromaDBImpl
def get_vector_db_service() -> VectorDBBase:
"""
工厂函数,初始化并返回向量数据库服务实例。
目前硬编码为 ChromaDB未来可以从配置中读取。
"""
# TODO: 从全局配置中读取数据库类型和路径
db_path = "data/chroma_db"
# ChromaDBImpl 是一个单例,所以这里每次调用都会返回同一个实例
return ChromaDBImpl(path=db_path)
# 全局向量数据库服务实例
vector_db_service: VectorDBBase = get_vector_db_service()
__all__ = ["vector_db_service", "VectorDBBase"]

View File

@@ -0,0 +1,117 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
class VectorDBBase(ABC):
"""
向量数据库的抽象基类 (ABC),定义了所有向量数据库实现必须遵循的接口。
"""
@abstractmethod
def __init__(self, path: str, **kwargs: Any):
"""
初始化向量数据库客户端。
Args:
path (str): 数据库文件的存储路径。
**kwargs: 其他特定于实现的参数。
"""
pass
@abstractmethod
def get_or_create_collection(self, name: str, **kwargs: Any) -> Any:
"""
获取或创建一个集合 (Collection)。
Args:
name (str): 集合的名称。
**kwargs: 其他特定于实现的参数 (例如 metadata)。
Returns:
Any: 代表集合的对象。
"""
pass
@abstractmethod
def add(
self,
collection_name: str,
embeddings: List[List[float]],
documents: Optional[List[str]] = None,
metadatas: Optional[List[Dict[str, Any]]] = None,
ids: Optional[List[str]] = None,
) -> None:
"""
向指定集合中添加数据。
Args:
collection_name (str): 目标集合的名称。
embeddings (List[List[float]]): 向量列表。
documents (Optional[List[str]], optional): 文档列表。Defaults to None.
metadatas (Optional[List[Dict[str, Any]]], optional): 元数据列表。Defaults to None.
ids (Optional[List[str]], optional): ID 列表。Defaults to None.
"""
pass
@abstractmethod
def query(
self,
collection_name: str,
query_embeddings: List[List[float]],
n_results: int = 1,
where: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Dict[str, List[Any]]:
"""
在指定集合中查询相似向量。
Args:
collection_name (str): 目标集合的名称。
query_embeddings (List[List[float]]): 用于查询的向量列表。
n_results (int, optional): 返回结果的数量。Defaults to 1.
where (Optional[Dict[str, Any]], optional): 元数据过滤条件。Defaults to None.
**kwargs: 其他特定于实现的参数。
Returns:
Dict[str, List[Any]]: 查询结果,通常包含 ids, distances, metadatas, documents。
"""
pass
@abstractmethod
def delete(
self,
collection_name: str,
ids: Optional[List[str]] = None,
where: Optional[Dict[str, Any]] = None,
) -> None:
"""
从指定集合中删除数据。
Args:
collection_name (str): 目标集合的名称。
ids (Optional[List[str]], optional): 要删除的条目的 ID 列表。Defaults to None.
where (Optional[Dict[str, Any]], optional): 基于元数据的过滤条件。Defaults to None.
"""
pass
@abstractmethod
def count(self, collection_name: str) -> int:
"""
获取指定集合中的条目总数。
Args:
collection_name (str): 目标集合的名称。
Returns:
int: 条目总数。
"""
pass
@abstractmethod
def delete_collection(self, name: str) -> None:
"""
删除一个集合。
Args:
name (str): 要删除的集合的名称。
"""
pass

View File

@@ -0,0 +1,137 @@
import threading
from typing import Any, Dict, List, Optional
import chromadb
from chromadb.config import Settings
from .base import VectorDBBase
from src.common.logger import get_logger
logger = get_logger("chromadb_impl")
class ChromaDBImpl(VectorDBBase):
"""
ChromaDB 的具体实现,遵循 VectorDBBase 接口。
采用单例模式,确保全局只有一个 ChromaDB 客户端实例。
"""
_instance = None
_lock = threading.Lock()
def __new__(cls, *args, **kwargs):
if not cls._instance:
with cls._lock:
if not cls._instance:
cls._instance = super(ChromaDBImpl, cls).__new__(cls)
return cls._instance
def __init__(self, path: str = "data/chroma_db", **kwargs: Any):
"""
初始化 ChromaDB 客户端。
由于是单例,这个初始化只会执行一次。
"""
if not hasattr(self, '_initialized'):
with self._lock:
if not hasattr(self, '_initialized'):
try:
self.client = chromadb.PersistentClient(
path=path,
settings=Settings(anonymized_telemetry=False)
)
self._collections: Dict[str, Any] = {}
self._initialized = True
logger.info(f"ChromaDB 客户端已初始化,数据库路径: {path}")
except Exception as e:
logger.error(f"ChromaDB 初始化失败: {e}")
self.client = None
self._initialized = False
def get_or_create_collection(self, name: str, **kwargs: Any) -> Any:
if not self.client:
raise ConnectionError("ChromaDB 客户端未初始化")
if name in self._collections:
return self._collections[name]
try:
collection = self.client.get_or_create_collection(name=name, **kwargs)
self._collections[name] = collection
logger.info(f"成功获取或创建集合: '{name}'")
return collection
except Exception as e:
logger.error(f"获取或创建集合 '{name}' 失败: {e}")
return None
def add(
self,
collection_name: str,
embeddings: List[List[float]],
documents: Optional[List[str]] = None,
metadatas: Optional[List[Dict[str, Any]]] = None,
ids: Optional[List[str]] = None,
) -> None:
collection = self.get_or_create_collection(collection_name)
if collection:
try:
collection.add(
embeddings=embeddings,
documents=documents,
metadatas=metadatas,
ids=ids,
)
except Exception as e:
logger.error(f"向集合 '{collection_name}' 添加数据失败: {e}")
def query(
self,
collection_name: str,
query_embeddings: List[List[float]],
n_results: int = 1,
where: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Dict[str, List[Any]]:
collection = self.get_or_create_collection(collection_name)
if collection:
try:
return collection.query(
query_embeddings=query_embeddings,
n_results=n_results,
where=where or {},
**kwargs,
)
except Exception as e:
logger.error(f"查询集合 '{collection_name}' 失败: {e}")
return {}
def delete(
self,
collection_name: str,
ids: Optional[List[str]] = None,
where: Optional[Dict[str, Any]] = None,
) -> None:
collection = self.get_or_create_collection(collection_name)
if collection:
try:
collection.delete(ids=ids, where=where)
except Exception as e:
logger.error(f"从集合 '{collection_name}' 删除数据失败: {e}")
def count(self, collection_name: str) -> int:
collection = self.get_or_create_collection(collection_name)
if collection:
try:
return collection.count()
except Exception as e:
logger.error(f"获取集合 '{collection_name}' 计数失败: {e}")
return 0
def delete_collection(self, name: str) -> None:
if not self.client:
raise ConnectionError("ChromaDB 客户端未初始化")
try:
self.client.delete_collection(name=name)
if name in self._collections:
del self._collections[name]
logger.info(f"集合 '{name}' 已被删除")
except Exception as e:
logger.error(f"删除集合 '{name}' 失败: {e}")

View File

@@ -31,6 +31,13 @@ class BaseTool(ABC):
history_ttl: int = 5 history_ttl: int = 5
"""工具调用历史记录的TTL值默认为5。设为0表示不记录历史""" """工具调用历史记录的TTL值默认为5。设为0表示不记录历史"""
enable_cache: bool = False
"""是否为该工具启用缓存"""
cache_ttl: int = 3600
"""缓存的TTL值默认为3600秒1小时"""
semantic_cache_query_key: Optional[str] = None
"""用于语义缓存的查询参数键名。如果设置,将使用此参数的值进行语义相似度搜索"""
def __init__(self, plugin_config: Optional[dict] = None): def __init__(self, plugin_config: Optional[dict] = None):
self.plugin_config = plugin_config or {} # 直接存储插件配置字典 self.plugin_config = plugin_config or {} # 直接存储插件配置字典

View File

@@ -11,7 +11,6 @@ from bs4 import BeautifulSoup
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system import BaseTool, ToolParamType, llm_api from src.plugin_system import BaseTool, ToolParamType, llm_api
from src.plugin_system.apis import config_api from src.plugin_system.apis import config_api
from src.common.cache_manager import tool_cache
from ..utils.formatters import format_url_parse_results from ..utils.formatters import format_url_parse_results
from ..utils.url_utils import parse_urls_from_input, validate_urls from ..utils.url_utils import parse_urls_from_input, validate_urls
@@ -30,6 +29,12 @@ class URLParserTool(BaseTool):
parameters = [ parameters = [
("urls", ToolParamType.STRING, "要理解的网站", True, None), ("urls", ToolParamType.STRING, "要理解的网站", True, None),
] ]
# --- 新的缓存配置 ---
enable_cache: bool = True
cache_ttl: int = 86400 # 缓存24小时
semantic_cache_query_key: str = "urls"
# --------------------
def __init__(self, plugin_config=None): def __init__(self, plugin_config=None):
super().__init__(plugin_config) super().__init__(plugin_config)
@@ -42,10 +47,11 @@ class URLParserTool(BaseTool):
if exa_api_keys is None: if exa_api_keys is None:
# 从插件配置文件读取 # 从插件配置文件读取
exa_api_keys = self.get_config("exa.api_keys", []) exa_api_keys = self.get_config("exa.api_keys", [])
# 创建API密钥管理器 # 创建API密钥管理器
from typing import cast, List
self.api_manager = create_api_key_manager_from_config( self.api_manager = create_api_key_manager_from_config(
exa_api_keys, cast(List[str], exa_api_keys),
lambda key: Exa(api_key=key), lambda key: Exa(api_key=key),
"Exa URL Parser" "Exa URL Parser"
) )
@@ -135,16 +141,6 @@ class URLParserTool(BaseTool):
""" """
执行URL内容提取和总结。优先使用Exa失败后尝试本地解析。 执行URL内容提取和总结。优先使用Exa失败后尝试本地解析。
""" """
# 获取当前文件路径用于缓存键
import os
current_file_path = os.path.abspath(__file__)
# 检查缓存
cached_result = await tool_cache.get(self.name, function_args, current_file_path)
if cached_result:
logger.info(f"缓存命中: {self.name} -> {function_args}")
return cached_result
urls_input = function_args.get("urls") urls_input = function_args.get("urls")
if not urls_input: if not urls_input:
return {"error": "URL列表不能为空。"} return {"error": "URL列表不能为空。"}
@@ -235,8 +231,4 @@ class URLParserTool(BaseTool):
"errors": error_messages "errors": error_messages
} }
# 保存到缓存
if "error" not in result:
await tool_cache.set(self.name, function_args, current_file_path, result)
return result return result

View File

@@ -7,7 +7,6 @@ from typing import Any, Dict, List
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system import BaseTool, ToolParamType from src.plugin_system import BaseTool, ToolParamType
from src.plugin_system.apis import config_api from src.plugin_system.apis import config_api
from src.common.cache_manager import tool_cache
from ..engines.exa_engine import ExaSearchEngine from ..engines.exa_engine import ExaSearchEngine
from ..engines.tavily_engine import TavilySearchEngine from ..engines.tavily_engine import TavilySearchEngine
@@ -31,6 +30,12 @@ class WebSurfingTool(BaseTool):
("time_range", ToolParamType.STRING, "指定搜索的时间范围,可以是 'any', 'week', 'month'。默认为 'any'", False, ["any", "week", "month"]) ("time_range", ToolParamType.STRING, "指定搜索的时间范围,可以是 'any', 'week', 'month'。默认为 'any'", False, ["any", "week", "month"])
] # type: ignore ] # type: ignore
# --- 新的缓存配置 ---
enable_cache: bool = True
cache_ttl: int = 7200 # 缓存2小时
semantic_cache_query_key: str = "query"
# --------------------
def __init__(self, plugin_config=None): def __init__(self, plugin_config=None):
super().__init__(plugin_config) super().__init__(plugin_config)
# 初始化搜索引擎 # 初始化搜索引擎
@@ -46,16 +51,6 @@ class WebSurfingTool(BaseTool):
if not query: if not query:
return {"error": "搜索查询不能为空。"} return {"error": "搜索查询不能为空。"}
# 获取当前文件路径用于缓存键
import os
current_file_path = os.path.abspath(__file__)
# 检查缓存
cached_result = await tool_cache.get(self.name, function_args, current_file_path, semantic_query=query)
if cached_result:
logger.info(f"缓存命中: {self.name} -> {function_args}")
return cached_result
# 读取搜索配置 # 读取搜索配置
enabled_engines = config_api.get_global_config("web_search.enabled_engines", ["ddg"]) enabled_engines = config_api.get_global_config("web_search.enabled_engines", ["ddg"])
search_strategy = config_api.get_global_config("web_search.search_strategy", "single") search_strategy = config_api.get_global_config("web_search.search_strategy", "single")
@@ -69,10 +64,6 @@ class WebSurfingTool(BaseTool):
result = await self._execute_fallback_search(function_args, enabled_engines) result = await self._execute_fallback_search(function_args, enabled_engines)
else: # single else: # single
result = await self._execute_single_search(function_args, enabled_engines) result = await self._execute_single_search(function_args, enabled_engines)
# 保存到缓存
if "error" not in result:
await tool_cache.set(self.name, function_args, current_file_path, result, semantic_query=query)
return result return result