Merge branch 'dev' of https://github.com/MoFox-Studio/MoFox_Bot into dev
This commit is contained in:
@@ -19,14 +19,13 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
|
||||||
|
|
||||||
# 添加项目根目录到路径
|
# 添加项目根目录到路径
|
||||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
|
|
||||||
async def generate_missing_embeddings(
|
async def generate_missing_embeddings(
|
||||||
target_node_types: List[str] = None,
|
target_node_types: list[str] = None,
|
||||||
batch_size: int = 50,
|
batch_size: int = 50,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -46,13 +45,13 @@ async def generate_missing_embeddings(
|
|||||||
target_node_types = [NodeType.TOPIC.value, NodeType.OBJECT.value]
|
target_node_types = [NodeType.TOPIC.value, NodeType.OBJECT.value]
|
||||||
|
|
||||||
print(f"\n{'='*80}")
|
print(f"\n{'='*80}")
|
||||||
print(f"🔧 为节点生成嵌入向量")
|
print("🔧 为节点生成嵌入向量")
|
||||||
print(f"{'='*80}\n")
|
print(f"{'='*80}\n")
|
||||||
print(f"目标节点类型: {', '.join(target_node_types)}")
|
print(f"目标节点类型: {', '.join(target_node_types)}")
|
||||||
print(f"批处理大小: {batch_size}\n")
|
print(f"批处理大小: {batch_size}\n")
|
||||||
|
|
||||||
# 1. 初始化记忆管理器
|
# 1. 初始化记忆管理器
|
||||||
print(f"🔧 正在初始化记忆管理器...")
|
print("🔧 正在初始化记忆管理器...")
|
||||||
await initialize_memory_manager()
|
await initialize_memory_manager()
|
||||||
manager = get_memory_manager()
|
manager = get_memory_manager()
|
||||||
|
|
||||||
@@ -60,10 +59,10 @@ async def generate_missing_embeddings(
|
|||||||
print("❌ 记忆管理器初始化失败")
|
print("❌ 记忆管理器初始化失败")
|
||||||
return
|
return
|
||||||
|
|
||||||
print(f"✅ 记忆管理器已初始化\n")
|
print("✅ 记忆管理器已初始化\n")
|
||||||
|
|
||||||
# 2. 获取已索引的节点ID
|
# 2. 获取已索引的节点ID
|
||||||
print(f"🔍 检查现有向量索引...")
|
print("🔍 检查现有向量索引...")
|
||||||
existing_node_ids = set()
|
existing_node_ids = set()
|
||||||
try:
|
try:
|
||||||
vector_count = manager.vector_store.collection.count()
|
vector_count = manager.vector_store.collection.count()
|
||||||
@@ -82,10 +81,10 @@ async def generate_missing_embeddings(
|
|||||||
print(f"✅ 发现 {len(existing_node_ids)} 个已索引节点\n")
|
print(f"✅ 发现 {len(existing_node_ids)} 个已索引节点\n")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"获取已索引节点ID失败: {e}")
|
logger.warning(f"获取已索引节点ID失败: {e}")
|
||||||
print(f"⚠️ 无法获取已索引节点,将尝试跳过重复项\n")
|
print("⚠️ 无法获取已索引节点,将尝试跳过重复项\n")
|
||||||
|
|
||||||
# 3. 收集需要生成嵌入的节点
|
# 3. 收集需要生成嵌入的节点
|
||||||
print(f"🔍 扫描需要生成嵌入的节点...")
|
print("🔍 扫描需要生成嵌入的节点...")
|
||||||
all_memories = manager.graph_store.get_all_memories()
|
all_memories = manager.graph_store.get_all_memories()
|
||||||
|
|
||||||
nodes_to_process = []
|
nodes_to_process = []
|
||||||
@@ -110,7 +109,7 @@ async def generate_missing_embeddings(
|
|||||||
})
|
})
|
||||||
type_stats[node.node_type.value]["need_emb"] += 1
|
type_stats[node.node_type.value]["need_emb"] += 1
|
||||||
|
|
||||||
print(f"\n📊 扫描结果:")
|
print("\n📊 扫描结果:")
|
||||||
for node_type in target_node_types:
|
for node_type in target_node_types:
|
||||||
stats = type_stats[node_type]
|
stats = type_stats[node_type]
|
||||||
already_ok = stats["already_indexed"]
|
already_ok = stats["already_indexed"]
|
||||||
@@ -121,11 +120,11 @@ async def generate_missing_embeddings(
|
|||||||
print(f"\n 总计: {total_target_nodes} 个目标节点, {len(nodes_to_process)} 个需要生成嵌入\n")
|
print(f"\n 总计: {total_target_nodes} 个目标节点, {len(nodes_to_process)} 个需要生成嵌入\n")
|
||||||
|
|
||||||
if len(nodes_to_process) == 0:
|
if len(nodes_to_process) == 0:
|
||||||
print(f"✅ 所有节点已有嵌入向量,无需生成")
|
print("✅ 所有节点已有嵌入向量,无需生成")
|
||||||
return
|
return
|
||||||
|
|
||||||
# 3. 批量生成嵌入
|
# 3. 批量生成嵌入
|
||||||
print(f"🚀 开始生成嵌入向量...\n")
|
print("🚀 开始生成嵌入向量...\n")
|
||||||
|
|
||||||
total_batches = (len(nodes_to_process) + batch_size - 1) // batch_size
|
total_batches = (len(nodes_to_process) + batch_size - 1) // batch_size
|
||||||
success_count = 0
|
success_count = 0
|
||||||
@@ -193,22 +192,22 @@ async def generate_missing_embeddings(
|
|||||||
print(f" 📊 总进度: {total_processed}/{len(nodes_to_process)} ({progress:.1f}%)\n")
|
print(f" 📊 总进度: {total_processed}/{len(nodes_to_process)} ({progress:.1f}%)\n")
|
||||||
|
|
||||||
# 4. 保存图数据(更新节点的 embedding 字段)
|
# 4. 保存图数据(更新节点的 embedding 字段)
|
||||||
print(f"💾 保存图数据...")
|
print("💾 保存图数据...")
|
||||||
try:
|
try:
|
||||||
await manager.persistence.save_graph_store(manager.graph_store)
|
await manager.persistence.save_graph_store(manager.graph_store)
|
||||||
print(f"✅ 图数据已保存\n")
|
print("✅ 图数据已保存\n")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"保存图数据失败", exc_info=True)
|
logger.error("保存图数据失败", exc_info=True)
|
||||||
print(f"❌ 保存失败: {e}\n")
|
print(f"❌ 保存失败: {e}\n")
|
||||||
|
|
||||||
# 5. 验证结果
|
# 5. 验证结果
|
||||||
print(f"🔍 验证向量索引...")
|
print("🔍 验证向量索引...")
|
||||||
final_vector_count = manager.vector_store.collection.count()
|
final_vector_count = manager.vector_store.collection.count()
|
||||||
stats = manager.graph_store.get_statistics()
|
stats = manager.graph_store.get_statistics()
|
||||||
total_nodes = stats["total_nodes"]
|
total_nodes = stats["total_nodes"]
|
||||||
|
|
||||||
print(f"\n{'='*80}")
|
print(f"\n{'='*80}")
|
||||||
print(f"📊 生成完成")
|
print("📊 生成完成")
|
||||||
print(f"{'='*80}")
|
print(f"{'='*80}")
|
||||||
print(f"处理节点数: {len(nodes_to_process)}")
|
print(f"处理节点数: {len(nodes_to_process)}")
|
||||||
print(f"成功生成: {success_count}")
|
print(f"成功生成: {success_count}")
|
||||||
@@ -219,7 +218,7 @@ async def generate_missing_embeddings(
|
|||||||
print(f"索引覆盖率: {final_vector_count / total_nodes * 100:.1f}%\n")
|
print(f"索引覆盖率: {final_vector_count / total_nodes * 100:.1f}%\n")
|
||||||
|
|
||||||
# 6. 测试搜索
|
# 6. 测试搜索
|
||||||
print(f"🧪 测试搜索功能...")
|
print("🧪 测试搜索功能...")
|
||||||
test_queries = ["小红帽蕾克", "拾风", "杰瑞喵"]
|
test_queries = ["小红帽蕾克", "拾风", "杰瑞喵"]
|
||||||
|
|
||||||
for query in test_queries:
|
for query in test_queries:
|
||||||
|
|||||||
@@ -4,13 +4,13 @@
|
|||||||
提供 Web API 用于可视化记忆图数据
|
提供 Web API 用于可视化记忆图数据
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from collections import defaultdict
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any
|
||||||
from collections import defaultdict
|
|
||||||
|
|
||||||
import orjson
|
import orjson
|
||||||
from fastapi import APIRouter, HTTPException, Request, Query
|
from fastapi import APIRouter, HTTPException, Query, Request
|
||||||
from fastapi.responses import HTMLResponse, JSONResponse
|
from fastapi.responses import HTMLResponse, JSONResponse
|
||||||
from fastapi.templating import Jinja2Templates
|
from fastapi.templating import Jinja2Templates
|
||||||
|
|
||||||
@@ -29,7 +29,7 @@ router = APIRouter()
|
|||||||
templates = Jinja2Templates(directory=str(Path(__file__).parent / "templates"))
|
templates = Jinja2Templates(directory=str(Path(__file__).parent / "templates"))
|
||||||
|
|
||||||
|
|
||||||
def find_available_data_files() -> List[Path]:
|
def find_available_data_files() -> list[Path]:
|
||||||
"""查找所有可用的记忆图数据文件"""
|
"""查找所有可用的记忆图数据文件"""
|
||||||
files = []
|
files = []
|
||||||
if not data_dir.exists():
|
if not data_dir.exists():
|
||||||
@@ -62,7 +62,7 @@ def find_available_data_files() -> List[Path]:
|
|||||||
return sorted(files, key=lambda f: f.stat().st_mtime, reverse=True)
|
return sorted(files, key=lambda f: f.stat().st_mtime, reverse=True)
|
||||||
|
|
||||||
|
|
||||||
def load_graph_data_from_file(file_path: Optional[Path] = None) -> Dict[str, Any]:
|
def load_graph_data_from_file(file_path: Path | None = None) -> dict[str, Any]:
|
||||||
"""从磁盘加载图数据"""
|
"""从磁盘加载图数据"""
|
||||||
global graph_data_cache, current_data_file
|
global graph_data_cache, current_data_file
|
||||||
|
|
||||||
@@ -85,7 +85,7 @@ def load_graph_data_from_file(file_path: Optional[Path] = None) -> Dict[str, Any
|
|||||||
if not graph_file.exists():
|
if not graph_file.exists():
|
||||||
return {"error": f"文件不存在: {graph_file}", "nodes": [], "edges": [], "stats": {}}
|
return {"error": f"文件不存在: {graph_file}", "nodes": [], "edges": [], "stats": {}}
|
||||||
|
|
||||||
with open(graph_file, "r", encoding="utf-8") as f:
|
with open(graph_file, encoding="utf-8") as f:
|
||||||
data = orjson.loads(f.read())
|
data = orjson.loads(f.read())
|
||||||
|
|
||||||
nodes = data.get("nodes", [])
|
nodes = data.get("nodes", [])
|
||||||
@@ -150,7 +150,7 @@ async def index(request: Request):
|
|||||||
return templates.TemplateResponse("visualizer.html", {"request": request})
|
return templates.TemplateResponse("visualizer.html", {"request": request})
|
||||||
|
|
||||||
|
|
||||||
def _format_graph_data_from_manager(memory_manager) -> Dict[str, Any]:
|
def _format_graph_data_from_manager(memory_manager) -> dict[str, Any]:
|
||||||
"""从 MemoryManager 提取并格式化图数据"""
|
"""从 MemoryManager 提取并格式化图数据"""
|
||||||
if not memory_manager.graph_store:
|
if not memory_manager.graph_store:
|
||||||
return {"nodes": [], "edges": [], "memories": [], "stats": {}}
|
return {"nodes": [], "edges": [], "memories": [], "stats": {}}
|
||||||
@@ -261,7 +261,7 @@ async def get_paginated_graph(
|
|||||||
page: int = Query(1, ge=1, description="页码"),
|
page: int = Query(1, ge=1, description="页码"),
|
||||||
page_size: int = Query(500, ge=100, le=2000, description="每页节点数"),
|
page_size: int = Query(500, ge=100, le=2000, description="每页节点数"),
|
||||||
min_importance: float = Query(0.0, ge=0.0, le=1.0, description="最小重要性阈值"),
|
min_importance: float = Query(0.0, ge=0.0, le=1.0, description="最小重要性阈值"),
|
||||||
node_types: Optional[str] = Query(None, description="节点类型过滤,逗号分隔"),
|
node_types: str | None = Query(None, description="节点类型过滤,逗号分隔"),
|
||||||
):
|
):
|
||||||
"""分页获取图数据,支持重要性过滤"""
|
"""分页获取图数据,支持重要性过滤"""
|
||||||
try:
|
try:
|
||||||
@@ -383,7 +383,7 @@ async def get_clustered_graph(
|
|||||||
return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
|
return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
|
||||||
|
|
||||||
|
|
||||||
def _cluster_graph_data(nodes: List[Dict], edges: List[Dict], max_nodes: int, cluster_threshold: int) -> Dict:
|
def _cluster_graph_data(nodes: list[dict], edges: list[dict], max_nodes: int, cluster_threshold: int) -> dict:
|
||||||
"""简单的图聚类算法:按类型和连接度聚类"""
|
"""简单的图聚类算法:按类型和连接度聚类"""
|
||||||
# 构建邻接表
|
# 构建邻接表
|
||||||
adjacency = defaultdict(set)
|
adjacency = defaultdict(set)
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
from collections import defaultdict
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Any, Literal
|
from typing import Literal
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Query
|
from fastapi import APIRouter, HTTPException, Query
|
||||||
|
|
||||||
|
|||||||
@@ -481,7 +481,7 @@ class MessageManager:
|
|||||||
try:
|
try:
|
||||||
chat_manager = get_chat_manager()
|
chat_manager = get_chat_manager()
|
||||||
chat_stream = await chat_manager.get_stream(stream_id)
|
chat_stream = await chat_manager.get_stream(stream_id)
|
||||||
if chat_stream and hasattr(chat_stream.context_manager.context, 'is_chatter_processing'):
|
if chat_stream and hasattr(chat_stream.context_manager.context, "is_chatter_processing"):
|
||||||
chat_stream.context_manager.context.is_chatter_processing = is_processing
|
chat_stream.context_manager.context.is_chatter_processing = is_processing
|
||||||
logger.debug(f"设置StreamContext处理状态: stream={stream_id}, processing={is_processing}")
|
logger.debug(f"设置StreamContext处理状态: stream={stream_id}, processing={is_processing}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -517,7 +517,7 @@ class MessageManager:
|
|||||||
try:
|
try:
|
||||||
chat_manager = get_chat_manager()
|
chat_manager = get_chat_manager()
|
||||||
chat_stream = await chat_manager.get_stream(stream_id)
|
chat_stream = await chat_manager.get_stream(stream_id)
|
||||||
if chat_stream and hasattr(chat_stream.context_manager.context, 'is_chatter_processing'):
|
if chat_stream and hasattr(chat_stream.context_manager.context, "is_chatter_processing"):
|
||||||
return chat_stream.context_manager.context.is_chatter_processing
|
return chat_stream.context_manager.context.is_chatter_processing
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -1177,10 +1177,10 @@ class DefaultReplyer:
|
|||||||
if unread_messages:
|
if unread_messages:
|
||||||
# 使用最后一条未读消息作为参考
|
# 使用最后一条未读消息作为参考
|
||||||
last_msg = unread_messages[-1]
|
last_msg = unread_messages[-1]
|
||||||
platform = last_msg.chat_info.platform if hasattr(last_msg, 'chat_info') else chat_stream.platform
|
platform = last_msg.chat_info.platform if hasattr(last_msg, "chat_info") else chat_stream.platform
|
||||||
user_id = last_msg.user_info.user_id if hasattr(last_msg, 'user_info') else ""
|
user_id = last_msg.user_info.user_id if hasattr(last_msg, "user_info") else ""
|
||||||
user_nickname = last_msg.user_info.user_nickname if hasattr(last_msg, 'user_info') else ""
|
user_nickname = last_msg.user_info.user_nickname if hasattr(last_msg, "user_info") else ""
|
||||||
user_cardname = last_msg.user_info.user_cardname if hasattr(last_msg, 'user_info') else ""
|
user_cardname = last_msg.user_info.user_cardname if hasattr(last_msg, "user_info") else ""
|
||||||
processed_plain_text = last_msg.processed_plain_text or ""
|
processed_plain_text = last_msg.processed_plain_text or ""
|
||||||
else:
|
else:
|
||||||
# 没有未读消息,使用默认值
|
# 没有未读消息,使用默认值
|
||||||
|
|||||||
@@ -5,12 +5,12 @@
|
|||||||
插件可以通过实现这些接口来扩展安全功能。
|
插件可以通过实现这些接口来扩展安全功能。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .interfaces import SecurityCheckResult, SecurityChecker
|
from .interfaces import SecurityChecker, SecurityCheckResult
|
||||||
from .manager import SecurityManager, get_security_manager
|
from .manager import SecurityManager, get_security_manager
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"SecurityChecker",
|
|
||||||
"SecurityCheckResult",
|
"SecurityCheckResult",
|
||||||
|
"SecurityChecker",
|
||||||
"SecurityManager",
|
"SecurityManager",
|
||||||
"get_security_manager",
|
"get_security_manager",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from typing import Any
|
|||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
from .interfaces import SecurityAction, SecurityCheckResult, SecurityChecker, SecurityLevel
|
from .interfaces import SecurityAction, SecurityChecker, SecurityCheckResult, SecurityLevel
|
||||||
|
|
||||||
logger = get_logger("security.manager")
|
logger = get_logger("security.manager")
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import copy
|
||||||
import re
|
import re
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
|
||||||
from src.chat.utils.prompt_params import PromptParameters
|
from src.chat.utils.prompt_params import PromptParameters
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
@@ -12,122 +14,205 @@ logger = get_logger("prompt_component_manager")
|
|||||||
|
|
||||||
class PromptComponentManager:
|
class PromptComponentManager:
|
||||||
"""
|
"""
|
||||||
管理所有 `BasePrompt` 组件的单例类。
|
一个统一的、动态的、可观测的提示词组件管理中心。
|
||||||
|
|
||||||
该管理器负责:
|
该管理器是整个提示词动态注入系统的核心,它负责:
|
||||||
1. 从 `component_registry` 中查询 `BasePrompt` 子类。
|
1. **规则加载**: 在系统启动时,自动扫描所有已注册的 `BasePrompt` 组件,
|
||||||
2. 根据注入点(目标Prompt名称)对它们进行筛选。
|
并将其静态定义的 `injection_rules` 加载为默认的动态规则。
|
||||||
3. 提供一个接口,以便在构建核心Prompt时,能够获取并执行所有相关的组件。
|
2. **动态管理**: 提供线程安全的 API,允许在运行时动态地添加、更新或移除注入规则,
|
||||||
|
使得提示词的结构可以被实时调整。
|
||||||
|
3. **状态观测**: 提供丰富的查询 API,用于观测系统当前完整的注入状态,
|
||||||
|
例如查询所有注入到特定目标的规则、或查询某个组件定义的所有规则。
|
||||||
|
4. **注入应用**: 在构建核心 Prompt 时,根据统一的、按优先级排序的规则集,
|
||||||
|
动态地修改和装配提示词模板,实现灵活的提示词组合。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _get_rules_for(self, target_prompt_name: str) -> list[tuple[InjectionRule, type[BasePrompt]]]:
|
def __init__(self):
|
||||||
"""
|
"""初始化管理器实例。"""
|
||||||
获取指定目标Prompt的所有注入规则及其关联的组件类。
|
# _dynamic_rules 是管理器的核心状态,存储所有注入规则。
|
||||||
|
# 结构: {
|
||||||
|
# "target_prompt_name": {
|
||||||
|
# "prompt_component_name": (InjectionRule, content_provider, source)
|
||||||
|
# }
|
||||||
|
# }
|
||||||
|
# content_provider 是一个异步函数,用于在应用规则时动态生成注入内容。
|
||||||
|
# source 记录了规则的来源(例如 "static_default" 或 "runtime")。
|
||||||
|
self._dynamic_rules: dict[str, dict[str, tuple[InjectionRule, Callable[..., Awaitable[str]], str]]] = {}
|
||||||
|
self._lock = asyncio.Lock() # 使用异步锁确保对 _dynamic_rules 的并发访问安全。
|
||||||
|
self._initialized = False # 标记静态规则是否已加载,防止重复加载。
|
||||||
|
|
||||||
Args:
|
# --- 核心生命周期与初始化 ---
|
||||||
target_prompt_name (str): 目标 Prompt 的名称。
|
|
||||||
|
|
||||||
Returns:
|
def load_static_rules(self):
|
||||||
list[tuple[InjectionRule, Type[BasePrompt]]]: 一个元组列表,
|
|
||||||
每个元组包含一个注入规则和其对应的 Prompt 组件类,并已根据优先级排序。
|
|
||||||
"""
|
"""
|
||||||
# 从注册表中获取所有已启用的 PROMPT 类型的组件
|
在系统启动时加载所有静态注入规则。
|
||||||
|
|
||||||
|
该方法会扫描所有已在 `component_registry` 中注册并启用的 Prompt 组件,
|
||||||
|
将其类变量 `injection_rules` 转换为管理器的动态规则。
|
||||||
|
这确保了所有插件定义的默认注入行为在系统启动时就能生效。
|
||||||
|
此操作是幂等的,一旦初始化完成就不会重复执行。
|
||||||
|
"""
|
||||||
|
if self._initialized:
|
||||||
|
return
|
||||||
|
logger.info("正在加载静态 Prompt 注入规则...")
|
||||||
|
|
||||||
|
# 从组件注册表中获取所有已启用的 Prompt 组件
|
||||||
enabled_prompts = component_registry.get_enabled_components_by_type(ComponentType.PROMPT)
|
enabled_prompts = component_registry.get_enabled_components_by_type(ComponentType.PROMPT)
|
||||||
matching_rules = []
|
|
||||||
|
|
||||||
# 遍历所有启用的 Prompt 组件,查找与目标 Prompt 相关的注入规则
|
|
||||||
for prompt_name, prompt_info in enabled_prompts.items():
|
for prompt_name, prompt_info in enabled_prompts.items():
|
||||||
if not isinstance(prompt_info, PromptInfo):
|
if not isinstance(prompt_info, PromptInfo):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# prompt_info.injection_rules 已经经过了后向兼容处理,确保总是列表
|
|
||||||
for rule in prompt_info.injection_rules:
|
|
||||||
# 如果规则的目标是当前指定的 Prompt
|
|
||||||
if rule.target_prompt == target_prompt_name:
|
|
||||||
# 获取该规则对应的组件类
|
|
||||||
component_class = component_registry.get_component_class(prompt_name, ComponentType.PROMPT)
|
component_class = component_registry.get_component_class(prompt_name, ComponentType.PROMPT)
|
||||||
# 确保获取到的确实是一个 BasePrompt 的子类
|
if not (component_class and issubclass(component_class, BasePrompt)):
|
||||||
if component_class and issubclass(component_class, BasePrompt):
|
logger.warning(f"无法为 '{prompt_name}' 加载静态规则,因为它不是一个有效的 Prompt 组件。")
|
||||||
matching_rules.append((rule, component_class))
|
continue
|
||||||
|
|
||||||
# 根据规则的优先级进行排序,数字越小,优先级越高,越先应用
|
def create_provider(cls: type[BasePrompt]) -> Callable[[PromptParameters], Awaitable[str]]:
|
||||||
matching_rules.sort(key=lambda x: x[0].priority)
|
"""
|
||||||
return matching_rules
|
为静态组件创建一个内容提供者闭包 (Content Provider Closure)。
|
||||||
|
|
||||||
|
这个闭包捕获了组件的类 `cls`,并返回一个标准的 `content_provider` 异步函数。
|
||||||
|
当 `apply_injections` 需要内容时,它会调用这个函数。
|
||||||
|
函数内部会实例化组件,并执行其 `execute` 方法来获取注入内容。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cls (type[BasePrompt]): 需要为其创建提供者的 Prompt 组件类。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Callable[[PromptParameters], Awaitable[str]]: 一个符合管理器标准的异步内容提供者。
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def content_provider(params: PromptParameters) -> str:
|
||||||
|
"""实际执行内容生成的异步函数。"""
|
||||||
|
try:
|
||||||
|
# 从注册表获取最新的组件信息,包括插件配置
|
||||||
|
p_info = component_registry.get_component_info(cls.prompt_name, ComponentType.PROMPT)
|
||||||
|
plugin_config = {}
|
||||||
|
if isinstance(p_info, PromptInfo):
|
||||||
|
plugin_config = component_registry.get_plugin_config(p_info.plugin_name)
|
||||||
|
|
||||||
|
# 实例化组件并执行
|
||||||
|
instance = cls(params=params, plugin_config=plugin_config)
|
||||||
|
result = await instance.execute()
|
||||||
|
return str(result) if result is not None else ""
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"执行静态规则提供者 '{cls.prompt_name}' 时出错: {e}", exc_info=True)
|
||||||
|
return "" # 出错时返回空字符串,避免影响主流程
|
||||||
|
|
||||||
|
return content_provider
|
||||||
|
|
||||||
|
# 为该组件的每条静态注入规则创建并注册一个动态规则
|
||||||
|
for rule in prompt_info.injection_rules:
|
||||||
|
provider = create_provider(component_class)
|
||||||
|
target_rules = self._dynamic_rules.setdefault(rule.target_prompt, {})
|
||||||
|
target_rules[prompt_name] = (rule, provider, "static_default")
|
||||||
|
|
||||||
|
self._initialized = True
|
||||||
|
logger.info(f"静态 Prompt 注入规则加载完成,共处理 {len(enabled_prompts)} 个组件。")
|
||||||
|
|
||||||
|
# --- 运行时规则管理 API ---
|
||||||
|
|
||||||
|
async def add_injection_rule(
|
||||||
|
self,
|
||||||
|
prompt_name: str,
|
||||||
|
rule: InjectionRule,
|
||||||
|
content_provider: Callable[..., Awaitable[str]],
|
||||||
|
source: str = "runtime",
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
动态添加或更新一条注入规则。
|
||||||
|
|
||||||
|
此方法允许在系统运行时,由外部逻辑(如插件、命令)向管理器中添加新的注入行为。
|
||||||
|
如果已存在同名组件针对同一目标的规则,此方法会覆盖旧规则。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt_name (str): 动态注入组件的唯一名称。
|
||||||
|
rule (InjectionRule): 描述注入行为的规则对象。
|
||||||
|
content_provider (Callable[..., Awaitable[str]]):
|
||||||
|
一个异步函数,用于在应用注入时动态生成内容。
|
||||||
|
函数签名应为: `async def provider(params: "PromptParameters") -> str`
|
||||||
|
source (str, optional): 规则的来源标识,默认为 "runtime"。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 如果成功添加或更新,则返回 True。
|
||||||
|
"""
|
||||||
|
async with self._lock:
|
||||||
|
target_rules = self._dynamic_rules.setdefault(rule.target_prompt, {})
|
||||||
|
target_rules[prompt_name] = (rule, content_provider, source)
|
||||||
|
logger.info(f"成功添加/更新注入规则: '{prompt_name}' -> '{rule.target_prompt}' (来源: {source})")
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def remove_injection_rule(self, prompt_name: str, target_prompt: str) -> bool:
|
||||||
|
"""
|
||||||
|
移除一条动态注入规则。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt_name (str): 要移除的注入组件的名称。
|
||||||
|
target_prompt (str): 该组件注入的目标核心提示词名称。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 如果成功移除,则返回 True;如果规则不存在,则返回 False。
|
||||||
|
"""
|
||||||
|
async with self._lock:
|
||||||
|
if target_prompt in self._dynamic_rules and prompt_name in self._dynamic_rules[target_prompt]:
|
||||||
|
del self._dynamic_rules[target_prompt][prompt_name]
|
||||||
|
# 如果目标下已无任何规则,则清理掉这个键
|
||||||
|
if not self._dynamic_rules[target_prompt]:
|
||||||
|
del self._dynamic_rules[target_prompt]
|
||||||
|
logger.info(f"成功移除注入规则: '{prompt_name}' from '{target_prompt}'")
|
||||||
|
return True
|
||||||
|
logger.warning(f"尝试移除注入规则失败: 未找到 '{prompt_name}' on '{target_prompt}'")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# --- 核心注入逻辑 ---
|
||||||
|
|
||||||
async def apply_injections(
|
async def apply_injections(
|
||||||
self, target_prompt_name: str, original_template: str, params: PromptParameters
|
self, target_prompt_name: str, original_template: str, params: PromptParameters
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
获取、实例化并执行所有相关组件,然后根据注入规则修改原始模板。
|
【核心方法】根据目标名称,应用所有匹配的注入规则,返回修改后的模板。
|
||||||
|
|
||||||
这是一个三步走的过程:
|
这是提示词构建流程中的关键步骤。它会执行以下操作:
|
||||||
1. 实例化所有需要执行的组件。
|
1. 检查并确保静态规则已加载。
|
||||||
2. 并行执行它们的 `execute` 方法以获取注入内容。
|
2. 获取所有注入到 `target_prompt_name` 的规则。
|
||||||
3. 按照优先级顺序,将内容注入到原始模板中。
|
3. 按照规则的 `priority` 属性进行升序排序,优先级数字越小越先应用。
|
||||||
|
4. 依次执行每个规则的 `content_provider` 来异步获取注入内容。
|
||||||
|
5. 根据规则的 `injection_type` (如 PREPEND, APPEND, REPLACE 等) 将内容应用到模板上。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
target_prompt_name (str): 目标 Prompt 的名称。
|
target_prompt_name (str): 目标核心提示词的名称。
|
||||||
original_template (str): 原始的、未经修改的 Prompt 模板字符串。
|
original_template (str): 未经修改的原始提示词模板。
|
||||||
params (PromptParameters): 传递给 Prompt 组件实例的参数。
|
params (PromptParameters): 当前请求的参数,会传递给 `content_provider`。
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: 应用了所有注入规则后,修改过的 Prompt 模板字符串。
|
str: 应用了所有注入规则后,最终生成的提示词模板字符串。
|
||||||
"""
|
"""
|
||||||
rules_with_classes = self._get_rules_for(target_prompt_name)
|
if not self._initialized:
|
||||||
# 如果没有找到任何匹配的规则,就直接返回原始模板,啥也不干
|
self.load_static_rules()
|
||||||
if not rules_with_classes:
|
|
||||||
|
# 步骤 1: 获取所有指向当前目标的规则
|
||||||
|
# 使用 .values() 获取 (rule, provider, source) 元组列表
|
||||||
|
rules_for_target = list(self._dynamic_rules.get(target_prompt_name, {}).values())
|
||||||
|
if not rules_for_target:
|
||||||
return original_template
|
return original_template
|
||||||
|
|
||||||
# --- 第一步: 实例化所有需要执行的组件 ---
|
# 步骤 2: 按优先级排序,数字越小越优先
|
||||||
instance_map = {} # 存储组件实例,虽然目前没直接用,但留着总没错
|
rules_for_target.sort(key=lambda x: x[0].priority)
|
||||||
tasks = [] # 存放所有需要并行执行的 execute 异步任务
|
|
||||||
components_to_execute = [] # 存放需要执行的组件类,用于后续结果映射
|
|
||||||
|
|
||||||
for rule, component_class in rules_with_classes:
|
# 步骤 3: 依次执行内容提供者并根据注入类型修改模板
|
||||||
# 如果注入类型是 REMOVE,那就不需要执行组件了,因为它不产生内容
|
modified_template = original_template
|
||||||
|
for rule, provider, source in rules_for_target:
|
||||||
|
content = ""
|
||||||
|
# 对于非 REMOVE 类型的注入,需要先获取内容
|
||||||
if rule.injection_type != InjectionType.REMOVE:
|
if rule.injection_type != InjectionType.REMOVE:
|
||||||
try:
|
try:
|
||||||
# 获取组件的元信息,主要是为了拿到插件名称来读取插件配置
|
content = await provider(params)
|
||||||
prompt_info = component_registry.get_component_info(
|
|
||||||
component_class.prompt_name, ComponentType.PROMPT
|
|
||||||
)
|
|
||||||
if not isinstance(prompt_info, PromptInfo):
|
|
||||||
plugin_config = {}
|
|
||||||
else:
|
|
||||||
# 从注册表获取该组件所属插件的配置
|
|
||||||
plugin_config = component_registry.get_plugin_config(prompt_info.plugin_name)
|
|
||||||
|
|
||||||
# 实例化组件,并传入参数和插件配置
|
|
||||||
instance = component_class(params=params, plugin_config=plugin_config)
|
|
||||||
instance_map[component_class.prompt_name] = instance
|
|
||||||
# 将组件的 execute 方法作为一个任务添加到列表中
|
|
||||||
tasks.append(instance.execute())
|
|
||||||
components_to_execute.append(component_class)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"实例化 Prompt 组件 '{component_class.prompt_name}' 失败: {e}")
|
logger.error(f"执行规则 '{rule}' (来源: {source}) 的内容提供者时失败: {e}", exc_info=True)
|
||||||
# 即使失败,也添加一个立即完成的空任务,以保持与其他任务的索引同步
|
continue # 跳过失败的 provider,不中断整个流程
|
||||||
tasks.append(asyncio.create_task(asyncio.sleep(0, result=e))) # type: ignore
|
|
||||||
|
|
||||||
# --- 第二步: 并行执行所有组件的 execute 方法 ---
|
|
||||||
# 使用 asyncio.gather 来同时运行所有任务,提高效率
|
|
||||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
|
||||||
# 创建一个从组件名到执行结果的映射,方便后续查找
|
|
||||||
result_map = {
|
|
||||||
components_to_execute[i].prompt_name: res
|
|
||||||
for i, res in enumerate(results)
|
|
||||||
if not isinstance(res, Exception) # 只包含成功的结果
|
|
||||||
}
|
|
||||||
# 单独处理并记录执行失败的组件
|
|
||||||
for i, res in enumerate(results):
|
|
||||||
if isinstance(res, Exception):
|
|
||||||
logger.error(f"执行 Prompt 组件 '{components_to_execute[i].prompt_name}' 失败: {res}")
|
|
||||||
|
|
||||||
# --- 第三步: 按优先级顺序应用注入规则 ---
|
|
||||||
modified_template = original_template
|
|
||||||
for rule, component_class in rules_with_classes:
|
|
||||||
# 从结果映射中获取该组件生成的内容
|
|
||||||
content = result_map.get(component_class.prompt_name)
|
|
||||||
|
|
||||||
|
# 应用注入逻辑
|
||||||
try:
|
try:
|
||||||
if rule.injection_type == InjectionType.PREPEND:
|
if rule.injection_type == InjectionType.PREPEND:
|
||||||
if content:
|
if content:
|
||||||
@@ -136,28 +221,178 @@ class PromptComponentManager:
|
|||||||
if content:
|
if content:
|
||||||
modified_template = f"{modified_template}\n{content}"
|
modified_template = f"{modified_template}\n{content}"
|
||||||
elif rule.injection_type == InjectionType.REPLACE:
|
elif rule.injection_type == InjectionType.REPLACE:
|
||||||
# 使用正则表达式替换目标内容
|
# 只有在 content 不为 None 且 target_content 有效时才执行替换
|
||||||
if content and rule.target_content:
|
if content is not None and rule.target_content:
|
||||||
modified_template = re.sub(rule.target_content, str(content), modified_template)
|
modified_template = re.sub(rule.target_content, str(content), modified_template)
|
||||||
elif rule.injection_type == InjectionType.INSERT_AFTER:
|
elif rule.injection_type == InjectionType.INSERT_AFTER:
|
||||||
# 在匹配到的内容后面插入
|
|
||||||
if content and rule.target_content:
|
if content and rule.target_content:
|
||||||
# re.sub a little trick: \g<0> represents the entire matched string
|
# 使用 `\g<0>` 在正则匹配的整个内容后添加新内容
|
||||||
replacement = f"\\g<0>\n{content}"
|
replacement = f"\\g<0>\n{content}"
|
||||||
modified_template = re.sub(rule.target_content, replacement, modified_template)
|
modified_template = re.sub(rule.target_content, replacement, modified_template)
|
||||||
elif rule.injection_type == InjectionType.REMOVE:
|
elif rule.injection_type == InjectionType.REMOVE:
|
||||||
# 使用正则表达式移除目标内容
|
|
||||||
if rule.target_content:
|
if rule.target_content:
|
||||||
modified_template = re.sub(rule.target_content, "", modified_template)
|
modified_template = re.sub(rule.target_content, "", modified_template)
|
||||||
except re.error as e:
|
except re.error as e:
|
||||||
logger.error(
|
logger.error(f"应用规则时发生正则错误: {e} (pattern: '{rule.target_content}')")
|
||||||
f"在为 '{component_class.prompt_name}' 应用规则时发生正则错误: {e} (pattern: '{rule.target_content}')"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"应用 Prompt 注入规则 '{rule}' 失败: {e}")
|
logger.error(f"应用注入规则 '{rule}' (来源: {source}) 失败: {e}", exc_info=True)
|
||||||
|
|
||||||
return modified_template
|
return modified_template
|
||||||
|
|
||||||
|
async def preview_prompt_injections(
|
||||||
|
self, target_prompt_name: str, params: PromptParameters
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
【预览功能】模拟应用所有注入规则,返回最终生成的模板字符串,而不实际修改任何状态。
|
||||||
|
|
||||||
# 创建全局单例
|
这个方法对于调试和测试非常有用,可以查看在特定参数下,
|
||||||
|
一个核心提示词经过所有注入规则处理后会变成什么样子。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target_prompt_name (str): 希望预览的目标核心提示词名称。
|
||||||
|
params (PromptParameters): 模拟的请求参数。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 模拟生成的最终提示词模板字符串。如果找不到模板,则返回错误信息。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 从全局提示词管理器获取最原始的模板内容
|
||||||
|
from src.chat.utils.prompt import global_prompt_manager
|
||||||
|
original_prompt = global_prompt_manager._prompts.get(target_prompt_name)
|
||||||
|
if not original_prompt:
|
||||||
|
logger.warning(f"无法预览 '{target_prompt_name}',因为找不到这个核心 Prompt。")
|
||||||
|
return f"Error: Prompt '{target_prompt_name}' not found."
|
||||||
|
original_template = original_prompt.template
|
||||||
|
except KeyError:
|
||||||
|
logger.warning(f"无法预览 '{target_prompt_name}',因为找不到这个核心 Prompt。")
|
||||||
|
return f"Error: Prompt '{target_prompt_name}' not found."
|
||||||
|
|
||||||
|
# 直接调用核心注入逻辑来模拟结果
|
||||||
|
return await self.apply_injections(target_prompt_name, original_template, params)
|
||||||
|
|
||||||
|
# --- 状态观测与查询 API ---
|
||||||
|
|
||||||
|
def get_core_prompts(self) -> list[str]:
|
||||||
|
"""获取所有已注册的核心提示词模板名称列表(即所有可注入的目标)。"""
|
||||||
|
from src.chat.utils.prompt import global_prompt_manager
|
||||||
|
return list(global_prompt_manager._prompts.keys())
|
||||||
|
|
||||||
|
def get_core_prompt_contents(self) -> dict[str, str]:
|
||||||
|
"""获取所有核心提示词模板的原始内容。"""
|
||||||
|
from src.chat.utils.prompt import global_prompt_manager
|
||||||
|
return {name: prompt.template for name, prompt in global_prompt_manager._prompts.items()}
|
||||||
|
|
||||||
|
def get_registered_prompt_component_info(self) -> list[PromptInfo]:
|
||||||
|
"""获取所有在 ComponentRegistry 中注册的 Prompt 组件信息。"""
|
||||||
|
components = component_registry.get_components_by_type(ComponentType.PROMPT).values()
|
||||||
|
return [info for info in components if isinstance(info, PromptInfo)]
|
||||||
|
|
||||||
|
async def get_full_injection_map(self) -> dict[str, list[dict]]:
|
||||||
|
"""
|
||||||
|
获取当前完整的注入映射图。
|
||||||
|
|
||||||
|
此方法提供了一个系统全局的注入视图,展示了每个核心提示词(target)
|
||||||
|
被哪些注入组件(source)以何种优先级注入。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict[str, list[dict]]: 一个字典,键是目标提示词名称,
|
||||||
|
值是按优先级排序的注入信息列表。
|
||||||
|
`[{"name": str, "priority": int, "source": str}]`
|
||||||
|
"""
|
||||||
|
injection_map = {}
|
||||||
|
async with self._lock:
|
||||||
|
# 合并所有动态规则的目标和所有核心提示词,确保所有潜在目标都被包含
|
||||||
|
all_targets = set(self._dynamic_rules.keys()) | set(self.get_core_prompts())
|
||||||
|
for target in sorted(all_targets):
|
||||||
|
rules = self._dynamic_rules.get(target, {})
|
||||||
|
if not rules:
|
||||||
|
injection_map[target] = []
|
||||||
|
continue
|
||||||
|
|
||||||
|
info_list = []
|
||||||
|
for prompt_name, (rule, _, source) in rules.items():
|
||||||
|
info_list.append({"name": prompt_name, "priority": rule.priority, "source": source})
|
||||||
|
|
||||||
|
# 按优先级排序后存入 map
|
||||||
|
info_list.sort(key=lambda x: x["priority"])
|
||||||
|
injection_map[target] = info_list
|
||||||
|
return injection_map
|
||||||
|
|
||||||
|
async def get_injections_for_prompt(self, target_prompt_name: str) -> list[dict]:
|
||||||
|
"""
|
||||||
|
获取指定核心提示词模板的所有注入信息(包含详细规则)。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target_prompt_name (str): 目标核心提示词的名称。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[dict]: 一个包含注入规则详细信息的列表,已按优先级排序。
|
||||||
|
"""
|
||||||
|
rules_for_target = self._dynamic_rules.get(target_prompt_name, {})
|
||||||
|
if not rules_for_target:
|
||||||
|
return []
|
||||||
|
|
||||||
|
info_list = []
|
||||||
|
for prompt_name, (rule, _, source) in rules_for_target.items():
|
||||||
|
info_list.append(
|
||||||
|
{
|
||||||
|
"name": prompt_name,
|
||||||
|
"priority": rule.priority,
|
||||||
|
"source": source,
|
||||||
|
"injection_type": rule.injection_type.value,
|
||||||
|
"target_content": rule.target_content,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
info_list.sort(key=lambda x: x["priority"])
|
||||||
|
return info_list
|
||||||
|
|
||||||
|
def get_all_dynamic_rules(self) -> dict[str, dict[str, "InjectionRule"]]:
|
||||||
|
"""
|
||||||
|
获取所有当前的动态注入规则,以 InjectionRule 对象形式返回。
|
||||||
|
|
||||||
|
此方法返回一个深拷贝的规则副本,隐藏了 `content_provider` 等内部实现细节。
|
||||||
|
适合用于展示或序列化当前的规则配置。
|
||||||
|
"""
|
||||||
|
rules_copy = {}
|
||||||
|
for target, rules in self._dynamic_rules.items():
|
||||||
|
target_copy = {name: rule for name, (rule, _, _) in rules.items()}
|
||||||
|
rules_copy[target] = target_copy
|
||||||
|
return copy.deepcopy(rules_copy)
|
||||||
|
|
||||||
|
def get_rules_for_target(self, target_prompt: str) -> dict[str, InjectionRule]:
|
||||||
|
"""
|
||||||
|
获取所有注入到指定核心提示词的动态规则。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target_prompt (str): 目标核心提示词的名称。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict[str, InjectionRule]: 一个字典,键是注入组件的名称,值是 `InjectionRule` 对象。
|
||||||
|
如果找不到任何注入到该目标的规则,则返回一个空字典。
|
||||||
|
"""
|
||||||
|
target_rules = self._dynamic_rules.get(target_prompt, {})
|
||||||
|
return {name: copy.deepcopy(rule_info[0]) for name, rule_info in target_rules.items()}
|
||||||
|
|
||||||
|
def get_rules_by_component(self, component_name: str) -> dict[str, InjectionRule]:
|
||||||
|
"""
|
||||||
|
获取由指定的单个注入组件定义的所有动态规则。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
component_name (str): 注入组件的名称。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict[str, InjectionRule]: 一个字典,键是目标核心提示词的名称,值是 `InjectionRule` 对象。
|
||||||
|
如果该组件没有定义任何注入规则,则返回一个空字典。
|
||||||
|
"""
|
||||||
|
found_rules = {}
|
||||||
|
for target, rules in self._dynamic_rules.items():
|
||||||
|
if component_name in rules:
|
||||||
|
rule_info = rules[component_name]
|
||||||
|
found_rules[target] = copy.deepcopy(rule_info[0])
|
||||||
|
return found_rules
|
||||||
|
|
||||||
|
|
||||||
|
# 创建全局单例 (Singleton)
|
||||||
|
# 在整个应用程序中,应该只使用这一个 `prompt_component_manager` 实例,
|
||||||
|
# 以确保所有部分都共享和操作同一份动态规则集。
|
||||||
prompt_component_manager = PromptComponentManager()
|
prompt_component_manager = PromptComponentManager()
|
||||||
|
|||||||
@@ -9,11 +9,12 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import builtins
|
||||||
import time
|
import time
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, Generic, List, Optional, Set, TypeVar, Union
|
from typing import Any, Generic, TypeVar
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.memory_utils import estimate_size_smart
|
from src.common.memory_utils import estimate_size_smart
|
||||||
@@ -96,7 +97,7 @@ class LRUCache(Generic[T]):
|
|||||||
self._lock = asyncio.Lock()
|
self._lock = asyncio.Lock()
|
||||||
self._stats = CacheStats()
|
self._stats = CacheStats()
|
||||||
|
|
||||||
async def get(self, key: str) -> Optional[T]:
|
async def get(self, key: str) -> T | None:
|
||||||
"""获取缓存值
|
"""获取缓存值
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -137,8 +138,8 @@ class LRUCache(Generic[T]):
|
|||||||
self,
|
self,
|
||||||
key: str,
|
key: str,
|
||||||
value: T,
|
value: T,
|
||||||
size: Optional[int] = None,
|
size: int | None = None,
|
||||||
ttl: Optional[float] = None,
|
ttl: float | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""设置缓存值
|
"""设置缓存值
|
||||||
|
|
||||||
@@ -287,8 +288,8 @@ class MultiLevelCache:
|
|||||||
async def get(
|
async def get(
|
||||||
self,
|
self,
|
||||||
key: str,
|
key: str,
|
||||||
loader: Optional[Callable[[], Any]] = None,
|
loader: Callable[[], Any] | None = None,
|
||||||
) -> Optional[Any]:
|
) -> Any | None:
|
||||||
"""从缓存获取数据
|
"""从缓存获取数据
|
||||||
|
|
||||||
查询顺序:L1 -> L2 -> loader
|
查询顺序:L1 -> L2 -> loader
|
||||||
@@ -329,8 +330,8 @@ class MultiLevelCache:
|
|||||||
self,
|
self,
|
||||||
key: str,
|
key: str,
|
||||||
value: Any,
|
value: Any,
|
||||||
size: Optional[int] = None,
|
size: int | None = None,
|
||||||
ttl: Optional[float] = None,
|
ttl: float | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""设置缓存值
|
"""设置缓存值
|
||||||
|
|
||||||
@@ -390,7 +391,7 @@ class MultiLevelCache:
|
|||||||
await self.l2_cache.clear()
|
await self.l2_cache.clear()
|
||||||
logger.info("所有缓存已清空")
|
logger.info("所有缓存已清空")
|
||||||
|
|
||||||
async def get_stats(self) -> Dict[str, Any]:
|
async def get_stats(self) -> dict[str, Any]:
|
||||||
"""获取所有缓存层的统计信息(修复版:避免锁嵌套,使用超时)"""
|
"""获取所有缓存层的统计信息(修复版:避免锁嵌套,使用超时)"""
|
||||||
# 🔧 修复:并行获取统计信息,避免锁嵌套
|
# 🔧 修复:并行获取统计信息,避免锁嵌套
|
||||||
l1_stats_task = asyncio.create_task(self._get_cache_stats_safe(self.l1_cache, "L1"))
|
l1_stats_task = asyncio.create_task(self._get_cache_stats_safe(self.l1_cache, "L1"))
|
||||||
@@ -492,7 +493,7 @@ class MultiLevelCache:
|
|||||||
logger.error(f"{cache_name}统计获取异常: {e}")
|
logger.error(f"{cache_name}统计获取异常: {e}")
|
||||||
return CacheStats()
|
return CacheStats()
|
||||||
|
|
||||||
async def _get_cache_keys_safe(self, cache) -> Set[str]:
|
async def _get_cache_keys_safe(self, cache) -> builtins.set[str]:
|
||||||
"""安全获取缓存键集合(带超时)"""
|
"""安全获取缓存键集合(带超时)"""
|
||||||
try:
|
try:
|
||||||
# 快速获取键集合,使用超时避免死锁
|
# 快速获取键集合,使用超时避免死锁
|
||||||
@@ -507,12 +508,12 @@ class MultiLevelCache:
|
|||||||
logger.error(f"缓存键获取异常: {e}")
|
logger.error(f"缓存键获取异常: {e}")
|
||||||
return set()
|
return set()
|
||||||
|
|
||||||
async def _extract_keys_with_lock(self, cache) -> Set[str]:
|
async def _extract_keys_with_lock(self, cache) -> builtins.set[str]:
|
||||||
"""在锁保护下提取键集合"""
|
"""在锁保护下提取键集合"""
|
||||||
async with cache._lock:
|
async with cache._lock:
|
||||||
return set(cache._cache.keys())
|
return set(cache._cache.keys())
|
||||||
|
|
||||||
async def _calculate_memory_usage_safe(self, cache, keys: Set[str]) -> int:
|
async def _calculate_memory_usage_safe(self, cache, keys: builtins.set[str]) -> int:
|
||||||
"""安全计算内存使用(带超时)"""
|
"""安全计算内存使用(带超时)"""
|
||||||
if not keys:
|
if not keys:
|
||||||
return 0
|
return 0
|
||||||
@@ -529,7 +530,7 @@ class MultiLevelCache:
|
|||||||
logger.error(f"内存计算异常: {e}")
|
logger.error(f"内存计算异常: {e}")
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
async def _calc_memory_with_lock(self, cache, keys: Set[str]) -> int:
|
async def _calc_memory_with_lock(self, cache, keys: builtins.set[str]) -> int:
|
||||||
"""在锁保护下计算内存使用"""
|
"""在锁保护下计算内存使用"""
|
||||||
total_size = 0
|
total_size = 0
|
||||||
async with cache._lock:
|
async with cache._lock:
|
||||||
@@ -749,7 +750,7 @@ class MultiLevelCache:
|
|||||||
|
|
||||||
|
|
||||||
# 全局缓存实例
|
# 全局缓存实例
|
||||||
_global_cache: Optional[MultiLevelCache] = None
|
_global_cache: MultiLevelCache | None = None
|
||||||
_cache_lock = asyncio.Lock()
|
_cache_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ import socket
|
|||||||
|
|
||||||
from fastapi import APIRouter, FastAPI
|
from fastapi import APIRouter, FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.staticfiles import StaticFiles
|
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
from uvicorn import Config
|
from uvicorn import Config
|
||||||
from uvicorn import Server as UvicornServer
|
from uvicorn import Server as UvicornServer
|
||||||
|
|||||||
@@ -1098,7 +1098,7 @@ class MemoryManager:
|
|||||||
# 2. 清理孤立边(指向已删除节点的边)
|
# 2. 清理孤立边(指向已删除节点的边)
|
||||||
edges_to_remove = []
|
edges_to_remove = []
|
||||||
|
|
||||||
for source, target, edge_id in self.graph_store.graph.edges(data='edge_id'):
|
for source, target, edge_id in self.graph_store.graph.edges(data="edge_id"):
|
||||||
# 检查边的源节点和目标节点是否还存在于node_to_memories中
|
# 检查边的源节点和目标节点是否还存在于node_to_memories中
|
||||||
if source not in self.graph_store.node_to_memories or \
|
if source not in self.graph_store.node_to_memories or \
|
||||||
target not in self.graph_store.node_to_memories:
|
target not in self.graph_store.node_to_memories:
|
||||||
@@ -2301,7 +2301,7 @@ class MemoryManager:
|
|||||||
# 使用 asyncio.wait_for 来支持取消
|
# 使用 asyncio.wait_for 来支持取消
|
||||||
await asyncio.wait_for(
|
await asyncio.wait_for(
|
||||||
asyncio.sleep(initial_delay),
|
asyncio.sleep(initial_delay),
|
||||||
timeout=float('inf') # 允许随时取消
|
timeout=float("inf") # 允许随时取消
|
||||||
)
|
)
|
||||||
|
|
||||||
# 检查是否仍然需要运行
|
# 检查是否仍然需要运行
|
||||||
|
|||||||
@@ -559,8 +559,8 @@ class MemoryTools:
|
|||||||
)
|
)
|
||||||
if len(initial_memory_ids) == 0:
|
if len(initial_memory_ids) == 0:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"⚠️ 向量搜索未找到任何记忆!"
|
"⚠️ 向量搜索未找到任何记忆!"
|
||||||
f"可能原因:1) 嵌入模型理解问题 2) 记忆节点未建立索引 3) 查询表达与存储内容差异过大"
|
"可能原因:1) 嵌入模型理解问题 2) 记忆节点未建立索引 3) 查询表达与存储内容差异过大"
|
||||||
)
|
)
|
||||||
# 输出相似节点的详细信息用于调试
|
# 输出相似节点的详细信息用于调试
|
||||||
if similar_nodes:
|
if similar_nodes:
|
||||||
@@ -738,12 +738,12 @@ class MemoryTools:
|
|||||||
activation_score = memory.activation
|
activation_score = memory.activation
|
||||||
|
|
||||||
# 🆕 动态权重计算:使用配置的基础权重 + 根据记忆类型微调
|
# 🆕 动态权重计算:使用配置的基础权重 + 根据记忆类型微调
|
||||||
memory_type = memory.memory_type.value if hasattr(memory.memory_type, 'value') else str(memory.memory_type)
|
memory_type = memory.memory_type.value if hasattr(memory.memory_type, "value") else str(memory.memory_type)
|
||||||
|
|
||||||
# 检测记忆的主要节点类型
|
# 检测记忆的主要节点类型
|
||||||
node_types_count = {}
|
node_types_count = {}
|
||||||
for node in memory.nodes:
|
for node in memory.nodes:
|
||||||
nt = node.node_type.value if hasattr(node.node_type, 'value') else str(node.node_type)
|
nt = node.node_type.value if hasattr(node.node_type, "value") else str(node.node_type)
|
||||||
node_types_count[nt] = node_types_count.get(nt, 0) + 1
|
node_types_count[nt] = node_types_count.get(nt, 0) + 1
|
||||||
|
|
||||||
dominant_node_type = max(node_types_count.items(), key=lambda x: x[1])[0] if node_types_count else "unknown"
|
dominant_node_type = max(node_types_count.items(), key=lambda x: x[1])[0] if node_types_count else "unknown"
|
||||||
@@ -1092,6 +1092,7 @@ class MemoryTools:
|
|||||||
response, _ = await llm.generate_response_async(prompt, temperature=0.3, max_tokens=300)
|
response, _ = await llm.generate_response_async(prompt, temperature=0.3, max_tokens=300)
|
||||||
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
import orjson
|
import orjson
|
||||||
|
|
||||||
# 清理Markdown代码块
|
# 清理Markdown代码块
|
||||||
|
|||||||
@@ -97,7 +97,7 @@ async def expand_memories_with_semantic_filter(
|
|||||||
source_node_id = edge.source_id
|
source_node_id = edge.source_id
|
||||||
|
|
||||||
# 🆕 根据边类型设置权重(优先扩展REFERENCE、ATTRIBUTE相关的边)
|
# 🆕 根据边类型设置权重(优先扩展REFERENCE、ATTRIBUTE相关的边)
|
||||||
edge_type_str = edge.edge_type.value if hasattr(edge.edge_type, 'value') else str(edge.edge_type)
|
edge_type_str = edge.edge_type.value if hasattr(edge.edge_type, "value") else str(edge.edge_type)
|
||||||
if edge_type_str == "REFERENCE":
|
if edge_type_str == "REFERENCE":
|
||||||
edge_weight = 1.3 # REFERENCE边权重最高(引用关系)
|
edge_weight = 1.3 # REFERENCE边权重最高(引用关系)
|
||||||
elif edge_type_str in ["ATTRIBUTE", "HAS_PROPERTY"]:
|
elif edge_type_str in ["ATTRIBUTE", "HAS_PROPERTY"]:
|
||||||
|
|||||||
@@ -78,11 +78,9 @@ __all__ = [
|
|||||||
# 消息
|
# 消息
|
||||||
"MaiMessages",
|
"MaiMessages",
|
||||||
# 工具函数
|
# 工具函数
|
||||||
"ManifestValidator",
|
|
||||||
"PluginInfo",
|
"PluginInfo",
|
||||||
# 增强命令系统
|
# 增强命令系统
|
||||||
"PlusCommand",
|
"PlusCommand",
|
||||||
"PlusCommandAdapter",
|
|
||||||
"PythonDependency",
|
"PythonDependency",
|
||||||
"ToolInfo",
|
"ToolInfo",
|
||||||
"ToolParamType",
|
"ToolParamType",
|
||||||
|
|||||||
@@ -31,4 +31,4 @@ __plugin_meta__ = PluginMetadata(
|
|||||||
# 导入插件主类
|
# 导入插件主类
|
||||||
from .plugin import AntiInjectionPlugin
|
from .plugin import AntiInjectionPlugin
|
||||||
|
|
||||||
__all__ = ["__plugin_meta__", "AntiInjectionPlugin"]
|
__all__ = ["AntiInjectionPlugin", "__plugin_meta__"]
|
||||||
|
|||||||
@@ -8,8 +8,8 @@ import time
|
|||||||
|
|
||||||
from src.chat.security.interfaces import (
|
from src.chat.security.interfaces import (
|
||||||
SecurityAction,
|
SecurityAction,
|
||||||
SecurityCheckResult,
|
|
||||||
SecurityChecker,
|
SecurityChecker,
|
||||||
|
SecurityCheckResult,
|
||||||
SecurityLevel,
|
SecurityLevel,
|
||||||
)
|
)
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
处理检测结果,执行相应的动作(允许/监控/加盾/阻止/反击)。
|
处理检测结果,执行相应的动作(允许/监控/加盾/阻止/反击)。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from src.chat.security.interfaces import SecurityAction, SecurityCheckResult
|
from src.chat.security.interfaces import SecurityCheckResult
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
from .counter_attack import CounterAttackGenerator
|
from .counter_attack import CounterAttackGenerator
|
||||||
|
|||||||
@@ -6,10 +6,10 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
import datetime
|
import datetime
|
||||||
import filetype
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
import filetype
|
||||||
from maim_message import UserInfo
|
from maim_message import UserInfo
|
||||||
|
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
|
|||||||
@@ -6,7 +6,7 @@
|
|||||||
|
|
||||||
import re
|
import re
|
||||||
from typing import ClassVar
|
from typing import ClassVar
|
||||||
|
from src.chat.utils.prompt_component_manager import prompt_component_manager
|
||||||
from src.plugin_system.apis import (
|
from src.plugin_system.apis import (
|
||||||
plugin_manage_api,
|
plugin_manage_api,
|
||||||
)
|
)
|
||||||
@@ -74,6 +74,7 @@ class SystemCommand(PlusCommand):
|
|||||||
• `/system permission` - 权限管理
|
• `/system permission` - 权限管理
|
||||||
• `/system plugin` - 插件管理
|
• `/system plugin` - 插件管理
|
||||||
• `/system schedule` - 定时任务管理
|
• `/system schedule` - 定时任务管理
|
||||||
|
• `/system prompt` - 提示词注入管理
|
||||||
"""
|
"""
|
||||||
elif target == "schedule":
|
elif target == "schedule":
|
||||||
help_text = """📅 定时任务管理帮助
|
help_text = """📅 定时任务管理帮助
|
||||||
@@ -113,8 +114,17 @@ class SystemCommand(PlusCommand):
|
|||||||
• /system permission nodes [插件名] - 查看权限节点
|
• /system permission nodes [插件名] - 查看权限节点
|
||||||
• /system permission allnodes - 查看所有权限节点详情
|
• /system permission allnodes - 查看所有权限节点详情
|
||||||
"""
|
"""
|
||||||
await self.send_text(help_text)
|
elif target == "prompt":
|
||||||
|
help_text = """📝 提示词注入管理帮助
|
||||||
|
|
||||||
|
🔎 查询命令 (需要 `system.prompt.view` 权限):
|
||||||
|
• `/system prompt help` - 显示此帮助
|
||||||
|
• `/system prompt map` - 查看全局注入关系图
|
||||||
|
• `/system prompt targets` - 列出所有可被注入的核心提示词
|
||||||
|
• `/system prompt components` - 列出所有已注册的提示词组件
|
||||||
|
• `/system prompt info <目标名>` - 查看特定核心提示词的注入详情
|
||||||
|
"""
|
||||||
|
await self.send_text(help_text)
|
||||||
|
|
||||||
# =================================================================
|
# =================================================================
|
||||||
# Plugin Management Section
|
# Plugin Management Section
|
||||||
@@ -231,6 +241,101 @@ class SystemCommand(PlusCommand):
|
|||||||
else:
|
else:
|
||||||
await self.send_text(f"❌ 恢复任务失败: `{schedule_id}`")
|
await self.send_text(f"❌ 恢复任务失败: `{schedule_id}`")
|
||||||
|
|
||||||
|
# =================================================================
|
||||||
|
# Prompt Management Section
|
||||||
|
# =================================================================
|
||||||
|
async def _handle_prompt_commands(self, args: list[str]):
|
||||||
|
"""处理提示词管理相关命令"""
|
||||||
|
if not args or args[0].lower() in ["help", "帮助"]:
|
||||||
|
await self._show_help("prompt")
|
||||||
|
return
|
||||||
|
|
||||||
|
action = args[0].lower()
|
||||||
|
remaining_args = args[1:]
|
||||||
|
|
||||||
|
if action in ["map", "关系图"]:
|
||||||
|
await self._show_injection_map()
|
||||||
|
elif action in ["targets", "目标"]:
|
||||||
|
await self._list_core_prompts()
|
||||||
|
elif action in ["components", "组件"]:
|
||||||
|
await self._list_prompt_components()
|
||||||
|
elif action in ["info", "详情"] and remaining_args:
|
||||||
|
await self._get_prompt_injection_info(remaining_args[0])
|
||||||
|
else:
|
||||||
|
await self.send_text("❌ 提示词管理命令不合法\n使用 /system prompt help 查看帮助")
|
||||||
|
|
||||||
|
@require_permission("prompt.view", deny_message="❌ 你没有查看提示词注入信息的权限")
|
||||||
|
async def _show_injection_map(self):
|
||||||
|
"""显示全局注入关系图"""
|
||||||
|
injection_map = await prompt_component_manager.get_full_injection_map()
|
||||||
|
if not injection_map:
|
||||||
|
await self.send_text("📊 当前没有任何提示词注入关系")
|
||||||
|
return
|
||||||
|
|
||||||
|
response_parts = ["📊 全局提示词注入关系图:\n"]
|
||||||
|
for target, injections in injection_map.items():
|
||||||
|
if injections:
|
||||||
|
response_parts.append(f"🎯 **{target}** (注入源):")
|
||||||
|
for inj in injections:
|
||||||
|
source_tag = f"({inj['source']})" if inj['source'] != 'static_default' else ''
|
||||||
|
response_parts.append(f" ⎿ `{inj['name']}` (优先级: {inj['priority']}) {source_tag}")
|
||||||
|
else:
|
||||||
|
response_parts.append(f"🎯 **{target}** (无注入)")
|
||||||
|
|
||||||
|
await self._send_long_message("\n".join(response_parts))
|
||||||
|
|
||||||
|
@require_permission("prompt.view", deny_message="❌ 你没有查看提示词注入信息的权限")
|
||||||
|
async def _list_core_prompts(self):
|
||||||
|
"""列出所有可注入的核心提示词"""
|
||||||
|
targets = prompt_component_manager.get_core_prompts()
|
||||||
|
if not targets:
|
||||||
|
await self.send_text("🎯 当前没有可注入的核心提示词")
|
||||||
|
return
|
||||||
|
|
||||||
|
response = "🎯 所有可注入的核心提示词:\n" + "\n".join([f"• `{name}`" for name in targets])
|
||||||
|
await self.send_text(response)
|
||||||
|
|
||||||
|
@require_permission("prompt.view", deny_message="❌ 你没有查看提示词注入信息的权限")
|
||||||
|
async def _list_prompt_components(self):
|
||||||
|
"""列出所有已注册的提示词组件"""
|
||||||
|
components = prompt_component_manager.get_registered_prompt_component_info()
|
||||||
|
if not components:
|
||||||
|
await self.send_text("🧩 当前没有已注册的提示词组件")
|
||||||
|
return
|
||||||
|
|
||||||
|
response_parts = [f"🧩 已注册的提示词组件 (共 {len(components)} 个):"]
|
||||||
|
for comp in components:
|
||||||
|
response_parts.append(f"• `{comp.name}` (来自: `{comp.plugin_name}`)")
|
||||||
|
|
||||||
|
await self._send_long_message("\n".join(response_parts))
|
||||||
|
|
||||||
|
|
||||||
|
@require_permission("prompt.view", deny_message="❌ 你没有查看提示词注入信息的权限")
|
||||||
|
async def _get_prompt_injection_info(self, target_name: str):
|
||||||
|
"""获取特定核心提示词的注入详情"""
|
||||||
|
injections = await prompt_component_manager.get_injections_for_prompt(target_name)
|
||||||
|
|
||||||
|
core_prompts = prompt_component_manager.get_core_prompts()
|
||||||
|
if target_name not in core_prompts:
|
||||||
|
await self.send_text(f"❌ 找不到核心提示词: `{target_name}`")
|
||||||
|
return
|
||||||
|
|
||||||
|
if not injections:
|
||||||
|
await self.send_text(f"🎯 核心提示词 `{target_name}` 当前没有被任何组件注入。")
|
||||||
|
return
|
||||||
|
|
||||||
|
response_parts = [f"🔎 核心提示词 `{target_name}` 的注入详情:"]
|
||||||
|
for inj in injections:
|
||||||
|
response_parts.append(
|
||||||
|
f" • **`{inj['name']}`** (优先级: {inj['priority']})"
|
||||||
|
)
|
||||||
|
response_parts.append(f" - 来源: `{inj['source']}`")
|
||||||
|
response_parts.append(f" - 类型: `{inj['injection_type']}`")
|
||||||
|
if inj.get('target_content'):
|
||||||
|
response_parts.append(f" - 操作目标: `{inj['target_content']}`")
|
||||||
|
|
||||||
|
await self.send_text("\n".join(response_parts))
|
||||||
|
|
||||||
# =================================================================
|
# =================================================================
|
||||||
# Permission Management Section
|
# Permission Management Section
|
||||||
# =================================================================
|
# =================================================================
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ import uuid
|
|||||||
import weakref
|
import weakref
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from contextlib import suppress
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@@ -31,6 +30,7 @@ logger = get_logger("unified_scheduler")
|
|||||||
|
|
||||||
# ==================== 配置和常量 ====================
|
# ==================== 配置和常量 ====================
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SchedulerConfig:
|
class SchedulerConfig:
|
||||||
"""调度器配置"""
|
"""调度器配置"""
|
||||||
@@ -61,8 +61,10 @@ class SchedulerConfig:
|
|||||||
|
|
||||||
# ==================== 枚举类型 ====================
|
# ==================== 枚举类型 ====================
|
||||||
|
|
||||||
|
|
||||||
class TriggerType(Enum):
|
class TriggerType(Enum):
|
||||||
"""触发类型枚举"""
|
"""触发类型枚举"""
|
||||||
|
|
||||||
TIME = "time" # 时间触发
|
TIME = "time" # 时间触发
|
||||||
EVENT = "event" # 事件触发(通过 event_manager)
|
EVENT = "event" # 事件触发(通过 event_manager)
|
||||||
CUSTOM = "custom" # 自定义条件触发
|
CUSTOM = "custom" # 自定义条件触发
|
||||||
@@ -70,6 +72,7 @@ class TriggerType(Enum):
|
|||||||
|
|
||||||
class TaskStatus(Enum):
|
class TaskStatus(Enum):
|
||||||
"""任务状态枚举"""
|
"""任务状态枚举"""
|
||||||
|
|
||||||
PENDING = "pending" # 等待触发
|
PENDING = "pending" # 等待触发
|
||||||
RUNNING = "running" # 正在执行
|
RUNNING = "running" # 正在执行
|
||||||
COMPLETED = "completed" # 已完成
|
COMPLETED = "completed" # 已完成
|
||||||
@@ -81,9 +84,11 @@ class TaskStatus(Enum):
|
|||||||
|
|
||||||
# ==================== 任务模型 ====================
|
# ==================== 任务模型 ====================
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TaskExecution:
|
class TaskExecution:
|
||||||
"""任务执行记录"""
|
"""任务执行记录"""
|
||||||
|
|
||||||
execution_id: str
|
execution_id: str
|
||||||
started_at: datetime
|
started_at: datetime
|
||||||
ended_at: datetime | None = None
|
ended_at: datetime | None = None
|
||||||
@@ -176,10 +181,7 @@ class ScheduleTask:
|
|||||||
|
|
||||||
def start_execution(self) -> TaskExecution:
|
def start_execution(self) -> TaskExecution:
|
||||||
"""开始新的执行"""
|
"""开始新的执行"""
|
||||||
execution = TaskExecution(
|
execution = TaskExecution(execution_id=str(uuid.uuid4()), started_at=datetime.now())
|
||||||
execution_id=str(uuid.uuid4()),
|
|
||||||
started_at=datetime.now()
|
|
||||||
)
|
|
||||||
self.current_execution = execution
|
self.current_execution = execution
|
||||||
self.status = TaskStatus.RUNNING
|
self.status = TaskStatus.RUNNING
|
||||||
return execution
|
return execution
|
||||||
@@ -218,6 +220,7 @@ class ScheduleTask:
|
|||||||
|
|
||||||
# ==================== 死锁检测器(重构版)====================
|
# ==================== 死锁检测器(重构版)====================
|
||||||
|
|
||||||
|
|
||||||
class DeadlockDetector:
|
class DeadlockDetector:
|
||||||
"""死锁检测器(重构版)
|
"""死锁检测器(重构版)
|
||||||
|
|
||||||
@@ -296,6 +299,7 @@ class DeadlockDetector:
|
|||||||
|
|
||||||
# ==================== 统一调度器(完全重构版)====================
|
# ==================== 统一调度器(完全重构版)====================
|
||||||
|
|
||||||
|
|
||||||
class UnifiedScheduler:
|
class UnifiedScheduler:
|
||||||
"""统一调度器(完全重构版)
|
"""统一调度器(完全重构版)
|
||||||
|
|
||||||
@@ -367,22 +371,14 @@ class UnifiedScheduler:
|
|||||||
self._start_time = datetime.now()
|
self._start_time = datetime.now()
|
||||||
|
|
||||||
# 启动后台任务
|
# 启动后台任务
|
||||||
self._check_loop_task = asyncio.create_task(
|
self._check_loop_task = asyncio.create_task(self._check_loop(), name="scheduler_check_loop")
|
||||||
self._check_loop(),
|
self._deadlock_check_task = asyncio.create_task(self._deadlock_check_loop(), name="scheduler_deadlock_check")
|
||||||
name="scheduler_check_loop"
|
self._cleanup_task = asyncio.create_task(self._cleanup_loop(), name="scheduler_cleanup")
|
||||||
)
|
|
||||||
self._deadlock_check_task = asyncio.create_task(
|
|
||||||
self._deadlock_check_loop(),
|
|
||||||
name="scheduler_deadlock_check"
|
|
||||||
)
|
|
||||||
self._cleanup_task = asyncio.create_task(
|
|
||||||
self._cleanup_loop(),
|
|
||||||
name="scheduler_cleanup"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 注册到 event_manager
|
# 注册到 event_manager
|
||||||
try:
|
try:
|
||||||
from src.plugin_system.core.event_manager import event_manager
|
from src.plugin_system.core.event_manager import event_manager
|
||||||
|
|
||||||
event_manager.register_scheduler_callback(self._handle_event_trigger)
|
event_manager.register_scheduler_callback(self._handle_event_trigger)
|
||||||
logger.debug("调度器已注册到 event_manager")
|
logger.debug("调度器已注册到 event_manager")
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -416,6 +412,7 @@ class UnifiedScheduler:
|
|||||||
# 取消注册 event_manager
|
# 取消注册 event_manager
|
||||||
try:
|
try:
|
||||||
from src.plugin_system.core.event_manager import event_manager
|
from src.plugin_system.core.event_manager import event_manager
|
||||||
|
|
||||||
event_manager.unregister_scheduler_callback()
|
event_manager.unregister_scheduler_callback()
|
||||||
logger.debug("调度器已从 event_manager 注销")
|
logger.debug("调度器已从 event_manager 注销")
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -426,9 +423,11 @@ class UnifiedScheduler:
|
|||||||
|
|
||||||
# 显示最终统计
|
# 显示最终统计
|
||||||
stats = self.get_statistics()
|
stats = self.get_statistics()
|
||||||
logger.info(f"调度器最终统计: 总任务={stats['total_tasks']}, "
|
logger.info(
|
||||||
|
f"调度器最终统计: 总任务={stats['total_tasks']}, "
|
||||||
f"执行次数={stats['total_executions']}, "
|
f"执行次数={stats['total_executions']}, "
|
||||||
f"失败={stats['total_failures']}")
|
f"失败={stats['total_failures']}"
|
||||||
|
)
|
||||||
|
|
||||||
# 清理资源
|
# 清理资源
|
||||||
self._tasks.clear()
|
self._tasks.clear()
|
||||||
@@ -442,8 +441,7 @@ class UnifiedScheduler:
|
|||||||
async def _cancel_all_running_tasks(self) -> None:
|
async def _cancel_all_running_tasks(self) -> None:
|
||||||
"""取消所有正在运行的任务"""
|
"""取消所有正在运行的任务"""
|
||||||
running_tasks = [
|
running_tasks = [
|
||||||
task for task in self._tasks.values()
|
task for task in self._tasks.values() if task.status == TaskStatus.RUNNING and task._asyncio_task
|
||||||
if task.status == TaskStatus.RUNNING and task._asyncio_task
|
|
||||||
]
|
]
|
||||||
|
|
||||||
if not running_tasks:
|
if not running_tasks:
|
||||||
@@ -458,15 +456,13 @@ class UnifiedScheduler:
|
|||||||
|
|
||||||
# 第二阶段:等待取消完成(带超时)
|
# 第二阶段:等待取消完成(带超时)
|
||||||
cancel_tasks = [
|
cancel_tasks = [
|
||||||
task._asyncio_task for task in running_tasks
|
task._asyncio_task for task in running_tasks if task._asyncio_task and not task._asyncio_task.done()
|
||||||
if task._asyncio_task and not task._asyncio_task.done()
|
|
||||||
]
|
]
|
||||||
|
|
||||||
if cancel_tasks:
|
if cancel_tasks:
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(
|
await asyncio.wait_for(
|
||||||
asyncio.gather(*cancel_tasks, return_exceptions=True),
|
asyncio.gather(*cancel_tasks, return_exceptions=True), timeout=self.config.shutdown_timeout
|
||||||
timeout=self.config.shutdown_timeout
|
|
||||||
)
|
)
|
||||||
logger.info("所有任务已成功取消")
|
logger.info("所有任务已成功取消")
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
@@ -484,10 +480,7 @@ class UnifiedScheduler:
|
|||||||
|
|
||||||
if not self._stopping:
|
if not self._stopping:
|
||||||
# 使用 create_task 避免阻塞循环
|
# 使用 create_task 避免阻塞循环
|
||||||
asyncio.create_task(
|
asyncio.create_task(self._check_and_trigger_tasks(), name="check_trigger_tasks")
|
||||||
self._check_and_trigger_tasks(),
|
|
||||||
name="check_trigger_tasks"
|
|
||||||
)
|
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.debug("调度器主循环被取消")
|
logger.debug("调度器主循环被取消")
|
||||||
@@ -505,10 +498,7 @@ class UnifiedScheduler:
|
|||||||
|
|
||||||
if not self._stopping:
|
if not self._stopping:
|
||||||
# 使用 create_task 避免阻塞循环,并限制错误传播
|
# 使用 create_task 避免阻塞循环,并限制错误传播
|
||||||
asyncio.create_task(
|
asyncio.create_task(self._safe_check_and_handle_deadlocks(), name="deadlock_check")
|
||||||
self._safe_check_and_handle_deadlocks(),
|
|
||||||
name="deadlock_check"
|
|
||||||
)
|
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.debug("死锁检测循环被取消")
|
logger.debug("死锁检测循环被取消")
|
||||||
@@ -624,10 +614,7 @@ class UnifiedScheduler:
|
|||||||
# 为每个任务创建独立的执行 Task
|
# 为每个任务创建独立的执行 Task
|
||||||
execution_tasks = []
|
execution_tasks = []
|
||||||
for task in tasks:
|
for task in tasks:
|
||||||
exec_task = asyncio.create_task(
|
exec_task = asyncio.create_task(self._execute_task(task), name=f"exec_{task.task_name}")
|
||||||
self._execute_task(task),
|
|
||||||
name=f"exec_{task.task_name}"
|
|
||||||
)
|
|
||||||
task._asyncio_task = exec_task
|
task._asyncio_task = exec_task
|
||||||
execution_tasks.append(exec_task)
|
execution_tasks.append(exec_task)
|
||||||
|
|
||||||
@@ -647,16 +634,12 @@ class UnifiedScheduler:
|
|||||||
timeout = task.timeout or self.config.task_default_timeout
|
timeout = task.timeout or self.config.task_default_timeout
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(
|
await asyncio.wait_for(self._run_callback(task), timeout=timeout)
|
||||||
self._run_callback(task),
|
|
||||||
timeout=timeout
|
|
||||||
)
|
|
||||||
|
|
||||||
# 执行成功
|
# 执行成功
|
||||||
task.finish_execution(success=True)
|
task.finish_execution(success=True)
|
||||||
self._total_executions += 1
|
self._total_executions += 1
|
||||||
logger.debug(f"任务 {task.task_name} 执行成功 "
|
logger.debug(f"任务 {task.task_name} 执行成功 (第{task.trigger_count}次)")
|
||||||
f"(第{task.trigger_count}次)")
|
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
# 任务超时
|
# 任务超时
|
||||||
@@ -683,8 +666,10 @@ class UnifiedScheduler:
|
|||||||
# 检查是否需要重试
|
# 检查是否需要重试
|
||||||
if self.config.enable_retry and task.retry_count < task.max_retries:
|
if self.config.enable_retry and task.retry_count < task.max_retries:
|
||||||
task.retry_count += 1
|
task.retry_count += 1
|
||||||
logger.info(f"任务 {task.task_name} 将在 {self.config.retry_delay}秒后重试 "
|
logger.info(
|
||||||
f"({task.retry_count}/{task.max_retries})")
|
f"任务 {task.task_name} 将在 {self.config.retry_delay}秒后重试 "
|
||||||
|
f"({task.retry_count}/{task.max_retries})"
|
||||||
|
)
|
||||||
await asyncio.sleep(self.config.retry_delay)
|
await asyncio.sleep(self.config.retry_delay)
|
||||||
task.status = TaskStatus.PENDING # 重置为待触发状态
|
task.status = TaskStatus.PENDING # 重置为待触发状态
|
||||||
|
|
||||||
@@ -706,8 +691,7 @@ class UnifiedScheduler:
|
|||||||
# 同步函数在线程池中运行,避免阻塞事件循环
|
# 同步函数在线程池中运行,避免阻塞事件循环
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
result = await loop.run_in_executor(
|
result = await loop.run_in_executor(
|
||||||
None,
|
None, lambda: task.callback(*task.callback_args, **task.callback_kwargs)
|
||||||
lambda: task.callback(*task.callback_args, **task.callback_kwargs)
|
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -721,6 +705,7 @@ class UnifiedScheduler:
|
|||||||
else:
|
else:
|
||||||
# 返回一个空的上下文管理器
|
# 返回一个空的上下文管理器
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
|
|
||||||
return nullcontext()
|
return nullcontext()
|
||||||
|
|
||||||
async def _move_to_completed(self, task: ScheduleTask) -> None:
|
async def _move_to_completed(self, task: ScheduleTask) -> None:
|
||||||
@@ -769,8 +754,7 @@ class UnifiedScheduler:
|
|||||||
for task in tasks_to_trigger:
|
for task in tasks_to_trigger:
|
||||||
# 将事件参数注入到回调
|
# 将事件参数注入到回调
|
||||||
exec_task = asyncio.create_task(
|
exec_task = asyncio.create_task(
|
||||||
self._execute_event_task(task, event_params),
|
self._execute_event_task(task, event_params), name=f"event_exec_{task.task_name}"
|
||||||
name=f"event_exec_{task.task_name}"
|
|
||||||
)
|
)
|
||||||
task._asyncio_task = exec_task
|
task._asyncio_task = exec_task
|
||||||
execution_tasks.append(exec_task)
|
execution_tasks.append(exec_task)
|
||||||
@@ -792,18 +776,12 @@ class UnifiedScheduler:
|
|||||||
merged_kwargs = {**task.callback_kwargs, **event_params}
|
merged_kwargs = {**task.callback_kwargs, **event_params}
|
||||||
|
|
||||||
if asyncio.iscoroutinefunction(task.callback):
|
if asyncio.iscoroutinefunction(task.callback):
|
||||||
await asyncio.wait_for(
|
await asyncio.wait_for(task.callback(*task.callback_args, **merged_kwargs), timeout=timeout)
|
||||||
task.callback(*task.callback_args, **merged_kwargs),
|
|
||||||
timeout=timeout
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
await asyncio.wait_for(
|
await asyncio.wait_for(
|
||||||
loop.run_in_executor(
|
loop.run_in_executor(None, lambda: task.callback(*task.callback_args, **merged_kwargs)),
|
||||||
None,
|
timeout=timeout,
|
||||||
lambda: task.callback(*task.callback_args, **merged_kwargs)
|
|
||||||
),
|
|
||||||
timeout=timeout
|
|
||||||
)
|
)
|
||||||
|
|
||||||
task.finish_execution(success=True)
|
task.finish_execution(success=True)
|
||||||
@@ -863,10 +841,7 @@ class UnifiedScheduler:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
health = self._deadlock_detector.get_health_score(task_id)
|
health = self._deadlock_detector.get_health_score(task_id)
|
||||||
logger.warning(
|
logger.warning(f"任务 {task_name} 疑似死锁: 运行时间={runtime:.1f}秒, 健康度={health:.2f}")
|
||||||
f"任务 {task_name} 疑似死锁: "
|
|
||||||
f"运行时间={runtime:.1f}秒, 健康度={health:.2f}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 尝试取消任务(每个取消操作独立处理错误)
|
# 尝试取消任务(每个取消操作独立处理错误)
|
||||||
try:
|
try:
|
||||||
@@ -893,19 +868,16 @@ class UnifiedScheduler:
|
|||||||
for i, timeout in enumerate(timeouts):
|
for i, timeout in enumerate(timeouts):
|
||||||
try:
|
try:
|
||||||
# 使用 asyncio.wait 代替 wait_for,避免重新抛出异常
|
# 使用 asyncio.wait 代替 wait_for,避免重新抛出异常
|
||||||
done, pending = await asyncio.wait(
|
done, pending = await asyncio.wait({task._asyncio_task}, timeout=timeout)
|
||||||
{task._asyncio_task},
|
|
||||||
timeout=timeout
|
|
||||||
)
|
|
||||||
|
|
||||||
if done:
|
if done:
|
||||||
# 任务已完成(可能是正常完成或被取消)
|
# 任务已完成(可能是正常完成或被取消)
|
||||||
logger.debug(f"任务 {task.task_name} 在阶段 {i+1} 成功停止")
|
logger.debug(f"任务 {task.task_name} 在阶段 {i + 1} 成功停止")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# 超时:继续下一阶段或放弃
|
# 超时:继续下一阶段或放弃
|
||||||
if i < len(timeouts) - 1:
|
if i < len(timeouts) - 1:
|
||||||
logger.warning(f"任务 {task.task_name} 取消阶段 {i+1} 超时,继续等待...")
|
logger.warning(f"任务 {task.task_name} 取消阶段 {i + 1} 超时,继续等待...")
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
logger.error(f"任务 {task.task_name} 取消失败,强制清理")
|
logger.error(f"任务 {task.task_name} 取消失败,强制清理")
|
||||||
@@ -927,8 +899,7 @@ class UnifiedScheduler:
|
|||||||
"""清理已完成的任务"""
|
"""清理已完成的任务"""
|
||||||
# 清理已完成的一次性任务
|
# 清理已完成的一次性任务
|
||||||
completed_tasks = [
|
completed_tasks = [
|
||||||
task for task in self._tasks.values()
|
task for task in self._tasks.values() if not task.is_recurring and task.status == TaskStatus.COMPLETED
|
||||||
if not task.is_recurring and task.status == TaskStatus.COMPLETED
|
|
||||||
]
|
]
|
||||||
|
|
||||||
for task in completed_tasks:
|
for task in completed_tasks:
|
||||||
@@ -1116,10 +1087,7 @@ class UnifiedScheduler:
|
|||||||
logger.info(f"强制触发任务: {task.task_name}")
|
logger.info(f"强制触发任务: {task.task_name}")
|
||||||
|
|
||||||
# 创建执行任务
|
# 创建执行任务
|
||||||
exec_task = asyncio.create_task(
|
exec_task = asyncio.create_task(self._execute_task(task), name=f"manual_trigger_{task.task_name}")
|
||||||
self._execute_task(task),
|
|
||||||
name=f"manual_trigger_{task.task_name}"
|
|
||||||
)
|
|
||||||
task._asyncio_task = exec_task
|
task._asyncio_task = exec_task
|
||||||
|
|
||||||
# 等待完成
|
# 等待完成
|
||||||
@@ -1274,11 +1242,13 @@ class UnifiedScheduler:
|
|||||||
runtime = 0.0
|
runtime = 0.0
|
||||||
if task.current_execution:
|
if task.current_execution:
|
||||||
runtime = (datetime.now() - task.current_execution.started_at).total_seconds()
|
runtime = (datetime.now() - task.current_execution.started_at).total_seconds()
|
||||||
running_tasks_info.append({
|
running_tasks_info.append(
|
||||||
|
{
|
||||||
"schedule_id": task.schedule_id[:8] + "...",
|
"schedule_id": task.schedule_id[:8] + "...",
|
||||||
"task_name": task.task_name,
|
"task_name": task.task_name,
|
||||||
"runtime": runtime,
|
"runtime": runtime,
|
||||||
})
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"is_running": self._running,
|
"is_running": self._running,
|
||||||
@@ -1316,6 +1286,7 @@ class UnifiedScheduler:
|
|||||||
# 全局调度器实例
|
# 全局调度器实例
|
||||||
unified_scheduler = UnifiedScheduler()
|
unified_scheduler = UnifiedScheduler()
|
||||||
|
|
||||||
|
|
||||||
async def initialize_scheduler():
|
async def initialize_scheduler():
|
||||||
"""初始化调度器
|
"""初始化调度器
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user