Merge branch 'master' of https://github.com/MaiBot-Plus/MaiMbot-Pro-Max

66  docs/memory_system_design_v2.md  Normal file
@@ -0,0 +1,66 @@
# New Three-Layer Memory System Architecture (V2.0) Design Document

## 1. Core Idea

This architecture establishes a clear, ordered information-processing pipeline that mimics how human memory moves from transient perception to long-term knowledge consolidation. Information passes through three stages: **short-term memory (STM)**, **mid-term memory (MTM)**, and **long-term memory (LTM)**, transforming a large volume of scattered messages into structured, durable knowledge.
## 2. Architecture Layers in Detail

### 2.1. Short-Term Memory (STM) - the "Message Buffer"

* **Responsibility**: Capture and buffer every incoming message that reaches the core, providing context for the current conversation and enabling fast responses.
* **Implementation**:
    * **In-memory queue**: A fixed-length in-memory queue (such as `collections.deque`) stores the most recent N raw messages (a suggested initial value is 200).
    * **Real-time vectorization**: When a message is enqueued, its text content is embedded asynchronously, producing a semantic "fingerprint".
    * **Fast retrieval**: An efficient vector-similarity library (such as FAISS or Annoy) retrieves the most relevant historical messages from the queue when a new message arrives, building the immediate context (see the sketch after this list).
* **Trigger**: When the queue reaches capacity, the oldest batch of messages (for example the first 50) is packaged and handed off to the mid-term memory module.
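A minimal sketch of this layer, assuming a placeholder `embed()` function, a 384-dimensional embedding, and the N=200 / batch-of-50 values suggested above; the real module would add asynchronous embedding and wire `on_overflow` to the MTM.

```python
import collections
import numpy as np
import faiss

EMBED_DIM = 384          # assumed embedding dimension
QUEUE_CAPACITY = 200     # suggested initial value from this document
EVICT_BATCH = 50         # oldest messages handed to MTM on overflow

def embed(text: str) -> np.ndarray:
    """Placeholder embedding (assumption); replace with the real model."""
    rng = np.random.default_rng(abs(hash(text)) % (2**32))
    v = rng.standard_normal(EMBED_DIM)
    return (v / np.linalg.norm(v)).astype("float32")

class ShortTermMemory:
    def __init__(self, on_overflow):
        self.messages = collections.deque()          # raw messages, oldest first
        self.index = faiss.IndexFlatIP(EMBED_DIM)    # inner product on unit vectors = cosine
        self.on_overflow = on_overflow               # callback that hands a batch to MTM

    def add(self, text: str) -> None:
        self.messages.append(text)
        self.index.add(embed(text).reshape(1, -1))
        if len(self.messages) > QUEUE_CAPACITY:
            batch = [self.messages.popleft() for _ in range(EVICT_BATCH)]
            self.on_overflow(batch)                  # MTM summarizes this batch
            # Rebuild the index for the remaining messages (FAISS has no cheap "pop front")
            self.index.reset()
            if self.messages:
                self.index.add(np.stack([embed(m) for m in self.messages]))

    def retrieve(self, query: str, k: int = 5) -> list[str]:
        if not self.messages:
            return []
        _, idx = self.index.search(embed(query).reshape(1, -1), min(k, len(self.messages)))
        return [self.messages[i] for i in idx[0] if i != -1]
```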
### 2.2. Mid-Term Memory (MTM) - the "Memory Compressor"

* **Responsibility**: Compress and summarize the large number of scattered messages coming from short-term memory into structured "memory fragments".
* **Implementation**:
    * **LLM summarization**: A large language model (LLM) performs a deep analysis of each message batch handed over by the STM and distills it into a concise "memory statement".
    * **Structured information**: Every memory fragment carries the following metadata (see the sketch after this list):
        * `memory_text`: the memory statement itself.
        * `keywords`: the list of associated keywords.
        * `time_range`: the time span the memory covers.
        * `importance_score`: an importance score assigned by the LLM.
        * `access_count`: an access counter, initialized to 0.
    * **Persistent storage**: Structured memory fragments are stored in the database; the existing `Memory` table can be reused or adapted.
* **Trigger**: Fired by the STM queue-overflow event.
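A minimal sketch of the fragment schema and the compression step, assuming a `summarize(texts) -> dict` LLM call that returns the fields below; the names are illustrative, not the project's actual API.

```python
from dataclasses import dataclass

@dataclass
class MemoryFragment:
    memory_text: str                     # concise statement produced by the LLM
    keywords: list[str]                  # associated keywords
    time_range: tuple[float, float]      # (start, end) UNIX timestamps covered
    importance_score: float              # LLM-assigned importance
    access_count: int = 0                # incremented on every retrieval

def compress_batch(batch: list[tuple[float, str]], summarize) -> MemoryFragment:
    """Turn an STM overflow batch [(timestamp, text), ...] into one fragment."""
    result = summarize([text for _, text in batch])   # assumed LLM call
    return MemoryFragment(
        memory_text=result["summary"],
        keywords=result["keywords"],
        time_range=(batch[0][0], batch[-1][0]),
        importance_score=result["importance"],
    )
```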
### 2.3. Long-Term Memory (LTM) - the "Knowledge Graph"

* **Responsibility**: Internalize proven, high-value mid-term memories into the system's core knowledge, building deep connections.
* **Implementation**:
    * **Promotion mechanism**: A periodic "memory consolidation" task scans the mid-term memory database; when a fragment's `access_count` reaches a preset threshold (for example 10), promotion is triggered (a sketch follows this list).
    * **Graph integration**: Promoted fragments are sent to the `Hippocampus` module. `Hippocampus` no longer processes raw chat logs directly; it works on these high-quality, pre-processed fragments, extracting core concepts (nodes) and the relations between them (edges), then merging that information into the existing knowledge graph.
* **Trigger**: Fired by a scheduled task (for example once per day).
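A minimal sketch of the promotion scan, assuming the fragments live in a SQLite table named `memory` with the metadata fields above plus a `promoted` flag, and that `hippocampus.ingest()` is the (hypothetical) entry point for graph integration.

```python
import sqlite3

PROMOTION_THRESHOLD = 10  # access_count value suggested in this document

def promote_hot_fragments(db_path: str, hippocampus) -> int:
    """Move frequently accessed MTM fragments into the LTM knowledge graph."""
    promoted = 0
    with sqlite3.connect(db_path) as conn:
        conn.row_factory = sqlite3.Row
        rows = conn.execute(
            "SELECT id, memory_text, keywords, time_range, importance_score "
            "FROM memory WHERE access_count >= ? AND promoted = 0",
            (PROMOTION_THRESHOLD,),
        ).fetchall()
        for row in rows:
            hippocampus.ingest(dict(row))             # extract nodes/edges, update graph
            conn.execute("UPDATE memory SET promoted = 1 WHERE id = ?", (row["id"],))
            promoted += 1
    return promoted
```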
## 3. Information Processing Flow

```mermaid
graph TD
    A[Input: new message] --> B{Short-term memory STM};
    B --> |real-time vector retrieval| C[Output: conversation context];
    B --> |queue full| D{Mid-term memory MTM};
    D --> |LLM summarization| E[Store in DB: memory fragments];
    E --> |keyword/time retrieval| C;
    E --> |high access count| F{Long-term memory LTM};
    F --> |LLM concept/relation extraction| G[Update: knowledge graph];
    G --> |spreading-activation graph retrieval| C;

    subgraph "In memory (fast)"
        B
    end

    subgraph "In database (persistent)"
        E
        G
    end
```
## 4. Refactoring Plan for Existing Modules

* **`InstantMemory`**: Replaced by the new **STM** and **MTM** modules. Its existing "decide whether to remember" and "summarize" functions are absorbed into the MTM pipeline.
* **`Hippocampus`**: Keeps its central role as the **LTM**, but its input source changes from "randomly sampled historical chat logs" to "high-value memory fragments promoted from the MTM". This should greatly improve both the efficiency and the quality of knowledge-graph construction.
@@ -165,10 +165,38 @@ configure_dependency_settings(auto_install_timeout=600)

 ## Workflow

 1. **Plugin initialization**: When a plugin class is instantiated, the system checks its dependencies automatically.
-2. **Dependency normalization**: String-format dependencies are converted into PythonDependency objects.
+2. **Dependency normalization**: String-format dependencies are converted into `PythonDependency` objects.
 3. **Installed check**: Each dependency package is imported tentatively and its version is checked.
-4. **Automatic installation**: If enabled, missing dependencies are installed automatically.
-5. **Error handling**: Detailed error information and installation logs are recorded.
+4. **Smart alias resolution (new)**: If the direct import fails (for example `import beautifulsoup4` fails), the system consults the built-in alias map (for example `beautifulsoup4` -> `bs4`) and retries the import under the alias.
+5. **Automatic installation**: If enabled, missing dependencies are installed automatically.
+6. **Error handling**: Detailed error information and installation logs are recorded.
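A minimal sketch of the two declaration styles this workflow supports; the `package_name` and `install_name` fields follow this document, while `version`, `optional`, and `description` mirror the best-practice and troubleshooting notes below and are otherwise assumed.

```python
# Inside a plugin class: plain strings work for packages whose install
# name matches their import name; PythonDependency is explicit about both.
python_dependencies = [
    "requests",                          # install name == import name
    "beautifulsoup4",                    # resolved to `bs4` by alias resolution
    PythonDependency(
        package_name="PIL",              # name used in `import ...`
        install_name="Pillow",           # name used by `pip install`
        version=">=9.0.0",               # assumed field: version constraint
        optional=True,                   # assumed field: don't block plugin load
        description="image processing",  # assumed field: shown to users
    ),
]
```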
## Smart Alias Resolution

To improve the development experience, the dependency-management system ships with a built-in smart alias resolution mechanism.

### The Problem It Solves

The Python ecosystem contains packages whose **install name** (used with `pip install`) differs from their **import name** (used in `import` statements). The classic examples:

- install name: `beautifulsoup4`, import name: `bs4`
- install name: `Pillow`, import name: `PIL`
- install name: `scikit-learn`, import name: `sklearn`

If a developer lists the plain string `"beautifulsoup4"` in `python_dependencies`, the standard dependency check fails because `import beautifulsoup4` is impossible.
### How It Works

When the dependency manager fails to import a package under its declared name, it:

1. Consults a built-in alias map covering hundreds of common packages.
2. If a matching import name is found, retries the import under that alias.
3. If the aliased import succeeds, the dependency check passes and version validation continues.

The process is automatic and is meant to handle the vast majority of common cases, sparing developers manual configuration. A sketch of the lookup follows.
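A minimal sketch of the fallback logic, assuming a small excerpt of the alias map; the real table and function names in the project may differ.

```python
import importlib

# Excerpt of the built-in alias map: install name -> import name (assumed content).
PACKAGE_ALIASES = {
    "beautifulsoup4": "bs4",
    "Pillow": "PIL",
    "scikit-learn": "sklearn",
    "opencv-python": "cv2",
}

def try_import(install_name: str):
    """Import a dependency, falling back to its known import alias."""
    try:
        return importlib.import_module(install_name)
    except ImportError:
        alias = PACKAGE_ALIASES.get(install_name)
        if alias is None:
            raise                      # unknown package: report the original failure
        return importlib.import_module(alias)
```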
### Caveats

- **Best practice**: Even with smart alias resolution, we still **strongly recommend** using a `PythonDependency` object to specify `package_name` (import name) and `install_name` (install name) explicitly; that guarantees the best accuracy and readability.
- **Coverage**: The built-in alias map covers a large number of common libraries, but 100% coverage cannot be guaranteed. For a package the map does not know, define it precisely with a `PythonDependency` object.

## Log Output Example
@@ -192,12 +220,13 @@ configure_dependency_settings(auto_install_timeout=600)

 ## Best Practices

-1. **Use detailed PythonDependency objects** for better control and documentation
-2. **Configure a PyPI mirror**, especially in mainland China, to speed up downloads significantly
-3. **Set optional dependencies sensibly** so non-core features cannot block plugin loading
-4. **Pin version requirements** to ensure compatibility
-5. **Add descriptions** to help users understand what each dependency is for
-6. **Test the dependency configuration** across environments to verify it is correct
+1. **Prefer `PythonDependency` objects**: the most reliable and explicit option, especially when install and import names differ.
+2. **Lean on smart alias resolution**: for common packages whose install and import names differ (such as `beautifulsoup4` or `Pillow`), the install name can go straight into the string list and the system resolves it automatically.
+3. **Configure a PyPI mirror**: especially in mainland China, this can speed up downloads significantly.
+4. **Set optional dependencies sensibly**: keep non-core features from blocking plugin loading.
+5. **Pin version requirements**: ensure compatibility.
+6. **Add descriptions**: help users understand what each dependency is for.
+7. **Test the dependency configuration**: verify it across environments.

 ## Security Considerations
@@ -225,7 +254,8 @@ configure_dependency_settings(auto_install_timeout=600)

 ### Import Errors

-1. Confirm that the package name matches the import name
-2. Check the optional-dependency configuration
-3. Verify that the installation succeeded
+1. **Confirm the package and import names**: check whether the install name and import name match; if not, specify `package_name` and `install_name` explicitly with a `PythonDependency` object.
+2. **Lean on automatic alias resolution**: for common libraries the system retries under the known alias; for an obscure library with mismatched names, use a `PythonDependency` object.
+3. **Check the optional-dependency configuration**: confirm that `optional=True` is set where intended.
+4. **Verify that the installation succeeded**: read the logs and confirm the `pip install` step reported no errors.
|||||||
@@ -60,5 +60,7 @@ exa_py
|
|||||||
asyncddgs
|
asyncddgs
|
||||||
opencv-python
|
opencv-python
|
||||||
Pillow
|
Pillow
|
||||||
|
chromadb
|
||||||
asyncio
|
asyncio
|
||||||
tavily-python
|
tavily-python
|
||||||
|
google-generativeai = 0.8.5
|
||||||
@@ -1,344 +1,224 @@

import time
import json
import sqlite3
import chromadb
import hashlib
import inspect
import numpy as np
import faiss
from typing import Any, Dict, Optional

from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config

logger = get_logger("cache_manager")


class CacheManager:
    """
    A general tool-cache manager with tiered and semantic caching.

    Implemented as a singleton, so the whole application shares one cache instance.
    L1 cache: in-memory dict (KV) + FAISS (vector).
    L2 cache: SQLite (KV) + ChromaDB (vector).
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        if not cls._instance:
            cls._instance = super(CacheManager, cls).__new__(cls)
        return cls._instance

    def __init__(self, default_ttl: int = 3600, db_path: str = "data/cache.db", chroma_path: str = "data/chroma_db"):
        """Initialize the cache manager."""
        if not hasattr(self, "_initialized"):
            self.default_ttl = default_ttl

            # L1 cache (in memory)
            self.l1_kv_cache: Dict[str, Dict[str, Any]] = {}
            embedding_dim = global_config.lpmm_knowledge.embedding_dimension
            self.l1_vector_index = faiss.IndexFlatIP(embedding_dim)
            self.l1_vector_id_to_key: Dict[int, str] = {}

            # L2 cache (persistent)
            self.db_path = db_path
            self._init_sqlite()
            self.chroma_client = chromadb.PersistentClient(path=chroma_path)
            self.chroma_collection = self.chroma_client.get_or_create_collection(name="semantic_cache")

            # Embedding model
            self.embedding_model = LLMRequest(model_config.model_task_config.embedding)

            self._initialized = True
            logger.info("缓存管理器已初始化: L1 (内存+FAISS), L2 (SQLite+ChromaDB)")

    def _init_sqlite(self):
        """Initialize the SQLite database and table schema."""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS cache (
                    key TEXT PRIMARY KEY,
                    value TEXT,
                    expires_at REAL
                )
            """)
            conn.commit()

    def _generate_key(self, tool_name: str, function_args: Dict[str, Any], tool_class: Any) -> str:
        """Generate a deterministic cache key that embeds a source-code hash for automatic invalidation."""
        try:
            source_code = inspect.getsource(tool_class)
            code_hash = hashlib.md5(source_code.encode()).hexdigest()
        except (TypeError, OSError) as e:
            code_hash = "unknown"
            logger.warning(f"无法获取 {tool_class.__name__} 的源代码,代码哈希将为 'unknown'。错误: {e}")

        try:
            sorted_args = json.dumps(function_args, sort_keys=True)
        except TypeError:
            sorted_args = repr(sorted(function_args.items()))
        return f"{tool_name}::{sorted_args}::{code_hash}"

    async def get(
        self, tool_name: str, function_args: Dict[str, Any], tool_class: Any, semantic_query: Optional[str] = None
    ) -> Optional[Any]:
        """Fetch a result from the cache. Lookup order: L1-KV -> L1-Vector -> L2-KV -> L2-Vector."""
        # Step 1: L1 exact-match lookup
        key = self._generate_key(tool_name, function_args, tool_class)
        logger.debug(f"生成的缓存键: {key}")
        if semantic_query:
            logger.debug(f"使用的语义查询: '{semantic_query}'")

        if key in self.l1_kv_cache:
            entry = self.l1_kv_cache[key]
            if time.time() < entry["expires_at"]:
                logger.info(f"命中L1键值缓存: {key}")
                return entry["data"]
            else:
                del self.l1_kv_cache[key]

        # Step 2: semantic (L1/L2) and L2 exact lookups
        query_embedding = None
        if semantic_query and self.embedding_model:
            embedding_result = await self.embedding_model.get_embedding(semantic_query)
            if embedding_result:
                query_embedding = np.array([embedding_result], dtype="float32")

        # Step 2a: L1 semantic cache (FAISS)
        if query_embedding is not None and self.l1_vector_index.ntotal > 0:
            faiss.normalize_L2(query_embedding)
            distances, indices = self.l1_vector_index.search(query_embedding, 1)
            if indices.size > 0 and distances[0][0] > 0.75:  # larger inner product = more similar
                hit_index = indices[0][0]
                l1_hit_key = self.l1_vector_id_to_key.get(hit_index)
                if l1_hit_key and l1_hit_key in self.l1_kv_cache:
                    logger.info(f"命中L1语义缓存: {l1_hit_key}")
                    return self.l1_kv_cache[l1_hit_key]["data"]

        # Step 2b: L2 exact cache (SQLite)
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT value, expires_at FROM cache WHERE key = ?", (key,))
            row = cursor.fetchone()
            if row:
                value, expires_at = row
                if time.time() < expires_at:
                    logger.info(f"命中L2键值缓存: {key}")
                    data = json.loads(value)
                    # Backfill L1
                    self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at}
                    return data
                else:
                    cursor.execute("DELETE FROM cache WHERE key = ?", (key,))
                    conn.commit()

        # Step 2c: L2 semantic cache (ChromaDB)
        if query_embedding is not None:
            results = self.chroma_collection.query(query_embeddings=query_embedding.tolist(), n_results=1)
            if results and results["ids"] and results["ids"][0]:
                distance = results["distances"][0][0] if results["distances"] and results["distances"][0] else "N/A"
                logger.debug(f"L2语义搜索找到最相似的结果: id={results['ids'][0]}, 距离={distance}")
                if distance != "N/A" and distance < 0.75:
                    l2_hit_key = results["ids"][0][0]  # ids come back per query; take the top hit
                    logger.info(f"命中L2语义缓存: key='{l2_hit_key}', 距离={distance:.4f}")
                    with sqlite3.connect(self.db_path) as conn:
                        cursor = conn.cursor()
                        cursor.execute("SELECT value, expires_at FROM cache WHERE key = ?", (l2_hit_key,))
                        row = cursor.fetchone()
                        if row:
                            value, expires_at = row
                            if time.time() < expires_at:
                                data = json.loads(value)
                                logger.debug(f"L2语义缓存返回的数据: {data}")
                                # Backfill L1 (KV and vector)
                                self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at}
                                new_id = self.l1_vector_index.ntotal
                                faiss.normalize_L2(query_embedding)
                                self.l1_vector_index.add(x=query_embedding)
                                self.l1_vector_id_to_key[new_id] = key
                                return data

        logger.debug(f"缓存未命中: {key}")
        return None

    async def set(
        self,
        tool_name: str,
        function_args: Dict[str, Any],
        tool_class: Any,
        data: Any,
        ttl: Optional[int] = None,
        semantic_query: Optional[str] = None,
    ):
        """Write a result into every cache tier."""
        if ttl is None:
            ttl = self.default_ttl
        if ttl <= 0:
            return

        key = self._generate_key(tool_name, function_args, tool_class)
        expires_at = time.time() + ttl

        # Write L1
        self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at}

        # Write L2
        value = json.dumps(data)
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            cursor.execute("REPLACE INTO cache (key, value, expires_at) VALUES (?, ?, ?)", (key, value, expires_at))
            conn.commit()

        # Write the semantic caches
        if semantic_query and self.embedding_model:
            embedding_result = await self.embedding_model.get_embedding(semantic_query)
            if embedding_result:
                embedding = np.array([embedding_result], dtype="float32")
                # Write the L1 vector index
                new_id = self.l1_vector_index.ntotal
                faiss.normalize_L2(embedding)
                self.l1_vector_index.add(x=embedding)
                self.l1_vector_id_to_key[new_id] = key
                # Write the L2 vector store
                self.chroma_collection.add(embeddings=embedding.tolist(), ids=[key])

        logger.info(f"已缓存条目: {key}, TTL: {ttl}s")

    def clear_l1(self):
        """Clear the L1 cache."""
        self.l1_kv_cache.clear()
        self.l1_vector_index.reset()
        self.l1_vector_id_to_key.clear()
        logger.info("L1 (内存+FAISS) 缓存已清空。")

    def clear_l2(self):
        """Clear the L2 cache."""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            cursor.execute("DELETE FROM cache")
            conn.commit()
        self.chroma_client.delete_collection(name="semantic_cache")
        self.chroma_collection = self.chroma_client.get_or_create_collection(name="semantic_cache")
        logger.info("L2 (SQLite & ChromaDB) 缓存已清空。")

    def clear_all(self):
        """Clear every cache tier."""
        self.clear_l1()
        self.clear_l2()
        logger.info("所有缓存层级已清空。")


# Global instance
tool_cache = CacheManager()
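A hedged usage sketch of the new manager from a tool's perspective; `MyTool` and `run_search` are illustrative names, while the `get`/`set` signatures match the code above. The call is left commented out because instantiating `tool_cache` at import time pulls in the project's configuration.

```python
from src.common.cache_manager import tool_cache

class MyTool:
    """Illustrative tool whose source code is hashed into the cache key."""

    async def run_search(self, query: str) -> dict:
        args = {"query": query}
        cached = await tool_cache.get("my_tool", args, MyTool, semantic_query=query)
        if cached is not None:
            return cached                                 # L1/L2, exact or semantic hit
        result = {"answer": f"results for {query}"}       # stand-in for real work
        await tool_cache.set("my_tool", args, MyTool, result, ttl=600, semantic_query=query)
        return result

# import asyncio; asyncio.run(MyTool().run_search("weather in Berlin"))
```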
344  src/common/cache_manager_backup.py  Normal file
@@ -0,0 +1,344 @@
import json
import hashlib
import re
from typing import Any, Dict, Optional
from datetime import datetime, timedelta
from pathlib import Path
from difflib import SequenceMatcher

from src.common.logger import get_logger

logger = get_logger("cache_manager")


class ToolCache:
    """Tool cache manager that caches tool-call results and supports fuzzy matching"""

    def __init__(
        self,
        cache_dir: str = "data/tool_cache",
        max_age_hours: int = 24,
        similarity_threshold: float = 0.65,
    ):
        """
        Initialize the cache manager

        Args:
            cache_dir: cache directory path
            max_age_hours: maximum cache lifetime in hours
            similarity_threshold: similarity threshold for fuzzy matching (0-1)
        """
        self.cache_dir = Path(cache_dir)
        self.max_age = timedelta(hours=max_age_hours)
        self.max_age_seconds = max_age_hours * 3600
        self.similarity_threshold = similarity_threshold
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _normalize_query(query: str) -> str:
        """
        Normalize query text for similarity comparison

        Args:
            query: the raw query text

        Returns:
            the normalized query text
        """
        if not query:
            return ""

        # Pure Python implementation
        normalized = query.lower()
        normalized = re.sub(r"[^\w\s]", " ", normalized)
        normalized = " ".join(normalized.split())
        return normalized

    def _calculate_similarity(self, text1: str, text2: str) -> float:
        """
        Compute the similarity between two texts

        Args:
            text1: first text
            text2: second text

        Returns:
            a similarity score (0-1)
        """
        if not text1 or not text2:
            return 0.0

        # Pure Python implementation
        norm_text1 = self._normalize_query(text1)
        norm_text2 = self._normalize_query(text2)

        if norm_text1 == norm_text2:
            return 1.0

        return SequenceMatcher(None, norm_text1, norm_text2).ratio()

    @staticmethod
    def _generate_cache_key(tool_name: str, function_args: Dict[str, Any]) -> str:
        """
        Generate a cache key

        Args:
            tool_name: tool name
            function_args: function arguments

        Returns:
            the cache key string
        """
        # Serialize the arguments in sorted order so identical arguments yield identical keys
        sorted_args = json.dumps(function_args, sort_keys=True, ensure_ascii=False)

        # Pure Python implementation
        cache_string = f"{tool_name}:{sorted_args}"
        return hashlib.md5(cache_string.encode("utf-8")).hexdigest()

    def _get_cache_file_path(self, cache_key: str) -> Path:
        """Get the cache file path"""
        return self.cache_dir / f"{cache_key}.json"

    def _is_cache_expired(self, cached_time: datetime) -> bool:
        """Check whether a cache entry has expired"""
        return datetime.now() - cached_time > self.max_age

    def _find_similar_cache(
        self, tool_name: str, function_args: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """
        Look up a similar cache entry

        Args:
            tool_name: tool name
            function_args: function arguments

        Returns:
            the similar cached result, or None if there is none
        """
        query = function_args.get("query", "")
        if not query:
            return None

        candidates = []
        cache_data_list = []

        # Walk every cache file and collect candidates
        for cache_file in self.cache_dir.glob("*.json"):
            try:
                with open(cache_file, "r", encoding="utf-8") as f:
                    cache_data = json.load(f)

                # Check that it belongs to the same tool
                if cache_data.get("tool_name") != tool_name:
                    continue

                # Check whether the entry has expired
                cached_time = datetime.fromisoformat(cache_data["timestamp"])
                if self._is_cache_expired(cached_time):
                    continue

                # Check that the other arguments (everything except query) match
                cached_args = cache_data.get("function_args", {})
                args_match = True
                for key, value in function_args.items():
                    if key != "query" and cached_args.get(key) != value:
                        args_match = False
                        break

                if not args_match:
                    continue

                # Collect the candidate
                cached_query = cached_args.get("query", "")
                candidates.append((cached_query, len(cache_data_list)))
                cache_data_list.append(cache_data)

            except Exception as e:
                logger.warning(f"检查缓存文件时出错: {cache_file}, 错误: {e}")
                continue

        if not candidates:
            logger.debug(
                f"未找到相似缓存: {tool_name}, 查询: '{query}',相似度阈值: {self.similarity_threshold}"
            )
            return None

        # Pure Python implementation
        best_match = None
        best_similarity = 0.0

        for cached_query, index in candidates:
            similarity = self._calculate_similarity(query, cached_query)
            if similarity > best_similarity and similarity >= self.similarity_threshold:
                best_similarity = similarity
                best_match = cache_data_list[index]

        if best_match is not None:
            cached_query = best_match["function_args"].get("query", "")
            logger.info(
                f"相似缓存命中,相似度: {best_similarity:.2f}, 原查询: '{cached_query}', 当前查询: '{query}'"
            )
            return best_match["result"]

        logger.debug(
            f"未找到相似缓存: {tool_name}, 查询: '{query}',相似度阈值: {self.similarity_threshold}"
        )
        return None

    def get(
        self, tool_name: str, function_args: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """
        Fetch a result from the cache, with exact and fuzzy matching

        Args:
            tool_name: tool name
            function_args: function arguments

        Returns:
            the cached result, or None if it is missing or expired
        """
        # Try an exact match first
        cache_key = self._generate_cache_key(tool_name, function_args)
        cache_file = self._get_cache_file_path(cache_key)

        if cache_file.exists():
            try:
                with open(cache_file, "r", encoding="utf-8") as f:
                    cache_data = json.load(f)

                # Check whether the entry has expired
                cached_time = datetime.fromisoformat(cache_data["timestamp"])
                if self._is_cache_expired(cached_time):
                    logger.debug(f"缓存已过期: {cache_key}")
                    cache_file.unlink()  # Remove the expired entry
                else:
                    logger.debug(f"精确匹配缓存: {tool_name}")
                    return cache_data["result"]

            except (json.JSONDecodeError, KeyError, ValueError) as e:
                logger.warning(f"读取缓存文件失败: {cache_file}, 错误: {e}")
                # Remove the corrupted cache file
                if cache_file.exists():
                    cache_file.unlink()

        # If the exact match failed, try a fuzzy match
        return self._find_similar_cache(tool_name, function_args)

    def set(
        self, tool_name: str, function_args: Dict[str, Any], result: Dict[str, Any]
    ) -> None:
        """
        Save a result to the cache

        Args:
            tool_name: tool name
            function_args: function arguments
            result: the result to cache
        """
        cache_key = self._generate_cache_key(tool_name, function_args)
        cache_file = self._get_cache_file_path(cache_key)

        cache_data = {
            "tool_name": tool_name,
            "function_args": function_args,
            "result": result,
            "timestamp": datetime.now().isoformat(),
        }

        try:
            with open(cache_file, "w", encoding="utf-8") as f:
                json.dump(cache_data, f, ensure_ascii=False, indent=2)
            logger.debug(f"缓存已保存: {tool_name} -> {cache_key}")
        except Exception as e:
            logger.error(f"保存缓存失败: {cache_file}, 错误: {e}")

    def clear_expired(self) -> int:
        """
        Remove expired cache entries

        Returns:
            the number of files removed
        """
        removed_count = 0

        for cache_file in self.cache_dir.glob("*.json"):
            try:
                with open(cache_file, "r", encoding="utf-8") as f:
                    cache_data = json.load(f)

                cached_time = datetime.fromisoformat(cache_data["timestamp"])
                if self._is_cache_expired(cached_time):
                    cache_file.unlink()
                    removed_count += 1
                    logger.debug(f"删除过期缓存: {cache_file}")

            except Exception as e:
                logger.warning(f"清理缓存文件时出错: {cache_file}, 错误: {e}")
                # Remove the corrupted file
                try:
                    cache_file.unlink()
                    removed_count += 1
                except OSError as unlink_error:
                    logger.warning(f"删除损坏的缓存文件失败: {cache_file}, 错误: {unlink_error}")

        logger.info(f"清理完成,删除了 {removed_count} 个过期缓存文件")
        return removed_count

    def clear_all(self) -> int:
        """
        Clear the whole cache

        Returns:
            the number of files removed
        """
        removed_count = 0

        for cache_file in self.cache_dir.glob("*.json"):
            try:
                cache_file.unlink()
                removed_count += 1
            except Exception as e:
                logger.warning(f"删除缓存文件失败: {cache_file}, 错误: {e}")

        logger.info(f"清空缓存完成,删除了 {removed_count} 个文件")
        return removed_count

    def get_stats(self) -> Dict[str, Any]:
        """
        Get cache statistics

        Returns:
            a dict of cache statistics
        """
        total_files = 0
        expired_files = 0
        total_size = 0

        for cache_file in self.cache_dir.glob("*.json"):
            try:
                total_files += 1
                total_size += cache_file.stat().st_size

                with open(cache_file, "r", encoding="utf-8") as f:
                    cache_data = json.load(f)

                cached_time = datetime.fromisoformat(cache_data["timestamp"])
                if self._is_cache_expired(cached_time):
                    expired_files += 1

            except (OSError, json.JSONDecodeError, KeyError, ValueError):
                expired_files += 1  # Count corrupted files as expired too

        return {
            "total_files": total_files,
            "expired_files": expired_files,
            "total_size_bytes": total_size,
            "cache_dir": str(self.cache_dir),
            "max_age_hours": self.max_age.total_seconds() / 3600,
            "similarity_threshold": self.similarity_threshold,
        }


tool_cache = ToolCache()
@@ -42,7 +42,8 @@ from src.config.official_configs import (

     ExaConfig,
     WebSearchConfig,
     TavilyConfig,
-    AntiPromptInjectionConfig
+    AntiPromptInjectionConfig,
+    PluginsConfig
 )

 from .api_ada_configs import (
@@ -365,6 +366,7 @@ class Config(ConfigBase):

     exa: ExaConfig = field(default_factory=lambda: ExaConfig())
     web_search: WebSearchConfig = field(default_factory=lambda: WebSearchConfig())
     tavily: TavilyConfig = field(default_factory=lambda: TavilyConfig())
+    plugins: PluginsConfig = field(default_factory=lambda: PluginsConfig())


 @dataclass
@@ -1,7 +1,7 @@

 import re

 from dataclasses import dataclass, field
-from typing import Literal, Optional, Dict
+from typing import Literal, Optional

 from src.config.config_base import ConfigBase
@@ -864,43 +864,6 @@ class ScheduleConfig(ConfigBase):

     guidelines: Optional[str] = field(default=None)
     """Schedule-generation guidelines; if None, the default guidelines are used"""
-
-
-@dataclass
-class VideoAnalysisConfig(ConfigBase):
-    """Video analysis configuration"""
-
-    enable: bool = True
-    """Whether video analysis is enabled"""
-
-    analysis_mode: Literal["frame_by_frame", "batch_frames", "auto"] = "auto"
-    """Analysis mode: frame-by-frame (slow but detailed), batched frames (fast but possibly simpler), or automatic selection"""
-
-    max_frames: int = 8
-    """Maximum number of frames to analyze"""
-
-    frame_quality: int = 85
-    """JPEG quality of extracted frames (1-100)"""
-
-    max_image_size: int = 800
-    """Maximum size of a single frame in pixels"""
-
-    batch_analysis_prompt: str = field(default="""请分析这个视频的内容。这些图片是从视频中按时间顺序提取的关键帧。
-
-请提供详细的分析,包括:
-1. 视频的整体内容和主题
-2. 主要人物、对象和场景描述
-3. 动作、情节和时间线发展
-4. 视觉风格和艺术特点
-5. 整体氛围和情感表达
-6. 任何特殊的视觉效果或文字内容
-
-请用中文回答,分析要详细准确。""")
-    """Prompt used for batch analysis"""
-
-    enable_frame_timing: bool = True
-    """Whether to include frame timing information in the analysis"""


 @dataclass
 class DependencyManagementConfig(ConfigBase):
     """Configuration for plugin Python dependency management"""
@@ -1065,3 +1028,12 @@ class AntiPromptInjectionConfig(ConfigBase):

     shield_suffix: str = " 🛡️"
     """Suffix appended to shielded messages"""
+
+
+@dataclass
+class PluginsConfig(ConfigBase):
+    """Plugin configuration"""
+
+    centralized_config: bool = field(
+        default=True, metadata={"description": "是否启用插件配置集中化管理"}
+    )
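A minimal sketch of how the new switch could drive config-file placement, matching the steps listed in the `_load_plugin_config` docstring later in this commit; `CONFIG_DIR` comes from the new import in plugin_base.py, the rest is illustrative.

```python
from pathlib import Path

def resolve_plugin_config_path(plugin_dir: str, config_file_name: str,
                               config_dir: str, centralized: bool) -> Path:
    """Pick where a plugin's user-editable config lives."""
    template = Path(plugin_dir) / config_file_name          # template inside the plugin
    if not centralized:
        return template                                     # legacy behavior
    user_copy = Path(config_dir) / "plugins" / config_file_name
    user_copy.parent.mkdir(parents=True, exist_ok=True)
    if not user_copy.exists() and template.exists():
        user_copy.write_bytes(template.read_bytes())        # first run: copy the template
    return user_copy
```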
@@ -1,32 +1,54 @@

 import asyncio
 import io
 import base64
-from typing import Callable, AsyncIterator, Optional, Coroutine, Any, List
+from typing import Callable, AsyncIterator, Optional, Coroutine, Any, List, Dict, Union

-from google import genai
-from google.genai.types import (
-    Content,
-    Part,
-    FunctionDeclaration,
-    GenerateContentResponse,
-    ContentListUnion,
-    ContentUnion,
-    ThinkingConfig,
-    Tool,
-    GenerateContentConfig,
-    EmbedContentResponse,
-    EmbedContentConfig,
-    SafetySetting,
-    HarmCategory,
-    HarmBlockThreshold,
-)
-from google.genai.errors import (
-    ClientError,
-    ServerError,
-    UnknownFunctionCallArgumentError,
-    UnsupportedFunctionError,
-    FunctionInvocationError,
-)
+import google.generativeai as genai
+from google.generativeai.types import (
+    GenerateContentResponse,
+    HarmCategory,
+    HarmBlockThreshold,
+)
+
+try:
+    # Try importing from the newer API
+    from google.generativeai import configure
+    from google.generativeai.types import SafetySetting, GenerationConfig
+except ImportError:
+    # Fall back to plain types
+    SafetySetting = Dict
+    GenerationConfig = Dict
+
+# Compatibility type aliases
+ContentDict = Dict
+PartDict = Dict
+ToolDict = Dict
+FunctionDeclaration = Dict
+Tool = Dict
+ContentListUnion = List[Dict]
+ContentUnion = Dict
+Content = Dict
+Part = Dict
+ThinkingConfig = Dict
+GenerateContentConfig = Dict
+EmbedContentConfig = Dict
+EmbedContentResponse = Dict
+
+# Exception types
+class ClientError(Exception):
+    pass
+
+class ServerError(Exception):
+    pass
+
+class UnknownFunctionCallArgumentError(Exception):
+    pass
+
+class UnsupportedFunctionError(Exception):
+    pass
+
+class FunctionInvocationError(Exception):
+    pass
+
 from src.config.api_ada_configs import ModelInfo, APIProvider
 from src.common.logger import get_logger
@@ -44,18 +66,17 @@ from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall

 logger = get_logger("Gemini客户端")

-gemini_safe_settings = [
-    SafetySetting(category=HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold=HarmBlockThreshold.BLOCK_NONE),
-    SafetySetting(category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold=HarmBlockThreshold.BLOCK_NONE),
-    SafetySetting(category=HarmCategory.HARM_CATEGORY_HARASSMENT, threshold=HarmBlockThreshold.BLOCK_NONE),
-    SafetySetting(category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, threshold=HarmBlockThreshold.BLOCK_NONE),
-    SafetySetting(category=HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY, threshold=HarmBlockThreshold.BLOCK_NONE),
-]
+SAFETY_SETTINGS = [
+    {"category": HarmCategory.HARM_CATEGORY_HATE_SPEECH, "threshold": HarmBlockThreshold.BLOCK_NONE},
+    {"category": HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, "threshold": HarmBlockThreshold.BLOCK_NONE},
+    {"category": HarmCategory.HARM_CATEGORY_HARASSMENT, "threshold": HarmBlockThreshold.BLOCK_NONE},
+    {"category": HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, "threshold": HarmBlockThreshold.BLOCK_NONE},
+]


 def _convert_messages(
     messages: list[Message],
-) -> tuple[ContentListUnion, list[str] | None]:
+) -> tuple[List[Dict], list[str] | None]:
     """
     Convert messages into the format required by the Gemini API
     :param messages: the message list
|
|||||||
normalized_format = format_mapping.get(image_format.lower(), image_format.lower())
|
normalized_format = format_mapping.get(image_format.lower(), image_format.lower())
|
||||||
return f"image/{normalized_format}"
|
return f"image/{normalized_format}"
|
||||||
|
|
||||||
def _convert_message_item(message: Message) -> Content:
|
def _convert_message_item(message: Message) -> Dict:
|
||||||
"""
|
"""
|
||||||
转换单个消息格式,除了system和tool类型的消息
|
转换单个消息格式,除了system和tool类型的消息
|
||||||
:param message: 消息对象
|
:param message: 消息对象
|
||||||
@@ -96,22 +117,25 @@ def _convert_messages(

         # Add the content
         if isinstance(message.content, str):
-            content = [Part.from_text(text=message.content)]
+            content = [{"text": message.content}]
         elif isinstance(message.content, list):
-            content: List[Part] = []
+            content = []
             for item in message.content:
                 if isinstance(item, tuple):
-                    content.append(
-                        Part.from_bytes(data=base64.b64decode(item[1]), mime_type=_get_correct_mime_type(item[0]))
-                    )
+                    content.append({
+                        "inline_data": {
+                            "mime_type": _get_correct_mime_type(item[0]),
+                            "data": item[1]
+                        }
+                    })
                 elif isinstance(item, str):
-                    content.append(Part.from_text(text=item))
+                    content.append({"text": item})
                 else:
                     raise RuntimeError("无法触及的代码:请使用MessageBuilder类构建消息对象")

-        return Content(role=role, parts=content)
+        return {"role": role, "parts": content}

-    temp_list: list[ContentUnion] = []
+    temp_list: List[Dict] = []
     system_instructions: list[str] = []
     for message in messages:
         if message.role == RoleType.System:
@@ -338,13 +362,10 @@ def _default_normal_response_parser(

 @client_registry.register_client_class("gemini")
 class GeminiClient(BaseClient):
-    client: genai.Client
-
     def __init__(self, api_provider: APIProvider):
         super().__init__(api_provider)
-        self.client = genai.Client(
-            api_key=api_provider.api_key,
-        )  # Unlike OpenAI, Gemini decides on its own whether to retry
+        # Configure Google Generative AI
+        genai.configure(api_key=api_provider.api_key)

     async def get_response(
         self,
@@ -396,18 +417,18 @@ class GeminiClient(BaseClient):

             "max_output_tokens": max_tokens,
             "temperature": temperature,
             "response_modalities": ["TEXT"],
-            "thinking_config": ThinkingConfig(
-                include_thoughts=True,
-                thinking_budget=(
+            "thinking_config": {
+                "include_thoughts": True,
+                "thinking_budget": (
                     extra_params["thinking_budget"]
                     if extra_params and "thinking_budget" in extra_params
                     else int(max_tokens / 2)  # Default thinking budget: half of max_tokens, to avoid empty replies
                 ),
-            ),
-            "safety_settings": gemini_safe_settings,  # Guard against empty replies
+            },
+            "safety_settings": SAFETY_SETTINGS,  # Guard against empty replies
         }
         if tools:
-            generation_config_dict["tools"] = Tool(function_declarations=tools)
+            generation_config_dict["tools"] = {"function_declarations": tools}
         if messages[1]:
             # If there is a system message, add it to the config
             generation_config_dict["system_instructions"] = messages[1]
@@ -417,15 +438,18 @@ class GeminiClient(BaseClient):

             generation_config_dict["response_mime_type"] = "application/json"
             generation_config_dict["response_schema"] = response_format.to_dict()

-        generation_config = GenerateContentConfig(**generation_config_dict)
+        generation_config = generation_config_dict

         try:
+            # Create the model instance
+            model = genai.GenerativeModel(model_info.model_identifier)
+
             if model_info.force_stream_mode:
                 req_task = asyncio.create_task(
-                    self.client.aio.models.generate_content_stream(
-                        model=model_info.model_identifier,
+                    model.generate_content_async(
                         contents=messages[0],
-                        config=generation_config,
+                        generation_config=generation_config,
+                        stream=True
                     )
                 )
                 while not req_task.done():
@@ -437,10 +461,9 @@ class GeminiClient(BaseClient):

                 resp, usage_record = await stream_response_handler(req_task.result(), interrupt_flag)
             else:
                 req_task = asyncio.create_task(
-                    self.client.aio.models.generate_content(
-                        model=model_info.model_identifier,
+                    model.generate_content_async(
                         contents=messages[0],
-                        config=generation_config,
+                        generation_config=generation_config
                     )
                 )
                 while not req_task.done():
@@ -451,16 +474,17 @@ class GeminiClient(BaseClient):

                 await asyncio.sleep(0.5)  # Wait 0.5s, then re-check the task and the interrupt flag

             resp, usage_record = async_response_parser(req_task.result())
-        except (ClientError, ServerError) as e:
-            # Re-wrap ClientError and ServerError as RespNotOkException
-            raise RespNotOkException(e.code, e.message) from None
-        except (
-            UnknownFunctionCallArgumentError,
-            UnsupportedFunctionError,
-            FunctionInvocationError,
-        ) as e:
-            raise ValueError(f"工具类型错误:请检查工具选项和参数:{str(e)}") from None
         except Exception as e:
+            # Handle Google Generative AI exceptions
+            if "rate limit" in str(e).lower():
+                raise RespNotOkException(429, "请求频率过高,请稍后再试") from None
+            elif "quota" in str(e).lower():
+                raise RespNotOkException(429, "配额已用完") from None
+            elif "invalid" in str(e).lower() or "bad request" in str(e).lower():
+                raise RespNotOkException(400, f"请求无效:{str(e)}") from None
+            elif "permission" in str(e).lower() or "forbidden" in str(e).lower():
+                raise RespNotOkException(403, "权限不足") from None
+            else:
-            raise NetworkConnectionError() from e
+                raise NetworkConnectionError() from e

         if usage_record:
@@ -535,7 +559,7 @@ class GeminiClient(BaseClient):

                     extra_params["thinking_budget"] if extra_params and "thinking_budget" in extra_params else 1024
                 ),
             ),
-            "safety_settings": gemini_safe_settings,
+            "safety_settings": SAFETY_SETTINGS,
         }
         generate_content_config = GenerateContentConfig(**generation_config_dict)
         prompt = "Generate a transcript of the speech. The language of the transcript should **match the language of the speech**."
@@ -17,6 +17,7 @@ from src.common.server import get_global_server, Server

 from src.mood.mood_manager import mood_manager
 from rich.traceback import install
 from src.manager.schedule_manager import schedule_manager
+from src.common.cache_manager import tool_cache
 # from src.api.main import start_api_server

 # Import the new plugin manager and hot-reload manager
@@ -207,6 +207,9 @@ class VideoAnalyzer:

         """Batch-analyze all frames"""
         self.logger.info(f"开始批量分析{len(frames)}帧")

+        if not frames:
+            return "❌ 没有可分析的帧"
+
         # Build the prompt
         prompt = self.batch_analysis_prompt
@@ -214,29 +217,78 @@ class VideoAnalyzer:
|
|||||||
prompt += f"\n\n用户问题: {user_question}"
|
prompt += f"\n\n用户问题: {user_question}"
|
||||||
|
|
||||||
# 添加帧信息到提示词
|
# 添加帧信息到提示词
|
||||||
|
frame_info = []
|
||||||
for i, (frame_base64, timestamp) in enumerate(frames):
|
for i, (frame_base64, timestamp) in enumerate(frames):
|
||||||
if self.enable_frame_timing:
|
if self.enable_frame_timing:
|
||||||
prompt += f"\n\n第{i+1}帧 (时间: {timestamp:.2f}s):"
|
frame_info.append(f"第{i+1}帧 (时间: {timestamp:.2f}s)")
|
||||||
|
else:
|
||||||
|
frame_info.append(f"第{i+1}帧")
|
||||||
|
|
||||||
|
prompt += f"\n\n视频包含{len(frames)}帧图像:{', '.join(frame_info)}"
|
||||||
|
prompt += "\n\n请基于所有提供的帧图像进行综合分析,描述视频的完整内容和故事发展。"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 使用第一帧进行分析(批量模式暂时使用单帧,后续可以优化为真正的多图片分析)
|
# 尝试使用多图片分析
|
||||||
if frames:
|
response = await self._analyze_multiple_frames(frames, prompt)
|
||||||
frame_base64, _ = frames[0]
|
self.logger.info("✅ 批量多图片分析完成")
|
||||||
prompt += f"\n\n注意:当前显示的是第1帧,请基于这一帧和提示词进行分析。视频共有{len(frames)}帧。"
|
return response
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"❌ 多图片分析失败: {e}")
|
||||||
|
# 降级到单帧分析
|
||||||
|
self.logger.warning("降级到单帧分析模式")
|
||||||
|
try:
|
||||||
|
frame_base64, timestamp = frames[0]
|
||||||
|
fallback_prompt = prompt + f"\n\n注意:由于技术限制,当前仅显示第1帧 (时间: {timestamp:.2f}s),视频共有{len(frames)}帧。请基于这一帧进行分析。"
|
||||||
|
|
||||||
response, _ = await self.video_llm.generate_response_for_image(
|
response, _ = await self.video_llm.generate_response_for_image(
|
||||||
prompt=prompt,
|
prompt=fallback_prompt,
|
||||||
image_base64=frame_base64,
|
image_base64=frame_base64,
|
||||||
image_format="jpeg"
|
image_format="jpeg"
|
||||||
)
|
)
|
||||||
self.logger.info("✅ 批量分析完成")
|
self.logger.info("✅ 降级的单帧分析完成")
|
||||||
return response
|
return response
|
||||||
else:
|
except Exception as fallback_e:
|
||||||
return "❌ 没有可分析的帧"
|
self.logger.error(f"❌ 降级分析也失败: {fallback_e}")
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"❌ 批量分析失败: {e}")
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
+    async def _analyze_multiple_frames(self, frames: List[Tuple[str, float]], prompt: str) -> str:
+        """使用多图片分析方法"""
+        self.logger.info(f"开始构建包含{len(frames)}帧的多图片分析请求")
+
+        # 导入MessageBuilder用于构建多图片消息
+        from src.llm_models.payload_content.message import MessageBuilder, RoleType
+        from src.llm_models.utils_model import RequestType
+
+        # 构建包含多张图片的消息
+        message_builder = MessageBuilder().set_role(RoleType.User).add_text_content(prompt)
+
+        # 添加所有帧图像
+        for i, (frame_base64, timestamp) in enumerate(frames):
+            message_builder.add_image_content("jpeg", frame_base64)
+            # self.logger.info(f"已添加第{i+1}帧到分析请求 (时间: {timestamp:.2f}s, 图片大小: {len(frame_base64)} chars)")
+
+        message = message_builder.build()
+        self.logger.info(f"✅ 多图片消息构建完成,包含{len(frames)}张图片")
+
+        # 获取模型信息和客户端
+        model_info, api_provider, client = await self.video_llm._get_best_model_and_client()
+        self.logger.info(f"使用模型: {model_info.name} 进行多图片分析")
+
+        # 直接执行多图片请求
+        api_response = await self.video_llm._execute_request(
+            api_provider=api_provider,
+            client=client,
+            request_type=RequestType.RESPONSE,
+            model_info=model_info,
+            message_list=[message],
+            temperature=None,
+            max_tokens=None
+        )
+
+        self.logger.info(f"视频识别完成,响应长度: {len(api_response.content or '')} ")
+        return api_response.content or "❌ 未获得响应内容"
+
     async def analyze_frames_sequential(self, frames: List[Tuple[str, float]], user_question: str = None) -> str:
         """逐帧分析并汇总"""
         self.logger.info(f"开始逐帧分析{len(frames)}帧")
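The builder chain in `_analyze_multiple_frames` is the heart of the new multi-image path. An illustrative condensation of just that chain, with placeholder frame data (the exact payload structure `build()` produces is not shown in this diff):

```python
from src.llm_models.payload_content.message import MessageBuilder, RoleType

prompt = "请基于所有提供的帧图像进行综合分析。"
frames = [("<base64-frame-1>", 0.0), ("<base64-frame-2>", 2.5)]  # (jpeg base64, timestamp)

message_builder = MessageBuilder().set_role(RoleType.User).add_text_content(prompt)
for frame_base64, _timestamp in frames:
    message_builder.add_image_content("jpeg", frame_base64)  # one image part per frame
message = message_builder.build()  # a single user message carrying text + N images
```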
@@ -355,7 +407,7 @@ class VideoAnalyzer:

         # 计算视频hash值
         video_hash = self._calculate_video_hash(video_bytes)
-        logger.info(f"视频hash: {video_hash[:16]}...")
+        # logger.info(f"视频hash: {video_hash[:16]}...")

         # 检查数据库中是否已存在该视频的分析结果
         existing_video = self._check_video_exists(video_hash)
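The hash is what keys the result cache consulted by `_check_video_exists`. Its implementation sits outside this diff; a plausible sketch, assuming a plain content digest:

```python
import hashlib

def calculate_video_hash(video_bytes: bytes) -> str:
    # Content-addressed key: identical uploads map to the same analysis row,
    # so a re-uploaded video skips the expensive LLM analysis entirely.
    return hashlib.sha256(video_bytes).hexdigest()

# video_hash[:16] is the prefix the (now commented-out) log line printed.
```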
@@ -1,13 +1,13 @@
 from abc import ABC, abstractmethod
 from typing import Dict, List, Any, Union
 import os
-import inspect
 import toml
 import json
 import shutil
 import datetime

 from src.common.logger import get_logger
+from src.config.config import CONFIG_DIR
 from src.plugin_system.base.component_types import (
     PluginInfo,
     PythonDependency,
@@ -71,6 +71,7 @@ class PluginBase(ABC):
         self.config: Dict[str, Any] = {}  # 插件配置
         self.plugin_dir = plugin_dir  # 插件目录路径
         self.log_prefix = f"[Plugin:{self.plugin_name}]"
+        self._is_enabled = self.enable_plugin  # 从插件定义中获取默认启用状态

         # 加载manifest文件
         self._load_manifest()
@@ -100,7 +101,7 @@ class PluginBase(ABC):
             description=self.plugin_description,
             version=self.plugin_version,
             author=self.plugin_author,
-            enabled=self.enable_plugin,
+            enabled=self._is_enabled,
             is_built_in=False,
             config_file=self.config_file_name or "",
             dependencies=self.dependencies.copy(),
@@ -453,86 +454,91 @@ class PluginBase(ABC):
             logger.error(f"{self.log_prefix} 保存配置文件失败: {e}", exc_info=True)

     def _load_plugin_config(self):  # sourcery skip: extract-method
-        """加载插件配置文件,支持版本检查和自动迁移"""
+        """
+        加载插件配置文件,实现集中化管理和自动迁移。
+
+        处理逻辑:
+        1. 确定插件模板配置文件路径(位于插件目录内)。
+        2. 如果模板不存在,则在插件目录内生成一份默认配置。
+        3. 确定用户配置文件路径(位于 `config/plugins/` 目录下)。
+        4. 如果用户配置文件不存在,则从插件目录复制模板文件过去。
+        5. 加载用户配置文件,并进行版本检查和自动迁移(如果需要)。
+        6. 最终加载的配置是用户配置文件。
+        """
         if not self.config_file_name:
             logger.debug(f"{self.log_prefix} 未指定配置文件,跳过加载")
             return

-        # 优先使用传入的插件目录路径
-        if self.plugin_dir:
-            plugin_dir = self.plugin_dir
-        else:
-            # fallback:尝试从类的模块信息获取路径
-            try:
-                plugin_module_path = inspect.getfile(self.__class__)
-                plugin_dir = os.path.dirname(plugin_module_path)
-            except (TypeError, OSError):
-                # 最后的fallback:从模块的__file__属性获取
-                module = inspect.getmodule(self.__class__)
-                if module and hasattr(module, "__file__") and module.__file__:
-                    plugin_dir = os.path.dirname(module.__file__)
-                else:
-                    logger.warning(f"{self.log_prefix} 无法获取插件目录路径,跳过配置加载")
-                    return
-
-        config_file_path = os.path.join(plugin_dir, self.config_file_name)
-
-        # 如果配置文件不存在,生成默认配置
-        if not os.path.exists(config_file_path):
-            logger.info(f"{self.log_prefix} 配置文件 {config_file_path} 不存在,将生成默认配置。")
-            self._generate_and_save_default_config(config_file_path)
-
-        if not os.path.exists(config_file_path):
-            logger.warning(f"{self.log_prefix} 配置文件 {config_file_path} 不存在且无法生成。")
-            return
-
-        file_ext = os.path.splitext(self.config_file_name)[1].lower()
-        if file_ext == ".toml":
-            # 加载现有配置
-            with open(config_file_path, "r", encoding="utf-8") as f:
-                existing_config = toml.load(f) or {}
-
-            # 检查配置版本
-            current_version = self._get_current_config_version(existing_config)
-
-            # 如果配置文件没有版本信息,跳过版本检查
-            if current_version == "0.0.0":
-                logger.debug(f"{self.log_prefix} 配置文件无版本信息,跳过版本检查")
-                self.config = existing_config
-            else:
-                expected_version = self._get_expected_config_version()
-
-                if current_version != expected_version:
-                    logger.info(
-                        f"{self.log_prefix} 检测到配置版本需要更新: 当前=v{current_version}, 期望=v{expected_version}"
-                    )
-
-                    # 生成新的默认配置结构
-                    new_config_structure = self._generate_config_from_schema()
-
-                    # 迁移旧配置值到新结构
-                    migrated_config = self._migrate_config_values(existing_config, new_config_structure)
-
-                    # 保存迁移后的配置
-                    self._save_config_to_file(migrated_config, config_file_path)
-
-                    logger.info(f"{self.log_prefix} 配置文件已从 v{current_version} 更新到 v{expected_version}")
-
-                    self.config = migrated_config
-                else:
-                    logger.debug(f"{self.log_prefix} 配置版本匹配 (v{current_version}),直接加载")
-                    self.config = existing_config
-
-            logger.debug(f"{self.log_prefix} 配置已从 {config_file_path} 加载")
-
-            # 从配置中更新 enable_plugin
-            if "plugin" in self.config and "enabled" in self.config["plugin"]:
-                self.enable_plugin = self.config["plugin"]["enabled"]  # type: ignore
-                logger.debug(f"{self.log_prefix} 从配置更新插件启用状态: {self.enable_plugin}")
-        else:
-            logger.warning(f"{self.log_prefix} 不支持的配置文件格式: {file_ext},仅支持 .toml")
-            self.config = {}
+        # 1. 确定插件模板配置文件路径
+        template_config_path = os.path.join(self.plugin_dir, self.config_file_name)
+
+        # 2. 如果模板不存在,则在插件目录内生成
+        if not os.path.exists(template_config_path):
+            logger.info(f"{self.log_prefix} 插件目录缺少配置文件 {template_config_path},将生成默认配置。")
+            self._generate_and_save_default_config(template_config_path)
+
+        # 3. 确定用户配置文件路径
+        plugin_config_dir = os.path.join(CONFIG_DIR, "plugins", self.plugin_name)
+        user_config_path = os.path.join(plugin_config_dir, self.config_file_name)
+
+        # 确保用户插件配置目录存在
+        os.makedirs(plugin_config_dir, exist_ok=True)
+
+        # 4. 如果用户配置文件不存在,从模板复制
+        if not os.path.exists(user_config_path):
+            try:
+                shutil.copy2(template_config_path, user_config_path)
+                logger.info(f"{self.log_prefix} 已从模板创建用户配置文件: {user_config_path}")
+            except IOError as e:
+                logger.error(f"{self.log_prefix} 复制配置文件失败: {e}", exc_info=True)
+                # 如果复制失败,后续将无法加载,直接返回
+                return
+
+        # 检查最终的用户配置文件是否存在
+        if not os.path.exists(user_config_path):
+            logger.warning(f"{self.log_prefix} 用户配置文件 {user_config_path} 不存在且无法创建。")
+            return
+
+        # 5. 加载、检查和迁移用户配置文件
+        _, file_ext = os.path.splitext(self.config_file_name)
+        if file_ext.lower() != ".toml":
+            logger.warning(f"{self.log_prefix} 不支持的配置文件格式: {file_ext},仅支持 .toml")
+            self.config = {}
+            return
+
+        try:
+            with open(user_config_path, "r", encoding="utf-8") as f:
+                existing_config = toml.load(f) or {}
+        except Exception as e:
+            logger.error(f"{self.log_prefix} 加载用户配置文件 {user_config_path} 失败: {e}", exc_info=True)
+            self.config = {}
+            return
+
+        current_version = self._get_current_config_version(existing_config)
+        expected_version = self._get_expected_config_version()
+
+        if current_version == "0.0.0":
+            logger.debug(f"{self.log_prefix} 用户配置文件无版本信息,跳过版本检查")
+            self.config = existing_config
+        elif current_version != expected_version:
+            logger.info(
+                f"{self.log_prefix} 检测到用户配置版本需要更新: 当前=v{current_version}, 期望=v{expected_version}"
+            )
+            new_config_structure = self._generate_config_from_schema()
+            migrated_config = self._migrate_config_values(existing_config, new_config_structure)
+            self._save_config_to_file(migrated_config, user_config_path)
+            logger.info(f"{self.log_prefix} 用户配置文件已从 v{current_version} 更新到 v{expected_version}")
+            self.config = migrated_config
+        else:
+            logger.debug(f"{self.log_prefix} 用户配置版本匹配 (v{current_version}),直接加载")
+            self.config = existing_config
+
+        logger.debug(f"{self.log_prefix} 配置已从 {user_config_path} 加载")
+
+        # 从配置中更新 enable_plugin 状态
+        if "plugin" in self.config and "enabled" in self.config["plugin"]:
+            self._is_enabled = self.config["plugin"]["enabled"]
+            logger.debug(f"{self.log_prefix} 从配置更新插件启用状态: {self._is_enabled}")

     def _check_dependencies(self) -> bool:
         """检查插件依赖"""
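To make the new path logic concrete, here is where the template and the centralized user copy end up for a hypothetical plugin named `my_plugin` with `config_file_name = "config.toml"` (values illustrative; `CONFIG_DIR` is whatever `src.config.config` defines):

```python
import os

CONFIG_DIR = "config"  # assumed value for illustration
plugin_dir = os.path.join("plugins", "my_plugin")  # self.plugin_dir
config_file_name = "config.toml"

# Template shipped with the plugin (auto-generated if missing):
template_config_path = os.path.join(plugin_dir, config_file_name)
# -> plugins/my_plugin/config.toml

# Centralized user copy that actually gets loaded and migrated:
user_config_path = os.path.join(CONFIG_DIR, "plugins", "my_plugin", config_file_name)
# -> config/plugins/my_plugin/config.toml
```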
src/plugin_system/utils/dependency_alias.py (new file, 134 lines)
@@ -0,0 +1,134 @@
# -*- coding: utf-8 -*-
"""
本模块包含一个从Python包的“安装名”到其“导入名”的映射。

这个映射表主要用于解决一个常见问题:某些Python包通过pip安装时使用的名称
与在代码中`import`时使用的名称不一致。例如,我们使用`pip install beautifulsoup4`
来安装,但在代码中却需要`import bs4`。

当插件系统检查依赖时,如果一个开发者只简单地在依赖列表中写了安装名
(例如 "beautifulsoup4"),标准的导入检查`import('beautifulsoup4')`会失败。
通过这个映射表,依赖管理器可以在初次导入检查失败后,查询是否存在一个
已知的别名(例如 "bs4"),并尝试使用该别名进行二次导入检查。

这样做的好处是:
1. 提升开发者体验:插件开发者无需强制记忆这些特殊的名称对应关系,或者强制
   使用更复杂的`PythonDependency`对象来分别指定安装名和导入名。
2. 增强系统健壮性:减少因名称不一致导致的插件加载失败问题。
3. 兼容性:对遵循最佳实践、正确指定了`package_name`和`install_name`的
   开发者没有任何影响。

开发者可以持续向这个列表中贡献新的映射关系,使其更加完善。
"""

INSTALL_NAME_TO_IMPORT_NAME = {
    # ============== 数据科学与机器学习 (Data Science & Machine Learning) ==============
    "scikit-learn": "sklearn",  # 机器学习库
    "scikit-image": "skimage",  # 图像处理库
    "opencv-python": "cv2",  # OpenCV 计算机视觉库
    "opencv-contrib-python": "cv2",  # OpenCV 扩展模块
    "tensorflow-gpu": "tensorflow",  # TensorFlow GPU版本
    "tensorboardx": "tensorboardX",  # TensorBoard 的封装
    "torchvision": "torchvision",  # PyTorch 视觉库 (通常与 torch 一起)
    "torchaudio": "torchaudio",  # PyTorch 音频库
    "catboost": "catboost",  # CatBoost 梯度提升库
    "lightgbm": "lightgbm",  # LightGBM 梯度提升库
    "xgboost": "xgboost",  # XGBoost 梯度提升库
    "imbalanced-learn": "imblearn",  # 处理不平衡数据集
    "seqeval": "seqeval",  # 序列标注评估
    "gensim": "gensim",  # 主题建模和NLP
    "nltk": "nltk",  # 自然语言工具包
    "spacy": "spacy",  # 工业级自然语言处理
    "fuzzywuzzy": "fuzzywuzzy",  # 模糊字符串匹配
    "python-levenshtein": "Levenshtein",  # Levenshtein 距离计算

    # ============== Web开发与API (Web Development & API) ==============
    "python-socketio": "socketio",  # Socket.IO 服务器和客户端
    "python-engineio": "engineio",  # Engine.IO 底层库
    "aiohttp": "aiohttp",  # 异步HTTP客户端/服务器
    "python-multipart": "multipart",  # 解析 multipart/form-data
    "uvloop": "uvloop",  # 高性能asyncio事件循环
    "httptools": "httptools",  # 高性能HTTP解析器
    "websockets": "websockets",  # WebSocket实现
    "fastapi": "fastapi",  # 高性能Web框架
    "starlette": "starlette",  # ASGI框架
    "uvicorn": "uvicorn",  # ASGI服务器
    "gunicorn": "gunicorn",  # WSGI服务器
    "django-rest-framework": "rest_framework",  # Django REST框架
    "django-cors-headers": "corsheaders",  # Django CORS处理
    "flask-jwt-extended": "flask_jwt_extended",  # Flask JWT扩展
    "flask-sqlalchemy": "flask_sqlalchemy",  # Flask SQLAlchemy扩展
    "flask-migrate": "flask_migrate",  # Flask Alembic迁移扩展
    "python-jose": "jose",  # JOSE (JWT, JWS, JWE) 实现
    "passlib": "passlib",  # 密码哈希库
    "bcrypt": "bcrypt",  # Bcrypt密码哈希

    # ============== 数据库 (Database) ==============
    "mysql-connector-python": "mysql.connector",  # MySQL官方驱动
    "psycopg2-binary": "psycopg2",  # PostgreSQL驱动 (二进制)
    "pymongo": "pymongo",  # MongoDB驱动
    "redis": "redis",  # Redis客户端
    "aioredis": "aioredis",  # 异步Redis客户端
    "sqlalchemy": "sqlalchemy",  # SQL工具包和ORM
    "alembic": "alembic",  # SQLAlchemy数据库迁移工具
    "tortoise-orm": "tortoise",  # 异步ORM

    # ============== 图像与多媒体 (Image & Multimedia) ==============
    "Pillow": "PIL",  # Python图像处理库 (PIL Fork)
    "moviepy": "moviepy",  # 视频编辑库
    "pydub": "pydub",  # 音频处理库
    "pycairo": "cairo",  # Cairo 2D图形库的Python绑定
    "wand": "wand",  # ImageMagick的Python绑定

    # ============== 解析与序列化 (Parsing & Serialization) ==============
    "beautifulsoup4": "bs4",  # HTML/XML解析库
    "lxml": "lxml",  # 高性能HTML/XML解析库
    "PyYAML": "yaml",  # YAML解析库
    "python-dotenv": "dotenv",  # .env文件解析
    "python-dateutil": "dateutil",  # 强大的日期时间解析
    "protobuf": "google.protobuf",  # Protocol Buffers
    "msgpack": "msgpack",  # MessagePack序列化
    "orjson": "orjson",  # 高性能JSON库
    "pydantic": "pydantic",  # 数据验证和设置管理

    # ============== 系统与硬件 (System & Hardware) ==============
    "pyserial": "serial",  # 串口通信
    "pyusb": "usb",  # USB访问
    "pybluez": "bluetooth",  # 蓝牙通信 (可能因平台而异)
    "psutil": "psutil",  # 系统信息和进程管理
    "watchdog": "watchdog",  # 文件系统事件监控
    "python-gnupg": "gnupg",  # GnuPG的Python接口

    # ============== 加密与安全 (Cryptography & Security) ==============
    "pycrypto": "Crypto",  # 加密库 (较旧)
    "pycryptodome": "Crypto",  # PyCrypto的现代分支
    "cryptography": "cryptography",  # 现代加密库
    "pyopenssl": "OpenSSL",  # OpenSSL的Python接口
    "service-identity": "service_identity",  # 服务身份验证

    # ============== 工具与杂项 (Utilities & Miscellaneous) ==============
    "setuptools": "setuptools",  # 打包工具
    "pip": "pip",  # 包安装器
    "tqdm": "tqdm",  # 进度条
    "regex": "regex",  # 替代的正则表达式引擎
    "colorama": "colorama",  # 跨平台彩色终端文本
    "termcolor": "termcolor",  # 终端颜色格式化
    "requests-oauthlib": "requests_oauthlib",  # OAuth for Requests
    "oauthlib": "oauthlib",  # 通用OAuth库
    "authlib": "authlib",  # OAuth和OpenID Connect客户端/服务器
    "pyjwt": "jwt",  # JSON Web Token实现
    "python-editor": "editor",  # 程序化地调用编辑器
    "prompt-toolkit": "prompt_toolkit",  # 构建交互式命令行
    "pygments": "pygments",  # 语法高亮
    "tabulate": "tabulate",  # 生成漂亮的表格
    "nats-client": "nats",  # NATS客户端
    "gitpython": "git",  # Git的Python接口
    "pygithub": "github",  # GitHub API v3的Python接口
    "python-gitlab": "gitlab",  # GitLab API的Python接口
    "jira": "jira",  # JIRA API的Python接口
    "python-jenkins": "jenkins",  # Jenkins API的Python接口
    "huggingface-hub": "huggingface_hub",  # Hugging Face Hub API
    "apache-airflow": "airflow",  # Airflow工作流管理
    "pandas-stubs": "pandas-stubs",  # Pandas的类型存根
    "data-science-types": "data_science_types",  # 数据科学类型
}
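A quick illustration of how the table is consulted — the same two-step probe the dependency manager performs in the hunks below:

```python
import importlib.util
from src.plugin_system.utils.dependency_alias import INSTALL_NAME_TO_IMPORT_NAME

install_name = "beautifulsoup4"
if importlib.util.find_spec(install_name) is None:  # the pip name is not importable
    alias = INSTALL_NAME_TO_IMPORT_NAME.get(install_name)  # -> "bs4"
    if alias and importlib.util.find_spec(alias) is not None:
        print(f"{install_name} resolves to importable module '{alias}'")
```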
@@ -8,6 +8,7 @@ from packaging.requirements import Requirement

 from src.common.logger import get_logger
 from src.plugin_system.base.component_types import PythonDependency
+from src.plugin_system.utils.dependency_alias import INSTALL_NAME_TO_IMPORT_NAME

 logger = get_logger("dependency_manager")
@@ -190,9 +191,11 @@ class DependencyManager:

     def _check_single_dependency(self, dep: PythonDependency) -> bool:
         """检查单个依赖是否满足要求"""
+
+        def _try_check(import_name: str) -> bool:
+            """尝试使用给定的导入名进行检查"""
             try:
-                # 尝试导入包
-                spec = importlib.util.find_spec(dep.package_name)
+                spec = importlib.util.find_spec(import_name)
                 if spec is None:
                     return False
@@ -202,14 +205,14 @@ class DependencyManager:

                 # 检查版本要求
                 try:
-                    module = importlib.import_module(dep.package_name)
+                    module = importlib.import_module(import_name)
                     installed_version = getattr(module, '__version__', None)

                     if installed_version is None:
                         # 尝试其他常见的版本属性
                         installed_version = getattr(module, 'VERSION', None)
                         if installed_version is None:
-                            logger.debug(f"无法获取包 {dep.package_name} 的版本信息,假设满足要求")
+                            logger.debug(f"无法获取包 {import_name} 的版本信息,假设满足要求")
                             return True

                     # 解析版本要求
@@ -217,13 +220,28 @@ class DependencyManager:
                     return version.parse(str(installed_version)) in req.specifier

                 except Exception as e:
-                    logger.debug(f"检查包 {dep.package_name} 版本时出错: {e}")
+                    logger.debug(f"检查包 {import_name} 版本时出错: {e}")
                     return True  # 如果无法检查版本,假设满足要求

             except ImportError:
                 return False
             except Exception as e:
-                logger.error(f"检查依赖 {dep.package_name} 时发生未知错误: {e}")
+                logger.error(f"检查依赖 {import_name} 时发生未知错误: {e}")
+                return False
+
+        # 1. 首先尝试使用原始的 package_name 进行检查
+        if _try_check(dep.package_name):
+            return True
+
+        # 2. 如果失败,查询别名映射表
+        # 注意:此时 dep.package_name 可能是 "requests" 这样的普通包名,也可能是 "beautifulsoup4" 这样的安装名
+        import_alias = INSTALL_NAME_TO_IMPORT_NAME.get(dep.package_name)
+        if import_alias:
+            logger.debug(f"依赖 '{dep.package_name}' 导入失败, 尝试使用别名 '{import_alias}'")
+            if _try_check(import_alias):
+                return True
+
+        # 3. 如果别名也失败了,或者没有别名,最终确认失败
         return False

     def _install_single_package(self, package: str, plugin_name: str = "") -> bool:
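End to end, the resolution order for a dependency declared only by its install name can be restated as a small pure function (hedged sketch; `try_check` stands in for the inner `_try_check` above):

```python
from src.plugin_system.utils.dependency_alias import INSTALL_NAME_TO_IMPORT_NAME

def check_with_alias(package_name: str, try_check) -> bool:
    """Restates the three numbered steps above; try_check is the inner _try_check."""
    if try_check(package_name):  # 1. literal import name, e.g. "requests"
        return True
    alias = INSTALL_NAME_TO_IMPORT_NAME.get(package_name)
    if alias and try_check(alias):  # 2. known alias, e.g. "beautifulsoup4" -> "bs4"
        return True
    return False  # 3. genuinely missing
```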
@@ -9,8 +9,6 @@ from src.common.logger import get_logger
 from src.plugin_system import (
     BasePlugin,
     ComponentInfo,
-    BaseAction,
-    BaseCommand,
     register_plugin
 )
 from src.plugin_system.base.config_types import ConfigField
@@ -68,6 +66,8 @@ class MaiZoneRefactoredPlugin(BasePlugin):
         },
         "schedule": {
             "enable_schedule": ConfigField(type=bool, default=False, description="是否启用定时发送"),
+            "random_interval_min_minutes": ConfigField(type=int, default=5, description="随机间隔分钟数下限"),
+            "random_interval_max_minutes": ConfigField(type=int, default=15, description="随机间隔分钟数上限"),
         },
         "cookie": {
             "http_fallback_host": ConfigField(type=str, default="127.0.0.1", description="备用Cookie获取服务的主机地址"),
@@ -5,7 +5,6 @@ QQ空间服务模块
 """

 import asyncio
-import base64
 import json
 import os
 import random
@@ -15,7 +14,6 @@ from typing import Callable, Optional, Dict, Any, List, Tuple
 import aiohttp
 import bs4
 import json5
-from src.chat.utils.utils_image import get_image_manager
 from src.common.logger import get_logger
 from src.plugin_system.apis import config_api, person_api
@@ -5,6 +5,7 @@
 """
 import asyncio
 import datetime
+import random
 import traceback
 from typing import Callable

@@ -91,8 +92,12 @@ class SchedulerService:
                     result.get("message", "")
                 )

-                # 6. 等待5分钟后进行下一次检查
-                await asyncio.sleep(300)
+                # 6. 计算并等待一个随机的时间间隔
+                min_minutes = self.get_config("schedule.random_interval_min_minutes", 5)
+                max_minutes = self.get_config("schedule.random_interval_max_minutes", 15)
+                wait_seconds = random.randint(min_minutes * 60, max_minutes * 60)
+                logger.info(f"下一次检查将在 {wait_seconds / 60:.2f} 分钟后进行。")
+                await asyncio.sleep(wait_seconds)

             except asyncio.CancelledError:
                 logger.info("定时任务循环被取消。")
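The effect of this change: instead of a fixed 300-second cadence, each cycle sleeps a uniformly random duration drawn from the configured bounds. A standalone check of the arithmetic, assuming the defaults from the ConfigField definitions above:

```python
import random

min_minutes, max_minutes = 5, 15  # defaults from the schedule config section
wait_seconds = random.randint(min_minutes * 60, max_minutes * 60)  # inclusive bounds
assert 300 <= wait_seconds <= 900  # i.e. between 5 and 15 minutes
print(f"next check in {wait_seconds / 60:.2f} min")
```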
@@ -113,7 +118,7 @@ class SchedulerService:
             with get_db_session() as session:
                 record = session.query(MaiZoneScheduleStatus).filter(
                     MaiZoneScheduleStatus.datetime_hour == hour_str,
-                    MaiZoneScheduleStatus.is_processed == True
+                    MaiZoneScheduleStatus.is_processed == True  # noqa: E712
                ).first()
                 return record is not None
         except Exception as e:
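The `noqa` is needed because flake8's E712 flags `== True`, yet SQLAlchemy column comparisons must stay as `==` (the operator is overloaded to emit SQL; Python's `is True` would evaluate locally and break the filter). An equivalent spelling that avoids the suppression, shown against the same query from the hunk above:

```python
# SQLAlchemy's explicit boolean comparison; generates the same WHERE clause.
record = session.query(MaiZoneScheduleStatus).filter(
    MaiZoneScheduleStatus.datetime_hour == hour_str,
    MaiZoneScheduleStatus.is_processed.is_(True),
).first()
```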
@@ -138,10 +143,10 @@ class SchedulerService:

         if record:
             # 如果存在,则更新状态
-            record.is_processed = True
-            record.processed_at = datetime.datetime.now()
-            record.send_success = success
-            record.story_content = content
+            record.is_processed = True  # type: ignore
+            record.processed_at = datetime.datetime.now()  # type: ignore
+            record.send_success = success  # type: ignore
+            record.story_content = content  # type: ignore
         else:
             # 如果不存在,则创建新记录
             new_record = MaiZoneScheduleStatus(
@@ -20,6 +20,7 @@ from src.plugin_system import (
     PythonDependency
 )
 from src.plugin_system.apis import config_api  # 添加config_api导入
+from src.common.cache_manager import tool_cache
 import httpx
 from bs4 import BeautifulSoup
@@ -86,6 +87,13 @@ class WebSurfingTool(BaseTool):
         if not query:
             return {"error": "搜索查询不能为空。"}

+        # 检查缓存
+        query = function_args.get("query")
+        cached_result = await tool_cache.get(self.name, function_args, tool_class=self.__class__, semantic_query=query)
+        if cached_result:
+            logger.info(f"缓存命中: {self.name} -> {function_args}")
+            return cached_result
+
         # 读取搜索配置
         enabled_engines = config_api.get_global_config("web_search.enabled_engines", ["ddg"])
         search_strategy = config_api.get_global_config("web_search.search_strategy", "single")
@@ -94,11 +102,18 @@ class WebSurfingTool(BaseTool):

         # 根据策略执行搜索
         if search_strategy == "parallel":
-            return await self._execute_parallel_search(function_args, enabled_engines)
+            result = await self._execute_parallel_search(function_args, enabled_engines)
         elif search_strategy == "fallback":
-            return await self._execute_fallback_search(function_args, enabled_engines)
+            result = await self._execute_fallback_search(function_args, enabled_engines)
         else:  # single
-            return await self._execute_single_search(function_args, enabled_engines)
+            result = await self._execute_single_search(function_args, enabled_engines)
+
+        # 保存到缓存
+        if "error" not in result:
+            query = function_args.get("query")
+            await tool_cache.set(self.name, function_args, self.__class__, result, semantic_query=query)
+
+        return result

     async def _execute_parallel_search(self, function_args: Dict[str, Any], enabled_engines: List[str]) -> Dict[str, Any]:
         """并行搜索策略:同时使用所有启用的搜索引擎"""
@@ -449,6 +464,12 @@ class URLParserTool(BaseTool):
         """
         执行URL内容提取和总结。优先使用Exa,失败后尝试本地解析。
         """
+        # 检查缓存
+        cached_result = await tool_cache.get(self.name, function_args, tool_class=self.__class__)
+        if cached_result:
+            logger.info(f"缓存命中: {self.name} -> {function_args}")
+            return cached_result
+
         urls_input = function_args.get("urls")
         if not urls_input:
             return {"error": "URL列表不能为空。"}
@@ -556,6 +577,10 @@ class URLParserTool(BaseTool):
             "errors": error_messages
         }

+        # 保存到缓存
+        if "error" not in result:
+            await tool_cache.set(self.name, function_args, self.__class__, result)
+
         return result

     def _format_results(self, results: List[Dict[str, Any]]) -> str:
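Both tools now follow the same get/compute/set shape around `tool_cache`. A minimal sketch of that wrapper, using only the call signatures visible in the diff; `run_tool` is a hypothetical stand-in for the strategy dispatch or the URL parsing work:

```python
from src.common.cache_manager import tool_cache

async def cached_execute(self, function_args, run_tool):
    # 1. Return a previous result for identical arguments, if any.
    cached = await tool_cache.get(self.name, function_args, tool_class=self.__class__)
    if cached:
        return cached
    # 2. Otherwise do the real work (hypothetical stand-in).
    result = await run_tool(function_args)
    # 3. Cache only successful results, so failed calls are retried next time.
    if "error" not in result:
        await tool_cache.set(self.name, function_args, self.__class__, result)
    return result
```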
@@ -382,3 +382,6 @@ enable_url_tool = true # 是否启用URL解析tool
 # 搜索引擎配置
 enabled_engines = ["ddg"] # 启用的搜索引擎列表,可选: "exa", "tavily", "ddg"
 search_strategy = "single" # 搜索策略: "single"(使用第一个可用引擎), "parallel"(并行使用所有启用的引擎), "fallback"(按顺序尝试,失败则尝试下一个)
+
+[plugins] # 插件配置
+centralized_config = true # 是否启用插件配置集中化管理