This commit is contained in:
雅诺狐
2025-08-18 17:29:32 +08:00
18 changed files with 1128 additions and 566 deletions

View File

@@ -0,0 +1,66 @@
# 全新三层记忆系统架构 (V2.0) 设计文档
## 1. 核心思想
本架构旨在建立一个清晰、有序的信息处理流水线,模拟人类记忆从瞬时感知到长期知识沉淀的过程。信息将经历**短期记忆 (STM)**、**中期记忆 (MTM)** 和 **长期记忆 (LTM)** 三个阶段,实现从海量、零散到结构化、深刻的转化。
## 2. 架构分层详解
### 2.1. 短期记忆 (STM - Short-Term Memory) - “消息缓冲区”
* **职责**: 捕获并暂存所有进入核心的最新消息,为即时对话提供上下文,实现快速响应。
* **实现方式**:
* **内存队列**: 采用定长的内存队列(如 `collections.deque`),存储最近的 N 条原始消息(建议初始值为 200)。
* **实时向量化**: 消息入队时,异步进行文本内容的语义向量化,生成“意义指纹”。
* **快速检索**: 利用高效的向量相似度计算库(如 FAISS, Annoy),在新消息到来时快速从队列中检索最相关的历史消息,构建即时上下文。
* **触发机制**: 当队列达到容量上限时,将最老的一批消息(例如前 50 条)打包,移交给中期记忆模块处理(参见下方的示意代码)。
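下面给出一个极简的 STM 示意代码(仅为说明思路的假设性草图,并非本次提交中的实现;`vectorize` 为占位的向量化函数,`EMBED_DIM`、批量大小等参数均为示例值):

```python
import collections
import numpy as np
import faiss

EMBED_DIM = 768        # 示例维度,应与实际嵌入模型一致
QUEUE_SIZE = 200       # 队列容量,建议初始值
HANDOFF_BATCH = 50     # 队列满时移交给 MTM 的消息数

def vectorize(text: str) -> np.ndarray:
    """占位的"意义指纹"函数:真实系统中应调用嵌入模型,这里仅用随机向量演示。"""
    rng = np.random.default_rng(abs(hash(text)) % (2**32))
    v = rng.random(EMBED_DIM, dtype=np.float32)
    return v / np.linalg.norm(v)

class ShortTermMemory:
    def __init__(self):
        # 定长内存队列,同时保存原始文本与其向量
        self.messages = collections.deque(maxlen=QUEUE_SIZE)
        self.vectors = collections.deque(maxlen=QUEUE_SIZE)

    def add(self, text: str):
        """消息入队并向量化;若队列已满,先弹出最老的一批消息供 MTM 总结。"""
        overflow = []
        if len(self.messages) == QUEUE_SIZE:
            overflow = [(self.messages.popleft(), self.vectors.popleft())
                        for _ in range(HANDOFF_BATCH)]
        self.messages.append(text)
        self.vectors.append(vectorize(text))
        return [m for m, _ in overflow]   # 非空时移交中期记忆模块

    def recall(self, query: str, k: int = 5):
        """用 FAISS 内积检索(向量已归一化,等价于余弦相似度)取回最相关的历史消息。"""
        if not self.messages:
            return []
        index = faiss.IndexFlatIP(EMBED_DIM)
        index.add(np.stack(self.vectors))
        _, idx = index.search(vectorize(query).reshape(1, -1), min(k, len(self.messages)))
        msgs = list(self.messages)
        return [msgs[i] for i in idx[0]]
```

实际实现中应换用真正的嵌入模型,并考虑向量索引与队列淘汰的同步策略。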
### 2.2. 中期记忆 (MTM - Mid-Term Memory) - “记忆压缩器”
* **职责**: 对来自短期记忆的大量零散信息进行压缩、总结,形成结构化的“记忆片段”。
* **实现方式**:
* **LLM 总结**: 调用大语言模型(LLM)对 STM 移交的消息包进行深度分析和总结,提炼成一段精简的“记忆陈述”(Memory Statement)。
* **信息结构化**: 每个记忆片段都将包含以下元数据:
* `memory_text`: 记忆陈述本身。
* `keywords`: 关联的关键词列表。
* `time_range`: 记忆所涉及的时间范围。
* `importance_score`: LLM 评估的重要性评分。
* `access_count`: 访问计数器,初始为 0。
* **持久化存储**: 将结构化的记忆片段存储在数据库中,可复用或改造现有 `Memory` 表(字段结构的示意见下方代码)。
* **触发机制**: 由 STM 的队列溢出事件触发。
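记忆片段的元数据结构与 STM→MTM 的移交入口,大致可以示意如下(假设性草图:`summarize_with_llm`、`handle_stm_overflow` 均为占位名称,真实字段以现有 `Memory` 表或其改造后的定义为准):

```python
import time
from dataclasses import dataclass
from typing import List, Tuple

@dataclass
class MemoryFragment:
    memory_text: str                  # 记忆陈述本身
    keywords: List[str]               # 关联的关键词列表
    time_range: Tuple[float, float]   # 记忆所涉及的时间范围(起止时间戳)
    importance_score: float           # LLM 评估的重要性评分
    access_count: int = 0             # 访问计数器,初始为 0

def summarize_with_llm(messages: List[str]) -> MemoryFragment:
    """占位实现:真实系统中应调用 LLM 总结消息包并抽取关键词、评估重要性。"""
    text = f"对 {len(messages)} 条消息的概括性记忆陈述(由 LLM 生成)"
    return MemoryFragment(
        memory_text=text,
        keywords=["示例", "关键词"],
        time_range=(time.time() - 3600, time.time()),
        importance_score=0.5,
    )

def handle_stm_overflow(batch: List[str]) -> MemoryFragment:
    """STM 队列溢出时的处理入口:总结并持久化(数据库写入此处省略)。"""
    fragment = summarize_with_llm(batch)
    # persist(fragment)  # 写入现有/改造后的 Memory 表
    return fragment
```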
### 2.3. 长期记忆 (LTM - Long-Term Memory) - “知识图谱”
* **职责**: 将经过验证的、具有高价值的中期记忆,内化为系统核心知识的一部分,构建深层联系。
* **实现方式**:
* **晋升机制**: 通过一个定期的“记忆整理”任务,扫描中期记忆数据库。当某个记忆片段的 `access_count` 达到预设阈值(例如 10 次),则触发晋升。
* **融入图谱**: 晋升的记忆片段将被送往 `Hippocampus` 模块。`Hippocampus` 将不再直接处理原始聊天记录,而是处理这些高质量、经过预处理的记忆片段。它会从中提取核心概念(节点)和它们之间的关系(边),然后将这些信息融入并更新现有的知识图谱。
* **触发机制**: 由定时任务(例如每天执行一次)触发,下方给出一个简化的晋升任务示意。
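晋升流程可以概括为一个定时任务(示意代码;`fetch_fragments`、`hippocampus.integrate`、`mark_promoted` 均为假设的接口名,阈值与频率沿用上文示例):

```python
import asyncio

PROMOTION_THRESHOLD = 10            # access_count 达到该值即晋升
SCAN_INTERVAL_SECONDS = 24 * 3600   # 每天执行一次

async def promote_memories(db, hippocampus):
    """扫描中期记忆库,把高价值片段交给 Hippocampus 融入知识图谱。"""
    fragments = await db.fetch_fragments(min_access_count=PROMOTION_THRESHOLD)  # 假设的查询接口
    for fragment in fragments:
        # Hippocampus 从片段中提取核心概念(节点)与关系(边),更新知识图谱
        await hippocampus.integrate(fragment)   # 假设的接口
        await db.mark_promoted(fragment)        # 避免重复晋升

async def memory_maintenance_loop(db, hippocampus):
    while True:
        await promote_memories(db, hippocampus)
        await asyncio.sleep(SCAN_INTERVAL_SECONDS)
```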
## 3. 信息处理流程
```mermaid
graph TD
A[输入: 新消息] --> B{短期记忆 STM};
B --> |实时向量检索| C[输出: 对话上下文];
B --> |队列满| D{中期记忆 MTM};
D --> |LLM 总结| E[存入数据库: 记忆片段];
E --> |关键词/时间检索| C;
E --> |访问次数高| F{长期记忆 LTM};
F --> |LLM 提取概念/关系| G[更新: 知识图谱];
G --> |图谱扩散激活检索| C;
subgraph "内存中 (高速)"
B
end
subgraph "数据库中 (持久化)"
E
G
end
```
## 4. 现有模块改造计划
* **`InstantMemory`**: 将被新的 **STM** 和 **MTM** 模块取代。其原有的“判断是否需要记忆”和“总结”的功能,将融入到 MTM 的处理流程中。
* **`Hippocampus`**: 将保留其作为 **LTM** 的核心地位,但其输入源将从“随机抽样的历史聊天记录”变更为“从 MTM 晋升的高价值记忆片段”。这将极大提升其构建知识图谱的效率和质量。

View File

@@ -165,10 +165,38 @@ configure_dependency_settings(auto_install_timeout=600)
## 工作流程

1. **插件初始化**: 当插件类被实例化时,系统自动检查依赖
2. **依赖标准化**: 将字符串格式的依赖转换为`PythonDependency`对象
3. **检查已安装**: 尝试导入每个依赖包并检查版本
4. **智能别名解析 (新增)**: 如果直接导入失败 (例如 `import beautifulsoup4` 失败),系统会查询内置的别名映射表 (例如 `beautifulsoup4` -> `bs4`),并尝试使用别名再次导入。
5. **自动安装**: 如果启用,自动安装缺失的依赖
6. **错误处理**: 记录详细的错误信息和安装日志
## 智能别名解析 (Smart Alias Resolution)
为了提升开发体验,依赖管理系统内置了一套智能别名解析机制。
### 解决的问题
Python生态中存在一些特殊的包,它们的**安装名** (在 `pip install` 中使用) 与**导入名** (在 `import` 语句中使用) 不一致。最典型的例子就是:
- 安装名: `beautifulsoup4`, 导入名: `bs4`
- 安装名: `Pillow`, 导入名: `PIL`
- 安装名: `scikit-learn`, 导入名: `sklearn`
如果开发者在 `python_dependencies` 列表中使用简单的字符串格式 `"beautifulsoup4"`,标准的依赖检查会因为无法 `import beautifulsoup4` 而失败。
### 工作原理
当依赖管理器通过包名直接导入失败时,它会:
1. 查询一个内置的、包含上百个常见包的别名映射表。
2. 如果在表中找到对应的导入名,则使用该别名再次尝试导入。
3. 如果使用别名导入成功,则依赖检查通过,并继续进行版本验证。
这个过程是自动的,旨在处理绝大多数常见情况,减少开发者手动配置的麻烦。
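其核心逻辑大致等价于下面的简化示意(非真实源码,仅演示“先按安装名导入、失败后查别名表重试”的流程;映射表内容为节选):

```python
import importlib.util

# 内置映射表的一小部分节选
INSTALL_NAME_TO_IMPORT_NAME = {
    "beautifulsoup4": "bs4",
    "Pillow": "PIL",
    "scikit-learn": "sklearn",
}

def can_import(name: str) -> bool:
    """判断某个导入名在当前环境下是否可用。"""
    return importlib.util.find_spec(name) is not None

def check_dependency(install_name: str) -> bool:
    # 1. 先按原始名称(通常是安装名)尝试导入
    if can_import(install_name):
        return True
    # 2. 导入失败则查询别名映射表,用导入名再试一次
    alias = INSTALL_NAME_TO_IMPORT_NAME.get(install_name)
    if alias and can_import(alias):
        return True
    # 3. 两者都失败,则认为依赖缺失
    return False
```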
### 注意事项
- **最佳实践**: 尽管有智能别名解析,我们仍然**强烈推荐**使用 `PythonDependency` 对象来明确指定 `package_name` (导入名) 和 `install_name` (安装名),这能确保最高的准确性和可读性。
- **覆盖范围**: 内置的别名映射表涵盖了大量常用库但无法保证100%覆盖所有情况。如果遇到别名库未收录的包,请使用 `PythonDependency` 对象进行精确定义。
## 日志输出示例
@@ -192,12 +220,13 @@ configure_dependency_settings(auto_install_timeout=600)
## 最佳实践

1. **优先使用`PythonDependency`对象**: 这是最可靠、最明确的方式,尤其是在安装名和导入名不同时(参见下方示例)。
2. **利用智能别名解析**: 对于常见的、安装名与导入名不一致的包 (如 `beautifulsoup4`, `Pillow` 等),可以直接在字符串列表里使用安装名,系统会自动解析。
3. **配置PyPI镜像源**: 特别是在中国大陆地区,可显著提升下载速度。
4. **合理设置可选依赖**: 避免非核心功能阻止插件加载。
5. **指定版本要求**: 确保兼容性。
6. **添加描述信息**: 帮助用户理解依赖的用途。
7. **测试依赖配置**: 在不同环境中验证依赖是否正确。
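下面是一个对比示例(仅为示意写法;`PythonDependency` 的字段名取自本文提到的 `package_name`、`install_name`、`version`、`optional`、`description`,具体以其实际定义为准):

```python
from src.plugin_system.base.component_types import PythonDependency

python_dependencies = [
    # 简单字符串:依赖智能别名解析来处理 beautifulsoup4 -> bs4
    "beautifulsoup4",

    # 推荐写法:显式区分导入名与安装名,并附带版本与说明
    PythonDependency(
        package_name="PIL",        # import 时使用的名称
        install_name="Pillow",     # pip install 时使用的名称
        version=">=10.0.0",        # 版本要求(示例)
        optional=True,             # 可选依赖,缺失时不阻止插件加载
        description="图像处理功能所需",
    ),
]
```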
## 安全考虑
@@ -225,7 +254,8 @@ configure_dependency_settings(auto_install_timeout=600)
### 导入错误

1. **确认包名与导入名**: 检查安装名和导入名是否一致。如果不一致,推荐使用 `PythonDependency` 对象明确指定 `package_name` 和 `install_name`。
2. **利用自动别名解析**: 对于常见库,系统会自动尝试解析别名。如果你的库比较冷门且名称不一致,请使用 `PythonDependency` 对象。
3. **检查可选依赖配置**: 确认 `optional=True` 是否被正确设置。
4. **验证安装是否成功**: 查看日志,确认 `pip install` 过程没有报错。

View File

@@ -60,5 +60,7 @@ exa_py
asyncddgs
opencv-python
Pillow
chromadb
asyncio
tavily-python
google-generativeai==0.8.5

View File

@@ -1,344 +1,224 @@
import time
import json
import sqlite3
import chromadb
import hashlib
import inspect
import numpy as np
import faiss
from typing import Any, Dict, Optional

from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config

logger = get_logger("cache_manager")


class CacheManager:
    """
    一个支持分层和语义缓存的通用工具缓存管理器。
    采用单例模式,确保在整个应用中只有一个缓存实例。
    L1缓存: 内存字典 (KV) + FAISS (Vector)。
    L2缓存: SQLite (KV) + ChromaDB (Vector)。
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        if not cls._instance:
            cls._instance = super(CacheManager, cls).__new__(cls)
        return cls._instance

    def __init__(self, default_ttl: int = 3600, db_path: str = "data/cache.db", chroma_path: str = "data/chroma_db"):
        """
        初始化缓存管理器
        """
        if not hasattr(self, '_initialized'):
            self.default_ttl = default_ttl

            # L1 缓存 (内存)
            self.l1_kv_cache: Dict[str, Dict[str, Any]] = {}
            embedding_dim = global_config.lpmm_knowledge.embedding_dimension
            self.l1_vector_index = faiss.IndexFlatIP(embedding_dim)
            self.l1_vector_id_to_key: Dict[int, str] = {}

            # L2 缓存 (持久化)
            self.db_path = db_path
            self._init_sqlite()
            self.chroma_client = chromadb.PersistentClient(path=chroma_path)
            self.chroma_collection = self.chroma_client.get_or_create_collection(name="semantic_cache")

            # 嵌入模型
            self.embedding_model = LLMRequest(model_config.model_task_config.embedding)

            self._initialized = True
            logger.info("缓存管理器已初始化: L1 (内存+FAISS), L2 (SQLite+ChromaDB)")

    def _init_sqlite(self):
        """初始化SQLite数据库和表结构。"""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS cache (
                    key TEXT PRIMARY KEY,
                    value TEXT,
                    expires_at REAL
                )
            """)
            conn.commit()

    def _generate_key(self, tool_name: str, function_args: Dict[str, Any], tool_class: Any) -> str:
        """生成确定性的缓存键,包含代码哈希以实现自动失效。"""
        try:
            source_code = inspect.getsource(tool_class)
            code_hash = hashlib.md5(source_code.encode()).hexdigest()
        except (TypeError, OSError) as e:
            code_hash = "unknown"
            logger.warning(f"无法获取 {tool_class.__name__} 的源代码,代码哈希将为 'unknown'。错误: {e}")
        try:
            sorted_args = json.dumps(function_args, sort_keys=True)
        except TypeError:
            sorted_args = repr(sorted(function_args.items()))
        return f"{tool_name}::{sorted_args}::{code_hash}"

    async def get(self, tool_name: str, function_args: Dict[str, Any], tool_class: Any, semantic_query: Optional[str] = None) -> Optional[Any]:
        """
        从缓存获取结果,查询顺序: L1-KV -> L1-Vector -> L2-KV -> L2-Vector。
        """
        # 步骤 1: L1 精确缓存查询
        key = self._generate_key(tool_name, function_args, tool_class)
        logger.debug(f"生成的缓存键: {key}")
        if semantic_query:
            logger.debug(f"使用的语义查询: '{semantic_query}'")

        if key in self.l1_kv_cache:
            entry = self.l1_kv_cache[key]
            if time.time() < entry["expires_at"]:
                logger.info(f"命中L1键值缓存: {key}")
                return entry["data"]
            else:
                del self.l1_kv_cache[key]

        # 步骤 2: L1/L2 语义和L2精确缓存查询
        query_embedding = None
        if semantic_query and self.embedding_model:
            embedding_result = await self.embedding_model.get_embedding(semantic_query)
            if embedding_result:
                query_embedding = np.array([embedding_result], dtype='float32')

        # 步骤 2a: L1 语义缓存 (FAISS)
        if query_embedding is not None and self.l1_vector_index.ntotal > 0:
            faiss.normalize_L2(query_embedding)
            distances, indices = self.l1_vector_index.search(query_embedding, 1)
            if indices.size > 0 and distances[0][0] > 0.75:  # IP 越大越相似
                hit_index = indices[0][0]
                l1_hit_key = self.l1_vector_id_to_key.get(hit_index)
                if l1_hit_key and l1_hit_key in self.l1_kv_cache:
                    logger.info(f"命中L1语义缓存: {l1_hit_key}")
                    return self.l1_kv_cache[l1_hit_key]["data"]

        # 步骤 2b: L2 精确缓存 (SQLite)
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT value, expires_at FROM cache WHERE key = ?", (key,))
            row = cursor.fetchone()
            if row:
                value, expires_at = row
                if time.time() < expires_at:
                    logger.info(f"命中L2键值缓存: {key}")
                    data = json.loads(value)
                    # 回填 L1
                    self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at}
                    return data
                else:
                    cursor.execute("DELETE FROM cache WHERE key = ?", (key,))
                    conn.commit()

        # 步骤 2c: L2 语义缓存 (ChromaDB)
        if query_embedding is not None:
            results = self.chroma_collection.query(query_embeddings=query_embedding.tolist(), n_results=1)
            if results and results['ids'] and results['ids'][0]:
                distance = results['distances'][0][0] if results['distances'] and results['distances'][0] else 'N/A'
                logger.debug(f"L2语义搜索找到最相似的结果: id={results['ids'][0]}, 距离={distance}")
                if distance != 'N/A' and distance < 0.75:
                    l2_hit_key = results['ids'][0]
                    logger.info(f"命中L2语义缓存: key='{l2_hit_key}', 距离={distance:.4f}")
                    with sqlite3.connect(self.db_path) as conn:
                        cursor = conn.cursor()
                        cursor.execute("SELECT value, expires_at FROM cache WHERE key = ?", (l2_hit_key if isinstance(l2_hit_key, str) else l2_hit_key[0],))
                        row = cursor.fetchone()
                        if row:
                            value, expires_at = row
                            if time.time() < expires_at:
                                data = json.loads(value)
                                logger.debug(f"L2语义缓存返回的数据: {data}")
                                # 回填 L1
                                self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at}
                                if query_embedding is not None:
                                    new_id = self.l1_vector_index.ntotal
                                    faiss.normalize_L2(query_embedding)
                                    self.l1_vector_index.add(x=query_embedding)
                                    self.l1_vector_id_to_key[new_id] = key
                                return data

        logger.debug(f"缓存未命中: {key}")
        return None

    async def set(self, tool_name: str, function_args: Dict[str, Any], tool_class: Any, data: Any, ttl: Optional[int] = None, semantic_query: Optional[str] = None):
        """将结果存入所有缓存层。"""
        if ttl is None:
            ttl = self.default_ttl
        if ttl <= 0:
            return

        key = self._generate_key(tool_name, function_args, tool_class)
        expires_at = time.time() + ttl

        # 写入 L1
        self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at}

        # 写入 L2
        value = json.dumps(data)
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            cursor.execute("REPLACE INTO cache (key, value, expires_at) VALUES (?, ?, ?)", (key, value, expires_at))
            conn.commit()

        # 写入语义缓存
        if semantic_query and self.embedding_model:
            embedding_result = await self.embedding_model.get_embedding(semantic_query)
            if embedding_result:
                embedding = np.array([embedding_result], dtype='float32')
                # 写入 L1 Vector
                new_id = self.l1_vector_index.ntotal
                faiss.normalize_L2(embedding)
                self.l1_vector_index.add(x=embedding)
                self.l1_vector_id_to_key[new_id] = key
                # 写入 L2 Vector
                self.chroma_collection.add(embeddings=embedding.tolist(), ids=[key])

        logger.info(f"已缓存条目: {key}, TTL: {ttl}s")

    def clear_l1(self):
        """清空L1缓存。"""
        self.l1_kv_cache.clear()
        self.l1_vector_index.reset()
        self.l1_vector_id_to_key.clear()
        logger.info("L1 (内存+FAISS) 缓存已清空。")

    def clear_l2(self):
        """清空L2缓存。"""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            cursor.execute("DELETE FROM cache")
            conn.commit()
        self.chroma_client.delete_collection(name="semantic_cache")
        self.chroma_collection = self.chroma_client.get_or_create_collection(name="semantic_cache")
        logger.info("L2 (SQLite & ChromaDB) 缓存已清空。")

    def clear_all(self):
        """清空所有缓存。"""
        self.clear_l1()
        self.clear_l2()
        logger.info("所有缓存层级已清空。")


# 全局实例
tool_cache = CacheManager()

View File

@@ -0,0 +1,344 @@
import json
import hashlib
import re
from typing import Any, Dict, Optional
from datetime import datetime, timedelta
from pathlib import Path
from difflib import SequenceMatcher
from src.common.logger import get_logger
logger = get_logger("cache_manager")
class ToolCache:
"""工具缓存管理器,用于缓存工具调用结果,支持近似匹配"""
def __init__(
self,
cache_dir: str = "data/tool_cache",
max_age_hours: int = 24,
similarity_threshold: float = 0.65,
):
"""
初始化缓存管理器
Args:
cache_dir: 缓存目录路径
max_age_hours: 缓存最大存活时间(小时)
similarity_threshold: 近似匹配的相似度阈值 (0-1)
"""
self.cache_dir = Path(cache_dir)
self.max_age = timedelta(hours=max_age_hours)
self.max_age_seconds = max_age_hours * 3600
self.similarity_threshold = similarity_threshold
self.cache_dir.mkdir(parents=True, exist_ok=True)
@staticmethod
def _normalize_query(query: str) -> str:
"""
标准化查询文本,用于相似度比较
Args:
query: 原始查询文本
Returns:
标准化后的查询文本
"""
if not query:
return ""
# 纯 Python 实现
normalized = query.lower()
normalized = re.sub(r"[^\w\s]", " ", normalized)
normalized = " ".join(normalized.split())
return normalized
def _calculate_similarity(self, text1: str, text2: str) -> float:
"""
计算两个文本的相似度
Args:
text1: 文本1
text2: 文本2
Returns:
相似度分数 (0-1)
"""
if not text1 or not text2:
return 0.0
# 纯 Python 实现
norm_text1 = self._normalize_query(text1)
norm_text2 = self._normalize_query(text2)
if norm_text1 == norm_text2:
return 1.0
return SequenceMatcher(None, norm_text1, norm_text2).ratio()
@staticmethod
def _generate_cache_key(tool_name: str, function_args: Dict[str, Any]) -> str:
"""
生成缓存键
Args:
tool_name: 工具名称
function_args: 函数参数
Returns:
缓存键字符串
"""
# 将参数排序后序列化,确保相同参数产生相同的键
sorted_args = json.dumps(function_args, sort_keys=True, ensure_ascii=False)
# 纯 Python 实现
cache_string = f"{tool_name}:{sorted_args}"
return hashlib.md5(cache_string.encode("utf-8")).hexdigest()
def _get_cache_file_path(self, cache_key: str) -> Path:
"""获取缓存文件路径"""
return self.cache_dir / f"{cache_key}.json"
def _is_cache_expired(self, cached_time: datetime) -> bool:
"""检查缓存是否过期"""
return datetime.now() - cached_time > self.max_age
def _find_similar_cache(
self, tool_name: str, function_args: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
"""
查找相似的缓存条目
Args:
tool_name: 工具名称
function_args: 函数参数
Returns:
相似的缓存结果如果不存在则返回None
"""
query = function_args.get("query", "")
if not query:
return None
candidates = []
cache_data_list = []
# 遍历所有缓存文件,收集候选项
for cache_file in self.cache_dir.glob("*.json"):
try:
with open(cache_file, "r", encoding="utf-8") as f:
cache_data = json.load(f)
# 检查是否是同一个工具
if cache_data.get("tool_name") != tool_name:
continue
# 检查缓存是否过期
cached_time = datetime.fromisoformat(cache_data["timestamp"])
if self._is_cache_expired(cached_time):
continue
# 检查其他参数是否匹配除了query
cached_args = cache_data.get("function_args", {})
args_match = True
for key, value in function_args.items():
if key != "query" and cached_args.get(key) != value:
args_match = False
break
if not args_match:
continue
# 收集候选项
cached_query = cached_args.get("query", "")
candidates.append((cached_query, len(cache_data_list)))
cache_data_list.append(cache_data)
except Exception as e:
logger.warning(f"检查缓存文件时出错: {cache_file}, 错误: {e}")
continue
if not candidates:
logger.debug(
f"未找到相似缓存: {tool_name}, 查询: '{query}',相似度阈值: {self.similarity_threshold}"
)
return None
# 纯 Python 实现
best_match = None
best_similarity = 0.0
for cached_query, index in candidates:
similarity = self._calculate_similarity(query, cached_query)
if similarity > best_similarity and similarity >= self.similarity_threshold:
best_similarity = similarity
best_match = cache_data_list[index]
if best_match is not None:
cached_query = best_match["function_args"].get("query", "")
logger.info(
f"相似缓存命中,相似度: {best_similarity:.2f}, 原查询: '{cached_query}', 当前查询: '{query}'"
)
return best_match["result"]
logger.debug(
f"未找到相似缓存: {tool_name}, 查询: '{query}',相似度阈值: {self.similarity_threshold}"
)
return None
def get(
self, tool_name: str, function_args: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
"""
从缓存获取结果,支持精确匹配和近似匹配
Args:
tool_name: 工具名称
function_args: 函数参数
Returns:
缓存的结果如果不存在或已过期则返回None
"""
# 首先尝试精确匹配
cache_key = self._generate_cache_key(tool_name, function_args)
cache_file = self._get_cache_file_path(cache_key)
if cache_file.exists():
try:
with open(cache_file, "r", encoding="utf-8") as f:
cache_data = json.load(f)
# 检查缓存是否过期
cached_time = datetime.fromisoformat(cache_data["timestamp"])
if self._is_cache_expired(cached_time):
logger.debug(f"缓存已过期: {cache_key}")
cache_file.unlink() # 删除过期缓存
else:
logger.debug(f"精确匹配缓存: {tool_name}")
return cache_data["result"]
except (json.JSONDecodeError, KeyError, ValueError) as e:
logger.warning(f"读取缓存文件失败: {cache_file}, 错误: {e}")
# 删除损坏的缓存文件
if cache_file.exists():
cache_file.unlink()
# 如果精确匹配失败,尝试近似匹配
return self._find_similar_cache(tool_name, function_args)
def set(
self, tool_name: str, function_args: Dict[str, Any], result: Dict[str, Any]
) -> None:
"""
将结果保存到缓存
Args:
tool_name: 工具名称
function_args: 函数参数
result: 缓存结果
"""
cache_key = self._generate_cache_key(tool_name, function_args)
cache_file = self._get_cache_file_path(cache_key)
cache_data = {
"tool_name": tool_name,
"function_args": function_args,
"result": result,
"timestamp": datetime.now().isoformat(),
}
try:
with open(cache_file, "w", encoding="utf-8") as f:
json.dump(cache_data, f, ensure_ascii=False, indent=2)
logger.debug(f"缓存已保存: {tool_name} -> {cache_key}")
except Exception as e:
logger.error(f"保存缓存失败: {cache_file}, 错误: {e}")
def clear_expired(self) -> int:
"""
清理过期缓存
Returns:
删除的文件数量
"""
removed_count = 0
for cache_file in self.cache_dir.glob("*.json"):
try:
with open(cache_file, "r", encoding="utf-8") as f:
cache_data = json.load(f)
cached_time = datetime.fromisoformat(cache_data["timestamp"])
if self._is_cache_expired(cached_time):
cache_file.unlink()
removed_count += 1
logger.debug(f"删除过期缓存: {cache_file}")
except Exception as e:
logger.warning(f"清理缓存文件时出错: {cache_file}, 错误: {e}")
# 删除损坏的文件
try:
cache_file.unlink()
removed_count += 1
except (OSError, json.JSONDecodeError, KeyError, ValueError):
logger.warning(f"删除损坏的缓存文件失败: {cache_file}, 错误: {e}")
logger.info(f"清理完成,删除了 {removed_count} 个过期缓存文件")
return removed_count
def clear_all(self) -> int:
"""
清空所有缓存
Returns:
删除的文件数量
"""
removed_count = 0
for cache_file in self.cache_dir.glob("*.json"):
try:
cache_file.unlink()
removed_count += 1
except Exception as e:
logger.warning(f"删除缓存文件失败: {cache_file}, 错误: {e}")
logger.info(f"清空缓存完成,删除了 {removed_count} 个文件")
return removed_count
def get_stats(self) -> Dict[str, Any]:
"""
获取缓存统计信息
Returns:
缓存统计信息字典
"""
total_files = 0
expired_files = 0
total_size = 0
for cache_file in self.cache_dir.glob("*.json"):
try:
total_files += 1
total_size += cache_file.stat().st_size
with open(cache_file, "r", encoding="utf-8") as f:
cache_data = json.load(f)
cached_time = datetime.fromisoformat(cache_data["timestamp"])
if self._is_cache_expired(cached_time):
expired_files += 1
except (OSError, json.JSONDecodeError, KeyError, ValueError):
expired_files += 1 # 损坏的文件也算作过期
return {
"total_files": total_files,
"expired_files": expired_files,
"total_size_bytes": total_size,
"cache_dir": str(self.cache_dir),
"max_age_hours": self.max_age.total_seconds() / 3600,
"similarity_threshold": self.similarity_threshold,
}
tool_cache = ToolCache()

View File

@@ -42,7 +42,8 @@ from src.config.official_configs import (
    ExaConfig,
    WebSearchConfig,
    TavilyConfig,
    AntiPromptInjectionConfig,
    PluginsConfig
)

from .api_ada_configs import (
@@ -365,6 +366,7 @@ class Config(ConfigBase):
    exa: ExaConfig = field(default_factory=lambda: ExaConfig())
    web_search: WebSearchConfig = field(default_factory=lambda: WebSearchConfig())
    tavily: TavilyConfig = field(default_factory=lambda: TavilyConfig())
    plugins: PluginsConfig = field(default_factory=lambda: PluginsConfig())

@dataclass

View File

@@ -1,7 +1,7 @@
import re
from dataclasses import dataclass, field
from typing import Literal, Optional

from src.config.config_base import ConfigBase
@@ -864,43 +864,6 @@ class ScheduleConfig(ConfigBase):
    guidelines: Optional[str] = field(default=None)
    """日程生成指导原则如果为None则使用默认指导原则"""
@dataclass
class VideoAnalysisConfig(ConfigBase):
"""视频分析配置类"""
enable: bool = True
"""是否启用视频分析功能"""
analysis_mode: Literal["frame_by_frame", "batch_frames", "auto"] = "auto"
"""分析模式:逐帧分析(慢但详细)、批量分析(快但可能略简单)或自动选择"""
max_frames: int = 8
"""最大分析帧数"""
frame_quality: int = 85
"""帧图像JPEG质量 (1-100)"""
max_image_size: int = 800
"""单帧最大图像尺寸(像素)"""
batch_analysis_prompt: str = field(default="""请分析这个视频的内容。这些图片是从视频中按时间顺序提取的关键帧。
请提供详细的分析,包括:
1. 视频的整体内容和主题
2. 主要人物、对象和场景描述
3. 动作、情节和时间线发展
4. 视觉风格和艺术特点
5. 整体氛围和情感表达
6. 任何特殊的视觉效果或文字内容
请用中文回答,分析要详细准确。""")
"""批量分析时使用的提示词"""
enable_frame_timing: bool = True
"""是否在分析中包含帧的时间信息"""
@dataclass
class DependencyManagementConfig(ConfigBase):
    """插件Python依赖管理配置类"""
@@ -988,10 +951,10 @@ class VideoAnalysisConfig(ConfigBase):
"""批量分析时使用的提示词""" """批量分析时使用的提示词"""
@dataclass @dataclass
class WebSearchConfig(ConfigBase): class WebSearchConfig(ConfigBase):
"""联网搜索组件配置类""" """联网搜索组件配置类"""
enable_web_search_tool: bool = True enable_web_search_tool: bool = True
"""是否启用联网搜索工具""" """是否启用联网搜索工具"""
@@ -1064,4 +1027,13 @@ class AntiPromptInjectionConfig(ConfigBase):
"""加盾消息前缀""" """加盾消息前缀"""
shield_suffix: str = " 🛡️" shield_suffix: str = " 🛡️"
"""加盾消息后缀""" """加盾消息后缀"""
@dataclass
class PluginsConfig(ConfigBase):
"""插件配置"""
centralized_config: bool = field(
default=True, metadata={"description": "是否启用插件配置集中化管理"}
)

View File

@@ -1,32 +1,54 @@
import asyncio
import io
import base64
from typing import Callable, AsyncIterator, Optional, Coroutine, Any, List, Dict, Union

import google.generativeai as genai
from google.generativeai.types import (
    GenerateContentResponse,
    HarmCategory,
    HarmBlockThreshold,
)

try:
    # 尝试从较新的API导入
    from google.generativeai import configure
    from google.generativeai.types import SafetySetting, GenerationConfig
except ImportError:
    # 回退到基本类型
    SafetySetting = Dict
    GenerationConfig = Dict

# 定义兼容性类型
ContentDict = Dict
PartDict = Dict
ToolDict = Dict
FunctionDeclaration = Dict
Tool = Dict
ContentListUnion = List[Dict]
ContentUnion = Dict
Content = Dict
Part = Dict
ThinkingConfig = Dict
GenerateContentConfig = Dict
EmbedContentConfig = Dict
EmbedContentResponse = Dict

# 定义异常类型
class ClientError(Exception):
    pass

class ServerError(Exception):
    pass

class UnknownFunctionCallArgumentError(Exception):
    pass

class UnsupportedFunctionError(Exception):
    pass

class FunctionInvocationError(Exception):
    pass

from src.config.api_ada_configs import ModelInfo, APIProvider
from src.common.logger import get_logger
@@ -44,18 +66,17 @@ from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
logger = get_logger("Gemini客户端")

SAFETY_SETTINGS = [
    {"category": HarmCategory.HARM_CATEGORY_HATE_SPEECH, "threshold": HarmBlockThreshold.BLOCK_NONE},
    {"category": HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, "threshold": HarmBlockThreshold.BLOCK_NONE},
    {"category": HarmCategory.HARM_CATEGORY_HARASSMENT, "threshold": HarmBlockThreshold.BLOCK_NONE},
    {"category": HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, "threshold": HarmBlockThreshold.BLOCK_NONE},
]
def _convert_messages( def _convert_messages(
messages: list[Message], messages: list[Message],
) -> tuple[ContentListUnion, list[str] | None]: ) -> tuple[List[Dict], list[str] | None]:
""" """
转换消息格式 - 将消息转换为Gemini API所需的格式 转换消息格式 - 将消息转换为Gemini API所需的格式
:param messages: 消息列表 :param messages: 消息列表
@@ -81,7 +102,7 @@ def _convert_messages(
normalized_format = format_mapping.get(image_format.lower(), image_format.lower()) normalized_format = format_mapping.get(image_format.lower(), image_format.lower())
return f"image/{normalized_format}" return f"image/{normalized_format}"
def _convert_message_item(message: Message) -> Content: def _convert_message_item(message: Message) -> Dict:
""" """
转换单个消息格式除了system和tool类型的消息 转换单个消息格式除了system和tool类型的消息
:param message: 消息对象 :param message: 消息对象
@@ -96,22 +117,25 @@ def _convert_messages(
# 添加Content # 添加Content
if isinstance(message.content, str): if isinstance(message.content, str):
content = [Part.from_text(text=message.content)] content = [{"text": message.content}]
elif isinstance(message.content, list): elif isinstance(message.content, list):
content: List[Part] = [] content = []
for item in message.content: for item in message.content:
if isinstance(item, tuple): if isinstance(item, tuple):
content.append( content.append({
Part.from_bytes(data=base64.b64decode(item[1]), mime_type=_get_correct_mime_type(item[0])) "inline_data": {
) "mime_type": _get_correct_mime_type(item[0]),
"data": item[1]
}
})
elif isinstance(item, str): elif isinstance(item, str):
content.append(Part.from_text(text=item)) content.append({"text": item})
else: else:
raise RuntimeError("无法触及的代码请使用MessageBuilder类构建消息对象") raise RuntimeError("无法触及的代码请使用MessageBuilder类构建消息对象")
return Content(role=role, parts=content) return {"role": role, "parts": content}
temp_list: list[ContentUnion] = [] temp_list: List[Dict] = []
system_instructions: list[str] = [] system_instructions: list[str] = []
for message in messages: for message in messages:
if message.role == RoleType.System: if message.role == RoleType.System:
@@ -338,13 +362,10 @@ def _default_normal_response_parser(
@client_registry.register_client_class("gemini") @client_registry.register_client_class("gemini")
class GeminiClient(BaseClient): class GeminiClient(BaseClient):
client: genai.Client
def __init__(self, api_provider: APIProvider): def __init__(self, api_provider: APIProvider):
super().__init__(api_provider) super().__init__(api_provider)
self.client = genai.Client( # 配置 Google Generative AI
api_key=api_provider.api_key, genai.configure(api_key=api_provider.api_key)
) # 这里和openai不一样gemini会自己决定自己是否需要retry
async def get_response( async def get_response(
self, self,
@@ -396,18 +417,18 @@ class GeminiClient(BaseClient):
"max_output_tokens": max_tokens, "max_output_tokens": max_tokens,
"temperature": temperature, "temperature": temperature,
"response_modalities": ["TEXT"], "response_modalities": ["TEXT"],
"thinking_config": ThinkingConfig( "thinking_config": {
include_thoughts=True, "include_thoughts": True,
thinking_budget=( "thinking_budget": (
extra_params["thinking_budget"] extra_params["thinking_budget"]
if extra_params and "thinking_budget" in extra_params if extra_params and "thinking_budget" in extra_params
else int(max_tokens / 2) # 默认思考预算为最大token数的一半防止空回复 else int(max_tokens / 2) # 默认思考预算为最大token数的一半防止空回复
), ),
), },
"safety_settings": gemini_safe_settings, # 防止空回复问题 "safety_settings": SAFETY_SETTINGS, # 防止空回复问题
} }
if tools: if tools:
generation_config_dict["tools"] = Tool(function_declarations=tools) generation_config_dict["tools"] = {"function_declarations": tools}
if messages[1]: if messages[1]:
# 如果有system消息则将其添加到配置中 # 如果有system消息则将其添加到配置中
generation_config_dict["system_instructions"] = messages[1] generation_config_dict["system_instructions"] = messages[1]
@@ -417,15 +438,18 @@ class GeminiClient(BaseClient):
generation_config_dict["response_mime_type"] = "application/json" generation_config_dict["response_mime_type"] = "application/json"
generation_config_dict["response_schema"] = response_format.to_dict() generation_config_dict["response_schema"] = response_format.to_dict()
generation_config = GenerateContentConfig(**generation_config_dict) generation_config = generation_config_dict
try: try:
# 创建模型实例
model = genai.GenerativeModel(model_info.model_identifier)
if model_info.force_stream_mode: if model_info.force_stream_mode:
req_task = asyncio.create_task( req_task = asyncio.create_task(
self.client.aio.models.generate_content_stream( model.generate_content_async(
model=model_info.model_identifier,
contents=messages[0], contents=messages[0],
config=generation_config, generation_config=generation_config,
stream=True
) )
) )
while not req_task.done(): while not req_task.done():
@@ -437,10 +461,9 @@ class GeminiClient(BaseClient):
resp, usage_record = await stream_response_handler(req_task.result(), interrupt_flag) resp, usage_record = await stream_response_handler(req_task.result(), interrupt_flag)
else: else:
req_task = asyncio.create_task( req_task = asyncio.create_task(
self.client.aio.models.generate_content( model.generate_content_async(
model=model_info.model_identifier,
contents=messages[0], contents=messages[0],
config=generation_config, generation_config=generation_config
) )
) )
while not req_task.done(): while not req_task.done():
@@ -451,17 +474,18 @@ class GeminiClient(BaseClient):
await asyncio.sleep(0.5) # 等待0.5秒后再次检查任务&中断信号量状态 await asyncio.sleep(0.5) # 等待0.5秒后再次检查任务&中断信号量状态
resp, usage_record = async_response_parser(req_task.result()) resp, usage_record = async_response_parser(req_task.result())
except (ClientError, ServerError) as e:
# 重封装ClientError和ServerError为RespNotOkException
raise RespNotOkException(e.code, e.message) from None
except (
UnknownFunctionCallArgumentError,
UnsupportedFunctionError,
FunctionInvocationError,
) as e:
raise ValueError(f"工具类型错误:请检查工具选项和参数:{str(e)}") from None
except Exception as e: except Exception as e:
raise NetworkConnectionError() from e # 处理Google Generative AI异常
if "rate limit" in str(e).lower():
raise RespNotOkException(429, "请求频率过高,请稍后再试") from None
elif "quota" in str(e).lower():
raise RespNotOkException(429, "配额已用完") from None
elif "invalid" in str(e).lower() or "bad request" in str(e).lower():
raise RespNotOkException(400, f"请求无效:{str(e)}") from None
elif "permission" in str(e).lower() or "forbidden" in str(e).lower():
raise RespNotOkException(403, "权限不足") from None
else:
raise NetworkConnectionError() from e
if usage_record: if usage_record:
resp.usage = UsageRecord( resp.usage = UsageRecord(
@@ -535,7 +559,7 @@ class GeminiClient(BaseClient):
extra_params["thinking_budget"] if extra_params and "thinking_budget" in extra_params else 1024 extra_params["thinking_budget"] if extra_params and "thinking_budget" in extra_params else 1024
), ),
), ),
"safety_settings": gemini_safe_settings, "safety_settings": SAFETY_SETTINGS,
} }
generate_content_config = GenerateContentConfig(**generation_config_dict) generate_content_config = GenerateContentConfig(**generation_config_dict)
prompt = "Generate a transcript of the speech. The language of the transcript should **match the language of the speech**." prompt = "Generate a transcript of the speech. The language of the transcript should **match the language of the speech**."

View File

@@ -17,6 +17,7 @@ from src.common.server import get_global_server, Server
from src.mood.mood_manager import mood_manager
from rich.traceback import install
from src.manager.schedule_manager import schedule_manager
from src.common.cache_manager import tool_cache

# from src.api.main import start_api_server

# 导入新的插件管理器和热重载管理器

View File

@@ -207,6 +207,9 @@ class VideoAnalyzer:
"""批量分析所有帧""" """批量分析所有帧"""
self.logger.info(f"开始批量分析{len(frames)}") self.logger.info(f"开始批量分析{len(frames)}")
if not frames:
return "❌ 没有可分析的帧"
# 构建提示词 # 构建提示词
prompt = self.batch_analysis_prompt prompt = self.batch_analysis_prompt
@@ -214,28 +217,77 @@ class VideoAnalyzer:
prompt += f"\n\n用户问题: {user_question}" prompt += f"\n\n用户问题: {user_question}"
# 添加帧信息到提示词 # 添加帧信息到提示词
frame_info = []
for i, (frame_base64, timestamp) in enumerate(frames): for i, (frame_base64, timestamp) in enumerate(frames):
if self.enable_frame_timing: if self.enable_frame_timing:
prompt += f"\n\n{i+1}帧 (时间: {timestamp:.2f}s):" frame_info.append(f"{i+1}帧 (时间: {timestamp:.2f}s)")
else:
frame_info.append(f"{i+1}")
prompt += f"\n\n视频包含{len(frames)}帧图像:{', '.join(frame_info)}"
prompt += "\n\n请基于所有提供的帧图像进行综合分析,描述视频的完整内容和故事发展。"
try: try:
# 使用第一帧进行分析(批量模式暂时使用单帧,后续可以优化为真正的多图片分析 # 尝试使用多图片分析
if frames: response = await self._analyze_multiple_frames(frames, prompt)
frame_base64, _ = frames[0] self.logger.info("✅ 批量多图片分析完成")
prompt += f"\n\n注意当前显示的是第1帧请基于这一帧和提示词进行分析。视频共有{len(frames)}帧。" return response
except Exception as e:
self.logger.error(f"❌ 多图片分析失败: {e}")
# 降级到单帧分析
self.logger.warning("降级到单帧分析模式")
try:
frame_base64, timestamp = frames[0]
fallback_prompt = prompt + f"\n\n注意由于技术限制当前仅显示第1帧 (时间: {timestamp:.2f}s),视频共有{len(frames)}帧。请基于这一帧进行分析。"
response, _ = await self.video_llm.generate_response_for_image( response, _ = await self.video_llm.generate_response_for_image(
prompt=prompt, prompt=fallback_prompt,
image_base64=frame_base64, image_base64=frame_base64,
image_format="jpeg" image_format="jpeg"
) )
self.logger.info("批量分析完成") self.logger.info("降级的单帧分析完成")
return response return response
else: except Exception as fallback_e:
return "❌ 没有可分析的帧" self.logger.error(f"❌ 降级分析也失败: {fallback_e}")
except Exception as e: raise
self.logger.error(f"❌ 批量分析失败: {e}")
raise async def _analyze_multiple_frames(self, frames: List[Tuple[str, float]], prompt: str) -> str:
"""使用多图片分析方法"""
self.logger.info(f"开始构建包含{len(frames)}帧的多图片分析请求")
# 导入MessageBuilder用于构建多图片消息
from src.llm_models.payload_content.message import MessageBuilder, RoleType
from src.llm_models.utils_model import RequestType
# 构建包含多张图片的消息
message_builder = MessageBuilder().set_role(RoleType.User).add_text_content(prompt)
# 添加所有帧图像
for i, (frame_base64, timestamp) in enumerate(frames):
message_builder.add_image_content("jpeg", frame_base64)
# self.logger.info(f"已添加第{i+1}帧到分析请求 (时间: {timestamp:.2f}s, 图片大小: {len(frame_base64)} chars)")
message = message_builder.build()
self.logger.info(f"✅ 多图片消息构建完成,包含{len(frames)}张图片")
# 获取模型信息和客户端
model_info, api_provider, client = await self.video_llm._get_best_model_and_client()
self.logger.info(f"使用模型: {model_info.name} 进行多图片分析")
# 直接执行多图片请求
api_response = await self.video_llm._execute_request(
api_provider=api_provider,
client=client,
request_type=RequestType.RESPONSE,
model_info=model_info,
message_list=[message],
temperature=None,
max_tokens=None
)
self.logger.info(f"视频识别完成,响应长度: {len(api_response.content or '')} ")
return api_response.content or "❌ 未获得响应内容"
async def analyze_frames_sequential(self, frames: List[Tuple[str, float]], user_question: str = None) -> str: async def analyze_frames_sequential(self, frames: List[Tuple[str, float]], user_question: str = None) -> str:
"""逐帧分析并汇总""" """逐帧分析并汇总"""
@@ -355,7 +407,7 @@ class VideoAnalyzer:
# 计算视频hash值 # 计算视频hash值
video_hash = self._calculate_video_hash(video_bytes) video_hash = self._calculate_video_hash(video_bytes)
logger.info(f"视频hash: {video_hash[:16]}...") # logger.info(f"视频hash: {video_hash[:16]}...")
# 检查数据库中是否已存在该视频的分析结果 # 检查数据库中是否已存在该视频的分析结果
existing_video = self._check_video_exists(video_hash) existing_video = self._check_video_exists(video_hash)

View File

@@ -1,13 +1,13 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, List, Any, Union from typing import Dict, List, Any, Union
import os import os
import inspect
import toml import toml
import json import json
import shutil import shutil
import datetime import datetime
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import CONFIG_DIR
from src.plugin_system.base.component_types import ( from src.plugin_system.base.component_types import (
PluginInfo, PluginInfo,
PythonDependency, PythonDependency,
@@ -71,6 +71,7 @@ class PluginBase(ABC):
self.config: Dict[str, Any] = {} # 插件配置 self.config: Dict[str, Any] = {} # 插件配置
self.plugin_dir = plugin_dir # 插件目录路径 self.plugin_dir = plugin_dir # 插件目录路径
self.log_prefix = f"[Plugin:{self.plugin_name}]" self.log_prefix = f"[Plugin:{self.plugin_name}]"
self._is_enabled = self.enable_plugin # 从插件定义中获取默认启用状态
# 加载manifest文件 # 加载manifest文件
self._load_manifest() self._load_manifest()
@@ -100,7 +101,7 @@ class PluginBase(ABC):
description=self.plugin_description, description=self.plugin_description,
version=self.plugin_version, version=self.plugin_version,
author=self.plugin_author, author=self.plugin_author,
enabled=self.enable_plugin, enabled=self._is_enabled,
is_built_in=False, is_built_in=False,
config_file=self.config_file_name or "", config_file=self.config_file_name or "",
dependencies=self.dependencies.copy(), dependencies=self.dependencies.copy(),
@@ -453,86 +454,91 @@ class PluginBase(ABC):
logger.error(f"{self.log_prefix} 保存配置文件失败: {e}", exc_info=True) logger.error(f"{self.log_prefix} 保存配置文件失败: {e}", exc_info=True)
def _load_plugin_config(self): # sourcery skip: extract-method def _load_plugin_config(self): # sourcery skip: extract-method
"""加载插件配置文件,支持版本检查和自动迁移""" """
加载插件配置文件,实现集中化管理和自动迁移。
处理逻辑:
1. 确定插件模板配置文件路径(位于插件目录内)。
2. 如果模板不存在,则在插件目录内生成一份默认配置。
3. 确定用户配置文件路径(位于 `config/plugins/` 目录下)。
4. 如果用户配置文件不存在,则从插件目录复制模板文件过去。
5. 加载用户配置文件,并进行版本检查和自动迁移(如果需要)。
6. 最终加载的配置是用户配置文件。
"""
if not self.config_file_name: if not self.config_file_name:
logger.debug(f"{self.log_prefix} 未指定配置文件,跳过加载") logger.debug(f"{self.log_prefix} 未指定配置文件,跳过加载")
return return
# 优先使用传入的插件目录路径 # 1. 确定插件模板配置文件路径
if self.plugin_dir: template_config_path = os.path.join(self.plugin_dir, self.config_file_name)
plugin_dir = self.plugin_dir
else: # 2. 如果模板不存在,则在插件目录内生成
# fallback尝试从类的模块信息获取路径 if not os.path.exists(template_config_path):
logger.info(f"{self.log_prefix} 插件目录缺少配置文件 {template_config_path},将生成默认配置。")
self._generate_and_save_default_config(template_config_path)
# 3. 确定用户配置文件路径
plugin_config_dir = os.path.join(CONFIG_DIR, "plugins", self.plugin_name)
user_config_path = os.path.join(plugin_config_dir, self.config_file_name)
# 确保用户插件配置目录存在
os.makedirs(plugin_config_dir, exist_ok=True)
# 4. 如果用户配置文件不存在,从模板复制
if not os.path.exists(user_config_path):
try: try:
plugin_module_path = inspect.getfile(self.__class__) shutil.copy2(template_config_path, user_config_path)
plugin_dir = os.path.dirname(plugin_module_path) logger.info(f"{self.log_prefix} 已从模板创建用户配置文件: {user_config_path}")
except (TypeError, OSError): except IOError as e:
# 最后的fallback从模块的__file__属性获取 logger.error(f"{self.log_prefix} 复制配置文件失败: {e}", exc_info=True)
module = inspect.getmodule(self.__class__) # 如果复制失败,后续将无法加载,直接返回
if module and hasattr(module, "__file__") and module.__file__: return
plugin_dir = os.path.dirname(module.__file__)
else:
logger.warning(f"{self.log_prefix} 无法获取插件目录路径,跳过配置加载")
return
config_file_path = os.path.join(plugin_dir, self.config_file_name) # 检查最终的用户配置文件是否存在
if not os.path.exists(user_config_path):
# 如果配置文件不存在,生成默认配置 logger.warning(f"{self.log_prefix} 用户配置文件 {user_config_path} 不存在且无法创建。")
if not os.path.exists(config_file_path):
logger.info(f"{self.log_prefix} 配置文件 {config_file_path} 不存在,将生成默认配置。")
self._generate_and_save_default_config(config_file_path)
if not os.path.exists(config_file_path):
logger.warning(f"{self.log_prefix} 配置文件 {config_file_path} 不存在且无法生成。")
return return
file_ext = os.path.splitext(self.config_file_name)[1].lower() # 5. 加载、检查和迁移用户配置文件
_, file_ext = os.path.splitext(self.config_file_name)
if file_ext == ".toml": if file_ext.lower() != ".toml":
# 加载现有配置
with open(config_file_path, "r", encoding="utf-8") as f:
existing_config = toml.load(f) or {}
# 检查配置版本
current_version = self._get_current_config_version(existing_config)
# 如果配置文件没有版本信息,跳过版本检查
if current_version == "0.0.0":
logger.debug(f"{self.log_prefix} 配置文件无版本信息,跳过版本检查")
self.config = existing_config
else:
expected_version = self._get_expected_config_version()
if current_version != expected_version:
logger.info(
f"{self.log_prefix} 检测到配置版本需要更新: 当前=v{current_version}, 期望=v{expected_version}"
)
# 生成新的默认配置结构
new_config_structure = self._generate_config_from_schema()
# 迁移旧配置值到新结构
migrated_config = self._migrate_config_values(existing_config, new_config_structure)
# 保存迁移后的配置
self._save_config_to_file(migrated_config, config_file_path)
logger.info(f"{self.log_prefix} 配置文件已从 v{current_version} 更新到 v{expected_version}")
self.config = migrated_config
else:
logger.debug(f"{self.log_prefix} 配置版本匹配 (v{current_version}),直接加载")
self.config = existing_config
logger.debug(f"{self.log_prefix} 配置已从 {config_file_path} 加载")
# 从配置中更新 enable_plugin
if "plugin" in self.config and "enabled" in self.config["plugin"]:
self.enable_plugin = self.config["plugin"]["enabled"] # type: ignore
logger.debug(f"{self.log_prefix} 从配置更新插件启用状态: {self.enable_plugin}")
else:
logger.warning(f"{self.log_prefix} 不支持的配置文件格式: {file_ext},仅支持 .toml") logger.warning(f"{self.log_prefix} 不支持的配置文件格式: {file_ext},仅支持 .toml")
self.config = {} self.config = {}
return
try:
with open(user_config_path, "r", encoding="utf-8") as f:
existing_config = toml.load(f) or {}
except Exception as e:
logger.error(f"{self.log_prefix} 加载用户配置文件 {user_config_path} 失败: {e}", exc_info=True)
self.config = {}
return
current_version = self._get_current_config_version(existing_config)
expected_version = self._get_expected_config_version()
if current_version == "0.0.0":
logger.debug(f"{self.log_prefix} 用户配置文件无版本信息,跳过版本检查")
self.config = existing_config
elif current_version != expected_version:
logger.info(
f"{self.log_prefix} 检测到用户配置版本需要更新: 当前=v{current_version}, 期望=v{expected_version}"
)
new_config_structure = self._generate_config_from_schema()
migrated_config = self._migrate_config_values(existing_config, new_config_structure)
self._save_config_to_file(migrated_config, user_config_path)
logger.info(f"{self.log_prefix} 用户配置文件已从 v{current_version} 更新到 v{expected_version}")
self.config = migrated_config
else:
logger.debug(f"{self.log_prefix} 用户配置版本匹配 (v{current_version}),直接加载")
self.config = existing_config
logger.debug(f"{self.log_prefix} 配置已从 {user_config_path} 加载")
# 从配置中更新 enable_plugin 状态
if "plugin" in self.config and "enabled" in self.config["plugin"]:
self._is_enabled = self.config["plugin"]["enabled"]
logger.debug(f"{self.log_prefix} 从配置更新插件启用状态: {self._is_enabled}")
def _check_dependencies(self) -> bool: def _check_dependencies(self) -> bool:
"""检查插件依赖""" """检查插件依赖"""

View File

@@ -0,0 +1,134 @@
# -*- coding: utf-8 -*-
"""
本模块包含一个从Python包的“安装名”到其“导入名”的映射。
这个映射表主要用于解决一个常见问题某些Python包通过pip安装时使用的名称
与在代码中`import`时使用的名称不一致。例如,我们使用`pip install beautifulsoup4`
来安装,但在代码中却需要`import bs4`。
当插件系统检查依赖时,如果一个开发者只简单地在依赖列表中写了安装名
(例如 "beautifulsoup4"),标准的导入检查`import('beautifulsoup4')`会失败。
通过这个映射表,依赖管理器可以在初次导入检查失败后,查询是否存在一个
已知的别名(例如 "bs4"),并尝试使用该别名进行二次导入检查。
这样做的好处是:
1. 提升开发者体验:插件开发者无需强制记忆这些特殊的名称对应关系,或者强制
使用更复杂的`PythonDependency`对象来分别指定安装名和导入名。
2. 增强系统健壮性:减少因名称不一致导致的插件加载失败问题。
3. 兼容性:对遵循最佳实践、正确指定了`package_name`和`install_name`的
开发者没有任何影响。
开发者可以持续向这个列表中贡献新的映射关系,使其更加完善。
"""
INSTALL_NAME_TO_IMPORT_NAME = {
# ============== 数据科学与机器学习 (Data Science & Machine Learning) ==============
"scikit-learn": "sklearn", # 机器学习库
"scikit-image": "skimage", # 图像处理库
"opencv-python": "cv2", # OpenCV 计算机视觉库
"opencv-contrib-python": "cv2", # OpenCV 扩展模块
"tensorflow-gpu": "tensorflow", # TensorFlow GPU版本
"tensorboardx": "tensorboardX", # TensorBoard 的封装
"torchvision": "torchvision", # PyTorch 视觉库 (通常与 torch 一起)
"torchaudio": "torchaudio", # PyTorch 音频库
"catboost": "catboost", # CatBoost 梯度提升库
"lightgbm": "lightgbm", # LightGBM 梯度提升库
"xgboost": "xgboost", # XGBoost 梯度提升库
"imbalanced-learn": "imblearn", # 处理不平衡数据集
"seqeval": "seqeval", # 序列标注评估
"gensim": "gensim", # 主题建模和NLP
"nltk": "nltk", # 自然语言工具包
"spacy": "spacy", # 工业级自然语言处理
"fuzzywuzzy": "fuzzywuzzy", # 模糊字符串匹配
"python-levenshtein": "Levenshtein", # Levenshtein 距离计算
# ============== Web开发与API (Web Development & API) ==============
"python-socketio": "socketio", # Socket.IO 服务器和客户端
"python-engineio": "engineio", # Engine.IO 底层库
"aiohttp": "aiohttp", # 异步HTTP客户端/服务器
"python-multipart": "multipart", # 解析 multipart/form-data
"uvloop": "uvloop", # 高性能asyncio事件循环
"httptools": "httptools", # 高性能HTTP解析器
"websockets": "websockets", # WebSocket实现
"fastapi": "fastapi", # 高性能Web框架
"starlette": "starlette", # ASGI框架
"uvicorn": "uvicorn", # ASGI服务器
"gunicorn": "gunicorn", # WSGI服务器
"django-rest-framework": "rest_framework", # Django REST框架
"django-cors-headers": "corsheaders", # Django CORS处理
"flask-jwt-extended": "flask_jwt_extended", # Flask JWT扩展
"flask-sqlalchemy": "flask_sqlalchemy", # Flask SQLAlchemy扩展
"flask-migrate": "flask_migrate", # Flask Alembic迁移扩展
"python-jose": "jose", # JOSE (JWT, JWS, JWE) 实现
"passlib": "passlib", # 密码哈希库
"bcrypt": "bcrypt", # Bcrypt密码哈希
# ============== 数据库 (Database) ==============
"mysql-connector-python": "mysql.connector", # MySQL官方驱动
"psycopg2-binary": "psycopg2", # PostgreSQL驱动 (二进制)
"pymongo": "pymongo", # MongoDB驱动
"redis": "redis", # Redis客户端
"aioredis": "aioredis", # 异步Redis客户端
"sqlalchemy": "sqlalchemy", # SQL工具包和ORM
"alembic": "alembic", # SQLAlchemy数据库迁移工具
"tortoise-orm": "tortoise", # 异步ORM
# ============== 图像与多媒体 (Image & Multimedia) ==============
"Pillow": "PIL", # Python图像处理库 (PIL Fork)
"moviepy": "moviepy", # 视频编辑库
"pydub": "pydub", # 音频处理库
"pycairo": "cairo", # Cairo 2D图形库的Python绑定
"wand": "wand", # ImageMagick的Python绑定
# ============== 解析与序列化 (Parsing & Serialization) ==============
"beautifulsoup4": "bs4", # HTML/XML解析库
"lxml": "lxml", # 高性能HTML/XML解析库
"PyYAML": "yaml", # YAML解析库
"python-dotenv": "dotenv", # .env文件解析
"python-dateutil": "dateutil", # 强大的日期时间解析
"protobuf": "google.protobuf", # Protocol Buffers
"msgpack": "msgpack", # MessagePack序列化
"orjson": "orjson", # 高性能JSON库
"pydantic": "pydantic", # 数据验证和设置管理
# ============== 系统与硬件 (System & Hardware) ==============
"pyserial": "serial", # 串口通信
"pyusb": "usb", # USB访问
"pybluez": "bluetooth", # 蓝牙通信 (可能因平台而异)
"psutil": "psutil", # 系统信息和进程管理
"watchdog": "watchdog", # 文件系统事件监控
"python-gnupg": "gnupg", # GnuPG的Python接口
# ============== 加密与安全 (Cryptography & Security) ==============
"pycrypto": "Crypto", # 加密库 (较旧)
"pycryptodome": "Crypto", # PyCrypto的现代分支
"cryptography": "cryptography", # 现代加密库
"pyopenssl": "OpenSSL", # OpenSSL的Python接口
"service-identity": "service_identity", # 服务身份验证
# ============== 工具与杂项 (Utilities & Miscellaneous) ==============
"setuptools": "setuptools", # 打包工具
"pip": "pip", # 包安装器
"tqdm": "tqdm", # 进度条
"regex": "regex", # 替代的正则表达式引擎
"colorama": "colorama", # 跨平台彩色终端文本
"termcolor": "termcolor", # 终端颜色格式化
"requests-oauthlib": "requests_oauthlib", # OAuth for Requests
"oauthlib": "oauthlib", # 通用OAuth库
"authlib": "authlib", # OAuth和OpenID Connect客户端/服务器
"pyjwt": "jwt", # JSON Web Token实现
"python-editor": "editor", # 程序化地调用编辑器
"prompt-toolkit": "prompt_toolkit", # 构建交互式命令行
"pygments": "pygments", # 语法高亮
"tabulate": "tabulate", # 生成漂亮的表格
"nats-client": "nats", # NATS客户端
"gitpython": "git", # Git的Python接口
"pygithub": "github", # GitHub API v3的Python接口
"python-gitlab": "gitlab", # GitLab API的Python接口
"jira": "jira", # JIRA API的Python接口
"python-jenkins": "jenkins", # Jenkins API的Python接口
"huggingface-hub": "huggingface_hub", # Hugging Face Hub API
"apache-airflow": "airflow", # Airflow工作流管理
"pandas-stubs": "pandas-stubs", # Pandas的类型存根
"data-science-types": "data_science_types", # 数据科学类型
}

View File

@@ -8,6 +8,7 @@ from packaging.requirements import Requirement
from src.common.logger import get_logger
from src.plugin_system.base.component_types import PythonDependency
from src.plugin_system.utils.dependency_alias import INSTALL_NAME_TO_IMPORT_NAME

logger = get_logger("dependency_manager")
@@ -190,41 +191,58 @@ class DependencyManager:
     def _check_single_dependency(self, dep: PythonDependency) -> bool:
         """检查单个依赖是否满足要求"""
-        try:
-            # 尝试导入包
-            spec = importlib.util.find_spec(dep.package_name)
-            if spec is None:
-                return False
-
-            # 如果没有版本要求,导入成功就够了
-            if not dep.version:
-                return True
-
-            # 检查版本要求
-            try:
-                module = importlib.import_module(dep.package_name)
-                installed_version = getattr(module, '__version__', None)
-
-                if installed_version is None:
-                    # 尝试其他常见的版本属性
-                    installed_version = getattr(module, 'VERSION', None)
-
-                if installed_version is None:
-                    logger.debug(f"无法获取包 {dep.package_name} 的版本信息,假设满足要求")
-                    return True
-
-                # 解析版本要求
-                req = Requirement(f"{dep.package_name}{dep.version}")
-                return version.parse(str(installed_version)) in req.specifier
-            except Exception as e:
-                logger.debug(f"检查包 {dep.package_name} 版本时出错: {e}")
-                return True  # 如果无法检查版本,假设满足要求
-        except ImportError:
-            return False
-        except Exception as e:
-            logger.error(f"检查依赖 {dep.package_name} 时发生未知错误: {e}")
-            return False
+        def _try_check(import_name: str) -> bool:
+            """尝试使用给定的导入名进行检查"""
+            try:
+                spec = importlib.util.find_spec(import_name)
+                if spec is None:
+                    return False
+
+                # 如果没有版本要求,导入成功就够了
+                if not dep.version:
+                    return True
+
+                # 检查版本要求
+                try:
+                    module = importlib.import_module(import_name)
+                    installed_version = getattr(module, '__version__', None)
+
+                    if installed_version is None:
+                        # 尝试其他常见的版本属性
+                        installed_version = getattr(module, 'VERSION', None)
+
+                    if installed_version is None:
+                        logger.debug(f"无法获取包 {import_name} 的版本信息,假设满足要求")
+                        return True
+
+                    # 解析版本要求
+                    req = Requirement(f"{dep.package_name}{dep.version}")
+                    return version.parse(str(installed_version)) in req.specifier
+                except Exception as e:
+                    logger.debug(f"检查包 {import_name} 版本时出错: {e}")
+                    return True  # 如果无法检查版本,假设满足要求
+            except ImportError:
+                return False
+            except Exception as e:
+                logger.error(f"检查依赖 {import_name} 时发生未知错误: {e}")
+                return False
+
+        # 1. 首先尝试使用原始的 package_name 进行检查
+        if _try_check(dep.package_name):
+            return True
+
+        # 2. 如果失败,查询别名映射表
+        # 注意:此时 dep.package_name 可能是 "requests" 这样的简单字符串,也可能是 "beautifulsoup4" 这样的安装名
+        import_alias = INSTALL_NAME_TO_IMPORT_NAME.get(dep.package_name)
+        if import_alias:
+            logger.debug(f"依赖 '{dep.package_name}' 导入失败, 尝试使用别名 '{import_alias}'")
+            if _try_check(import_alias):
+                return True
+
+        # 3. 如果别名也失败了,或者没有别名,最终确认失败
+        return False

     def _install_single_package(self, package: str, plugin_name: str = "") -> bool:
         """安装单个包"""

View File

@@ -9,8 +9,6 @@ from src.common.logger import get_logger
 from src.plugin_system import (
     BasePlugin,
     ComponentInfo,
-    BaseAction,
-    BaseCommand,
     register_plugin
 )
 from src.plugin_system.base.config_types import ConfigField
@@ -68,6 +66,8 @@ class MaiZoneRefactoredPlugin(BasePlugin):
     },
     "schedule": {
         "enable_schedule": ConfigField(type=bool, default=False, description="是否启用定时发送"),
+        "random_interval_min_minutes": ConfigField(type=int, default=5, description="随机间隔分钟数下限"),
+        "random_interval_max_minutes": ConfigField(type=int, default=15, description="随机间隔分钟数上限"),
     },
     "cookie": {
         "http_fallback_host": ConfigField(type=str, default="127.0.0.1", description="备用Cookie获取服务的主机地址"),

View File

@@ -5,7 +5,6 @@ QQ空间服务模块
""" """
import asyncio import asyncio
import base64
import json import json
import os import os
import random import random
@@ -15,7 +14,6 @@ from typing import Callable, Optional, Dict, Any, List, Tuple
 import aiohttp
 import bs4
 import json5
-from src.chat.utils.utils_image import get_image_manager
 from src.common.logger import get_logger
 from src.plugin_system.apis import config_api, person_api

View File

@@ -5,6 +5,7 @@
""" """
import asyncio import asyncio
import datetime import datetime
import random
import traceback import traceback
from typing import Callable from typing import Callable
@@ -91,8 +92,12 @@ class SchedulerService:
result.get("message", "") result.get("message", "")
) )
# 6. 等待5分钟后进行下一次检查 # 6. 计算并等待一个随机的时间间隔
await asyncio.sleep(300) min_minutes = self.get_config("schedule.random_interval_min_minutes", 5)
max_minutes = self.get_config("schedule.random_interval_max_minutes", 15)
wait_seconds = random.randint(min_minutes * 60, max_minutes * 60)
logger.info(f"下一次检查将在 {wait_seconds / 60:.2f} 分钟后进行。")
await asyncio.sleep(wait_seconds)
except asyncio.CancelledError: except asyncio.CancelledError:
logger.info("定时任务循环被取消。") logger.info("定时任务循环被取消。")
@@ -113,7 +118,7 @@ class SchedulerService:
             with get_db_session() as session:
                 record = session.query(MaiZoneScheduleStatus).filter(
                     MaiZoneScheduleStatus.datetime_hour == hour_str,
-                    MaiZoneScheduleStatus.is_processed == True
+                    MaiZoneScheduleStatus.is_processed == True  # noqa: E712
                 ).first()
                 return record is not None
         except Exception as e:
@@ -138,10 +143,10 @@ class SchedulerService:
                 if record:
                     # 如果存在,则更新状态
-                    record.is_processed = True
-                    record.processed_at = datetime.datetime.now()
-                    record.send_success = success
-                    record.story_content = content
+                    record.is_processed = True  # type: ignore
+                    record.processed_at = datetime.datetime.now()  # type: ignore
+                    record.send_success = success  # type: ignore
+                    record.story_content = content  # type: ignore
                 else:
                     # 如果不存在,则创建新记录
                     new_record = MaiZoneScheduleStatus(
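
保留 `== True` 并标注 `# noqa: E712`,是因为 SQLAlchemy 依靠重载 `==` 来生成 SQL 比较表达式,改成 `is True` 只会得到一个普通的 Python 布尔值而不是查询条件;`# type: ignore` 则用于绕过 ORM 列属性与普通 Python 值在静态类型检查中的不匹配。若想同时满足 lint 与查询语义,一种常见替代是列对象的 `.is_()` 方法,示意如下(沿用上文的 `session`、`hour_str` 等名称,并非本次提交采用的写法):

```python
# 使用 SQLAlchemy 的 .is_() 生成 IS 比较,既避免 E712也保持查询语义
record = session.query(MaiZoneScheduleStatus).filter(
    MaiZoneScheduleStatus.datetime_hour == hour_str,
    MaiZoneScheduleStatus.is_processed.is_(True),
).first()
```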

View File

@@ -20,6 +20,7 @@ from src.plugin_system import (
     PythonDependency
 )
 from src.plugin_system.apis import config_api  # 添加config_api导入
+from src.common.cache_manager import tool_cache

 import httpx
 from bs4 import BeautifulSoup
@@ -86,6 +87,13 @@ class WebSurfingTool(BaseTool):
         if not query:
             return {"error": "搜索查询不能为空。"}

+        # 检查缓存
+        query = function_args.get("query")
+        cached_result = await tool_cache.get(self.name, function_args, tool_class=self.__class__, semantic_query=query)
+        if cached_result:
+            logger.info(f"缓存命中: {self.name} -> {function_args}")
+            return cached_result
+
         # 读取搜索配置
         enabled_engines = config_api.get_global_config("web_search.enabled_engines", ["ddg"])
         search_strategy = config_api.get_global_config("web_search.search_strategy", "single")
@@ -94,11 +102,18 @@ class WebSurfingTool(BaseTool):
         # 根据策略执行搜索
         if search_strategy == "parallel":
-            return await self._execute_parallel_search(function_args, enabled_engines)
+            result = await self._execute_parallel_search(function_args, enabled_engines)
         elif search_strategy == "fallback":
-            return await self._execute_fallback_search(function_args, enabled_engines)
+            result = await self._execute_fallback_search(function_args, enabled_engines)
         else:  # single
-            return await self._execute_single_search(function_args, enabled_engines)
+            result = await self._execute_single_search(function_args, enabled_engines)
+
+        # 保存到缓存
+        if "error" not in result:
+            query = function_args.get("query")
+            await tool_cache.set(self.name, function_args, self.__class__, result, semantic_query=query)
+
+        return result

     async def _execute_parallel_search(self, function_args: Dict[str, Any], enabled_engines: List[str]) -> Dict[str, Any]:
         """并行搜索策略:同时使用所有启用的搜索引擎"""
@@ -449,6 +464,12 @@ class URLParserTool(BaseTool):
""" """
执行URL内容提取和总结。优先使用Exa失败后尝试本地解析。 执行URL内容提取和总结。优先使用Exa失败后尝试本地解析。
""" """
# 检查缓存
cached_result = await tool_cache.get(self.name, function_args, tool_class=self.__class__)
if cached_result:
logger.info(f"缓存命中: {self.name} -> {function_args}")
return cached_result
urls_input = function_args.get("urls") urls_input = function_args.get("urls")
if not urls_input: if not urls_input:
return {"error": "URL列表不能为空。"} return {"error": "URL列表不能为空。"}
@@ -555,6 +576,10 @@ class URLParserTool(BaseTool):
"content": formatted_content, "content": formatted_content,
"errors": error_messages "errors": error_messages
} }
# 保存到缓存
if "error" not in result:
await tool_cache.set(self.name, function_args, self.__class__, result)
return result return result
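
两个工具接入缓存的方式一致:执行前以工具名和入参查缓存,执行成功(结果不含 `error`)后写回。下面把这一模式抽成一个通用示意,`tool_cache.get` / `tool_cache.set` 的参数形式直接取自上面的改动,其余名称(`MyTool`、`_do_work`、`execute` 方法名以及 `BaseTool` 的导入路径)均为假设:

```python
from typing import Any, Dict

from src.common.cache_manager import tool_cache
from src.plugin_system import BaseTool  # 假设 BaseTool 由此导出


class MyTool(BaseTool):  # 假设的工具类,仅演示缓存接入模式
    name = "my_tool"

    async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]:
        # 1. 先查缓存,命中则直接返回
        cached_result = await tool_cache.get(self.name, function_args, tool_class=self.__class__)
        if cached_result:
            return cached_result

        # 2. 未命中则真正执行(_do_work 为假设的内部实现)
        result = await self._do_work(function_args)

        # 3. 仅在成功时写入缓存,避免把错误结果缓存下来
        if "error" not in result:
            await tool_cache.set(self.name, function_args, self.__class__, result)

        return result

    async def _do_work(self, function_args: Dict[str, Any]) -> Dict[str, Any]:
        return {"content": "..."}
```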

View File

@@ -382,3 +382,6 @@ enable_url_tool = true # 是否启用URL解析tool
 # 搜索引擎配置
 enabled_engines = ["ddg"] # 启用的搜索引擎列表,可选: "exa", "tavily", "ddg"
 search_strategy = "single" # 搜索策略: "single"(使用第一个可用引擎), "parallel"(并行使用所有启用的引擎), "fallback"(按顺序尝试,失败则尝试下一个)
+
+[plugins] # 插件配置
+centralized_config = true # 是否启用插件配置集中化管理
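
新增的 `[plugins]` 段可以像上面的 `web_search` 各项一样读取;以下为假设性示意,`plugins.centralized_config` 这一键名取自新增配置,具体读取入口以插件系统实际实现为准:

```python
from src.plugin_system.apis import config_api

# 读取插件配置集中化开关,默认值与配置文件中的 true 保持一致
centralized = config_api.get_global_config("plugins.centralized_config", True)
if centralized:
    # 此处仅为示意:实际如何消费该开关取决于插件系统的实现
    print("插件配置集中化管理已启用")
```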