From 1752831024b6ddc56ee9b1d3004209c761219c31 Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Sun, 17 Aug 2025 21:26:35 +0800 Subject: [PATCH 01/10] =?UTF-8?q?perf(web=5Fsearch):=20=E7=BC=93=E5=AD=98?= =?UTF-8?q?=E7=BD=91=E7=BB=9C=E5=B7=A5=E5=85=B7=E8=B0=83=E7=94=A8=E4=BB=A5?= =?UTF-8?q?=E5=8A=A0=E9=80=9F=E5=93=8D=E5=BA=94?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 为 `WebSurfingTool` 和 `URLParserTool` 集成 `tool_cache`,避免对相同参数的重复请求。 此更改通过在执行网络搜索或URL解析前检查缓存来优化性能。如果找到先前成功的结果,则立即返回缓存数据,从而显著减少延迟和外部API的使用。仅当缓存未命中时,工具才会继续执行其原始逻辑,并将成功的结果存入缓存以备将来使用。 --- .../built_in/web_search_tool/plugin.py | 29 +++++++++++++++++-- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/src/plugins/built_in/web_search_tool/plugin.py b/src/plugins/built_in/web_search_tool/plugin.py index 05f318cd6..b2212f836 100644 --- a/src/plugins/built_in/web_search_tool/plugin.py +++ b/src/plugins/built_in/web_search_tool/plugin.py @@ -20,6 +20,7 @@ from src.plugin_system import ( PythonDependency ) from src.plugin_system.apis import config_api # 添加config_api导入 +from src.common.cache_manager import tool_cache import httpx from bs4 import BeautifulSoup @@ -86,6 +87,12 @@ class WebSurfingTool(BaseTool): if not query: return {"error": "搜索查询不能为空。"} + # 检查缓存 + cached_result = tool_cache.get(self.name, function_args) + if cached_result: + logger.info(f"缓存命中: {self.name} -> {function_args}") + return cached_result + # 读取搜索配置 enabled_engines = config_api.get_global_config("web_search.enabled_engines", ["ddg"]) search_strategy = config_api.get_global_config("web_search.search_strategy", "single") @@ -94,11 +101,17 @@ class WebSurfingTool(BaseTool): # 根据策略执行搜索 if search_strategy == "parallel": - return await self._execute_parallel_search(function_args, enabled_engines) + result = await self._execute_parallel_search(function_args, enabled_engines) elif search_strategy == "fallback": - return await self._execute_fallback_search(function_args, enabled_engines) + result = await self._execute_fallback_search(function_args, enabled_engines) else: # single - return await self._execute_single_search(function_args, enabled_engines) + result = await self._execute_single_search(function_args, enabled_engines) + + # 保存到缓存 + if "error" not in result: + tool_cache.set(self.name, function_args, result) + + return result async def _execute_parallel_search(self, function_args: Dict[str, Any], enabled_engines: List[str]) -> Dict[str, Any]: """并行搜索策略:同时使用所有启用的搜索引擎""" @@ -449,6 +462,12 @@ class URLParserTool(BaseTool): """ 执行URL内容提取和总结。优先使用Exa,失败后尝试本地解析。 """ + # 检查缓存 + cached_result = tool_cache.get(self.name, function_args) + if cached_result: + logger.info(f"缓存命中: {self.name} -> {function_args}") + return cached_result + urls_input = function_args.get("urls") if not urls_input: return {"error": "URL列表不能为空。"} @@ -555,6 +574,10 @@ class URLParserTool(BaseTool): "content": formatted_content, "errors": error_messages } + + # 保存到缓存 + if "error" not in result: + tool_cache.set(self.name, function_args, result) return result From 51c0d2a1e86db39ac46e74288d3ceed665458514 Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Sun, 17 Aug 2025 22:18:26 +0800 Subject: [PATCH 02/10] =?UTF-8?q?refactor(cache):=20=E9=87=8D=E6=9E=84?= =?UTF-8?q?=E7=BC=93=E5=AD=98=E7=B3=BB=E7=BB=9F=E4=B8=BA=E5=88=86=E5=B1=82?= =?UTF-8?q?=E8=AF=AD=E4=B9=89=E7=BC=93=E5=AD=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 将原有的基于文件的 `ToolCache` 替换为全新的 `CacheManager`,引入了更复杂和高效的分层语义缓存机制。 新系统特性: 
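在进入特性列表之前,先用一个最小的示意说明“语义命中”的判定方式(假设性示例,仅用于演示:0.75 的相似度阈值与实现一致,查询文本与向量均为虚构):

```python
# 语义缓存命中示意:嵌入向量归一化后,内积即余弦相似度
import numpy as np

def normalize(v: np.ndarray) -> np.ndarray:
    return v / np.linalg.norm(v)

# 假设嵌入模型将两个措辞不同、语义相近的查询映射到相近的向量
vec_cached = normalize(np.array([0.8, 0.1, 0.6], dtype="float32"))  # 已缓存:"北京今天天气怎么样"
vec_query = normalize(np.array([0.7, 0.2, 0.7], dtype="float32"))   # 新查询:"今天北京的天气如何"

similarity = float(np.dot(vec_cached, vec_query))
if similarity > 0.75:  # 与 CacheManager 中 FAISS 检索使用的阈值一致
    print(f"语义命中,相似度 {similarity:.3f},直接复用缓存结果,无需再次调用搜索引擎")
```

新系统的主要特性如下: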
- **分层缓存**: - L1 缓存: 内存字典 (KV) + FAISS (向量),用于极速访问。 - L2 缓存: SQLite (KV) + ChromaDB (向量),用于持久化存储。 - **语义缓存**: 利用嵌入模型 (Embedding) 对查询进行向量化,实现基于语义相似度的缓存命中,显著提高了缓存命中率。 - **自动失效**: 缓存键包含工具源代码的哈希值,当工具代码更新时,相关缓存会自动失效,避免了脏数据问题。 - **异步支持**: 缓存的 `get` 和 `set` 方法现在是异步的,以适应项目中异步化的工具调用流程。 `web_search_tool` 已更新以使用新的 `CacheManager`,在调用缓存时传递 `tool_class` 和 `semantic_query` 以充分利用新功能。 Co-Authored-By: tt-P607 <68868379+tt-P607@users.noreply.github.com> --- src/common/cache_manager.py | 509 +++++++----------- src/common/cache_manager_backup.py | 344 ++++++++++++ src/main.py | 1 + .../built_in/web_search_tool/plugin.py | 10 +- 4 files changed, 546 insertions(+), 318 deletions(-) create mode 100644 src/common/cache_manager_backup.py diff --git a/src/common/cache_manager.py b/src/common/cache_manager.py index ecaff3458..c1141aa66 100644 --- a/src/common/cache_manager.py +++ b/src/common/cache_manager.py @@ -1,344 +1,225 @@ +import time import json +import sqlite3 +import chromadb import hashlib -import re +import inspect +import numpy as np +import faiss from typing import Any, Dict, Optional -from datetime import datetime, timedelta -from pathlib import Path -from difflib import SequenceMatcher - from src.common.logger import get_logger +from src.llm_models.utils_model import LLMRequest +from src.config.api_ada_configs import ModelTaskConfig +from src.config.config import global_config, model_config logger = get_logger("cache_manager") +class CacheManager: + """ + 一个支持分层和语义缓存的通用工具缓存管理器。 + 采用单例模式,确保在整个应用中只有一个缓存实例。 + L1缓存: 内存字典 (KV) + FAISS (Vector)。 + L2缓存: SQLite (KV) + ChromaDB (Vector)。 + """ + _instance = None -class ToolCache: - """工具缓存管理器,用于缓存工具调用结果,支持近似匹配""" + def __new__(cls, *args, **kwargs): + if not cls._instance: + cls._instance = super(CacheManager, cls).__new__(cls) + return cls._instance - def __init__( - self, - cache_dir: str = "data/tool_cache", - max_age_hours: int = 24, - similarity_threshold: float = 0.65, - ): + def __init__(self, default_ttl: int = 3600, db_path: str = "data/cache.db", chroma_path: str = "data/chroma_db"): """ - 初始化缓存管理器 - - Args: - cache_dir: 缓存目录路径 - max_age_hours: 缓存最大存活时间(小时) - similarity_threshold: 近似匹配的相似度阈值 (0-1) + 初始化缓存管理器。 """ - self.cache_dir = Path(cache_dir) - self.max_age = timedelta(hours=max_age_hours) - self.max_age_seconds = max_age_hours * 3600 - self.similarity_threshold = similarity_threshold - self.cache_dir.mkdir(parents=True, exist_ok=True) + if not hasattr(self, '_initialized'): + self.default_ttl = default_ttl + + # L1 缓存 (内存) + self.l1_kv_cache: Dict[str, Dict[str, Any]] = {} + embedding_dim = global_config.lpmm_knowledge.embedding_dimension + self.l1_vector_index = faiss.IndexFlatIP(embedding_dim) + self.l1_vector_id_to_key: Dict[int, str] = {} + + # L2 缓存 (持久化) + self.db_path = db_path + self._init_sqlite() + self.chroma_client = chromadb.PersistentClient(path=chroma_path) + self.chroma_collection = self.chroma_client.get_or_create_collection(name="semantic_cache") + + # 嵌入模型 + self.embedding_model = LLMRequest(model_config.model_task_config.embedding) - @staticmethod - def _normalize_query(query: str) -> str: + self._initialized = True + logger.info("缓存管理器已初始化: L1 (内存+FAISS), L2 (SQLite+ChromaDB)") + + def _init_sqlite(self): + """初始化SQLite数据库和表结构。""" + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + cursor.execute(""" + CREATE TABLE IF NOT EXISTS cache ( + key TEXT PRIMARY KEY, + value TEXT, + expires_at REAL + ) + """) + conn.commit() + + def _generate_key(self, tool_name: str, function_args: Dict[str, Any], 
tool_class: Any) -> str: + """生成确定性的缓存键,包含代码哈希以实现自动失效。""" + try: + source_code = inspect.getsource(tool_class) + code_hash = hashlib.md5(source_code.encode()).hexdigest() + except (TypeError, OSError) as e: + code_hash = "unknown" + logger.warning(f"无法获取 {tool_class.__name__} 的源代码,代码哈希将为 'unknown'。错误: {e}") + try: + sorted_args = json.dumps(function_args, sort_keys=True) + except TypeError: + sorted_args = repr(sorted(function_args.items())) + return f"{tool_name}::{sorted_args}::{code_hash}" + + async def get(self, tool_name: str, function_args: Dict[str, Any], tool_class: Any, semantic_query: Optional[str] = None) -> Optional[Any]: """ - 标准化查询文本,用于相似度比较 - - Args: - query: 原始查询文本 - - Returns: - 标准化后的查询文本 + 从缓存获取结果,查询顺序: L1-KV -> L1-Vector -> L2-KV -> L2-Vector。 """ - if not query: - return "" + # 步骤 1: L1 精确缓存查询 + key = self._generate_key(tool_name, function_args, tool_class) + logger.debug(f"生成的缓存键: {key}") + if semantic_query: + logger.debug(f"使用的语义查询: '{semantic_query}'") - # 纯 Python 实现 - normalized = query.lower() - normalized = re.sub(r"[^\w\s]", " ", normalized) - normalized = " ".join(normalized.split()) - return normalized + if key in self.l1_kv_cache: + entry = self.l1_kv_cache[key] + if time.time() < entry["expires_at"]: + logger.info(f"命中L1键值缓存: {key}") + return entry["data"] + else: + del self.l1_kv_cache[key] - def _calculate_similarity(self, text1: str, text2: str) -> float: - """ - 计算两个文本的相似度 + # 步骤 2: L1/L2 语义和L2精确缓存查询 + query_embedding = None + if semantic_query and self.embedding_model: + embedding_result = await self.embedding_model.get_embedding(semantic_query) + if embedding_result: + query_embedding = np.array([embedding_result], dtype='float32') - Args: - text1: 文本1 - text2: 文本2 + # 步骤 2a: L1 语义缓存 (FAISS) + if query_embedding is not None and self.l1_vector_index.ntotal > 0: + faiss.normalize_L2(query_embedding) + distances, indices = self.l1_vector_index.search(query_embedding, 1) + if indices.size > 0 and distances[0][0] > 0.75: # IP 越大越相似 + hit_index = indices[0][0] + l1_hit_key = self.l1_vector_id_to_key.get(hit_index) + if l1_hit_key and l1_hit_key in self.l1_kv_cache: + logger.info(f"命中L1语义缓存: {l1_hit_key}") + return self.l1_kv_cache[l1_hit_key]["data"] - Returns: - 相似度分数 (0-1) - """ - if not text1 or not text2: - return 0.0 + # 步骤 2b: L2 精确缓存 (SQLite) + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + cursor.execute("SELECT value, expires_at FROM cache WHERE key = ?", (key,)) + row = cursor.fetchone() + if row: + value, expires_at = row + if time.time() < expires_at: + logger.info(f"命中L2键值缓存: {key}") + data = json.loads(value) + # 回填 L1 + self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at} + return data + else: + cursor.execute("DELETE FROM cache WHERE key = ?", (key,)) + conn.commit() - # 纯 Python 实现 - norm_text1 = self._normalize_query(text1) - norm_text2 = self._normalize_query(text2) + # 步骤 2c: L2 语义缓存 (ChromaDB) + if query_embedding is not None: + results = self.chroma_collection.query(query_embeddings=query_embedding.tolist(), n_results=1) + if results and results['ids'] and results['ids'][0]: + distance = results['distances'][0][0] if results['distances'] and results['distances'][0] else 'N/A' + logger.debug(f"L2语义搜索找到最相似的结果: id={results['ids'][0]}, 距离={distance}") + if distance != 'N/A' and distance < 0.75: + l2_hit_key = results['ids'][0] + logger.info(f"命中L2语义缓存: key='{l2_hit_key}', 距离={distance:.4f}") + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + cursor.execute("SELECT value, expires_at 
FROM cache WHERE key = ?", (l2_hit_key if isinstance(l2_hit_key, str) else l2_hit_key[0],)) + row = cursor.fetchone() + if row: + value, expires_at = row + if time.time() < expires_at: + data = json.loads(value) + logger.debug(f"L2语义缓存返回的数据: {data}") + # 回填 L1 + self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at} + if query_embedding is not None: + new_id = self.l1_vector_index.ntotal + faiss.normalize_L2(query_embedding) + self.l1_vector_index.add(x=query_embedding) + self.l1_vector_id_to_key[new_id] = key + return data - if norm_text1 == norm_text2: - return 1.0 - - return SequenceMatcher(None, norm_text1, norm_text2).ratio() - - @staticmethod - def _generate_cache_key(tool_name: str, function_args: Dict[str, Any]) -> str: - """ - 生成缓存键 - - Args: - tool_name: 工具名称 - function_args: 函数参数 - - Returns: - 缓存键字符串 - """ - # 将参数排序后序列化,确保相同参数产生相同的键 - sorted_args = json.dumps(function_args, sort_keys=True, ensure_ascii=False) - - # 纯 Python 实现 - cache_string = f"{tool_name}:{sorted_args}" - return hashlib.md5(cache_string.encode("utf-8")).hexdigest() - - def _get_cache_file_path(self, cache_key: str) -> Path: - """获取缓存文件路径""" - return self.cache_dir / f"{cache_key}.json" - - def _is_cache_expired(self, cached_time: datetime) -> bool: - """检查缓存是否过期""" - return datetime.now() - cached_time > self.max_age - - def _find_similar_cache( - self, tool_name: str, function_args: Dict[str, Any] - ) -> Optional[Dict[str, Any]]: - """ - 查找相似的缓存条目 - - Args: - tool_name: 工具名称 - function_args: 函数参数 - - Returns: - 相似的缓存结果,如果不存在则返回None - """ - query = function_args.get("query", "") - if not query: - return None - - candidates = [] - cache_data_list = [] - - # 遍历所有缓存文件,收集候选项 - for cache_file in self.cache_dir.glob("*.json"): - try: - with open(cache_file, "r", encoding="utf-8") as f: - cache_data = json.load(f) - - # 检查是否是同一个工具 - if cache_data.get("tool_name") != tool_name: - continue - - # 检查缓存是否过期 - cached_time = datetime.fromisoformat(cache_data["timestamp"]) - if self._is_cache_expired(cached_time): - continue - - # 检查其他参数是否匹配(除了query) - cached_args = cache_data.get("function_args", {}) - args_match = True - for key, value in function_args.items(): - if key != "query" and cached_args.get(key) != value: - args_match = False - break - - if not args_match: - continue - - # 收集候选项 - cached_query = cached_args.get("query", "") - candidates.append((cached_query, len(cache_data_list))) - cache_data_list.append(cache_data) - - except Exception as e: - logger.warning(f"检查缓存文件时出错: {cache_file}, 错误: {e}") - continue - - if not candidates: - logger.debug( - f"未找到相似缓存: {tool_name}, 查询: '{query}',相似度阈值: {self.similarity_threshold}" - ) - return None - - # 纯 Python 实现 - best_match = None - best_similarity = 0.0 - - for cached_query, index in candidates: - similarity = self._calculate_similarity(query, cached_query) - if similarity > best_similarity and similarity >= self.similarity_threshold: - best_similarity = similarity - best_match = cache_data_list[index] - - if best_match is not None: - cached_query = best_match["function_args"].get("query", "") - logger.info( - f"相似缓存命中,相似度: {best_similarity:.2f}, 原查询: '{cached_query}', 当前查询: '{query}'" - ) - return best_match["result"] - - logger.debug( - f"未找到相似缓存: {tool_name}, 查询: '{query}',相似度阈值: {self.similarity_threshold}" - ) + logger.debug(f"缓存未命中: {key}") return None - def get( - self, tool_name: str, function_args: Dict[str, Any] - ) -> Optional[Dict[str, Any]]: - """ - 从缓存获取结果,支持精确匹配和近似匹配 + async def set(self, tool_name: str, function_args: Dict[str, Any], 
tool_class: Any, data: Any, ttl: Optional[int] = None, semantic_query: Optional[str] = None): + """将结果存入所有缓存层。""" + if ttl is None: + ttl = self.default_ttl + if ttl <= 0: + return - Args: - tool_name: 工具名称 - function_args: 函数参数 + key = self._generate_key(tool_name, function_args, tool_class) + expires_at = time.time() + ttl + + # 写入 L1 + self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at} - Returns: - 缓存的结果,如果不存在或已过期则返回None - """ - # 首先尝试精确匹配 - cache_key = self._generate_cache_key(tool_name, function_args) - cache_file = self._get_cache_file_path(cache_key) + # 写入 L2 + value = json.dumps(data) + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + cursor.execute("REPLACE INTO cache (key, value, expires_at) VALUES (?, ?, ?)", (key, value, expires_at)) + conn.commit() - if cache_file.exists(): - try: - with open(cache_file, "r", encoding="utf-8") as f: - cache_data = json.load(f) + # 写入语义缓存 + if semantic_query and self.embedding_model: + embedding_result = await self.embedding_model.get_embedding(semantic_query) + if embedding_result: + embedding = np.array([embedding_result], dtype='float32') + # 写入 L1 Vector + new_id = self.l1_vector_index.ntotal + faiss.normalize_L2(embedding) + self.l1_vector_index.add(x=embedding) + self.l1_vector_id_to_key[new_id] = key + # 写入 L2 Vector + self.chroma_collection.add(embeddings=embedding.tolist(), ids=[key]) - # 检查缓存是否过期 - cached_time = datetime.fromisoformat(cache_data["timestamp"]) - if self._is_cache_expired(cached_time): - logger.debug(f"缓存已过期: {cache_key}") - cache_file.unlink() # 删除过期缓存 - else: - logger.debug(f"精确匹配缓存: {tool_name}") - return cache_data["result"] + logger.info(f"已缓存条目: {key}, TTL: {ttl}s") - except (json.JSONDecodeError, KeyError, ValueError) as e: - logger.warning(f"读取缓存文件失败: {cache_file}, 错误: {e}") - # 删除损坏的缓存文件 - if cache_file.exists(): - cache_file.unlink() + def clear_l1(self): + """清空L1缓存。""" + self.l1_kv_cache.clear() + self.l1_vector_index.reset() + self.l1_vector_id_to_key.clear() + logger.info("L1 (内存+FAISS) 缓存已清空。") - # 如果精确匹配失败,尝试近似匹配 - return self._find_similar_cache(tool_name, function_args) + def clear_l2(self): + """清空L2缓存。""" + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + cursor.execute("DELETE FROM cache") + conn.commit() + self.chroma_client.delete_collection(name="semantic_cache") + self.chroma_collection = self.chroma_client.get_or_create_collection(name="semantic_cache") + logger.info("L2 (SQLite & ChromaDB) 缓存已清空。") - def set( - self, tool_name: str, function_args: Dict[str, Any], result: Dict[str, Any] - ) -> None: - """ - 将结果保存到缓存 + def clear_all(self): + """清空所有缓存。""" + self.clear_l1() + self.clear_l2() + logger.info("所有缓存层级已清空。") - Args: - tool_name: 工具名称 - function_args: 函数参数 - result: 缓存结果 - """ - cache_key = self._generate_cache_key(tool_name, function_args) - cache_file = self._get_cache_file_path(cache_key) - - cache_data = { - "tool_name": tool_name, - "function_args": function_args, - "result": result, - "timestamp": datetime.now().isoformat(), - } - - try: - with open(cache_file, "w", encoding="utf-8") as f: - json.dump(cache_data, f, ensure_ascii=False, indent=2) - logger.debug(f"缓存已保存: {tool_name} -> {cache_key}") - except Exception as e: - logger.error(f"保存缓存失败: {cache_file}, 错误: {e}") - - def clear_expired(self) -> int: - """ - 清理过期缓存 - - Returns: - 删除的文件数量 - """ - removed_count = 0 - - for cache_file in self.cache_dir.glob("*.json"): - try: - with open(cache_file, "r", encoding="utf-8") as f: - cache_data = json.load(f) - - cached_time 
= datetime.fromisoformat(cache_data["timestamp"]) - if self._is_cache_expired(cached_time): - cache_file.unlink() - removed_count += 1 - logger.debug(f"删除过期缓存: {cache_file}") - - except Exception as e: - logger.warning(f"清理缓存文件时出错: {cache_file}, 错误: {e}") - # 删除损坏的文件 - try: - cache_file.unlink() - removed_count += 1 - except (OSError, json.JSONDecodeError, KeyError, ValueError): - logger.warning(f"删除损坏的缓存文件失败: {cache_file}, 错误: {e}") - - logger.info(f"清理完成,删除了 {removed_count} 个过期缓存文件") - return removed_count - - def clear_all(self) -> int: - """ - 清空所有缓存 - - Returns: - 删除的文件数量 - """ - removed_count = 0 - - for cache_file in self.cache_dir.glob("*.json"): - try: - cache_file.unlink() - removed_count += 1 - except Exception as e: - logger.warning(f"删除缓存文件失败: {cache_file}, 错误: {e}") - - logger.info(f"清空缓存完成,删除了 {removed_count} 个文件") - return removed_count - - def get_stats(self) -> Dict[str, Any]: - """ - 获取缓存统计信息 - - Returns: - 缓存统计信息字典 - """ - total_files = 0 - expired_files = 0 - total_size = 0 - - for cache_file in self.cache_dir.glob("*.json"): - try: - total_files += 1 - total_size += cache_file.stat().st_size - - with open(cache_file, "r", encoding="utf-8") as f: - cache_data = json.load(f) - - cached_time = datetime.fromisoformat(cache_data["timestamp"]) - if self._is_cache_expired(cached_time): - expired_files += 1 - - except (OSError, json.JSONDecodeError, KeyError, ValueError): - expired_files += 1 # 损坏的文件也算作过期 - - return { - "total_files": total_files, - "expired_files": expired_files, - "total_size_bytes": total_size, - "cache_dir": str(self.cache_dir), - "max_age_hours": self.max_age.total_seconds() / 3600, - "similarity_threshold": self.similarity_threshold, - } - -tool_cache = ToolCache() \ No newline at end of file +# 全局实例 +tool_cache = CacheManager() \ No newline at end of file diff --git a/src/common/cache_manager_backup.py b/src/common/cache_manager_backup.py new file mode 100644 index 000000000..ecaff3458 --- /dev/null +++ b/src/common/cache_manager_backup.py @@ -0,0 +1,344 @@ +import json +import hashlib +import re +from typing import Any, Dict, Optional +from datetime import datetime, timedelta +from pathlib import Path +from difflib import SequenceMatcher + +from src.common.logger import get_logger + +logger = get_logger("cache_manager") + + +class ToolCache: + """工具缓存管理器,用于缓存工具调用结果,支持近似匹配""" + + def __init__( + self, + cache_dir: str = "data/tool_cache", + max_age_hours: int = 24, + similarity_threshold: float = 0.65, + ): + """ + 初始化缓存管理器 + + Args: + cache_dir: 缓存目录路径 + max_age_hours: 缓存最大存活时间(小时) + similarity_threshold: 近似匹配的相似度阈值 (0-1) + """ + self.cache_dir = Path(cache_dir) + self.max_age = timedelta(hours=max_age_hours) + self.max_age_seconds = max_age_hours * 3600 + self.similarity_threshold = similarity_threshold + self.cache_dir.mkdir(parents=True, exist_ok=True) + + @staticmethod + def _normalize_query(query: str) -> str: + """ + 标准化查询文本,用于相似度比较 + + Args: + query: 原始查询文本 + + Returns: + 标准化后的查询文本 + """ + if not query: + return "" + + # 纯 Python 实现 + normalized = query.lower() + normalized = re.sub(r"[^\w\s]", " ", normalized) + normalized = " ".join(normalized.split()) + return normalized + + def _calculate_similarity(self, text1: str, text2: str) -> float: + """ + 计算两个文本的相似度 + + Args: + text1: 文本1 + text2: 文本2 + + Returns: + 相似度分数 (0-1) + """ + if not text1 or not text2: + return 0.0 + + # 纯 Python 实现 + norm_text1 = self._normalize_query(text1) + norm_text2 = self._normalize_query(text2) + + if norm_text1 == norm_text2: + return 1.0 + + return 
SequenceMatcher(None, norm_text1, norm_text2).ratio() + + @staticmethod + def _generate_cache_key(tool_name: str, function_args: Dict[str, Any]) -> str: + """ + 生成缓存键 + + Args: + tool_name: 工具名称 + function_args: 函数参数 + + Returns: + 缓存键字符串 + """ + # 将参数排序后序列化,确保相同参数产生相同的键 + sorted_args = json.dumps(function_args, sort_keys=True, ensure_ascii=False) + + # 纯 Python 实现 + cache_string = f"{tool_name}:{sorted_args}" + return hashlib.md5(cache_string.encode("utf-8")).hexdigest() + + def _get_cache_file_path(self, cache_key: str) -> Path: + """获取缓存文件路径""" + return self.cache_dir / f"{cache_key}.json" + + def _is_cache_expired(self, cached_time: datetime) -> bool: + """检查缓存是否过期""" + return datetime.now() - cached_time > self.max_age + + def _find_similar_cache( + self, tool_name: str, function_args: Dict[str, Any] + ) -> Optional[Dict[str, Any]]: + """ + 查找相似的缓存条目 + + Args: + tool_name: 工具名称 + function_args: 函数参数 + + Returns: + 相似的缓存结果,如果不存在则返回None + """ + query = function_args.get("query", "") + if not query: + return None + + candidates = [] + cache_data_list = [] + + # 遍历所有缓存文件,收集候选项 + for cache_file in self.cache_dir.glob("*.json"): + try: + with open(cache_file, "r", encoding="utf-8") as f: + cache_data = json.load(f) + + # 检查是否是同一个工具 + if cache_data.get("tool_name") != tool_name: + continue + + # 检查缓存是否过期 + cached_time = datetime.fromisoformat(cache_data["timestamp"]) + if self._is_cache_expired(cached_time): + continue + + # 检查其他参数是否匹配(除了query) + cached_args = cache_data.get("function_args", {}) + args_match = True + for key, value in function_args.items(): + if key != "query" and cached_args.get(key) != value: + args_match = False + break + + if not args_match: + continue + + # 收集候选项 + cached_query = cached_args.get("query", "") + candidates.append((cached_query, len(cache_data_list))) + cache_data_list.append(cache_data) + + except Exception as e: + logger.warning(f"检查缓存文件时出错: {cache_file}, 错误: {e}") + continue + + if not candidates: + logger.debug( + f"未找到相似缓存: {tool_name}, 查询: '{query}',相似度阈值: {self.similarity_threshold}" + ) + return None + + # 纯 Python 实现 + best_match = None + best_similarity = 0.0 + + for cached_query, index in candidates: + similarity = self._calculate_similarity(query, cached_query) + if similarity > best_similarity and similarity >= self.similarity_threshold: + best_similarity = similarity + best_match = cache_data_list[index] + + if best_match is not None: + cached_query = best_match["function_args"].get("query", "") + logger.info( + f"相似缓存命中,相似度: {best_similarity:.2f}, 原查询: '{cached_query}', 当前查询: '{query}'" + ) + return best_match["result"] + + logger.debug( + f"未找到相似缓存: {tool_name}, 查询: '{query}',相似度阈值: {self.similarity_threshold}" + ) + return None + + def get( + self, tool_name: str, function_args: Dict[str, Any] + ) -> Optional[Dict[str, Any]]: + """ + 从缓存获取结果,支持精确匹配和近似匹配 + + Args: + tool_name: 工具名称 + function_args: 函数参数 + + Returns: + 缓存的结果,如果不存在或已过期则返回None + """ + # 首先尝试精确匹配 + cache_key = self._generate_cache_key(tool_name, function_args) + cache_file = self._get_cache_file_path(cache_key) + + if cache_file.exists(): + try: + with open(cache_file, "r", encoding="utf-8") as f: + cache_data = json.load(f) + + # 检查缓存是否过期 + cached_time = datetime.fromisoformat(cache_data["timestamp"]) + if self._is_cache_expired(cached_time): + logger.debug(f"缓存已过期: {cache_key}") + cache_file.unlink() # 删除过期缓存 + else: + logger.debug(f"精确匹配缓存: {tool_name}") + return cache_data["result"] + + except (json.JSONDecodeError, KeyError, ValueError) as e: + logger.warning(f"读取缓存文件失败: 
{cache_file}, 错误: {e}") + # 删除损坏的缓存文件 + if cache_file.exists(): + cache_file.unlink() + + # 如果精确匹配失败,尝试近似匹配 + return self._find_similar_cache(tool_name, function_args) + + def set( + self, tool_name: str, function_args: Dict[str, Any], result: Dict[str, Any] + ) -> None: + """ + 将结果保存到缓存 + + Args: + tool_name: 工具名称 + function_args: 函数参数 + result: 缓存结果 + """ + cache_key = self._generate_cache_key(tool_name, function_args) + cache_file = self._get_cache_file_path(cache_key) + + cache_data = { + "tool_name": tool_name, + "function_args": function_args, + "result": result, + "timestamp": datetime.now().isoformat(), + } + + try: + with open(cache_file, "w", encoding="utf-8") as f: + json.dump(cache_data, f, ensure_ascii=False, indent=2) + logger.debug(f"缓存已保存: {tool_name} -> {cache_key}") + except Exception as e: + logger.error(f"保存缓存失败: {cache_file}, 错误: {e}") + + def clear_expired(self) -> int: + """ + 清理过期缓存 + + Returns: + 删除的文件数量 + """ + removed_count = 0 + + for cache_file in self.cache_dir.glob("*.json"): + try: + with open(cache_file, "r", encoding="utf-8") as f: + cache_data = json.load(f) + + cached_time = datetime.fromisoformat(cache_data["timestamp"]) + if self._is_cache_expired(cached_time): + cache_file.unlink() + removed_count += 1 + logger.debug(f"删除过期缓存: {cache_file}") + + except Exception as e: + logger.warning(f"清理缓存文件时出错: {cache_file}, 错误: {e}") + # 删除损坏的文件 + try: + cache_file.unlink() + removed_count += 1 + except (OSError, json.JSONDecodeError, KeyError, ValueError): + logger.warning(f"删除损坏的缓存文件失败: {cache_file}, 错误: {e}") + + logger.info(f"清理完成,删除了 {removed_count} 个过期缓存文件") + return removed_count + + def clear_all(self) -> int: + """ + 清空所有缓存 + + Returns: + 删除的文件数量 + """ + removed_count = 0 + + for cache_file in self.cache_dir.glob("*.json"): + try: + cache_file.unlink() + removed_count += 1 + except Exception as e: + logger.warning(f"删除缓存文件失败: {cache_file}, 错误: {e}") + + logger.info(f"清空缓存完成,删除了 {removed_count} 个文件") + return removed_count + + def get_stats(self) -> Dict[str, Any]: + """ + 获取缓存统计信息 + + Returns: + 缓存统计信息字典 + """ + total_files = 0 + expired_files = 0 + total_size = 0 + + for cache_file in self.cache_dir.glob("*.json"): + try: + total_files += 1 + total_size += cache_file.stat().st_size + + with open(cache_file, "r", encoding="utf-8") as f: + cache_data = json.load(f) + + cached_time = datetime.fromisoformat(cache_data["timestamp"]) + if self._is_cache_expired(cached_time): + expired_files += 1 + + except (OSError, json.JSONDecodeError, KeyError, ValueError): + expired_files += 1 # 损坏的文件也算作过期 + + return { + "total_files": total_files, + "expired_files": expired_files, + "total_size_bytes": total_size, + "cache_dir": str(self.cache_dir), + "max_age_hours": self.max_age.total_seconds() / 3600, + "similarity_threshold": self.similarity_threshold, + } + +tool_cache = ToolCache() \ No newline at end of file diff --git a/src/main.py b/src/main.py index 6e07ee692..475b92699 100644 --- a/src/main.py +++ b/src/main.py @@ -17,6 +17,7 @@ from src.common.server import get_global_server, Server from src.mood.mood_manager import mood_manager from rich.traceback import install from src.manager.schedule_manager import schedule_manager +from src.common.cache_manager import tool_cache # from src.api.main import start_api_server # 导入新的插件管理器和热重载管理器 diff --git a/src/plugins/built_in/web_search_tool/plugin.py b/src/plugins/built_in/web_search_tool/plugin.py index b2212f836..c4415fd99 100644 --- a/src/plugins/built_in/web_search_tool/plugin.py +++ 
b/src/plugins/built_in/web_search_tool/plugin.py @@ -88,7 +88,8 @@ class WebSurfingTool(BaseTool): return {"error": "搜索查询不能为空。"} # 检查缓存 - cached_result = tool_cache.get(self.name, function_args) + query = function_args.get("query") + cached_result = await tool_cache.get(self.name, function_args, tool_class=self.__class__, semantic_query=query) if cached_result: logger.info(f"缓存命中: {self.name} -> {function_args}") return cached_result @@ -109,7 +110,8 @@ class WebSurfingTool(BaseTool): # 保存到缓存 if "error" not in result: - tool_cache.set(self.name, function_args, result) + query = function_args.get("query") + await tool_cache.set(self.name, function_args, self.__class__, result, semantic_query=query) return result @@ -463,7 +465,7 @@ class URLParserTool(BaseTool): 执行URL内容提取和总结。优先使用Exa,失败后尝试本地解析。 """ # 检查缓存 - cached_result = tool_cache.get(self.name, function_args) + cached_result = await tool_cache.get(self.name, function_args, tool_class=self.__class__) if cached_result: logger.info(f"缓存命中: {self.name} -> {function_args}") return cached_result @@ -577,7 +579,7 @@ class URLParserTool(BaseTool): # 保存到缓存 if "error" not in result: - tool_cache.set(self.name, function_args, result) + await tool_cache.set(self.name, function_args, self.__class__, result) return result From c47d666d0758e2083f3e17c0fa48ba1ac0098349 Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Mon, 18 Aug 2025 11:54:00 +0800 Subject: [PATCH 03/10] =?UTF-8?q?chore(deps):=20=E7=A7=BB=E9=99=A4?= =?UTF-8?q?=E6=9C=AA=E4=BD=BF=E7=94=A8=E7=9A=84=20ModelTaskConfig=20?= =?UTF-8?q?=E5=AF=BC=E5=85=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 从 cache_manager.py 文件中删除了对 `src.config.api_ada_configs` 中 `ModelTaskConfig` 的导入,因为它在该文件中并未被使用。 添加了记忆系统的大饼 Co-authored-by: 雅诺狐 --- docs/memory_system_design_v2.md | 66 +++++++++++++++++++++++++++++++++ src/common/cache_manager.py | 1 - 2 files changed, 66 insertions(+), 1 deletion(-) create mode 100644 docs/memory_system_design_v2.md diff --git a/docs/memory_system_design_v2.md b/docs/memory_system_design_v2.md new file mode 100644 index 000000000..2e16d5c2a --- /dev/null +++ b/docs/memory_system_design_v2.md @@ -0,0 +1,66 @@ +# 全新三层记忆系统架构 (V2.0) 设计文档 + +## 1. 核心思想 + +本架构旨在建立一个清晰、有序的信息处理流水线,模拟人类记忆从瞬时感知到长期知识沉淀的过程。信息将经历**短期记忆 (STM)**、**中期记忆 (MTM)** 和 **长期记忆 (LTM)** 三个阶段,实现从海量、零散到结构化、深刻的转化。 + +## 2. 架构分层详解 + +### 2.1. 短期记忆 (STM - Short-Term Memory) - “消息缓冲区” + +* **职责**: 捕获并暂存所有进入核心的最新消息,为即时对话提供上下文,实现快速响应。 +* **实现方式**: + * **内存队列**: 采用定长的内存队列(如 `collections.deque`),存储最近的 N 条原始消息(建议初始值为 200)。 + * **实时向量化**: 消息入队时,异步进行文本内容的语义向量化,生成“意义指纹”。 + * **快速检索**: 利用高效的向量相似度计算库(如 FAISS, Annoy),在新消息到来时,快速从队列中检索最相关的历史消息,构建即时上下文。 +* **触发机制**: 当队列达到容量上限时,将最老的一批消息(例如前 50 条)打包,移交给中期记忆模块处理。 + +### 2.2. 中期记忆 (MTM - Mid-Term Memory) - “记忆压缩器” + +* **职责**: 对来自短期记忆的大量零散信息进行压缩、总结,形成结构化的“记忆片段”。 +* **实现方式**: + * **LLM 总结**: 调用大语言模型(LLM)对 STM 移交的消息包进行深度分析和总结,提炼成一段精简的“记忆陈述”(Memory Statement)。 + * **信息结构化**: 每个记忆片段都将包含以下元数据: + * `memory_text`: 记忆陈述本身。 + * `keywords`: 关联的关键词列表。 + * `time_range`: 记忆所涉及的时间范围。 + * `importance_score`: LLM 评估的重要性评分。 + * `access_count`: 访问计数器,初始为 0。 + * **持久化存储**: 将结构化的记忆片段存储在数据库中,可复用或改造现有 `Memory` 表。 +* **触发机制**: 由 STM 的队列溢出事件触发。 + +### 2.3. 
长期记忆 (LTM - Long-Term Memory) - “知识图谱” + +* **职责**: 将经过验证的、具有高价值的中期记忆,内化为系统核心知识的一部分,构建深层联系。 +* **实现方式**: + * **晋升机制**: 通过一个定期的“记忆整理”任务,扫描中期记忆数据库。当某个记忆片段的 `access_count` 达到预设阈值(例如 10 次),则触发晋升。 + * **融入图谱**: 晋升的记忆片段将被送往 `Hippocampus` 模块。`Hippocampus` 将不再直接处理原始聊天记录,而是处理这些高质量、经过预处理的记忆片段。它会从中提取核心概念(节点)和它们之间的关系(边),然后将这些信息融入并更新现有的知识图谱。 +* **触发机制**: 由定时任务(例如每天执行一次)触发。 + +## 3. 信息处理流程 + +```mermaid +graph TD + A[输入: 新消息] --> B{短期记忆 STM}; + B --> |实时向量检索| C[输出: 对话上下文]; + B --> |队列满| D{中期记忆 MTM}; + D --> |LLM 总结| E[存入数据库: 记忆片段]; + E --> |关键词/时间检索| C; + E --> |访问次数高| F{长期记忆 LTM}; + F --> |LLM 提取概念/关系| G[更新: 知识图谱]; + G --> |图谱扩散激活检索| C; + + subgraph "内存中 (高速)" + B + end + + subgraph "数据库中 (持久化)" + E + G + end +``` + +## 4. 现有模块改造计划 + +* **`InstantMemory`**: 将被新的 **STM** 和 **MTM** 模块取代。其原有的“判断是否需要记忆”和“总结”的功能,将融入到 MTM 的处理流程中。 +* **`Hippocampus`**: 将保留其作为 **LTM** 的核心地位,但其输入源将从“随机抽样的历史聊天记录”变更为“从 MTM 晋升的高价值记忆片段”。这将极大提升其构建知识图谱的效率和质量。 \ No newline at end of file diff --git a/src/common/cache_manager.py b/src/common/cache_manager.py index c1141aa66..28fcd0d87 100644 --- a/src/common/cache_manager.py +++ b/src/common/cache_manager.py @@ -9,7 +9,6 @@ import faiss from typing import Any, Dict, Optional from src.common.logger import get_logger from src.llm_models.utils_model import LLMRequest -from src.config.api_ada_configs import ModelTaskConfig from src.config.config import global_config, model_config logger = get_logger("cache_manager") From dcc2e4caff271f90de863372faa0c4ccedcd040d Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Mon, 18 Aug 2025 12:04:35 +0800 Subject: [PATCH 04/10] =?UTF-8?q?feat(maizone):=20=E6=96=B0=E5=A2=9E?= =?UTF-8?q?=E5=AE=9A=E6=97=B6=E4=BB=BB=E5=8A=A1=E9=9A=8F=E6=9C=BA=E9=97=B4?= =?UTF-8?q?=E9=9A=94=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 为了避免定时任务在完全相同的时间点触发,引入了随机延迟机制。 现在,定时任务的执行间隔将在设定的最小和最大分钟数之间随机波动,使行为模式更难被预测。 此功能可通过配置项进行调整,默认间隔为 5 到 15 分钟。 Co-Authored-By: tt-P607 <68868379+tt-P607@users.noreply.github.com> --- src/plugins/built_in/maizone_refactored/plugin.py | 2 ++ .../maizone_refactored/services/scheduler_service.py | 9 +++++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/plugins/built_in/maizone_refactored/plugin.py b/src/plugins/built_in/maizone_refactored/plugin.py index ccf8874e1..17fb6678b 100644 --- a/src/plugins/built_in/maizone_refactored/plugin.py +++ b/src/plugins/built_in/maizone_refactored/plugin.py @@ -68,6 +68,8 @@ class MaiZoneRefactoredPlugin(BasePlugin): }, "schedule": { "enable_schedule": ConfigField(type=bool, default=False, description="是否启用定时发送"), + "random_interval_min_minutes": ConfigField(type=int, default=5, description="随机间隔分钟数下限"), + "random_interval_max_minutes": ConfigField(type=int, default=15, description="随机间隔分钟数上限"), }, "cookie": { "http_fallback_host": ConfigField(type=str, default="127.0.0.1", description="备用Cookie获取服务的主机地址"), diff --git a/src/plugins/built_in/maizone_refactored/services/scheduler_service.py b/src/plugins/built_in/maizone_refactored/services/scheduler_service.py index 6047f43c5..501f9958c 100644 --- a/src/plugins/built_in/maizone_refactored/services/scheduler_service.py +++ b/src/plugins/built_in/maizone_refactored/services/scheduler_service.py @@ -5,6 +5,7 @@ """ import asyncio import datetime +import random import traceback from typing import Callable @@ -91,8 +92,12 @@ class SchedulerService: result.get("message", "") ) - # 6. 等待5分钟后进行下一次检查 - await asyncio.sleep(300) + # 6. 
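# --- 示意(非补丁原文):随机抖动的作用 ---
# 将固定 300 秒的轮询替换为 [min, max] 分钟上的均匀随机等待后,
# 平均轮询周期为 (min + max) / 2 分钟(默认 5~15 分钟,约合 10 分钟),
# 且相邻两次触发时刻不再可预测。独立的最小演示:
#
#     import random
#     min_minutes, max_minutes = 5, 15  # 对应默认配置项
#     wait_seconds = random.randint(min_minutes * 60, max_minutes * 60)
#     # 期望等待时间 E[wait_seconds] = (5 + 15) / 2 * 60 = 600 秒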
计算并等待一个随机的时间间隔 + min_minutes = self.get_config("schedule.random_interval_min_minutes", 5) + max_minutes = self.get_config("schedule.random_interval_max_minutes", 15) + wait_seconds = random.randint(min_minutes * 60, max_minutes * 60) + logger.info(f"下一次检查将在 {wait_seconds / 60:.2f} 分钟后进行。") + await asyncio.sleep(wait_seconds) except asyncio.CancelledError: logger.info("定时任务循环被取消。") From 3b3eb080daa7183ddcc803831ca4c9c3def78b09 Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Mon, 18 Aug 2025 13:00:13 +0800 Subject: [PATCH 05/10] =?UTF-8?q?feat(plugin):=20=E5=AE=9E=E7=8E=B0?= =?UTF-8?q?=E6=8F=92=E4=BB=B6=E9=85=8D=E7=BD=AE=E9=9B=86=E4=B8=AD=E5=8C=96?= =?UTF-8?q?=E7=AE=A1=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 将插件配置文件从各自的插件目录迁移至项目根目录下的 `config/plugins/` 文件夹中,方便用户统一管理和修改。 主要变更: - 新增 `plugins.centralized_config` 总开关,用于控制是否启用此功能。 - 修改插件加载逻辑,现在会从 `config/plugins//` 目录读取用户配置。 - 如果用户配置不存在,会自动从插件目录下的模板配置文件复制一份。 - 保留了原有的配置版本检查和自动迁移功能,现在作用于用户配置文件。 --- src/config/config.py | 4 +- src/config/official_configs.py | 15 +- src/plugin_system/base/plugin_base.py | 149 +++++++++--------- .../built_in/maizone_refactored/plugin.py | 2 - .../services/scheduler_service.py | 8 +- template/bot_config_template.toml | 5 +- 6 files changed, 101 insertions(+), 82 deletions(-) diff --git a/src/config/config.py b/src/config/config.py index 2d1b12d2d..654016f6a 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -41,7 +41,8 @@ from src.config.official_configs import ( DependencyManagementConfig, ExaConfig, WebSearchConfig, - TavilyConfig + TavilyConfig, + PluginsConfig ) from .api_ada_configs import ( @@ -362,6 +363,7 @@ class Config(ConfigBase): exa: ExaConfig = field(default_factory=lambda: ExaConfig()) web_search: WebSearchConfig = field(default_factory=lambda: WebSearchConfig()) tavily: TavilyConfig = field(default_factory=lambda: TavilyConfig()) + plugins: PluginsConfig = field(default_factory=lambda: PluginsConfig()) @dataclass diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 1de885121..3a7c6112a 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -988,10 +988,10 @@ class VideoAnalysisConfig(ConfigBase): """批量分析时使用的提示词""" -@dataclass +@dataclass class WebSearchConfig(ConfigBase): """联网搜索组件配置类""" - + enable_web_search_tool: bool = True """是否启用联网搜索工具""" @@ -1002,4 +1002,13 @@ class WebSearchConfig(ConfigBase): """启用的搜索引擎列表,可选: 'exa', 'tavily', 'ddg'""" search_strategy: str = "single" - """搜索策略: 'single'(使用第一个可用引擎), 'parallel'(并行使用所有启用的引擎), 'fallback'(按顺序尝试,失败则尝试下一个)""" \ No newline at end of file + """搜索策略: 'single'(使用第一个可用引擎), 'parallel'(并行使用所有启用的引擎), 'fallback'(按顺序尝试,失败则尝试下一个)""" + + +@dataclass +class PluginsConfig(ConfigBase): + """插件配置""" + + centralized_config: bool = field( + default=True, metadata={"description": "是否启用插件配置集中化管理"} + ) \ No newline at end of file diff --git a/src/plugin_system/base/plugin_base.py b/src/plugin_system/base/plugin_base.py index 36af7dd31..39e57b1aa 100644 --- a/src/plugin_system/base/plugin_base.py +++ b/src/plugin_system/base/plugin_base.py @@ -8,6 +8,7 @@ import shutil import datetime from src.common.logger import get_logger +from src.config.config import CONFIG_DIR from src.plugin_system.base.component_types import ( PluginInfo, PythonDependency, @@ -71,6 +72,7 @@ class PluginBase(ABC): self.config: Dict[str, Any] = {} # 插件配置 self.plugin_dir = plugin_dir # 插件目录路径 self.log_prefix = f"[Plugin:{self.plugin_name}]" + self._is_enabled = 
self.enable_plugin # 从插件定义中获取默认启用状态 # 加载manifest文件 self._load_manifest() @@ -100,7 +102,7 @@ class PluginBase(ABC): description=self.plugin_description, version=self.plugin_version, author=self.plugin_author, - enabled=self.enable_plugin, + enabled=self._is_enabled, is_built_in=False, config_file=self.config_file_name or "", dependencies=self.dependencies.copy(), @@ -453,86 +455,91 @@ class PluginBase(ABC): logger.error(f"{self.log_prefix} 保存配置文件失败: {e}", exc_info=True) def _load_plugin_config(self): # sourcery skip: extract-method - """加载插件配置文件,支持版本检查和自动迁移""" + """ + 加载插件配置文件,实现集中化管理和自动迁移。 + + 处理逻辑: + 1. 确定插件模板配置文件路径(位于插件目录内)。 + 2. 如果模板不存在,则在插件目录内生成一份默认配置。 + 3. 确定用户配置文件路径(位于 `config/plugins/` 目录下)。 + 4. 如果用户配置文件不存在,则从插件目录复制模板文件过去。 + 5. 加载用户配置文件,并进行版本检查和自动迁移(如果需要)。 + 6. 最终加载的配置是用户配置文件。 + """ if not self.config_file_name: logger.debug(f"{self.log_prefix} 未指定配置文件,跳过加载") return - # 优先使用传入的插件目录路径 - if self.plugin_dir: - plugin_dir = self.plugin_dir - else: - # fallback:尝试从类的模块信息获取路径 + # 1. 确定插件模板配置文件路径 + template_config_path = os.path.join(self.plugin_dir, self.config_file_name) + + # 2. 如果模板不存在,则在插件目录内生成 + if not os.path.exists(template_config_path): + logger.info(f"{self.log_prefix} 插件目录缺少配置文件 {template_config_path},将生成默认配置。") + self._generate_and_save_default_config(template_config_path) + + # 3. 确定用户配置文件路径 + plugin_config_dir = os.path.join(CONFIG_DIR, "plugins", self.plugin_name) + user_config_path = os.path.join(plugin_config_dir, self.config_file_name) + + # 确保用户插件配置目录存在 + os.makedirs(plugin_config_dir, exist_ok=True) + + # 4. 如果用户配置文件不存在,从模板复制 + if not os.path.exists(user_config_path): try: - plugin_module_path = inspect.getfile(self.__class__) - plugin_dir = os.path.dirname(plugin_module_path) - except (TypeError, OSError): - # 最后的fallback:从模块的__file__属性获取 - module = inspect.getmodule(self.__class__) - if module and hasattr(module, "__file__") and module.__file__: - plugin_dir = os.path.dirname(module.__file__) - else: - logger.warning(f"{self.log_prefix} 无法获取插件目录路径,跳过配置加载") - return + shutil.copy2(template_config_path, user_config_path) + logger.info(f"{self.log_prefix} 已从模板创建用户配置文件: {user_config_path}") + except IOError as e: + logger.error(f"{self.log_prefix} 复制配置文件失败: {e}", exc_info=True) + # 如果复制失败,后续将无法加载,直接返回 + return - config_file_path = os.path.join(plugin_dir, self.config_file_name) - - # 如果配置文件不存在,生成默认配置 - if not os.path.exists(config_file_path): - logger.info(f"{self.log_prefix} 配置文件 {config_file_path} 不存在,将生成默认配置。") - self._generate_and_save_default_config(config_file_path) - - if not os.path.exists(config_file_path): - logger.warning(f"{self.log_prefix} 配置文件 {config_file_path} 不存在且无法生成。") + # 检查最终的用户配置文件是否存在 + if not os.path.exists(user_config_path): + logger.warning(f"{self.log_prefix} 用户配置文件 {user_config_path} 不存在且无法创建。") return - file_ext = os.path.splitext(self.config_file_name)[1].lower() - - if file_ext == ".toml": - # 加载现有配置 - with open(config_file_path, "r", encoding="utf-8") as f: - existing_config = toml.load(f) or {} - - # 检查配置版本 - current_version = self._get_current_config_version(existing_config) - - # 如果配置文件没有版本信息,跳过版本检查 - if current_version == "0.0.0": - logger.debug(f"{self.log_prefix} 配置文件无版本信息,跳过版本检查") - self.config = existing_config - else: - expected_version = self._get_expected_config_version() - - if current_version != expected_version: - logger.info( - f"{self.log_prefix} 检测到配置版本需要更新: 当前=v{current_version}, 期望=v{expected_version}" - ) - - # 生成新的默认配置结构 - new_config_structure = self._generate_config_from_schema() - - # 迁移旧配置值到新结构 - migrated_config = 
self._migrate_config_values(existing_config, new_config_structure) - - # 保存迁移后的配置 - self._save_config_to_file(migrated_config, config_file_path) - - logger.info(f"{self.log_prefix} 配置文件已从 v{current_version} 更新到 v{expected_version}") - - self.config = migrated_config - else: - logger.debug(f"{self.log_prefix} 配置版本匹配 (v{current_version}),直接加载") - self.config = existing_config - - logger.debug(f"{self.log_prefix} 配置已从 {config_file_path} 加载") - - # 从配置中更新 enable_plugin - if "plugin" in self.config and "enabled" in self.config["plugin"]: - self.enable_plugin = self.config["plugin"]["enabled"] # type: ignore - logger.debug(f"{self.log_prefix} 从配置更新插件启用状态: {self.enable_plugin}") - else: + # 5. 加载、检查和迁移用户配置文件 + _, file_ext = os.path.splitext(self.config_file_name) + if file_ext.lower() != ".toml": logger.warning(f"{self.log_prefix} 不支持的配置文件格式: {file_ext},仅支持 .toml") self.config = {} + return + + try: + with open(user_config_path, "r", encoding="utf-8") as f: + existing_config = toml.load(f) or {} + except Exception as e: + logger.error(f"{self.log_prefix} 加载用户配置文件 {user_config_path} 失败: {e}", exc_info=True) + self.config = {} + return + + current_version = self._get_current_config_version(existing_config) + expected_version = self._get_expected_config_version() + + if current_version == "0.0.0": + logger.debug(f"{self.log_prefix} 用户配置文件无版本信息,跳过版本检查") + self.config = existing_config + elif current_version != expected_version: + logger.info( + f"{self.log_prefix} 检测到用户配置版本需要更新: 当前=v{current_version}, 期望=v{expected_version}" + ) + new_config_structure = self._generate_config_from_schema() + migrated_config = self._migrate_config_values(existing_config, new_config_structure) + self._save_config_to_file(migrated_config, user_config_path) + logger.info(f"{self.log_prefix} 用户配置文件已从 v{current_version} 更新到 v{expected_version}") + self.config = migrated_config + else: + logger.debug(f"{self.log_prefix} 用户配置版本匹配 (v{current_version}),直接加载") + self.config = existing_config + + logger.debug(f"{self.log_prefix} 配置已从 {user_config_path} 加载") + + # 从配置中更新 enable_plugin 状态 + if "plugin" in self.config and "enabled" in self.config["plugin"]: + self._is_enabled = self.config["plugin"]["enabled"] + logger.debug(f"{self.log_prefix} 从配置更新插件启用状态: {self._is_enabled}") def _check_dependencies(self) -> bool: """检查插件依赖""" diff --git a/src/plugins/built_in/maizone_refactored/plugin.py b/src/plugins/built_in/maizone_refactored/plugin.py index 17fb6678b..d0f9d2ad9 100644 --- a/src/plugins/built_in/maizone_refactored/plugin.py +++ b/src/plugins/built_in/maizone_refactored/plugin.py @@ -9,8 +9,6 @@ from src.common.logger import get_logger from src.plugin_system import ( BasePlugin, ComponentInfo, - BaseAction, - BaseCommand, register_plugin ) from src.plugin_system.base.config_types import ConfigField diff --git a/src/plugins/built_in/maizone_refactored/services/scheduler_service.py b/src/plugins/built_in/maizone_refactored/services/scheduler_service.py index 501f9958c..1288f2953 100644 --- a/src/plugins/built_in/maizone_refactored/services/scheduler_service.py +++ b/src/plugins/built_in/maizone_refactored/services/scheduler_service.py @@ -143,10 +143,10 @@ class SchedulerService: if record: # 如果存在,则更新状态 - record.is_processed = True - record.processed_at = datetime.datetime.now() - record.send_success = success - record.story_content = content + record.is_processed = True # type: ignore + record.processed_at = datetime.datetime.now()# type: ignore + record.send_success = success# type: ignore + record.story_content = content# 
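# --- 示意(非补丁原文):本次提交的配置集中化流程,可概括为如下骨架 ---
# (路径约定与上文 _load_plugin_config 的实现一致;此处将 CONFIG_DIR
#  简化为字面量 "config",属演示假设)
#
#     import os, shutil
#
#     def resolve_user_config(plugin_dir, plugin_name, config_file_name):
#         """返回用户配置路径;首次运行时从插件目录的模板复制一份。"""
#         template_path = os.path.join(plugin_dir, config_file_name)
#         user_dir = os.path.join("config", "plugins", plugin_name)
#         os.makedirs(user_dir, exist_ok=True)
#         user_path = os.path.join(user_dir, config_file_name)
#         if not os.path.exists(user_path):
#             shutil.copy2(template_path, user_path)  # 模板 -> 用户配置
#         return user_path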
type: ignore else: # 如果不存在,则创建新记录 new_record = MaiZoneScheduleStatus( diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index b82185d9d..ac057dac2 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -1,5 +1,5 @@ [inner] -version = "6.3.6" +version = "6.3.7" #----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读---- #如果你想要修改配置文件,请递增version的值 @@ -350,3 +350,6 @@ enable_url_tool = true # 是否启用URL解析tool # 搜索引擎配置 enabled_engines = ["ddg"] # 启用的搜索引擎列表,可选: "exa", "tavily", "ddg" search_strategy = "single" # 搜索策略: "single"(使用第一个可用引擎), "parallel"(并行使用所有启用的引擎), "fallback"(按顺序尝试,失败则尝试下一个) + +[plugins] # 插件配置 +centralized_config = true # 是否启用插件配置集中化管理 From d43d352ca52c7c0f9261730afb38d6cd3f37bd2b Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Mon, 18 Aug 2025 13:16:54 +0800 Subject: [PATCH 06/10] =?UTF-8?q?refactor(config):=20=E7=A7=BB=E9=99=A4?= =?UTF-8?q?=E8=A7=86=E9=A2=91=E5=88=86=E6=9E=90=E7=9B=B8=E5=85=B3=E9=87=8D?= =?UTF-8?q?=E5=A4=8D=E7=9A=84=E9=85=8D=E7=BD=AE=E5=8F=8A=E6=9C=AA=E4=BD=BF?= =?UTF-8?q?=E7=94=A8=E7=9A=84=E5=AF=BC=E5=85=A5(=E6=89=80=E4=BB=A5?= =?UTF-8?q?=E6=88=91=E6=8C=BA=E5=A5=BD=E5=A5=87=E4=B8=BA=E4=BB=80=E4=B9=88?= =?UTF-8?q?VideoAnalysisConfig=E8=83=BD=E6=9C=89=E4=B8=A4=E4=B8=AA)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 同时,清理了多个文件中未使用的导入,包括 `inspect`、`base64` 和 `get_image_manager`,以保持代码库的整洁。 --- src/config/official_configs.py | 39 +------------------ src/plugin_system/base/plugin_base.py | 1 - .../services/qzone_service.py | 2 - .../services/scheduler_service.py | 2 +- 4 files changed, 2 insertions(+), 42 deletions(-) diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 3a7c6112a..8a81c785b 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -1,7 +1,7 @@ import re from dataclasses import dataclass, field -from typing import Literal, Optional, Dict +from typing import Literal, Optional from src.config.config_base import ConfigBase @@ -864,43 +864,6 @@ class ScheduleConfig(ConfigBase): guidelines: Optional[str] = field(default=None) """日程生成指导原则,如果为None则使用默认指导原则""" - -@dataclass -class VideoAnalysisConfig(ConfigBase): - """视频分析配置类""" - - enable: bool = True - """是否启用视频分析功能""" - - analysis_mode: Literal["frame_by_frame", "batch_frames", "auto"] = "auto" - """分析模式:逐帧分析(慢但详细)、批量分析(快但可能略简单)或自动选择""" - - max_frames: int = 8 - """最大分析帧数""" - - frame_quality: int = 85 - """帧图像JPEG质量 (1-100)""" - - max_image_size: int = 800 - """单帧最大图像尺寸(像素)""" - - batch_analysis_prompt: str = field(default="""请分析这个视频的内容。这些图片是从视频中按时间顺序提取的关键帧。 - -请提供详细的分析,包括: -1. 视频的整体内容和主题 -2. 主要人物、对象和场景描述 -3. 动作、情节和时间线发展 -4. 视觉风格和艺术特点 -5. 整体氛围和情感表达 -6. 
任何特殊的视觉效果或文字内容 - -请用中文回答,分析要详细准确。""") - """批量分析时使用的提示词""" - - enable_frame_timing: bool = True - """是否在分析中包含帧的时间信息""" - - @dataclass class DependencyManagementConfig(ConfigBase): """插件Python依赖管理配置类""" diff --git a/src/plugin_system/base/plugin_base.py b/src/plugin_system/base/plugin_base.py index 39e57b1aa..3c71bb6e8 100644 --- a/src/plugin_system/base/plugin_base.py +++ b/src/plugin_system/base/plugin_base.py @@ -1,7 +1,6 @@ from abc import ABC, abstractmethod from typing import Dict, List, Any, Union import os -import inspect import toml import json import shutil diff --git a/src/plugins/built_in/maizone_refactored/services/qzone_service.py b/src/plugins/built_in/maizone_refactored/services/qzone_service.py index 9495bd414..fa9d1df90 100644 --- a/src/plugins/built_in/maizone_refactored/services/qzone_service.py +++ b/src/plugins/built_in/maizone_refactored/services/qzone_service.py @@ -5,7 +5,6 @@ QQ空间服务模块 """ import asyncio -import base64 import json import os import random @@ -15,7 +14,6 @@ from typing import Callable, Optional, Dict, Any, List, Tuple import aiohttp import bs4 import json5 -from src.chat.utils.utils_image import get_image_manager from src.common.logger import get_logger from src.plugin_system.apis import config_api, person_api diff --git a/src/plugins/built_in/maizone_refactored/services/scheduler_service.py b/src/plugins/built_in/maizone_refactored/services/scheduler_service.py index 1288f2953..5aff2a218 100644 --- a/src/plugins/built_in/maizone_refactored/services/scheduler_service.py +++ b/src/plugins/built_in/maizone_refactored/services/scheduler_service.py @@ -118,7 +118,7 @@ class SchedulerService: with get_db_session() as session: record = session.query(MaiZoneScheduleStatus).filter( MaiZoneScheduleStatus.datetime_hour == hour_str, - MaiZoneScheduleStatus.is_processed == True + MaiZoneScheduleStatus.is_processed == True # noqa: E712 ).first() return record is not None except Exception as e: From 22f6cd2d947a8ad327cc70844056f7297d5c2527 Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Mon, 18 Aug 2025 13:48:55 +0800 Subject: [PATCH 07/10] =?UTF-8?q?feat(deps):=20=E5=AE=9E=E7=8E=B0=E4=BE=9D?= =?UTF-8?q?=E8=B5=96=E5=8C=85=E6=99=BA=E8=83=BD=E5=88=AB=E5=90=8D=E8=A7=A3?= =?UTF-8?q?=E6=9E=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 引入了依赖包智能别名解析机制,以解决 Python 生态中常见的安装名与导入名不一致的问题(如 `beautifulsoup4` -> `bs4`)。 当通过包名直接导入失败时,依赖管理器会自动查询一个内置的别名映射表,并尝试使用别名再次导入。这大大提升了开发者在定义简单字符串格式依赖时的体验,减少了因名称不一致导致的依赖检查失败。 同时,更新了相关文档,详细说明了该功能的工作原理、解决了什么问题,并更新了最佳实践。 --- docs/plugins/dependency-management.md | 54 +++++-- src/plugin_system/utils/dependency_alias.py | 134 ++++++++++++++++++ src/plugin_system/utils/dependency_manager.py | 82 ++++++----- 3 files changed, 226 insertions(+), 44 deletions(-) create mode 100644 src/plugin_system/utils/dependency_alias.py diff --git a/docs/plugins/dependency-management.md b/docs/plugins/dependency-management.md index ada951db7..e5eed554a 100644 --- a/docs/plugins/dependency-management.md +++ b/docs/plugins/dependency-management.md @@ -165,10 +165,38 @@ configure_dependency_settings(auto_install_timeout=600) ## 工作流程 1. **插件初始化**: 当插件类被实例化时,系统自动检查依赖 -2. **依赖标准化**: 将字符串格式的依赖转换为PythonDependency对象 +2. **依赖标准化**: 将字符串格式的依赖转换为`PythonDependency`对象 3. **检查已安装**: 尝试导入每个依赖包并检查版本 -4. **自动安装**: 如果启用,自动安装缺失的依赖 -5. **错误处理**: 记录详细的错误信息和安装日志 +4. **智能别名解析 (新增)**: 如果直接导入失败 (例如 `import beautifulsoup4` 失败),系统会查询内置的别名映射表 (例如 `beautifulsoup4` -> `bs4`),并尝试使用别名再次导入。 +5. **自动安装**: 如果启用,自动安装缺失的依赖 +6. 
**错误处理**: 记录详细的错误信息和安装日志 + +## 智能别名解析 (Smart Alias Resolution) + +为了提升开发体验,依赖管理系统内置了一套智能别名解析机制。 + +### 解决的问题 + +Python生态中存在一些特殊的包,它们的**安装名** (在 `pip install` 中使用) 与**导入名** (在 `import` 语句中使用) 不一致。最典型的例子就是: +- 安装名: `beautifulsoup4`, 导入名: `bs4` +- 安装名: `Pillow`, 导入名: `PIL` +- 安装名: `scikit-learn`, 导入名: `sklearn` + +如果开发者在 `python_dependencies` 列表中使用简单的字符串格式 `"beautifulsoup4"`,标准的依赖检查会因为无法 `import beautifulsoup4` 而失败。 + +### 工作原理 + +当依赖管理器通过包名直接导入失败时,它会: +1. 查询一个内置的、包含上百个常见包的别名映射表。 +2. 如果在表中找到对应的导入名,则使用该别名再次尝试导入。 +3. 如果使用别名导入成功,则依赖检查通过,并继续进行版本验证。 + +这个过程是自动的,旨在处理绝大多数常见情况,减少开发者手动配置的麻烦。 + +### 注意事项 + +- **最佳实践**: 尽管有智能别名解析,我们仍然**强烈推荐**使用 `PythonDependency` 对象来明确指定 `package_name` (导入名) 和 `install_name` (安装名),这能确保最高的准确性和可读性。 +- **覆盖范围**: 内置的别名映射表涵盖了大量常用库,但无法保证100%覆盖所有情况。如果遇到别名库未收录的包,请使用 `PythonDependency` 对象进行精确定义。 ## 日志输出示例 @@ -192,12 +220,13 @@ configure_dependency_settings(auto_install_timeout=600) ## 最佳实践 -1. **使用详细的PythonDependency对象** 以获得更好的控制和文档 -2. **配置PyPI镜像源** 特别是在中国大陆地区,可显著提升下载速度 -3. **合理设置可选依赖** 避免非核心功能阻止插件加载 -4. **指定版本要求** 确保兼容性 -5. **添加描述信息** 帮助用户理解依赖的用途 -6. **测试依赖配置** 在不同环境中验证依赖是否正确 +1. **优先使用`PythonDependency`对象**: 这是最可靠、最明确的方式,尤其是在安装名和导入名不同时。 +2. **利用智能别名解析**: 对于常见的、安装名与导入名不一致的包 (如 `beautifulsoup4`, `Pillow` 等),可以直接在字符串列表里使用安装名,系统会自动解析。 +3. **配置PyPI镜像源**: 特别是在中国大陆地区,可显著提升下载速度。 +4. **合理设置可选依赖**: 避免非核心功能阻止插件加载。 +5. **指定版本要求**: 确保兼容性。 +6. **添加描述信息**: 帮助用户理解依赖的用途。 +7. **测试依赖配置**: 在不同环境中验证依赖是否正确。 ## 安全考虑 @@ -225,7 +254,8 @@ configure_dependency_settings(auto_install_timeout=600) ### 导入错误 -1. 确认包名与导入名一致 -2. 检查可选依赖配置 -3. 验证安装是否成功 +1. **确认包名与导入名**: 检查安装名和导入名是否一致。如果不一致,推荐使用 `PythonDependency` 对象明确指定 `package_name` 和 `install_name`。 +2. **利用自动别名解析**: 对于常见库,系统会自动尝试解析别名。如果你的库比较冷门且名称不一致,请使用 `PythonDependency` 对象。 +3. **检查可选依赖配置**: 确认 `optional=True` 是否被正确设置。 +4. **验证安装是否成功**: 查看日志,确认 `pip install` 过程没有报错。 diff --git a/src/plugin_system/utils/dependency_alias.py b/src/plugin_system/utils/dependency_alias.py new file mode 100644 index 000000000..5a817286c --- /dev/null +++ b/src/plugin_system/utils/dependency_alias.py @@ -0,0 +1,134 @@ +# -*- coding: utf-8 -*- +""" +本模块包含一个从Python包的“安装名”到其“导入名”的映射。 + +这个映射表主要用于解决一个常见问题:某些Python包通过pip安装时使用的名称 +与在代码中`import`时使用的名称不一致。例如,我们使用`pip install beautifulsoup4` +来安装,但在代码中却需要`import bs4`。 + +当插件系统检查依赖时,如果一个开发者只简单地在依赖列表中写了安装名 +(例如 "beautifulsoup4"),标准的导入检查`import('beautifulsoup4')`会失败。 +通过这个映射表,依赖管理器可以在初次导入检查失败后,查询是否存在一个 +已知的别名(例如 "bs4"),并尝试使用该别名进行二次导入检查。 + +这样做的好处是: +1. 提升开发者体验:插件开发者无需强制记忆这些特殊的名称对应关系,或者强制 + 使用更复杂的`PythonDependency`对象来分别指定安装名和导入名。 +2. 增强系统健壮性:减少因名称不一致导致的插件加载失败问题。 +3. 
兼容性:对遵循最佳实践、正确指定了`package_name`和`install_name`的 + 开发者没有任何影响。 + +开发者可以持续向这个列表中贡献新的映射关系,使其更加完善。 +""" + +INSTALL_NAME_TO_IMPORT_NAME = { + # ============== 数据科学与机器学习 (Data Science & Machine Learning) ============== + "scikit-learn": "sklearn", # 机器学习库 + "scikit-image": "skimage", # 图像处理库 + "opencv-python": "cv2", # OpenCV 计算机视觉库 + "opencv-contrib-python": "cv2", # OpenCV 扩展模块 + "tensorflow-gpu": "tensorflow", # TensorFlow GPU版本 + "tensorboardx": "tensorboardX", # TensorBoard 的封装 + "torchvision": "torchvision", # PyTorch 视觉库 (通常与 torch 一起) + "torchaudio": "torchaudio", # PyTorch 音频库 + "catboost": "catboost", # CatBoost 梯度提升库 + "lightgbm": "lightgbm", # LightGBM 梯度提升库 + "xgboost": "xgboost", # XGBoost 梯度提升库 + "imbalanced-learn": "imblearn", # 处理不平衡数据集 + "seqeval": "seqeval", # 序列标注评估 + "gensim": "gensim", # 主题建模和NLP + "nltk": "nltk", # 自然语言工具包 + "spacy": "spacy", # 工业级自然语言处理 + "fuzzywuzzy": "fuzzywuzzy", # 模糊字符串匹配 + "python-levenshtein": "Levenshtein", # Levenshtein 距离计算 + + # ============== Web开发与API (Web Development & API) ============== + "python-socketio": "socketio", # Socket.IO 服务器和客户端 + "python-engineio": "engineio", # Engine.IO 底层库 + "aiohttp": "aiohttp", # 异步HTTP客户端/服务器 + "python-multipart": "multipart", # 解析 multipart/form-data + "uvloop": "uvloop", # 高性能asyncio事件循环 + "httptools": "httptools", # 高性能HTTP解析器 + "websockets": "websockets", # WebSocket实现 + "fastapi": "fastapi", # 高性能Web框架 + "starlette": "starlette", # ASGI框架 + "uvicorn": "uvicorn", # ASGI服务器 + "gunicorn": "gunicorn", # WSGI服务器 + "django-rest-framework": "rest_framework", # Django REST框架 + "django-cors-headers": "corsheaders", # Django CORS处理 + "flask-jwt-extended": "flask_jwt_extended", # Flask JWT扩展 + "flask-sqlalchemy": "flask_sqlalchemy", # Flask SQLAlchemy扩展 + "flask-migrate": "flask_migrate", # Flask Alembic迁移扩展 + "python-jose": "jose", # JOSE (JWT, JWS, JWE) 实现 + "passlib": "passlib", # 密码哈希库 + "bcrypt": "bcrypt", # Bcrypt密码哈希 + + # ============== 数据库 (Database) ============== + "mysql-connector-python": "mysql.connector", # MySQL官方驱动 + "psycopg2-binary": "psycopg2", # PostgreSQL驱动 (二进制) + "pymongo": "pymongo", # MongoDB驱动 + "redis": "redis", # Redis客户端 + "aioredis": "aioredis", # 异步Redis客户端 + "sqlalchemy": "sqlalchemy", # SQL工具包和ORM + "alembic": "alembic", # SQLAlchemy数据库迁移工具 + "tortoise-orm": "tortoise", # 异步ORM + + # ============== 图像与多媒体 (Image & Multimedia) ============== + "Pillow": "PIL", # Python图像处理库 (PIL Fork) + "moviepy": "moviepy", # 视频编辑库 + "pydub": "pydub", # 音频处理库 + "pycairo": "cairo", # Cairo 2D图形库的Python绑定 + "wand": "wand", # ImageMagick的Python绑定 + + # ============== 解析与序列化 (Parsing & Serialization) ============== + "beautifulsoup4": "bs4", # HTML/XML解析库 + "lxml": "lxml", # 高性能HTML/XML解析库 + "PyYAML": "yaml", # YAML解析库 + "python-dotenv": "dotenv", # .env文件解析 + "python-dateutil": "dateutil", # 强大的日期时间解析 + "protobuf": "google.protobuf", # Protocol Buffers + "msgpack": "msgpack", # MessagePack序列化 + "orjson": "orjson", # 高性能JSON库 + "pydantic": "pydantic", # 数据验证和设置管理 + + # ============== 系统与硬件 (System & Hardware) ============== + "pyserial": "serial", # 串口通信 + "pyusb": "usb", # USB访问 + "pybluez": "bluetooth", # 蓝牙通信 (可能因平台而异) + "psutil": "psutil", # 系统信息和进程管理 + "watchdog": "watchdog", # 文件系统事件监控 + "python-gnupg": "gnupg", # GnuPG的Python接口 + + # ============== 加密与安全 (Cryptography & Security) ============== + "pycrypto": "Crypto", # 加密库 (较旧) + "pycryptodome": "Crypto", # PyCrypto的现代分支 + "cryptography": "cryptography", # 现代加密库 + "pyopenssl": "OpenSSL", # OpenSSL的Python接口 + "service-identity": 
"service_identity", # 服务身份验证 + + # ============== 工具与杂项 (Utilities & Miscellaneous) ============== + "setuptools": "setuptools", # 打包工具 + "pip": "pip", # 包安装器 + "tqdm": "tqdm", # 进度条 + "regex": "regex", # 替代的正则表达式引擎 + "colorama": "colorama", # 跨平台彩色终端文本 + "termcolor": "termcolor", # 终端颜色格式化 + "requests-oauthlib": "requests_oauthlib", # OAuth for Requests + "oauthlib": "oauthlib", # 通用OAuth库 + "authlib": "authlib", # OAuth和OpenID Connect客户端/服务器 + "pyjwt": "jwt", # JSON Web Token实现 + "python-editor": "editor", # 程序化地调用编辑器 + "prompt-toolkit": "prompt_toolkit", # 构建交互式命令行 + "pygments": "pygments", # 语法高亮 + "tabulate": "tabulate", # 生成漂亮的表格 + "nats-client": "nats", # NATS客户端 + "gitpython": "git", # Git的Python接口 + "pygithub": "github", # GitHub API v3的Python接口 + "python-gitlab": "gitlab", # GitLab API的Python接口 + "jira": "jira", # JIRA API的Python接口 + "python-jenkins": "jenkins", # Jenkins API的Python接口 + "huggingface-hub": "huggingface_hub", # Hugging Face Hub API + "apache-airflow": "airflow", # Airflow工作流管理 + "pandas-stubs": "pandas-stubs", # Pandas的类型存根 + "data-science-types": "data_science_types", # 数据科学类型 +} \ No newline at end of file diff --git a/src/plugin_system/utils/dependency_manager.py b/src/plugin_system/utils/dependency_manager.py index 33363168b..e524a7bd7 100644 --- a/src/plugin_system/utils/dependency_manager.py +++ b/src/plugin_system/utils/dependency_manager.py @@ -8,6 +8,7 @@ from packaging.requirements import Requirement from src.common.logger import get_logger from src.plugin_system.base.component_types import PythonDependency +from src.plugin_system.utils.dependency_alias import INSTALL_NAME_TO_IMPORT_NAME logger = get_logger("dependency_manager") @@ -190,41 +191,58 @@ class DependencyManager: def _check_single_dependency(self, dep: PythonDependency) -> bool: """检查单个依赖是否满足要求""" - try: - # 尝试导入包 - spec = importlib.util.find_spec(dep.package_name) - if spec is None: - return False - - # 如果没有版本要求,导入成功就够了 - if not dep.version: - return True - - # 检查版本要求 + + def _try_check(import_name: str) -> bool: + """尝试使用给定的导入名进行检查""" try: - module = importlib.import_module(dep.package_name) - installed_version = getattr(module, '__version__', None) - - if installed_version is None: - # 尝试其他常见的版本属性 - installed_version = getattr(module, 'VERSION', None) + spec = importlib.util.find_spec(import_name) + if spec is None: + return False + + # 如果没有版本要求,导入成功就够了 + if not dep.version: + return True + + # 检查版本要求 + try: + module = importlib.import_module(import_name) + installed_version = getattr(module, '__version__', None) + if installed_version is None: - logger.debug(f"无法获取包 {dep.package_name} 的版本信息,假设满足要求") - return True - - # 解析版本要求 - req = Requirement(f"{dep.package_name}{dep.version}") - return version.parse(str(installed_version)) in req.specifier - + # 尝试其他常见的版本属性 + installed_version = getattr(module, 'VERSION', None) + if installed_version is None: + logger.debug(f"无法获取包 {import_name} 的版本信息,假设满足要求") + return True + + # 解析版本要求 + req = Requirement(f"{dep.package_name}{dep.version}") + return version.parse(str(installed_version)) in req.specifier + + except Exception as e: + logger.debug(f"检查包 {import_name} 版本时出错: {e}") + return True # 如果无法检查版本,假设满足要求 + + except ImportError: + return False except Exception as e: - logger.debug(f"检查包 {dep.package_name} 版本时出错: {e}") - return True # 如果无法检查版本,假设满足要求 - - except ImportError: - return False - except Exception as e: - logger.error(f"检查依赖 {dep.package_name} 时发生未知错误: {e}") - return False + logger.error(f"检查依赖 {import_name} 时发生未知错误: {e}") + return False + 
+}
\ No newline at end of file
diff --git a/src/plugin_system/utils/dependency_manager.py b/src/plugin_system/utils/dependency_manager.py
index 33363168b..e524a7bd7 100644
--- a/src/plugin_system/utils/dependency_manager.py
+++ b/src/plugin_system/utils/dependency_manager.py
@@ -8,6 +8,7 @@ from packaging.requirements import Requirement

 from src.common.logger import get_logger
 from src.plugin_system.base.component_types import PythonDependency
+from src.plugin_system.utils.dependency_alias import INSTALL_NAME_TO_IMPORT_NAME

 logger = get_logger("dependency_manager")

@@ -190,41 +191,58 @@ class DependencyManager:

     def _check_single_dependency(self, dep: PythonDependency) -> bool:
         """Check whether a single dependency is satisfied"""
-        try:
-            # Try importing the package
-            spec = importlib.util.find_spec(dep.package_name)
-            if spec is None:
-                return False
-
-            # Without a version requirement, a successful import is enough
-            if not dep.version:
-                return True
-
-            # Check the version requirement
+
+        def _try_check(import_name: str) -> bool:
+            """Attempt the check using the given import name"""
             try:
-                module = importlib.import_module(dep.package_name)
-                installed_version = getattr(module, '__version__', None)
-
-                if installed_version is None:
-                    # Try other common version attributes
-                    installed_version = getattr(module, 'VERSION', None)
+                spec = importlib.util.find_spec(import_name)
+                if spec is None:
+                    return False
+
+                # Without a version requirement, a successful import is enough
+                if not dep.version:
+                    return True
+
+                # Check the version requirement
+                try:
+                    module = importlib.import_module(import_name)
+                    installed_version = getattr(module, '__version__', None)
+
                     if installed_version is None:
-                        logger.debug(f"No version info available for package {dep.package_name}; assuming the requirement is met")
-                        return True
-
-                # Parse the version requirement
-                req = Requirement(f"{dep.package_name}{dep.version}")
-                return version.parse(str(installed_version)) in req.specifier
-
+                        # Try other common version attributes
+                        installed_version = getattr(module, 'VERSION', None)
+                        if installed_version is None:
+                            logger.debug(f"No version info available for package {import_name}; assuming the requirement is met")
+                            return True
+
+                    # Parse the version requirement
+                    req = Requirement(f"{dep.package_name}{dep.version}")
+                    return version.parse(str(installed_version)) in req.specifier
+
+                except Exception as e:
+                    logger.debug(f"Error while checking the version of package {import_name}: {e}")
+                    return True  # if the version cannot be checked, assume the requirement is met
+
+            except ImportError:
+                return False
             except Exception as e:
-                logger.debug(f"Error while checking the version of package {dep.package_name}: {e}")
-                return True  # if the version cannot be checked, assume the requirement is met
-
-        except ImportError:
-            return False
-        except Exception as e:
-            logger.error(f"Unknown error while checking dependency {dep.package_name}: {e}")
-            return False
+                logger.error(f"Unknown error while checking dependency {import_name}: {e}")
+                return False
+
+        # 1. First try the check with the original package_name
+        if _try_check(dep.package_name):
+            return True
+
+        # 2. On failure, consult the alias table.
+        # Note: at this point dep.package_name may be a plain install name such as "requests" or "beautifulsoup4"
+        import_alias = INSTALL_NAME_TO_IMPORT_NAME.get(dep.package_name)
+        if import_alias:
+            logger.debug(f"Import of dependency '{dep.package_name}' failed; trying alias '{import_alias}'")
+            if _try_check(import_alias):
+                return True
+
+        # 3. If the alias also fails, or there is no alias, the check finally fails
+        return False

     def _install_single_package(self, package: str, plugin_name: str = "") -> bool:
         """Install a single package"""

From 9205edf8cabfeb3d3c1f5c6eed9384c002089f9a Mon Sep 17 00:00:00 2001
From: Furina-1013-create <189647097+Furina-1013-create@users.noreply.github.com>
Date: Mon, 18 Aug 2025 15:47:36 +0800
Subject: [PATCH 08/10] Update requirements.txt, which an earlier commit forgot to update
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 requirements.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/requirements.txt b/requirements.txt
index 4c83dda83..b5d553cb9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -60,5 +60,6 @@ exa_py
 asyncddgs
 opencv-python
 Pillow
+chromadb
 asyncio
 tavily-python
\ No newline at end of file

From 23aec68cc010bf790ea3c256dda9b98f60a8ca38 Mon Sep 17 00:00:00 2001
From: Furina-1013-create <189647097+Furina-1013-create@users.noreply.github.com>
Date: Mon, 18 Aug 2025 16:18:21 +0800
Subject: [PATCH 09/10] Fix a potential import problem in the Gemini-specific
 gemini_client.py and add a fallback mechanism
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/llm_models/model_client/gemini_client.py | 156 +++++++++++--------
 1 file changed, 90 insertions(+), 66 deletions(-)

diff --git a/src/llm_models/model_client/gemini_client.py b/src/llm_models/model_client/gemini_client.py
index 60c0c3901..9bda858ef 100644
--- a/src/llm_models/model_client/gemini_client.py
+++ b/src/llm_models/model_client/gemini_client.py
@@ -1,32 +1,54 @@
 import asyncio
 import io
 import base64
-from typing import Callable, AsyncIterator, Optional, Coroutine, Any, List
+from typing import Callable, AsyncIterator, Optional, Coroutine, Any, List, Dict, Union

-from google import genai
-from google.genai.types import (
-    Content,
-    Part,
-    FunctionDeclaration,
+import google.generativeai as genai
+from google.generativeai.types import (
     GenerateContentResponse,
-    ContentListUnion,
-    ContentUnion,
-    ThinkingConfig,
-    Tool,
-    GenerateContentConfig,
-    EmbedContentResponse,
-    EmbedContentConfig,
-    SafetySetting,
     HarmCategory,
     HarmBlockThreshold,
 )
-from google.genai.errors import (
-    ClientError,
-    ServerError,
-    UnknownFunctionCallArgumentError,
-    UnsupportedFunctionError,
-    FunctionInvocationError,
-)
+
+try:
+    # Try importing from the newer API surface
+    from google.generativeai import configure
+    from google.generativeai.types import SafetySetting, GenerationConfig
+except ImportError:
+    # Fall back to basic types
+    SafetySetting = Dict
+    GenerationConfig = Dict
+
+# Compatibility type aliases
+ContentDict = Dict
+PartDict = Dict
+ToolDict = Dict
+FunctionDeclaration = Dict
+Tool = Dict
+ContentListUnion = List[Dict]
+ContentUnion = Dict
+Content = Dict
+Part = Dict
+ThinkingConfig = Dict
+GenerateContentConfig = Dict
+EmbedContentConfig = Dict
+EmbedContentResponse = Dict
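+# NOTE: the aliases above intentionally degrade to plain dict/list types so the
+# annotations in the rest of this module stay valid when only the legacy
+# google-generativeai SDK is installed; they carry no runtime behaviour.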
+
+# Exception types
+class ClientError(Exception):
+    pass
+
+class ServerError(Exception):
+    pass
+
+class UnknownFunctionCallArgumentError(Exception):
+    pass
+
+class UnsupportedFunctionError(Exception):
+    pass
+
+class FunctionInvocationError(Exception):
+    pass

 from src.config.api_ada_configs import ModelInfo, APIProvider
 from src.common.logger import get_logger
@@ -44,18 +66,17 @@ from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall

 logger = get_logger("Gemini client")

-gemini_safe_settings = [
-    SafetySetting(category=HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold=HarmBlockThreshold.BLOCK_NONE),
-    SafetySetting(category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold=HarmBlockThreshold.BLOCK_NONE),
-    SafetySetting(category=HarmCategory.HARM_CATEGORY_HARASSMENT, threshold=HarmBlockThreshold.BLOCK_NONE),
-    SafetySetting(category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, threshold=HarmBlockThreshold.BLOCK_NONE),
-    SafetySetting(category=HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY, threshold=HarmBlockThreshold.BLOCK_NONE),
+SAFETY_SETTINGS = [
+    {"category": HarmCategory.HARM_CATEGORY_HATE_SPEECH, "threshold": HarmBlockThreshold.BLOCK_NONE},
+    {"category": HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, "threshold": HarmBlockThreshold.BLOCK_NONE},
+    {"category": HarmCategory.HARM_CATEGORY_HARASSMENT, "threshold": HarmBlockThreshold.BLOCK_NONE},
+    {"category": HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, "threshold": HarmBlockThreshold.BLOCK_NONE},
 ]


 def _convert_messages(
     messages: list[Message],
-) -> tuple[ContentListUnion, list[str] | None]:
+) -> tuple[List[Dict], list[str] | None]:
     """
     Convert message format - transform messages into the format required by the Gemini API
     :param messages: list of messages
@@ -81,7 +102,7 @@ def _convert_messages(
         normalized_format = format_mapping.get(image_format.lower(), image_format.lower())
         return f"image/{normalized_format}"

-    def _convert_message_item(message: Message) -> Content:
+    def _convert_message_item(message: Message) -> Dict:
         """
         Convert a single message (other than system and tool messages)
         :param message: the message object
@@ -96,22 +117,25 @@ def _convert_messages(

         # Add the content parts
         if isinstance(message.content, str):
-            content = [Part.from_text(text=message.content)]
+            content = [{"text": message.content}]
         elif isinstance(message.content, list):
-            content: List[Part] = []
+            content = []
             for item in message.content:
                 if isinstance(item, tuple):
-                    content.append(
-                        Part.from_bytes(data=base64.b64decode(item[1]), mime_type=_get_correct_mime_type(item[0]))
-                    )
+                    content.append({
+                        "inline_data": {
+                            "mime_type": _get_correct_mime_type(item[0]),
+                            # decode to raw bytes, as the removed Part.from_bytes call did
+                            "data": base64.b64decode(item[1])
+                        }
+                    })
                 elif isinstance(item, str):
-                    content.append(Part.from_text(text=item))
+                    content.append({"text": item})
                 else:
                     raise RuntimeError("Unreachable code: build message objects with the MessageBuilder class")

-        return Content(role=role, parts=content)
+        return {"role": role, "parts": content}
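+    # Shape of the dict produced above for a text+image user message:
+    # {"role": "user", "parts": [{"text": "..."},
+    #     {"inline_data": {"mime_type": "image/jpeg", "data": b"..."}}]}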
"response_modalities": ["TEXT"], - "thinking_config": ThinkingConfig( - include_thoughts=True, - thinking_budget=( + "thinking_config": { + "include_thoughts": True, + "thinking_budget": ( extra_params["thinking_budget"] if extra_params and "thinking_budget" in extra_params else int(max_tokens / 2) # 默认思考预算为最大token数的一半,防止空回复 ), - ), - "safety_settings": gemini_safe_settings, # 防止空回复问题 + }, + "safety_settings": SAFETY_SETTINGS, # 防止空回复问题 } if tools: - generation_config_dict["tools"] = Tool(function_declarations=tools) + generation_config_dict["tools"] = {"function_declarations": tools} if messages[1]: # 如果有system消息,则将其添加到配置中 generation_config_dict["system_instructions"] = messages[1] @@ -417,15 +438,18 @@ class GeminiClient(BaseClient): generation_config_dict["response_mime_type"] = "application/json" generation_config_dict["response_schema"] = response_format.to_dict() - generation_config = GenerateContentConfig(**generation_config_dict) + generation_config = generation_config_dict try: + # 创建模型实例 + model = genai.GenerativeModel(model_info.model_identifier) + if model_info.force_stream_mode: req_task = asyncio.create_task( - self.client.aio.models.generate_content_stream( - model=model_info.model_identifier, + model.generate_content_async( contents=messages[0], - config=generation_config, + generation_config=generation_config, + stream=True ) ) while not req_task.done(): @@ -437,10 +461,9 @@ class GeminiClient(BaseClient): resp, usage_record = await stream_response_handler(req_task.result(), interrupt_flag) else: req_task = asyncio.create_task( - self.client.aio.models.generate_content( - model=model_info.model_identifier, + model.generate_content_async( contents=messages[0], - config=generation_config, + generation_config=generation_config ) ) while not req_task.done(): @@ -451,17 +474,18 @@ class GeminiClient(BaseClient): await asyncio.sleep(0.5) # 等待0.5秒后再次检查任务&中断信号量状态 resp, usage_record = async_response_parser(req_task.result()) - except (ClientError, ServerError) as e: - # 重封装ClientError和ServerError为RespNotOkException - raise RespNotOkException(e.code, e.message) from None - except ( - UnknownFunctionCallArgumentError, - UnsupportedFunctionError, - FunctionInvocationError, - ) as e: - raise ValueError(f"工具类型错误:请检查工具选项和参数:{str(e)}") from None except Exception as e: - raise NetworkConnectionError() from e + # 处理Google Generative AI异常 + if "rate limit" in str(e).lower(): + raise RespNotOkException(429, "请求频率过高,请稍后再试") from None + elif "quota" in str(e).lower(): + raise RespNotOkException(429, "配额已用完") from None + elif "invalid" in str(e).lower() or "bad request" in str(e).lower(): + raise RespNotOkException(400, f"请求无效:{str(e)}") from None + elif "permission" in str(e).lower() or "forbidden" in str(e).lower(): + raise RespNotOkException(403, "权限不足") from None + else: + raise NetworkConnectionError() from e if usage_record: resp.usage = UsageRecord( @@ -535,7 +559,7 @@ class GeminiClient(BaseClient): extra_params["thinking_budget"] if extra_params and "thinking_budget" in extra_params else 1024 ), ), - "safety_settings": gemini_safe_settings, + "safety_settings": SAFETY_SETTINGS, } generate_content_config = GenerateContentConfig(**generation_config_dict) prompt = "Generate a transcript of the speech. The language of the transcript should **match the language of the speech**." 
+google-generativeai==0.8.5
\ No newline at end of file
diff --git a/src/multimodal/video_analyzer.py b/src/multimodal/video_analyzer.py
index db25fb105..033e28109 100644
--- a/src/multimodal/video_analyzer.py
+++ b/src/multimodal/video_analyzer.py
@@ -207,6 +207,9 @@ class VideoAnalyzer:
         """Analyze all frames in one batch"""
         self.logger.info(f"Starting batch analysis of {len(frames)} frames")

+        if not frames:
+            return "❌ No frames to analyze"
+
         # Build the prompt
         prompt = self.batch_analysis_prompt

@@ -214,28 +217,77 @@ class VideoAnalyzer:
             prompt += f"\n\nUser question: {user_question}"

         # Add frame information to the prompt
+        frame_info = []
         for i, (frame_base64, timestamp) in enumerate(frames):
             if self.enable_frame_timing:
-                prompt += f"\n\nFrame {i+1} (time: {timestamp:.2f}s):"
+                frame_info.append(f"frame {i+1} (time: {timestamp:.2f}s)")
+            else:
+                frame_info.append(f"frame {i+1}")
+
+        prompt += f"\n\nThe video contains {len(frames)} frames: {', '.join(frame_info)}"
+        prompt += "\n\nPlease analyze all provided frames together and describe the video's full content and how its story develops."

         try:
-            # Analyze using the first frame (batch mode is single-frame for now; true multi-image analysis can come later)
-            if frames:
-                frame_base64, _ = frames[0]
-                prompt += f"\n\nNote: only frame 1 is shown; please analyze based on this frame and the prompt. The video has {len(frames)} frames in total."
+            # Try multi-image analysis
+            response = await self._analyze_multiple_frames(frames, prompt)
+            self.logger.info("✅ Batch multi-image analysis finished")
+            return response
+
+        except Exception as e:
+            self.logger.error(f"❌ Multi-image analysis failed: {e}")
+            # Fall back to single-frame analysis
+            self.logger.warning("Falling back to single-frame analysis mode")
+            try:
+                frame_base64, timestamp = frames[0]
+                fallback_prompt = prompt + f"\n\nNote: due to technical limitations only frame 1 (time: {timestamp:.2f}s) is shown; the video has {len(frames)} frames in total. Please analyze based on this frame."
                 response, _ = await self.video_llm.generate_response_for_image(
-                    prompt=prompt,
+                    prompt=fallback_prompt,
                     image_base64=frame_base64,
                     image_format="jpeg"
                 )
-                self.logger.info("✅ Batch analysis finished")
+                self.logger.info("✅ Fallback single-frame analysis finished")
                 return response
-            else:
-                return "❌ No frames to analyze"
-        except Exception as e:
-            self.logger.error(f"❌ Batch analysis failed: {e}")
-            raise
+            except Exception as fallback_e:
+                self.logger.error(f"❌ Fallback analysis failed as well: {fallback_e}")
+                raise
+
+    async def _analyze_multiple_frames(self, frames: List[Tuple[str, float]], prompt: str) -> str:
+        """Multi-image analysis"""
+        self.logger.info(f"Building a multi-image analysis request containing {len(frames)} frames")
+
+        # Import MessageBuilder for constructing the multi-image message
+        from src.llm_models.payload_content.message import MessageBuilder, RoleType
+        from src.llm_models.utils_model import RequestType
+
+        # Build a single message carrying multiple images
+        message_builder = MessageBuilder().set_role(RoleType.User).add_text_content(prompt)
+
+        # Append every frame image
+        for i, (frame_base64, timestamp) in enumerate(frames):
+            message_builder.add_image_content("jpeg", frame_base64)
+            # self.logger.info(f"Added frame {i+1} to the request (time: {timestamp:.2f}s, image size: {len(frame_base64)} chars)")
+
+        message = message_builder.build()
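+        # `message` now carries one text part plus len(frames) image parts, so
+        # the model receives every frame in a single request instead of only
+        # frames[0] as the old batch path did.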
+        self.logger.info(f"✅ Multi-image message built, containing {len(frames)} images")
+
+        # Fetch the model info and client
+        model_info, api_provider, client = await self.video_llm._get_best_model_and_client()
+        self.logger.info(f"Using model {model_info.name} for multi-image analysis")
+
+        # Execute the multi-image request directly
+        api_response = await self.video_llm._execute_request(
+            api_provider=api_provider,
+            client=client,
+            request_type=RequestType.RESPONSE,
+            model_info=model_info,
+            message_list=[message],
+            temperature=None,
+            max_tokens=None
+        )
+
+        self.logger.info(f"Video analysis finished; response length: {len(api_response.content or '')}")
+        return api_response.content or "❌ No response content received"

     async def analyze_frames_sequential(self, frames: List[Tuple[str, float]], user_question: str = None) -> str:
         """Analyze frames one by one and summarize"""
@@ -355,7 +407,7 @@ class VideoAnalyzer:

         # Compute the video hash
         video_hash = self._calculate_video_hash(video_bytes)
-        logger.info(f"Video hash: {video_hash[:16]}...")
+        # logger.info(f"Video hash: {video_hash[:16]}...")

         # Check whether an analysis result for this video already exists in the database
         existing_video = self._check_video_exists(video_hash)