Merge branch 'master' of https://github.com/MoFox-Studio/MoFox_Bot
This commit is contained in:
124
docs/plugins/tool_caching_guide.md
Normal file
124
docs/plugins/tool_caching_guide.md
Normal file
@@ -0,0 +1,124 @@
|
||||
# 自动化工具缓存系统使用指南
|
||||
|
||||
为了提升性能并减少不必要的重复计算或API调用,MMC内置了一套强大且易于使用的自动化工具缓存系统。该系统同时支持传统的**精确缓存**和先进的**语义缓存**。工具开发者无需编写任何手动缓存逻辑,只需在工具类中设置几个属性,即可轻松启用和配置缓存行为。
|
||||
|
||||
## 核心概念
|
||||
|
||||
- **精确缓存 (KV Cache)**: 当一个工具被调用时,系统会根据工具名称和所有参数生成一个唯一的键。只有当**下一次调用的工具名和所有参数与之前完全一致**时,才会命中缓存。
|
||||
- **语义缓存 (Vector Cache)**: 它不要求参数完全一致,而是理解参数的**语义和意图**。例如,`"查询深圳今天的天气"` 和 `"今天深圳天气怎么样"` 这两个不同的查询,在语义上是高度相似的。如果启用了语义缓存,第二个查询就能成功命中由第一个查询产生的缓存结果。
|
||||
|
||||
## 如何为你的工具启用缓存
|
||||
|
||||
为你的工具(必须继承自 `BaseTool`)启用缓存非常简单,只需在你的工具类定义中添加以下一个或多个属性即可:
|
||||
|
||||
### 1. `enable_cache: bool`
|
||||
|
||||
这是启用缓存的总开关。
|
||||
|
||||
- **类型**: `bool`
|
||||
- **默认值**: `False`
|
||||
- **作用**: 设置为 `True` 即可为该工具启用缓存功能。如果为 `False`,后续的所有缓存配置都将无效。
|
||||
|
||||
**示例**:
|
||||
```python
|
||||
class MyAwesomeTool(BaseTool):
|
||||
# ... 其他定义 ...
|
||||
enable_cache: bool = True
|
||||
```
|
||||
|
||||
### 2. `cache_ttl: int`
|
||||
|
||||
设置缓存的生存时间(Time-To-Live)。
|
||||
|
||||
- **类型**: `int`
|
||||
- **单位**: 秒
|
||||
- **默认值**: `3600` (1小时)
|
||||
- **作用**: 定义缓存条目在被视为过期之前可以存活多长时间。
|
||||
|
||||
**示例**:
|
||||
```python
|
||||
class MyLongTermCacheTool(BaseTool):
|
||||
# ... 其他定义 ...
|
||||
enable_cache: bool = True
|
||||
cache_ttl: int = 86400 # 缓存24小时
|
||||
```
|
||||
|
||||
### 3. `semantic_cache_query_key: Optional[str]`
|
||||
|
||||
启用语义缓存的关键。
|
||||
|
||||
- **类型**: `Optional[str]`
|
||||
- **默认值**: `None`
|
||||
- **作用**:
|
||||
- 将此属性的值设置为你工具的某个**参数的名称**(字符串)。
|
||||
- 自动化缓存系统在工作时,会提取该参数的值,将其转换为向量,并进行语义相似度搜索。
|
||||
- 如果该值为 `None`,则此工具**仅使用精确缓存**。
|
||||
|
||||
**示例**:
|
||||
```python
|
||||
class WebSurfingTool(BaseTool):
|
||||
name: str = "web_search"
|
||||
parameters = [
|
||||
("query", ToolParamType.STRING, "要搜索的关键词或问题。", True, None),
|
||||
# ... 其他参数 ...
|
||||
]
|
||||
|
||||
# --- 缓存配置 ---
|
||||
enable_cache: bool = True
|
||||
cache_ttl: int = 7200 # 缓存2小时
|
||||
semantic_cache_query_key: str = "query" # <-- 关键!
|
||||
```
|
||||
在上面的例子中,`web_search` 工具的 `"query"` 参数值(例如,用户输入的搜索词)将被用于语义缓存搜索。
|
||||
|
||||
## 完整示例
|
||||
|
||||
假设我们有一个调用外部API来获取股票价格的工具。由于股价在短时间内相对稳定,且查询意图可能相似(如 "苹果股价" vs "AAPL股价"),因此非常适合使用缓存。
|
||||
|
||||
```python
|
||||
# in your_plugin/tools/stock_checker.py
|
||||
|
||||
from src.plugin_system import BaseTool, ToolParamType
|
||||
|
||||
class StockCheckerTool(BaseTool):
|
||||
"""
|
||||
一个用于查询股票价格的工具。
|
||||
"""
|
||||
name: str = "get_stock_price"
|
||||
description: str = "获取指定公司或股票代码的最新价格。"
|
||||
available_for_llm: bool = True
|
||||
parameters = [
|
||||
("symbol", ToolParamType.STRING, "公司名称或股票代码 (e.g., 'AAPL', '苹果')", True, None),
|
||||
]
|
||||
|
||||
# --- 缓存配置 ---
|
||||
# 1. 开启缓存
|
||||
enable_cache: bool = True
|
||||
# 2. 股价信息缓存10分钟
|
||||
cache_ttl: int = 600
|
||||
# 3. 使用 "symbol" 参数进行语义搜索
|
||||
semantic_cache_query_key: str = "symbol"
|
||||
# --------------------
|
||||
|
||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
symbol = function_args.get("symbol")
|
||||
|
||||
# ... 这里是你调用外部API获取股票价格的逻辑 ...
|
||||
# price = await some_stock_api.get_price(symbol)
|
||||
price = 123.45 # 示例价格
|
||||
|
||||
return {
|
||||
"type": "stock_price_result",
|
||||
"content": f"{symbol} 的当前价格是 ${price}"
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
通过以上简单的三行配置,`StockCheckerTool` 现在就拥有了强大的自动化缓存能力:
|
||||
|
||||
- 当用户查询 `"苹果"` 时,工具会执行并缓存结果。
|
||||
- 在接下来的10分钟内,如果再次查询 `"苹果"`,将直接从精确缓存返回结果。
|
||||
- 更智能的是,如果另一个用户查询 `"AAPL"`,语义缓存系统会识别出 `"AAPL"` 和 `"苹果"` 在语义上高度相关,大概率也会直接返回缓存的结果,而无需再次调用API。
|
||||
|
||||
---
|
||||
|
||||
现在,你可以专注于实现工具的核心逻辑,把缓存的复杂性交给MMC的自动化系统来处理。
|
||||
128
docs/vector_db_usage_guide.md
Normal file
128
docs/vector_db_usage_guide.md
Normal file
@@ -0,0 +1,128 @@
|
||||
# 统一向量数据库服务使用指南
|
||||
|
||||
本文档旨在说明如何在 `mmc` 项目中使用新集成的统一向量数据库服务。该服务提供了一个标准化的接口,用于与底层向量数据库(当前为 ChromaDB)进行交互,同时确保了代码的解耦和未来的可扩展性。
|
||||
|
||||
## 核心设计理念
|
||||
|
||||
1. **统一入口**: 所有对向量数据库的操作都应通过全局单例 `vector_db_service` 进行。
|
||||
2. **抽象接口**: 服务遵循 `VectorDBBase` 抽象基类定义的接口,未来可以轻松替换为其他向量数据库(如 Milvus, FAISS)而无需修改业务代码。
|
||||
3. **单例模式**: 整个应用程序共享一个数据库客户端实例,避免了资源浪费和管理混乱。
|
||||
4. **数据隔离**: 使用不同的 `collection` 名称来隔离不同业务模块(如语义缓存、瞬时记忆)的数据。在 `collection` 内部,使用 `metadata` 字段(如 `chat_id`)来隔离不同用户或会话的数据。
|
||||
|
||||
## 如何使用
|
||||
|
||||
### 1. 导入服务
|
||||
|
||||
在任何需要使用向量数据库的文件中,只需导入全局服务实例:
|
||||
|
||||
```python
|
||||
from src.common.vector_db import vector_db_service
|
||||
```
|
||||
|
||||
### 2. 主要操作
|
||||
|
||||
`vector_db_service` 对象提供了所有你需要的方法,这些方法都定义在 `VectorDBBase` 中。
|
||||
|
||||
#### a. 获取或创建集合 (Collection)
|
||||
|
||||
在操作数据之前,你需要先指定一个集合。如果集合不存在,它将被自动创建。
|
||||
|
||||
```python
|
||||
# 为语义缓存创建一个集合
|
||||
vector_db_service.get_or_create_collection(name="semantic_cache")
|
||||
|
||||
# 为瞬时记忆创建一个集合
|
||||
vector_db_service.get_or_create_collection(
|
||||
name="instant_memory",
|
||||
metadata={"hnsw:space": "cosine"} # 可以传入特定于实现的参数
|
||||
)
|
||||
```
|
||||
|
||||
#### b. 添加数据
|
||||
|
||||
使用 `add` 方法向指定集合中添加向量、文档和元数据。
|
||||
|
||||
```python
|
||||
collection_name = "instant_memory"
|
||||
chat_id = "user_123"
|
||||
message_id = "msg_abc"
|
||||
embedding_vector = [0.1, 0.2, 0.3, ...] # 你的 embedding 向量
|
||||
content = "你好,这是一个测试消息"
|
||||
|
||||
vector_db_service.add(
|
||||
collection_name=collection_name,
|
||||
embeddings=[embedding_vector],
|
||||
documents=[content],
|
||||
metadatas=[{
|
||||
"chat_id": chat_id,
|
||||
"timestamp": 1678886400.0,
|
||||
"sender": "user"
|
||||
}],
|
||||
ids=[message_id]
|
||||
)
|
||||
```
|
||||
|
||||
#### c. 查询数据
|
||||
|
||||
使用 `query` 方法来查找相似的向量。你可以使用 `where` 子句来过滤元数据。
|
||||
|
||||
```python
|
||||
query_vector = [0.11, 0.22, 0.33, ...] # 用于查询的向量
|
||||
collection_name = "instant_memory"
|
||||
chat_id_to_query = "user_123"
|
||||
|
||||
results = vector_db_service.query(
|
||||
collection_name=collection_name,
|
||||
query_embeddings=[query_vector],
|
||||
n_results=5, # 返回最相似的5个结果
|
||||
where={"chat_id": chat_id_to_query} # **重要**: 使用 where 来隔离不同聊天的数据
|
||||
)
|
||||
|
||||
# results 的结构:
|
||||
# {
|
||||
# 'ids': [['msg_abc']],
|
||||
# 'distances': [[0.0123]],
|
||||
# 'metadatas': [[{'chat_id': 'user_123', ...}]],
|
||||
# 'embeddings': None,
|
||||
# 'documents': [['你好,这是一个测试消息']]
|
||||
# }
|
||||
print(results)
|
||||
```
|
||||
|
||||
#### d. 删除数据
|
||||
|
||||
你可以根据 `id` 或 `where` 条件来删除数据。
|
||||
|
||||
```python
|
||||
# 根据 ID 删除
|
||||
vector_db_service.delete(
|
||||
collection_name="instant_memory",
|
||||
ids=["msg_abc"]
|
||||
)
|
||||
|
||||
# 根据 where 条件删除 (例如,删除某个用户的所有记忆)
|
||||
vector_db_service.delete(
|
||||
collection_name="instant_memory",
|
||||
where={"chat_id": "user_123"}
|
||||
)
|
||||
```
|
||||
|
||||
#### e. 获取集合数量
|
||||
|
||||
使用 `count` 方法获取一个集合中的条目总数。
|
||||
|
||||
```python
|
||||
count = vector_db_service.count(collection_name="semantic_cache")
|
||||
print(f"语义缓存集合中有 {count} 条数据。")
|
||||
```
|
||||
**注意**: `count` 方法目前返回整个集合的条目数,不会根据 `where` 条件进行过滤。
|
||||
|
||||
### 3. 代码位置
|
||||
|
||||
- **抽象基类**: [`mmc/src/common/vector_db/base.py`](mmc/src/common/vector_db/base.py)
|
||||
- **ChromaDB 实现**: [`mmc/src/common/vector_db/chromadb_impl.py`](mmc/src/common/vector_db/chromadb_impl.py)
|
||||
- **服务入口**: [`mmc/src/common/vector_db/__init__.py`](mmc/src/common/vector_db/__init__.py)
|
||||
|
||||
---
|
||||
|
||||
这份完整的文档应该能帮助您和团队的其他成员正确地使用新的向量数据库服务。如果您有任何其他问题,请随时提出。
|
||||
@@ -4,10 +4,9 @@ from typing import List, Dict, Any
|
||||
from dataclasses import dataclass
|
||||
import threading
|
||||
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.utils import get_embedding
|
||||
from src.common.vector_db import vector_db_service
|
||||
|
||||
|
||||
logger = get_logger("vector_instant_memory_v2")
|
||||
@@ -45,10 +44,7 @@ class VectorInstantMemoryV2:
|
||||
self.chat_id = chat_id
|
||||
self.retention_hours = retention_hours
|
||||
self.cleanup_interval = cleanup_interval
|
||||
|
||||
# ChromaDB相关
|
||||
self.client = None
|
||||
self.collection = None
|
||||
self.collection_name = "instant_memory"
|
||||
|
||||
# 清理任务相关
|
||||
self.cleanup_task = None
|
||||
@@ -61,22 +57,16 @@ class VectorInstantMemoryV2:
|
||||
logger.info(f"向量瞬时记忆系统V2初始化完成: {chat_id} (保留{retention_hours}小时)")
|
||||
|
||||
def _init_chroma(self):
|
||||
"""初始化ChromaDB连接"""
|
||||
"""使用全局服务初始化向量数据库集合"""
|
||||
try:
|
||||
db_path = f"./data/memory_vectors/{self.chat_id}"
|
||||
self.client = chromadb.PersistentClient(
|
||||
path=db_path,
|
||||
settings=Settings(anonymized_telemetry=False)
|
||||
)
|
||||
self.collection = self.client.get_or_create_collection(
|
||||
name="chat_messages",
|
||||
# 现在我们只获取集合,而不是创建新的客户端
|
||||
vector_db_service.get_or_create_collection(
|
||||
name=self.collection_name,
|
||||
metadata={"hnsw:space": "cosine"}
|
||||
)
|
||||
logger.info(f"向量记忆数据库初始化成功: {db_path}")
|
||||
logger.info(f"向量记忆集合 '{self.collection_name}' 已准备就绪")
|
||||
except Exception as e:
|
||||
logger.error(f"ChromaDB初始化失败: {e}")
|
||||
self.client = None
|
||||
self.collection = None
|
||||
logger.error(f"获取向量记忆集合失败: {e}")
|
||||
|
||||
def _start_cleanup_task(self):
|
||||
"""启动定时清理任务"""
|
||||
@@ -95,35 +85,23 @@ class VectorInstantMemoryV2:
|
||||
|
||||
def _cleanup_expired_messages(self):
|
||||
"""清理过期的聊天记录"""
|
||||
if not self.collection:
|
||||
return
|
||||
|
||||
try:
|
||||
# 计算过期时间戳
|
||||
expire_time = time.time() - (self.retention_hours * 3600)
|
||||
|
||||
# 查询所有记录
|
||||
all_results = self.collection.get(
|
||||
where={"chat_id": self.chat_id},
|
||||
include=["metadatas"]
|
||||
# 使用 where 条件来删除过期记录
|
||||
# 注意: ChromaDB 的 where 过滤器目前对 timestamp 的 $lt 操作支持可能有限
|
||||
# 一个更可靠的方法是 get() -> filter -> delete()
|
||||
# 但为了简化,我们先尝试直接 delete
|
||||
|
||||
# TODO: 确认 ChromaDB 对 $lt 在 metadata 上的支持。如果不支持,需要实现 get-filter-delete 模式。
|
||||
vector_db_service.delete(
|
||||
collection_name=self.collection_name,
|
||||
where={
|
||||
"chat_id": self.chat_id,
|
||||
"timestamp": {"$lt": expire_time}
|
||||
}
|
||||
)
|
||||
|
||||
# 找出过期的记录ID
|
||||
expired_ids = []
|
||||
metadatas = all_results.get("metadatas") or []
|
||||
ids = all_results.get("ids") or []
|
||||
|
||||
for i, metadata in enumerate(metadatas):
|
||||
if metadata and isinstance(metadata, dict):
|
||||
timestamp = metadata.get("timestamp", 0)
|
||||
if isinstance(timestamp, (int, float)) and timestamp < expire_time:
|
||||
if i < len(ids):
|
||||
expired_ids.append(ids[i])
|
||||
|
||||
# 批量删除过期记录
|
||||
if expired_ids:
|
||||
self.collection.delete(ids=expired_ids)
|
||||
logger.info(f"清理了 {len(expired_ids)} 条过期聊天记录")
|
||||
logger.info(f"已为 chat_id '{self.chat_id}' 触发过期记录清理")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"清理过期记录失败: {e}")
|
||||
@@ -139,7 +117,7 @@ class VectorInstantMemoryV2:
|
||||
Returns:
|
||||
bool: 是否存储成功
|
||||
"""
|
||||
if not self.collection or not content.strip():
|
||||
if not content.strip():
|
||||
return False
|
||||
|
||||
try:
|
||||
@@ -149,10 +127,8 @@ class VectorInstantMemoryV2:
|
||||
logger.warning(f"消息向量生成失败: {content[:50]}...")
|
||||
return False
|
||||
|
||||
# 生成唯一消息ID
|
||||
message_id = f"{self.chat_id}_{int(time.time() * 1000)}_{hash(content) % 10000}"
|
||||
|
||||
# 创建消息对象
|
||||
message = ChatMessage(
|
||||
message_id=message_id,
|
||||
chat_id=self.chat_id,
|
||||
@@ -161,8 +137,9 @@ class VectorInstantMemoryV2:
|
||||
sender=sender
|
||||
)
|
||||
|
||||
# 存储到ChromaDB
|
||||
self.collection.add(
|
||||
# 使用新的服务存储
|
||||
vector_db_service.add(
|
||||
collection_name=self.collection_name,
|
||||
embeddings=[message_vector],
|
||||
documents=[content],
|
||||
metadatas=[{
|
||||
@@ -194,23 +171,23 @@ class VectorInstantMemoryV2:
|
||||
Returns:
|
||||
List[Dict]: 相似消息列表,包含content、similarity、timestamp等信息
|
||||
"""
|
||||
if not self.collection or not query.strip():
|
||||
if not query.strip():
|
||||
return []
|
||||
|
||||
try:
|
||||
# 生成查询向量
|
||||
query_vector = await get_embedding(query)
|
||||
if not query_vector:
|
||||
return []
|
||||
|
||||
# 向量相似度搜索
|
||||
results = self.collection.query(
|
||||
# 使用新的服务进行查询
|
||||
results = vector_db_service.query(
|
||||
collection_name=self.collection_name,
|
||||
query_embeddings=[query_vector],
|
||||
n_results=top_k,
|
||||
where={"chat_id": self.chat_id}
|
||||
)
|
||||
|
||||
if not results['documents'] or not results['documents'][0]:
|
||||
if not results.get('documents') or not results['documents'][0]:
|
||||
return []
|
||||
|
||||
# 处理搜索结果
|
||||
@@ -311,15 +288,18 @@ class VectorInstantMemoryV2:
|
||||
"cleanup_interval": self.cleanup_interval,
|
||||
"system_status": "running" if self.is_running else "stopped",
|
||||
"total_messages": 0,
|
||||
"db_status": "connected" if self.collection else "disconnected"
|
||||
"db_status": "connected"
|
||||
}
|
||||
|
||||
if self.collection:
|
||||
try:
|
||||
result = self.collection.count()
|
||||
stats["total_messages"] = result
|
||||
except Exception:
|
||||
stats["total_messages"] = "查询失败"
|
||||
try:
|
||||
# 注意:count() 现在没有 chat_id 过滤,返回的是整个集合的数量
|
||||
# 若要精确计数,需要 get(where={"chat_id": ...}) 然后 len(results['ids'])
|
||||
# 这里为了简化,暂时显示集合总数
|
||||
result = vector_db_service.count(collection_name=self.collection_name)
|
||||
stats["total_messages"] = result
|
||||
except Exception:
|
||||
stats["total_messages"] = "查询失败"
|
||||
stats["db_status"] = "disconnected"
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
@@ -4,13 +4,13 @@ import hashlib
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import faiss
|
||||
import chromadb
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.database.sqlalchemy_models import CacheEntries
|
||||
from src.common.database.sqlalchemy_database_api import db_query, db_save
|
||||
from src.common.vector_db import vector_db_service
|
||||
|
||||
logger = get_logger("cache_manager")
|
||||
|
||||
@@ -28,25 +28,23 @@ class CacheManager:
|
||||
cls._instance = super(CacheManager, cls).__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, default_ttl: int = 3600, chroma_path: str = "data/chroma_db"):
|
||||
def __init__(self, default_ttl: int = 3600):
|
||||
"""
|
||||
初始化缓存管理器。
|
||||
"""
|
||||
if not hasattr(self, '_initialized'):
|
||||
self.default_ttl = default_ttl
|
||||
|
||||
self.semantic_cache_collection_name = "semantic_cache"
|
||||
|
||||
# L1 缓存 (内存)
|
||||
self.l1_kv_cache: Dict[str, Dict[str, Any]] = {}
|
||||
embedding_dim = global_config.lpmm_knowledge.embedding_dimension
|
||||
self.l1_vector_index = faiss.IndexFlatIP(embedding_dim)
|
||||
self.l1_vector_id_to_key: Dict[int, str] = {}
|
||||
|
||||
# 语义缓存 (ChromaDB)
|
||||
|
||||
self.chroma_client = chromadb.PersistentClient(path=chroma_path)
|
||||
self.chroma_collection = self.chroma_client.get_or_create_collection(name="semantic_cache")
|
||||
# L2 向量缓存 (使用新的服务)
|
||||
vector_db_service.get_or_create_collection(self.semantic_cache_collection_name)
|
||||
|
||||
|
||||
# 嵌入模型
|
||||
self.embedding_model = LLMRequest(model_config.model_task_config.embedding)
|
||||
|
||||
@@ -152,18 +150,20 @@ class CacheManager:
|
||||
return self.l1_kv_cache[l1_hit_key]["data"]
|
||||
|
||||
# 步骤 2b: L2 精确缓存 (数据库)
|
||||
cache_results = await db_query(
|
||||
cache_results_obj = await db_query(
|
||||
model_class=CacheEntries,
|
||||
query_type="get",
|
||||
filters={"cache_key": key},
|
||||
single_result=True
|
||||
)
|
||||
|
||||
if cache_results:
|
||||
expires_at = cache_results["expires_at"]
|
||||
if cache_results_obj:
|
||||
# 使用 getattr 安全访问属性,避免 Pylance 类型检查错误
|
||||
expires_at = getattr(cache_results_obj, "expires_at", 0)
|
||||
if time.time() < expires_at:
|
||||
logger.info(f"命中L2键值缓存: {key}")
|
||||
data = orjson.loads(cache_results["cache_value"])
|
||||
cache_value = getattr(cache_results_obj, "cache_value", "{}")
|
||||
data = orjson.loads(cache_value)
|
||||
|
||||
# 更新访问统计
|
||||
await db_query(
|
||||
@@ -172,7 +172,7 @@ class CacheManager:
|
||||
filters={"cache_key": key},
|
||||
data={
|
||||
"last_accessed": time.time(),
|
||||
"access_count": cache_results["access_count"] + 1
|
||||
"access_count": getattr(cache_results_obj, "access_count", 0) + 1
|
||||
}
|
||||
)
|
||||
|
||||
@@ -187,29 +187,35 @@ class CacheManager:
|
||||
filters={"cache_key": key}
|
||||
)
|
||||
|
||||
# 步骤 2c: L2 语义缓存 (ChromaDB)
|
||||
if query_embedding is not None and self.chroma_collection:
|
||||
# 步骤 2c: L2 语义缓存 (VectorDB Service)
|
||||
if query_embedding is not None:
|
||||
try:
|
||||
results = self.chroma_collection.query(query_embeddings=query_embedding.tolist(), n_results=1)
|
||||
if results and results['ids'] and results['ids'][0]:
|
||||
distance = results['distances'][0][0] if results['distances'] and results['distances'][0] else 'N/A'
|
||||
results = vector_db_service.query(
|
||||
collection_name=self.semantic_cache_collection_name,
|
||||
query_embeddings=query_embedding.tolist(),
|
||||
n_results=1
|
||||
)
|
||||
if results and results.get('ids') and results['ids'][0]:
|
||||
distance = results['distances'][0][0] if results.get('distances') and results['distances'][0] else 'N/A'
|
||||
logger.debug(f"L2语义搜索找到最相似的结果: id={results['ids'][0]}, 距离={distance}")
|
||||
|
||||
if distance != 'N/A' and distance < 0.75:
|
||||
l2_hit_key = results['ids'][0][0] if isinstance(results['ids'][0], list) else results['ids'][0]
|
||||
logger.info(f"命中L2语义缓存: key='{l2_hit_key}', 距离={distance:.4f}")
|
||||
|
||||
# 从数据库获取缓存数据
|
||||
semantic_cache_results = await db_query(
|
||||
semantic_cache_results_obj = await db_query(
|
||||
model_class=CacheEntries,
|
||||
query_type="get",
|
||||
filters={"cache_key": l2_hit_key},
|
||||
single_result=True
|
||||
)
|
||||
|
||||
if semantic_cache_results:
|
||||
expires_at = semantic_cache_results["expires_at"]
|
||||
if semantic_cache_results_obj:
|
||||
expires_at = getattr(semantic_cache_results_obj, "expires_at", 0)
|
||||
if time.time() < expires_at:
|
||||
data = orjson.loads(semantic_cache_results["cache_value"])
|
||||
cache_value = getattr(semantic_cache_results_obj, "cache_value", "{}")
|
||||
data = orjson.loads(cache_value)
|
||||
logger.debug(f"L2语义缓存返回的数据: {data}")
|
||||
|
||||
# 回填 L1
|
||||
@@ -218,13 +224,13 @@ class CacheManager:
|
||||
try:
|
||||
new_id = self.l1_vector_index.ntotal
|
||||
faiss.normalize_L2(query_embedding)
|
||||
self.l1_vector_index.add(x=query_embedding)
|
||||
self.l1_vector_index.add(x=query_embedding) # type: ignore
|
||||
self.l1_vector_id_to_key[new_id] = key
|
||||
except Exception as e:
|
||||
logger.error(f"回填L1向量索引时发生错误: {e}")
|
||||
return data
|
||||
except Exception as e:
|
||||
logger.warning(f"ChromaDB查询失败: {e}")
|
||||
logger.warning(f"VectorDB Service 查询失败: {e}")
|
||||
|
||||
logger.debug(f"缓存未命中: {key}")
|
||||
return None
|
||||
@@ -261,22 +267,27 @@ class CacheManager:
|
||||
)
|
||||
|
||||
# 写入语义缓存
|
||||
if semantic_query and self.embedding_model and self.chroma_collection:
|
||||
if semantic_query and self.embedding_model:
|
||||
try:
|
||||
embedding_result = await self.embedding_model.get_embedding(semantic_query)
|
||||
if embedding_result:
|
||||
# embedding_result是一个元组(embedding_vector, model_name),取第一个元素
|
||||
embedding_vector = embedding_result[0] if isinstance(embedding_result, tuple) else embedding_result
|
||||
validated_embedding = self._validate_embedding(embedding_vector)
|
||||
if validated_embedding is not None:
|
||||
embedding = np.array([validated_embedding], dtype='float32')
|
||||
|
||||
# 写入 L1 Vector
|
||||
new_id = self.l1_vector_index.ntotal
|
||||
faiss.normalize_L2(embedding)
|
||||
self.l1_vector_index.add(x=embedding)
|
||||
self.l1_vector_index.add(x=embedding) # type: ignore
|
||||
self.l1_vector_id_to_key[new_id] = key
|
||||
# 写入 L2 Vector
|
||||
self.chroma_collection.add(embeddings=embedding.tolist(), ids=[key])
|
||||
|
||||
# 写入 L2 Vector (使用新的服务)
|
||||
vector_db_service.add(
|
||||
collection_name=self.semantic_cache_collection_name,
|
||||
embeddings=embedding.tolist(),
|
||||
ids=[key]
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"语义缓存写入失败: {e}")
|
||||
|
||||
@@ -298,15 +309,14 @@ class CacheManager:
|
||||
filters={} # 删除所有记录
|
||||
)
|
||||
|
||||
# 清空ChromaDB
|
||||
if self.chroma_collection:
|
||||
try:
|
||||
self.chroma_client.delete_collection(name="semantic_cache")
|
||||
self.chroma_collection = self.chroma_client.get_or_create_collection(name="semantic_cache")
|
||||
except Exception as e:
|
||||
logger.warning(f"清空ChromaDB失败: {e}")
|
||||
# 清空 VectorDB
|
||||
try:
|
||||
vector_db_service.delete_collection(name=self.semantic_cache_collection_name)
|
||||
vector_db_service.get_or_create_collection(name=self.semantic_cache_collection_name)
|
||||
except Exception as e:
|
||||
logger.warning(f"清空 VectorDB 集合失败: {e}")
|
||||
|
||||
logger.info("L2 (数据库 & ChromaDB) 缓存已清空。")
|
||||
logger.info("L2 (数据库 & VectorDB) 缓存已清空。")
|
||||
|
||||
async def clear_all(self):
|
||||
"""清空所有缓存。"""
|
||||
|
||||
@@ -4,9 +4,11 @@ from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
import json
|
||||
from pathlib import Path
|
||||
import inspect
|
||||
|
||||
from .logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.common.cache_manager import tool_cache
|
||||
|
||||
logger = get_logger("tool_history")
|
||||
|
||||
@@ -113,34 +115,6 @@ class ToolHistoryManager:
|
||||
except Exception as e:
|
||||
logger.error(f"记录工具调用时发生错误: {e}")
|
||||
|
||||
def find_cached_result(self, tool_name: str, args: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""查找匹配的缓存记录
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
args: 工具调用参数
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: 如果找到匹配的缓存记录则返回结果,否则返回None
|
||||
"""
|
||||
# 检查是否启用历史记录
|
||||
if not global_config.tool.history.enable_history:
|
||||
return None
|
||||
|
||||
# 清理输入参数中的敏感信息以便比较
|
||||
sanitized_input_args = self._sanitize_args(args)
|
||||
|
||||
# 按时间倒序遍历历史记录
|
||||
for record in reversed(self._history):
|
||||
if (record["tool_name"] == tool_name and
|
||||
record["status"] == "completed" and
|
||||
record["ttl_count"] < record.get("ttl", 5)):
|
||||
# 比较参数是否匹配
|
||||
if self._sanitize_args(record["arguments"]) == sanitized_input_args:
|
||||
logger.info(f"工具 {tool_name} 命中缓存记录")
|
||||
return record["result"]
|
||||
return None
|
||||
|
||||
def _sanitize_args(self, args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""清理参数中的敏感信息"""
|
||||
sensitive_keys = ['api_key', 'token', 'password', 'secret']
|
||||
@@ -327,27 +301,78 @@ class ToolHistoryManager:
|
||||
|
||||
def wrap_tool_executor():
|
||||
"""
|
||||
包装工具执行器以添加历史记录功能
|
||||
包装工具执行器以添加历史记录和缓存功能
|
||||
这个函数应该在系统启动时被调用一次
|
||||
"""
|
||||
from src.plugin_system.core.tool_use import ToolExecutor
|
||||
from src.plugin_system.apis.tool_api import get_tool_instance
|
||||
original_execute = ToolExecutor.execute_tool_call
|
||||
history_manager = ToolHistoryManager()
|
||||
|
||||
async def wrapped_execute_tool_call(self, tool_call, tool_instance=None):
|
||||
start_time = time.time()
|
||||
|
||||
# 确保我们有 tool_instance
|
||||
if not tool_instance:
|
||||
tool_instance = get_tool_instance(tool_call.func_name)
|
||||
|
||||
# 首先检查缓存
|
||||
if cached_result := history_manager.find_cached_result(tool_call.func_name, tool_call.args):
|
||||
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具 {tool_call.func_name} 执行")
|
||||
return cached_result
|
||||
# 如果没有 tool_instance,就无法进行缓存检查,直接执行
|
||||
if not tool_instance:
|
||||
result = await original_execute(self, tool_call, None)
|
||||
execution_time = time.time() - start_time
|
||||
history_manager.record_tool_call(
|
||||
tool_name=tool_call.func_name,
|
||||
args=tool_call.args,
|
||||
result=result,
|
||||
execution_time=execution_time,
|
||||
status="completed",
|
||||
chat_id=getattr(self, 'chat_id', None),
|
||||
ttl=5 # Default TTL
|
||||
)
|
||||
return result
|
||||
|
||||
# 新的缓存逻辑
|
||||
if tool_instance.enable_cache:
|
||||
try:
|
||||
tool_file_path = inspect.getfile(tool_instance.__class__)
|
||||
semantic_query = None
|
||||
if tool_instance.semantic_cache_query_key:
|
||||
semantic_query = tool_call.args.get(tool_instance.semantic_cache_query_key)
|
||||
|
||||
cached_result = await tool_cache.get(
|
||||
tool_name=tool_call.func_name,
|
||||
function_args=tool_call.args,
|
||||
tool_file_path=tool_file_path,
|
||||
semantic_query=semantic_query
|
||||
)
|
||||
if cached_result:
|
||||
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具 {tool_call.func_name} 执行")
|
||||
return cached_result
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}检查工具缓存时出错: {e}")
|
||||
|
||||
try:
|
||||
result = await original_execute(self, tool_call, tool_instance)
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
# 获取工具的ttl值
|
||||
ttl = getattr(tool_instance, 'history_ttl', 5) if tool_instance else 5
|
||||
# 缓存结果
|
||||
if tool_instance.enable_cache:
|
||||
try:
|
||||
tool_file_path = inspect.getfile(tool_instance.__class__)
|
||||
semantic_query = None
|
||||
if tool_instance.semantic_cache_query_key:
|
||||
semantic_query = tool_call.args.get(tool_instance.semantic_cache_query_key)
|
||||
|
||||
await tool_cache.set(
|
||||
tool_name=tool_call.func_name,
|
||||
function_args=tool_call.args,
|
||||
tool_file_path=tool_file_path,
|
||||
data=result,
|
||||
ttl=tool_instance.cache_ttl,
|
||||
semantic_query=semantic_query
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}设置工具缓存时出错: {e}")
|
||||
|
||||
# 记录成功的调用
|
||||
history_manager.record_tool_call(
|
||||
@@ -357,16 +382,13 @@ def wrap_tool_executor():
|
||||
execution_time=execution_time,
|
||||
status="completed",
|
||||
chat_id=getattr(self, 'chat_id', None),
|
||||
ttl=ttl
|
||||
ttl=tool_instance.history_ttl
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
execution_time = time.time() - start_time
|
||||
# 获取工具的ttl值
|
||||
ttl = getattr(tool_instance, 'history_ttl', 5) if tool_instance else 5
|
||||
|
||||
# 记录失败的调用
|
||||
history_manager.record_tool_call(
|
||||
tool_name=tool_call.func_name,
|
||||
@@ -375,7 +397,7 @@ def wrap_tool_executor():
|
||||
execution_time=execution_time,
|
||||
status="error",
|
||||
chat_id=getattr(self, 'chat_id', None),
|
||||
ttl=ttl
|
||||
ttl=tool_instance.history_ttl
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
19
src/common/vector_db/__init__.py
Normal file
19
src/common/vector_db/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from .base import VectorDBBase
|
||||
from .chromadb_impl import ChromaDBImpl
|
||||
|
||||
def get_vector_db_service() -> VectorDBBase:
|
||||
"""
|
||||
工厂函数,初始化并返回向量数据库服务实例。
|
||||
|
||||
目前硬编码为 ChromaDB,未来可以从配置中读取。
|
||||
"""
|
||||
# TODO: 从全局配置中读取数据库类型和路径
|
||||
db_path = "data/chroma_db"
|
||||
|
||||
# ChromaDBImpl 是一个单例,所以这里每次调用都会返回同一个实例
|
||||
return ChromaDBImpl(path=db_path)
|
||||
|
||||
# 全局向量数据库服务实例
|
||||
vector_db_service: VectorDBBase = get_vector_db_service()
|
||||
|
||||
__all__ = ["vector_db_service", "VectorDBBase"]
|
||||
117
src/common/vector_db/base.py
Normal file
117
src/common/vector_db/base.py
Normal file
@@ -0,0 +1,117 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
class VectorDBBase(ABC):
|
||||
"""
|
||||
向量数据库的抽象基类 (ABC),定义了所有向量数据库实现必须遵循的接口。
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, path: str, **kwargs: Any):
|
||||
"""
|
||||
初始化向量数据库客户端。
|
||||
|
||||
Args:
|
||||
path (str): 数据库文件的存储路径。
|
||||
**kwargs: 其他特定于实现的参数。
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_or_create_collection(self, name: str, **kwargs: Any) -> Any:
|
||||
"""
|
||||
获取或创建一个集合 (Collection)。
|
||||
|
||||
Args:
|
||||
name (str): 集合的名称。
|
||||
**kwargs: 其他特定于实现的参数 (例如 metadata)。
|
||||
|
||||
Returns:
|
||||
Any: 代表集合的对象。
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add(
|
||||
self,
|
||||
collection_name: str,
|
||||
embeddings: List[List[float]],
|
||||
documents: Optional[List[str]] = None,
|
||||
metadatas: Optional[List[Dict[str, Any]]] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
向指定集合中添加数据。
|
||||
|
||||
Args:
|
||||
collection_name (str): 目标集合的名称。
|
||||
embeddings (List[List[float]]): 向量列表。
|
||||
documents (Optional[List[str]], optional): 文档列表。Defaults to None.
|
||||
metadatas (Optional[List[Dict[str, Any]]], optional): 元数据列表。Defaults to None.
|
||||
ids (Optional[List[str]], optional): ID 列表。Defaults to None.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def query(
|
||||
self,
|
||||
collection_name: str,
|
||||
query_embeddings: List[List[float]],
|
||||
n_results: int = 1,
|
||||
where: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, List[Any]]:
|
||||
"""
|
||||
在指定集合中查询相似向量。
|
||||
|
||||
Args:
|
||||
collection_name (str): 目标集合的名称。
|
||||
query_embeddings (List[List[float]]): 用于查询的向量列表。
|
||||
n_results (int, optional): 返回结果的数量。Defaults to 1.
|
||||
where (Optional[Dict[str, Any]], optional): 元数据过滤条件。Defaults to None.
|
||||
**kwargs: 其他特定于实现的参数。
|
||||
|
||||
Returns:
|
||||
Dict[str, List[Any]]: 查询结果,通常包含 ids, distances, metadatas, documents。
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(
|
||||
self,
|
||||
collection_name: str,
|
||||
ids: Optional[List[str]] = None,
|
||||
where: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
从指定集合中删除数据。
|
||||
|
||||
Args:
|
||||
collection_name (str): 目标集合的名称。
|
||||
ids (Optional[List[str]], optional): 要删除的条目的 ID 列表。Defaults to None.
|
||||
where (Optional[Dict[str, Any]], optional): 基于元数据的过滤条件。Defaults to None.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def count(self, collection_name: str) -> int:
|
||||
"""
|
||||
获取指定集合中的条目总数。
|
||||
|
||||
Args:
|
||||
collection_name (str): 目标集合的名称。
|
||||
|
||||
Returns:
|
||||
int: 条目总数。
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_collection(self, name: str) -> None:
|
||||
"""
|
||||
删除一个集合。
|
||||
|
||||
Args:
|
||||
name (str): 要删除的集合的名称。
|
||||
"""
|
||||
pass
|
||||
137
src/common/vector_db/chromadb_impl.py
Normal file
137
src/common/vector_db/chromadb_impl.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import threading
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
|
||||
from .base import VectorDBBase
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("chromadb_impl")
|
||||
|
||||
class ChromaDBImpl(VectorDBBase):
|
||||
"""
|
||||
ChromaDB 的具体实现,遵循 VectorDBBase 接口。
|
||||
采用单例模式,确保全局只有一个 ChromaDB 客户端实例。
|
||||
"""
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if not cls._instance:
|
||||
with cls._lock:
|
||||
if not cls._instance:
|
||||
cls._instance = super(ChromaDBImpl, cls).__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, path: str = "data/chroma_db", **kwargs: Any):
|
||||
"""
|
||||
初始化 ChromaDB 客户端。
|
||||
由于是单例,这个初始化只会执行一次。
|
||||
"""
|
||||
if not hasattr(self, '_initialized'):
|
||||
with self._lock:
|
||||
if not hasattr(self, '_initialized'):
|
||||
try:
|
||||
self.client = chromadb.PersistentClient(
|
||||
path=path,
|
||||
settings=Settings(anonymized_telemetry=False)
|
||||
)
|
||||
self._collections: Dict[str, Any] = {}
|
||||
self._initialized = True
|
||||
logger.info(f"ChromaDB 客户端已初始化,数据库路径: {path}")
|
||||
except Exception as e:
|
||||
logger.error(f"ChromaDB 初始化失败: {e}")
|
||||
self.client = None
|
||||
self._initialized = False
|
||||
|
||||
def get_or_create_collection(self, name: str, **kwargs: Any) -> Any:
|
||||
if not self.client:
|
||||
raise ConnectionError("ChromaDB 客户端未初始化")
|
||||
|
||||
if name in self._collections:
|
||||
return self._collections[name]
|
||||
|
||||
try:
|
||||
collection = self.client.get_or_create_collection(name=name, **kwargs)
|
||||
self._collections[name] = collection
|
||||
logger.info(f"成功获取或创建集合: '{name}'")
|
||||
return collection
|
||||
except Exception as e:
|
||||
logger.error(f"获取或创建集合 '{name}' 失败: {e}")
|
||||
return None
|
||||
|
||||
def add(
|
||||
self,
|
||||
collection_name: str,
|
||||
embeddings: List[List[float]],
|
||||
documents: Optional[List[str]] = None,
|
||||
metadatas: Optional[List[Dict[str, Any]]] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
collection = self.get_or_create_collection(collection_name)
|
||||
if collection:
|
||||
try:
|
||||
collection.add(
|
||||
embeddings=embeddings,
|
||||
documents=documents,
|
||||
metadatas=metadatas,
|
||||
ids=ids,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"向集合 '{collection_name}' 添加数据失败: {e}")
|
||||
|
||||
def query(
|
||||
self,
|
||||
collection_name: str,
|
||||
query_embeddings: List[List[float]],
|
||||
n_results: int = 1,
|
||||
where: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, List[Any]]:
|
||||
collection = self.get_or_create_collection(collection_name)
|
||||
if collection:
|
||||
try:
|
||||
return collection.query(
|
||||
query_embeddings=query_embeddings,
|
||||
n_results=n_results,
|
||||
where=where or {},
|
||||
**kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"查询集合 '{collection_name}' 失败: {e}")
|
||||
return {}
|
||||
|
||||
def delete(
|
||||
self,
|
||||
collection_name: str,
|
||||
ids: Optional[List[str]] = None,
|
||||
where: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
collection = self.get_or_create_collection(collection_name)
|
||||
if collection:
|
||||
try:
|
||||
collection.delete(ids=ids, where=where)
|
||||
except Exception as e:
|
||||
logger.error(f"从集合 '{collection_name}' 删除数据失败: {e}")
|
||||
|
||||
def count(self, collection_name: str) -> int:
|
||||
collection = self.get_or_create_collection(collection_name)
|
||||
if collection:
|
||||
try:
|
||||
return collection.count()
|
||||
except Exception as e:
|
||||
logger.error(f"获取集合 '{collection_name}' 计数失败: {e}")
|
||||
return 0
|
||||
|
||||
def delete_collection(self, name: str) -> None:
|
||||
if not self.client:
|
||||
raise ConnectionError("ChromaDB 客户端未初始化")
|
||||
|
||||
try:
|
||||
self.client.delete_collection(name=name)
|
||||
if name in self._collections:
|
||||
del self._collections[name]
|
||||
logger.info(f"集合 '{name}' 已被删除")
|
||||
except Exception as e:
|
||||
logger.error(f"删除集合 '{name}' 失败: {e}")
|
||||
@@ -31,6 +31,13 @@ class BaseTool(ABC):
|
||||
history_ttl: int = 5
|
||||
"""工具调用历史记录的TTL值,默认为5。设为0表示不记录历史"""
|
||||
|
||||
enable_cache: bool = False
|
||||
"""是否为该工具启用缓存"""
|
||||
cache_ttl: int = 3600
|
||||
"""缓存的TTL值(秒),默认为3600秒(1小时)"""
|
||||
semantic_cache_query_key: Optional[str] = None
|
||||
"""用于语义缓存的查询参数键名。如果设置,将使用此参数的值进行语义相似度搜索"""
|
||||
|
||||
def __init__(self, plugin_config: Optional[dict] = None):
|
||||
self.plugin_config = plugin_config or {} # 直接存储插件配置字典
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@ from bs4 import BeautifulSoup
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system import BaseTool, ToolParamType, llm_api
|
||||
from src.plugin_system.apis import config_api
|
||||
from src.common.cache_manager import tool_cache
|
||||
|
||||
from ..utils.formatters import format_url_parse_results
|
||||
from ..utils.url_utils import parse_urls_from_input, validate_urls
|
||||
@@ -30,6 +29,12 @@ class URLParserTool(BaseTool):
|
||||
parameters = [
|
||||
("urls", ToolParamType.STRING, "要理解的网站", True, None),
|
||||
]
|
||||
|
||||
# --- 新的缓存配置 ---
|
||||
enable_cache: bool = True
|
||||
cache_ttl: int = 86400 # 缓存24小时
|
||||
semantic_cache_query_key: str = "urls"
|
||||
# --------------------
|
||||
|
||||
def __init__(self, plugin_config=None):
|
||||
super().__init__(plugin_config)
|
||||
@@ -42,10 +47,11 @@ class URLParserTool(BaseTool):
|
||||
if exa_api_keys is None:
|
||||
# 从插件配置文件读取
|
||||
exa_api_keys = self.get_config("exa.api_keys", [])
|
||||
|
||||
|
||||
# 创建API密钥管理器
|
||||
from typing import cast, List
|
||||
self.api_manager = create_api_key_manager_from_config(
|
||||
exa_api_keys,
|
||||
cast(List[str], exa_api_keys),
|
||||
lambda key: Exa(api_key=key),
|
||||
"Exa URL Parser"
|
||||
)
|
||||
@@ -135,16 +141,6 @@ class URLParserTool(BaseTool):
|
||||
"""
|
||||
执行URL内容提取和总结。优先使用Exa,失败后尝试本地解析。
|
||||
"""
|
||||
# 获取当前文件路径用于缓存键
|
||||
import os
|
||||
current_file_path = os.path.abspath(__file__)
|
||||
|
||||
# 检查缓存
|
||||
cached_result = await tool_cache.get(self.name, function_args, current_file_path)
|
||||
if cached_result:
|
||||
logger.info(f"缓存命中: {self.name} -> {function_args}")
|
||||
return cached_result
|
||||
|
||||
urls_input = function_args.get("urls")
|
||||
if not urls_input:
|
||||
return {"error": "URL列表不能为空。"}
|
||||
@@ -235,8 +231,4 @@ class URLParserTool(BaseTool):
|
||||
"errors": error_messages
|
||||
}
|
||||
|
||||
# 保存到缓存
|
||||
if "error" not in result:
|
||||
await tool_cache.set(self.name, function_args, current_file_path, result)
|
||||
|
||||
return result
|
||||
|
||||
@@ -7,7 +7,6 @@ from typing import Any, Dict, List
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system import BaseTool, ToolParamType
|
||||
from src.plugin_system.apis import config_api
|
||||
from src.common.cache_manager import tool_cache
|
||||
|
||||
from ..engines.exa_engine import ExaSearchEngine
|
||||
from ..engines.tavily_engine import TavilySearchEngine
|
||||
@@ -31,6 +30,12 @@ class WebSurfingTool(BaseTool):
|
||||
("time_range", ToolParamType.STRING, "指定搜索的时间范围,可以是 'any', 'week', 'month'。默认为 'any'。", False, ["any", "week", "month"])
|
||||
] # type: ignore
|
||||
|
||||
# --- 新的缓存配置 ---
|
||||
enable_cache: bool = True
|
||||
cache_ttl: int = 7200 # 缓存2小时
|
||||
semantic_cache_query_key: str = "query"
|
||||
# --------------------
|
||||
|
||||
def __init__(self, plugin_config=None):
|
||||
super().__init__(plugin_config)
|
||||
# 初始化搜索引擎
|
||||
@@ -46,16 +51,6 @@ class WebSurfingTool(BaseTool):
|
||||
if not query:
|
||||
return {"error": "搜索查询不能为空。"}
|
||||
|
||||
# 获取当前文件路径用于缓存键
|
||||
import os
|
||||
current_file_path = os.path.abspath(__file__)
|
||||
|
||||
# 检查缓存
|
||||
cached_result = await tool_cache.get(self.name, function_args, current_file_path, semantic_query=query)
|
||||
if cached_result:
|
||||
logger.info(f"缓存命中: {self.name} -> {function_args}")
|
||||
return cached_result
|
||||
|
||||
# 读取搜索配置
|
||||
enabled_engines = config_api.get_global_config("web_search.enabled_engines", ["ddg"])
|
||||
search_strategy = config_api.get_global_config("web_search.search_strategy", "single")
|
||||
@@ -69,10 +64,6 @@ class WebSurfingTool(BaseTool):
|
||||
result = await self._execute_fallback_search(function_args, enabled_engines)
|
||||
else: # single
|
||||
result = await self._execute_single_search(function_args, enabled_engines)
|
||||
|
||||
# 保存到缓存
|
||||
if "error" not in result:
|
||||
await tool_cache.set(self.name, function_args, current_file_path, result, semantic_query=query)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
Reference in New Issue
Block a user