This commit is contained in:
Windpicker-owo
2025-11-12 13:38:12 +08:00
36 changed files with 934 additions and 626 deletions

View File

@@ -19,14 +19,13 @@
import asyncio
import sys
from pathlib import Path
from typing import List
# 添加项目根目录到路径
sys.path.insert(0, str(Path(__file__).parent.parent))
async def generate_missing_embeddings(
target_node_types: List[str] = None,
target_node_types: list[str] = None,
batch_size: int = 50,
):
"""
@@ -46,13 +45,13 @@ async def generate_missing_embeddings(
target_node_types = [NodeType.TOPIC.value, NodeType.OBJECT.value]
print(f"\n{'='*80}")
print(f"🔧 为节点生成嵌入向量")
print("🔧 为节点生成嵌入向量")
print(f"{'='*80}\n")
print(f"目标节点类型: {', '.join(target_node_types)}")
print(f"批处理大小: {batch_size}\n")
# 1. 初始化记忆管理器
print(f"🔧 正在初始化记忆管理器...")
print("🔧 正在初始化记忆管理器...")
await initialize_memory_manager()
manager = get_memory_manager()
@@ -60,10 +59,10 @@ async def generate_missing_embeddings(
print("❌ 记忆管理器初始化失败")
return
print(f"✅ 记忆管理器已初始化\n")
print("✅ 记忆管理器已初始化\n")
# 2. 获取已索引的节点ID
print(f"🔍 检查现有向量索引...")
print("🔍 检查现有向量索引...")
existing_node_ids = set()
try:
vector_count = manager.vector_store.collection.count()
@@ -82,10 +81,10 @@ async def generate_missing_embeddings(
print(f"✅ 发现 {len(existing_node_ids)} 个已索引节点\n")
except Exception as e:
logger.warning(f"获取已索引节点ID失败: {e}")
print(f"⚠️ 无法获取已索引节点,将尝试跳过重复项\n")
print("⚠️ 无法获取已索引节点,将尝试跳过重复项\n")
# 3. 收集需要生成嵌入的节点
print(f"🔍 扫描需要生成嵌入的节点...")
print("🔍 扫描需要生成嵌入的节点...")
all_memories = manager.graph_store.get_all_memories()
nodes_to_process = []
@@ -110,7 +109,7 @@ async def generate_missing_embeddings(
})
type_stats[node.node_type.value]["need_emb"] += 1
print(f"\n📊 扫描结果:")
print("\n📊 扫描结果:")
for node_type in target_node_types:
stats = type_stats[node_type]
already_ok = stats["already_indexed"]
@@ -121,11 +120,11 @@ async def generate_missing_embeddings(
print(f"\n 总计: {total_target_nodes} 个目标节点, {len(nodes_to_process)} 个需要生成嵌入\n")
if len(nodes_to_process) == 0:
print(f"✅ 所有节点已有嵌入向量,无需生成")
print("✅ 所有节点已有嵌入向量,无需生成")
return
# 3. 批量生成嵌入
print(f"🚀 开始生成嵌入向量...\n")
print("🚀 开始生成嵌入向量...\n")
total_batches = (len(nodes_to_process) + batch_size - 1) // batch_size
success_count = 0
@@ -193,22 +192,22 @@ async def generate_missing_embeddings(
print(f" 📊 总进度: {total_processed}/{len(nodes_to_process)} ({progress:.1f}%)\n")
# 4. 保存图数据(更新节点的 embedding 字段)
print(f"💾 保存图数据...")
print("💾 保存图数据...")
try:
await manager.persistence.save_graph_store(manager.graph_store)
print(f"✅ 图数据已保存\n")
print("✅ 图数据已保存\n")
except Exception as e:
logger.error(f"保存图数据失败", exc_info=True)
logger.error("保存图数据失败", exc_info=True)
print(f"❌ 保存失败: {e}\n")
# 5. 验证结果
print(f"🔍 验证向量索引...")
print("🔍 验证向量索引...")
final_vector_count = manager.vector_store.collection.count()
stats = manager.graph_store.get_statistics()
total_nodes = stats["total_nodes"]
print(f"\n{'='*80}")
print(f"📊 生成完成")
print("📊 生成完成")
print(f"{'='*80}")
print(f"处理节点数: {len(nodes_to_process)}")
print(f"成功生成: {success_count}")
@@ -219,7 +218,7 @@ async def generate_missing_embeddings(
print(f"索引覆盖率: {final_vector_count / total_nodes * 100:.1f}%\n")
# 6. 测试搜索
print(f"🧪 测试搜索功能...")
print("🧪 测试搜索功能...")
test_queries = ["小红帽蕾克", "拾风", "杰瑞喵"]
for query in test_queries:

View File

@@ -4,13 +4,13 @@
提供 Web API 用于可视化记忆图数据
"""
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
from collections import defaultdict
from typing import Any
import orjson
from fastapi import APIRouter, HTTPException, Request, Query
from fastapi import APIRouter, HTTPException, Query, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.templating import Jinja2Templates
@@ -29,7 +29,7 @@ router = APIRouter()
templates = Jinja2Templates(directory=str(Path(__file__).parent / "templates"))
def find_available_data_files() -> List[Path]:
def find_available_data_files() -> list[Path]:
"""查找所有可用的记忆图数据文件"""
files = []
if not data_dir.exists():
@@ -62,7 +62,7 @@ def find_available_data_files() -> List[Path]:
return sorted(files, key=lambda f: f.stat().st_mtime, reverse=True)
def load_graph_data_from_file(file_path: Optional[Path] = None) -> Dict[str, Any]:
def load_graph_data_from_file(file_path: Path | None = None) -> dict[str, Any]:
"""从磁盘加载图数据"""
global graph_data_cache, current_data_file
@@ -85,7 +85,7 @@ def load_graph_data_from_file(file_path: Optional[Path] = None) -> Dict[str, Any
if not graph_file.exists():
return {"error": f"文件不存在: {graph_file}", "nodes": [], "edges": [], "stats": {}}
with open(graph_file, "r", encoding="utf-8") as f:
with open(graph_file, encoding="utf-8") as f:
data = orjson.loads(f.read())
nodes = data.get("nodes", [])
@@ -150,7 +150,7 @@ async def index(request: Request):
return templates.TemplateResponse("visualizer.html", {"request": request})
def _format_graph_data_from_manager(memory_manager) -> Dict[str, Any]:
def _format_graph_data_from_manager(memory_manager) -> dict[str, Any]:
"""从 MemoryManager 提取并格式化图数据"""
if not memory_manager.graph_store:
return {"nodes": [], "edges": [], "memories": [], "stats": {}}
@@ -261,7 +261,7 @@ async def get_paginated_graph(
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(500, ge=100, le=2000, description="每页节点数"),
min_importance: float = Query(0.0, ge=0.0, le=1.0, description="最小重要性阈值"),
node_types: Optional[str] = Query(None, description="节点类型过滤,逗号分隔"),
node_types: str | None = Query(None, description="节点类型过滤,逗号分隔"),
):
"""分页获取图数据,支持重要性过滤"""
try:
@@ -383,7 +383,7 @@ async def get_clustered_graph(
return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
def _cluster_graph_data(nodes: List[Dict], edges: List[Dict], max_nodes: int, cluster_threshold: int) -> Dict:
def _cluster_graph_data(nodes: list[dict], edges: list[dict], max_nodes: int, cluster_threshold: int) -> dict:
"""简单的图聚类算法:按类型和连接度聚类"""
# 构建邻接表
adjacency = defaultdict(set)

View File

@@ -1,6 +1,5 @@
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Any, Literal
from typing import Literal
from fastapi import APIRouter, HTTPException, Query

View File

@@ -481,7 +481,7 @@ class MessageManager:
try:
chat_manager = get_chat_manager()
chat_stream = await chat_manager.get_stream(stream_id)
if chat_stream and hasattr(chat_stream.context_manager.context, 'is_chatter_processing'):
if chat_stream and hasattr(chat_stream.context_manager.context, "is_chatter_processing"):
chat_stream.context_manager.context.is_chatter_processing = is_processing
logger.debug(f"设置StreamContext处理状态: stream={stream_id}, processing={is_processing}")
except Exception as e:
@@ -517,7 +517,7 @@ class MessageManager:
try:
chat_manager = get_chat_manager()
chat_stream = await chat_manager.get_stream(stream_id)
if chat_stream and hasattr(chat_stream.context_manager.context, 'is_chatter_processing'):
if chat_stream and hasattr(chat_stream.context_manager.context, "is_chatter_processing"):
return chat_stream.context_manager.context.is_chatter_processing
except Exception:
pass

View File

@@ -1177,10 +1177,10 @@ class DefaultReplyer:
if unread_messages:
# 使用最后一条未读消息作为参考
last_msg = unread_messages[-1]
platform = last_msg.chat_info.platform if hasattr(last_msg, 'chat_info') else chat_stream.platform
user_id = last_msg.user_info.user_id if hasattr(last_msg, 'user_info') else ""
user_nickname = last_msg.user_info.user_nickname if hasattr(last_msg, 'user_info') else ""
user_cardname = last_msg.user_info.user_cardname if hasattr(last_msg, 'user_info') else ""
platform = last_msg.chat_info.platform if hasattr(last_msg, "chat_info") else chat_stream.platform
user_id = last_msg.user_info.user_id if hasattr(last_msg, "user_info") else ""
user_nickname = last_msg.user_info.user_nickname if hasattr(last_msg, "user_info") else ""
user_cardname = last_msg.user_info.user_cardname if hasattr(last_msg, "user_info") else ""
processed_plain_text = last_msg.processed_plain_text or ""
else:
# 没有未读消息,使用默认值

View File

@@ -5,12 +5,12 @@
插件可以通过实现这些接口来扩展安全功能。
"""
from .interfaces import SecurityCheckResult, SecurityChecker
from .interfaces import SecurityChecker, SecurityCheckResult
from .manager import SecurityManager, get_security_manager
__all__ = [
"SecurityChecker",
"SecurityCheckResult",
"SecurityChecker",
"SecurityManager",
"get_security_manager",
]

View File

@@ -10,7 +10,7 @@ from typing import Any
from src.common.logger import get_logger
from .interfaces import SecurityAction, SecurityCheckResult, SecurityChecker, SecurityLevel
from .interfaces import SecurityAction, SecurityChecker, SecurityCheckResult, SecurityLevel
logger = get_logger("security.manager")

View File

@@ -1,5 +1,7 @@
import asyncio
import copy
import re
from collections.abc import Awaitable, Callable
from src.chat.utils.prompt_params import PromptParameters
from src.common.logger import get_logger
@@ -12,122 +14,205 @@ logger = get_logger("prompt_component_manager")
class PromptComponentManager:
"""
管理所有 `BasePrompt` 组件的单例类
一个统一的、动态的、可观测的提示词组件管理中心
该管理器负责:
1. 从 `component_registry` 中查询 `BasePrompt` 子类。
2. 根据注入点目标Prompt名称对它们进行筛选
3. 提供一个接口以便在构建核心Prompt时能够获取并执行所有相关的组件。
该管理器是整个提示词动态注入系统的核心,它负责:
1. **规则加载**: 在系统启动时,自动扫描所有已注册的 `BasePrompt` 组件,
并将其静态定义的 `injection_rules` 加载为默认的动态规则
2. **动态管理**: 提供线程安全的 API允许在运行时动态地添加、更新或移除注入规则
使得提示词的结构可以被实时调整。
3. **状态观测**: 提供丰富的查询 API用于观测系统当前完整的注入状态
例如查询所有注入到特定目标的规则、或查询某个组件定义的所有规则。
4. **注入应用**: 在构建核心 Prompt 时,根据统一的、按优先级排序的规则集,
动态地修改和装配提示词模板,实现灵活的提示词组合。
"""
def _get_rules_for(self, target_prompt_name: str) -> list[tuple[InjectionRule, type[BasePrompt]]]:
"""
获取指定目标Prompt的所有注入规则及其关联的组件类
def __init__(self):
"""初始化管理器实例。"""
# _dynamic_rules 是管理器的核心状态,存储所有注入规则
# 结构: {
# "target_prompt_name": {
# "prompt_component_name": (InjectionRule, content_provider, source)
# }
# }
# content_provider 是一个异步函数,用于在应用规则时动态生成注入内容。
# source 记录了规则的来源(例如 "static_default" 或 "runtime")。
self._dynamic_rules: dict[str, dict[str, tuple[InjectionRule, Callable[..., Awaitable[str]], str]]] = {}
self._lock = asyncio.Lock() # 使用异步锁确保对 _dynamic_rules 的并发访问安全。
self._initialized = False # 标记静态规则是否已加载,防止重复加载。
Args:
target_prompt_name (str): 目标 Prompt 的名称。
# --- 核心生命周期与初始化 ---
Returns:
list[tuple[InjectionRule, Type[BasePrompt]]]: 一个元组列表,
每个元组包含一个注入规则和其对应的 Prompt 组件类,并已根据优先级排序。
def load_static_rules(self):
"""
# 从注册表中获取所有已启用的 PROMPT 类型的组件
在系统启动时加载所有静态注入规则。
该方法会扫描所有已在 `component_registry` 中注册并启用的 Prompt 组件,
将其类变量 `injection_rules` 转换为管理器的动态规则。
这确保了所有插件定义的默认注入行为在系统启动时就能生效。
此操作是幂等的,一旦初始化完成就不会重复执行。
"""
if self._initialized:
return
logger.info("正在加载静态 Prompt 注入规则...")
# 从组件注册表中获取所有已启用的 Prompt 组件
enabled_prompts = component_registry.get_enabled_components_by_type(ComponentType.PROMPT)
matching_rules = []
# 遍历所有启用的 Prompt 组件,查找与目标 Prompt 相关的注入规则
for prompt_name, prompt_info in enabled_prompts.items():
if not isinstance(prompt_info, PromptInfo):
continue
# prompt_info.injection_rules 已经经过了后向兼容处理,确保总是列表
for rule in prompt_info.injection_rules:
# 如果规则的目标是当前指定的 Prompt
if rule.target_prompt == target_prompt_name:
# 获取该规则对应的组件类
component_class = component_registry.get_component_class(prompt_name, ComponentType.PROMPT)
# 确保获取到的确实是一个 BasePrompt 的子类
if component_class and issubclass(component_class, BasePrompt):
matching_rules.append((rule, component_class))
if not (component_class and issubclass(component_class, BasePrompt)):
logger.warning(f"无法为 '{prompt_name}' 加载静态规则,因为它不是一个有效的 Prompt 组件。")
continue
# 根据规则的优先级进行排序,数字越小,优先级越高,越先应用
matching_rules.sort(key=lambda x: x[0].priority)
return matching_rules
def create_provider(cls: type[BasePrompt]) -> Callable[[PromptParameters], Awaitable[str]]:
"""
为静态组件创建一个内容提供者闭包 (Content Provider Closure)。
这个闭包捕获了组件的类 `cls`,并返回一个标准的 `content_provider` 异步函数。
当 `apply_injections` 需要内容时,它会调用这个函数。
函数内部会实例化组件,并执行其 `execute` 方法来获取注入内容。
Args:
cls (type[BasePrompt]): 需要为其创建提供者的 Prompt 组件类。
Returns:
Callable[[PromptParameters], Awaitable[str]]: 一个符合管理器标准的异步内容提供者。
"""
async def content_provider(params: PromptParameters) -> str:
"""实际执行内容生成的异步函数。"""
try:
# 从注册表获取最新的组件信息,包括插件配置
p_info = component_registry.get_component_info(cls.prompt_name, ComponentType.PROMPT)
plugin_config = {}
if isinstance(p_info, PromptInfo):
plugin_config = component_registry.get_plugin_config(p_info.plugin_name)
# 实例化组件并执行
instance = cls(params=params, plugin_config=plugin_config)
result = await instance.execute()
return str(result) if result is not None else ""
except Exception as e:
logger.error(f"执行静态规则提供者 '{cls.prompt_name}' 时出错: {e}", exc_info=True)
return "" # 出错时返回空字符串,避免影响主流程
return content_provider
# 为该组件的每条静态注入规则创建并注册一个动态规则
for rule in prompt_info.injection_rules:
provider = create_provider(component_class)
target_rules = self._dynamic_rules.setdefault(rule.target_prompt, {})
target_rules[prompt_name] = (rule, provider, "static_default")
self._initialized = True
logger.info(f"静态 Prompt 注入规则加载完成,共处理 {len(enabled_prompts)} 个组件。")
# --- 运行时规则管理 API ---
async def add_injection_rule(
self,
prompt_name: str,
rule: InjectionRule,
content_provider: Callable[..., Awaitable[str]],
source: str = "runtime",
) -> bool:
"""
动态添加或更新一条注入规则。
此方法允许在系统运行时,由外部逻辑(如插件、命令)向管理器中添加新的注入行为。
如果已存在同名组件针对同一目标的规则,此方法会覆盖旧规则。
Args:
prompt_name (str): 动态注入组件的唯一名称。
rule (InjectionRule): 描述注入行为的规则对象。
content_provider (Callable[..., Awaitable[str]]):
一个异步函数,用于在应用注入时动态生成内容。
函数签名应为: `async def provider(params: "PromptParameters") -> str`
source (str, optional): 规则的来源标识,默认为 "runtime"
Returns:
bool: 如果成功添加或更新,则返回 True。
"""
async with self._lock:
target_rules = self._dynamic_rules.setdefault(rule.target_prompt, {})
target_rules[prompt_name] = (rule, content_provider, source)
logger.info(f"成功添加/更新注入规则: '{prompt_name}' -> '{rule.target_prompt}' (来源: {source})")
return True
async def remove_injection_rule(self, prompt_name: str, target_prompt: str) -> bool:
"""
移除一条动态注入规则。
Args:
prompt_name (str): 要移除的注入组件的名称。
target_prompt (str): 该组件注入的目标核心提示词名称。
Returns:
bool: 如果成功移除,则返回 True如果规则不存在则返回 False。
"""
async with self._lock:
if target_prompt in self._dynamic_rules and prompt_name in self._dynamic_rules[target_prompt]:
del self._dynamic_rules[target_prompt][prompt_name]
# 如果目标下已无任何规则,则清理掉这个键
if not self._dynamic_rules[target_prompt]:
del self._dynamic_rules[target_prompt]
logger.info(f"成功移除注入规则: '{prompt_name}' from '{target_prompt}'")
return True
logger.warning(f"尝试移除注入规则失败: 未找到 '{prompt_name}' on '{target_prompt}'")
return False
# --- 核心注入逻辑 ---
async def apply_injections(
self, target_prompt_name: str, original_template: str, params: PromptParameters
) -> str:
"""
获取、实例化并执行所有相关组件,然后根据注入规则修改原始模板。
【核心方法】根据目标名称,应用所有匹配的注入规则,返回修改后的模板。
这是一个三步走的过程
1. 实例化所有需要执行的组件
2. 并行执行它们的 `execute` 方法以获取注入内容
3. 按照优先级顺序,将内容注入到原始模板中
这是提示词构建流程中的关键步骤。它会执行以下操作
1. 检查并确保静态规则已加载
2. 获取所有注入到 `target_prompt_name` 的规则
3. 按照规则的 `priority` 属性进行升序排序,优先级数字越小越先应用
4. 依次执行每个规则的 `content_provider` 来异步获取注入内容。
5. 根据规则的 `injection_type` (如 PREPEND, APPEND, REPLACE 等) 将内容应用到模板上。
Args:
target_prompt_name (str): 目标 Prompt 的名称。
original_template (str): 原始的、未经修改的 Prompt 模板字符串
params (PromptParameters): 传递给 Prompt 组件实例的参数
target_prompt_name (str): 目标核心提示词的名称。
original_template (str): 未经修改的原始提示词模板
params (PromptParameters): 当前请求的参数,会传递给 `content_provider`
Returns:
str: 应用了所有注入规则后,修改过的 Prompt 模板字符串。
str: 应用了所有注入规则后,最终生成的提示词模板字符串。
"""
rules_with_classes = self._get_rules_for(target_prompt_name)
# 如果没有找到任何匹配的规则,就直接返回原始模板,啥也不干
if not rules_with_classes:
if not self._initialized:
self.load_static_rules()
# 步骤 1: 获取所有指向当前目标的规则
# 使用 .values() 获取 (rule, provider, source) 元组列表
rules_for_target = list(self._dynamic_rules.get(target_prompt_name, {}).values())
if not rules_for_target:
return original_template
# --- 第一步: 实例化所有需要执行的组件 ---
instance_map = {} # 存储组件实例,虽然目前没直接用,但留着总没错
tasks = [] # 存放所有需要并行执行的 execute 异步任务
components_to_execute = [] # 存放需要执行的组件类,用于后续结果映射
# 步骤 2: 按优先级排序,数字越小越优先
rules_for_target.sort(key=lambda x: x[0].priority)
for rule, component_class in rules_with_classes:
# 如果注入类型是 REMOVE那就不需要执行组件了因为它不产生内容
# 步骤 3: 依次执行内容提供者并根据注入类型修改模板
modified_template = original_template
for rule, provider, source in rules_for_target:
content = ""
# 对于非 REMOVE 类型的注入,需要先获取内容
if rule.injection_type != InjectionType.REMOVE:
try:
# 获取组件的元信息,主要是为了拿到插件名称来读取插件配置
prompt_info = component_registry.get_component_info(
component_class.prompt_name, ComponentType.PROMPT
)
if not isinstance(prompt_info, PromptInfo):
plugin_config = {}
else:
# 从注册表获取该组件所属插件的配置
plugin_config = component_registry.get_plugin_config(prompt_info.plugin_name)
# 实例化组件,并传入参数和插件配置
instance = component_class(params=params, plugin_config=plugin_config)
instance_map[component_class.prompt_name] = instance
# 将组件的 execute 方法作为一个任务添加到列表中
tasks.append(instance.execute())
components_to_execute.append(component_class)
content = await provider(params)
except Exception as e:
logger.error(f"实例化 Prompt 组件 '{component_class.prompt_name}' 失败: {e}")
# 即使失败,也添加一个立即完成的空任务,以保持与其他任务的索引同步
tasks.append(asyncio.create_task(asyncio.sleep(0, result=e))) # type: ignore
# --- 第二步: 并行执行所有组件的 execute 方法 ---
# 使用 asyncio.gather 来同时运行所有任务,提高效率
results = await asyncio.gather(*tasks, return_exceptions=True)
# 创建一个从组件名到执行结果的映射,方便后续查找
result_map = {
components_to_execute[i].prompt_name: res
for i, res in enumerate(results)
if not isinstance(res, Exception) # 只包含成功的结果
}
# 单独处理并记录执行失败的组件
for i, res in enumerate(results):
if isinstance(res, Exception):
logger.error(f"执行 Prompt 组件 '{components_to_execute[i].prompt_name}' 失败: {res}")
# --- 第三步: 按优先级顺序应用注入规则 ---
modified_template = original_template
for rule, component_class in rules_with_classes:
# 从结果映射中获取该组件生成的内容
content = result_map.get(component_class.prompt_name)
logger.error(f"执行规则 '{rule}' (来源: {source}) 的内容提供者时失败: {e}", exc_info=True)
continue # 跳过失败的 provider不中断整个流程
# 应用注入逻辑
try:
if rule.injection_type == InjectionType.PREPEND:
if content:
@@ -136,28 +221,178 @@ class PromptComponentManager:
if content:
modified_template = f"{modified_template}\n{content}"
elif rule.injection_type == InjectionType.REPLACE:
# 使用正则表达式替换目标内容
if content and rule.target_content:
# 只有在 content 不为 None 且 target_content 有效时才执行替换
if content is not None and rule.target_content:
modified_template = re.sub(rule.target_content, str(content), modified_template)
elif rule.injection_type == InjectionType.INSERT_AFTER:
# 在匹配到的内容后面插入
if content and rule.target_content:
# re.sub a little trick: \g<0> represents the entire matched string
# 使用 `\g<0>` 在正则匹配的整个内容后添加新内容
replacement = f"\\g<0>\n{content}"
modified_template = re.sub(rule.target_content, replacement, modified_template)
elif rule.injection_type == InjectionType.REMOVE:
# 使用正则表达式移除目标内容
if rule.target_content:
modified_template = re.sub(rule.target_content, "", modified_template)
except re.error as e:
logger.error(
f"在为 '{component_class.prompt_name}' 应用规则时发生正则错误: {e} (pattern: '{rule.target_content}')"
)
logger.error(f"应用规则时发生正则错误: {e} (pattern: '{rule.target_content}')")
except Exception as e:
logger.error(f"应用 Prompt 注入规则 '{rule}' 失败: {e}")
logger.error(f"应用注入规则 '{rule}' (来源: {source}) 失败: {e}", exc_info=True)
return modified_template
async def preview_prompt_injections(
self, target_prompt_name: str, params: PromptParameters
) -> str:
"""
【预览功能】模拟应用所有注入规则,返回最终生成的模板字符串,而不实际修改任何状态。
# 创建全局单例
这个方法对于调试和测试非常有用,可以查看在特定参数下,
一个核心提示词经过所有注入规则处理后会变成什么样子。
Args:
target_prompt_name (str): 希望预览的目标核心提示词名称。
params (PromptParameters): 模拟的请求参数。
Returns:
str: 模拟生成的最终提示词模板字符串。如果找不到模板,则返回错误信息。
"""
try:
# 从全局提示词管理器获取最原始的模板内容
from src.chat.utils.prompt import global_prompt_manager
original_prompt = global_prompt_manager._prompts.get(target_prompt_name)
if not original_prompt:
logger.warning(f"无法预览 '{target_prompt_name}',因为找不到这个核心 Prompt。")
return f"Error: Prompt '{target_prompt_name}' not found."
original_template = original_prompt.template
except KeyError:
logger.warning(f"无法预览 '{target_prompt_name}',因为找不到这个核心 Prompt。")
return f"Error: Prompt '{target_prompt_name}' not found."
# 直接调用核心注入逻辑来模拟结果
return await self.apply_injections(target_prompt_name, original_template, params)
# --- 状态观测与查询 API ---
def get_core_prompts(self) -> list[str]:
"""获取所有已注册的核心提示词模板名称列表(即所有可注入的目标)。"""
from src.chat.utils.prompt import global_prompt_manager
return list(global_prompt_manager._prompts.keys())
def get_core_prompt_contents(self) -> dict[str, str]:
"""获取所有核心提示词模板的原始内容。"""
from src.chat.utils.prompt import global_prompt_manager
return {name: prompt.template for name, prompt in global_prompt_manager._prompts.items()}
def get_registered_prompt_component_info(self) -> list[PromptInfo]:
"""获取所有在 ComponentRegistry 中注册的 Prompt 组件信息。"""
components = component_registry.get_components_by_type(ComponentType.PROMPT).values()
return [info for info in components if isinstance(info, PromptInfo)]
async def get_full_injection_map(self) -> dict[str, list[dict]]:
"""
获取当前完整的注入映射图。
此方法提供了一个系统全局的注入视图展示了每个核心提示词target
被哪些注入组件source以何种优先级注入。
Returns:
dict[str, list[dict]]: 一个字典,键是目标提示词名称,
值是按优先级排序的注入信息列表。
`[{"name": str, "priority": int, "source": str}]`
"""
injection_map = {}
async with self._lock:
# 合并所有动态规则的目标和所有核心提示词,确保所有潜在目标都被包含
all_targets = set(self._dynamic_rules.keys()) | set(self.get_core_prompts())
for target in sorted(all_targets):
rules = self._dynamic_rules.get(target, {})
if not rules:
injection_map[target] = []
continue
info_list = []
for prompt_name, (rule, _, source) in rules.items():
info_list.append({"name": prompt_name, "priority": rule.priority, "source": source})
# 按优先级排序后存入 map
info_list.sort(key=lambda x: x["priority"])
injection_map[target] = info_list
return injection_map
async def get_injections_for_prompt(self, target_prompt_name: str) -> list[dict]:
"""
获取指定核心提示词模板的所有注入信息(包含详细规则)。
Args:
target_prompt_name (str): 目标核心提示词的名称。
Returns:
list[dict]: 一个包含注入规则详细信息的列表,已按优先级排序。
"""
rules_for_target = self._dynamic_rules.get(target_prompt_name, {})
if not rules_for_target:
return []
info_list = []
for prompt_name, (rule, _, source) in rules_for_target.items():
info_list.append(
{
"name": prompt_name,
"priority": rule.priority,
"source": source,
"injection_type": rule.injection_type.value,
"target_content": rule.target_content,
}
)
info_list.sort(key=lambda x: x["priority"])
return info_list
def get_all_dynamic_rules(self) -> dict[str, dict[str, "InjectionRule"]]:
"""
获取所有当前的动态注入规则,以 InjectionRule 对象形式返回。
此方法返回一个深拷贝的规则副本,隐藏了 `content_provider` 等内部实现细节。
适合用于展示或序列化当前的规则配置。
"""
rules_copy = {}
for target, rules in self._dynamic_rules.items():
target_copy = {name: rule for name, (rule, _, _) in rules.items()}
rules_copy[target] = target_copy
return copy.deepcopy(rules_copy)
def get_rules_for_target(self, target_prompt: str) -> dict[str, InjectionRule]:
"""
获取所有注入到指定核心提示词的动态规则。
Args:
target_prompt (str): 目标核心提示词的名称。
Returns:
dict[str, InjectionRule]: 一个字典,键是注入组件的名称,值是 `InjectionRule` 对象。
如果找不到任何注入到该目标的规则,则返回一个空字典。
"""
target_rules = self._dynamic_rules.get(target_prompt, {})
return {name: copy.deepcopy(rule_info[0]) for name, rule_info in target_rules.items()}
def get_rules_by_component(self, component_name: str) -> dict[str, InjectionRule]:
"""
获取由指定的单个注入组件定义的所有动态规则。
Args:
component_name (str): 注入组件的名称。
Returns:
dict[str, InjectionRule]: 一个字典,键是目标核心提示词的名称,值是 `InjectionRule` 对象。
如果该组件没有定义任何注入规则,则返回一个空字典。
"""
found_rules = {}
for target, rules in self._dynamic_rules.items():
if component_name in rules:
rule_info = rules[component_name]
found_rules[target] = copy.deepcopy(rule_info[0])
return found_rules
# 创建全局单例 (Singleton)
# 在整个应用程序中,应该只使用这一个 `prompt_component_manager` 实例,
# 以确保所有部分都共享和操作同一份动态规则集。
prompt_component_manager = PromptComponentManager()

View File

@@ -9,11 +9,12 @@
"""
import asyncio
import builtins
import time
from collections import OrderedDict
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Dict, Generic, List, Optional, Set, TypeVar, Union
from typing import Any, Generic, TypeVar
from src.common.logger import get_logger
from src.common.memory_utils import estimate_size_smart
@@ -96,7 +97,7 @@ class LRUCache(Generic[T]):
self._lock = asyncio.Lock()
self._stats = CacheStats()
async def get(self, key: str) -> Optional[T]:
async def get(self, key: str) -> T | None:
"""获取缓存值
Args:
@@ -137,8 +138,8 @@ class LRUCache(Generic[T]):
self,
key: str,
value: T,
size: Optional[int] = None,
ttl: Optional[float] = None,
size: int | None = None,
ttl: float | None = None,
) -> None:
"""设置缓存值
@@ -287,8 +288,8 @@ class MultiLevelCache:
async def get(
self,
key: str,
loader: Optional[Callable[[], Any]] = None,
) -> Optional[Any]:
loader: Callable[[], Any] | None = None,
) -> Any | None:
"""从缓存获取数据
查询顺序L1 -> L2 -> loader
@@ -329,8 +330,8 @@ class MultiLevelCache:
self,
key: str,
value: Any,
size: Optional[int] = None,
ttl: Optional[float] = None,
size: int | None = None,
ttl: float | None = None,
) -> None:
"""设置缓存值
@@ -390,7 +391,7 @@ class MultiLevelCache:
await self.l2_cache.clear()
logger.info("所有缓存已清空")
async def get_stats(self) -> Dict[str, Any]:
async def get_stats(self) -> dict[str, Any]:
"""获取所有缓存层的统计信息(修复版:避免锁嵌套,使用超时)"""
# 🔧 修复:并行获取统计信息,避免锁嵌套
l1_stats_task = asyncio.create_task(self._get_cache_stats_safe(self.l1_cache, "L1"))
@@ -492,7 +493,7 @@ class MultiLevelCache:
logger.error(f"{cache_name}统计获取异常: {e}")
return CacheStats()
async def _get_cache_keys_safe(self, cache) -> Set[str]:
async def _get_cache_keys_safe(self, cache) -> builtins.set[str]:
"""安全获取缓存键集合(带超时)"""
try:
# 快速获取键集合,使用超时避免死锁
@@ -507,12 +508,12 @@ class MultiLevelCache:
logger.error(f"缓存键获取异常: {e}")
return set()
async def _extract_keys_with_lock(self, cache) -> Set[str]:
async def _extract_keys_with_lock(self, cache) -> builtins.set[str]:
"""在锁保护下提取键集合"""
async with cache._lock:
return set(cache._cache.keys())
async def _calculate_memory_usage_safe(self, cache, keys: Set[str]) -> int:
async def _calculate_memory_usage_safe(self, cache, keys: builtins.set[str]) -> int:
"""安全计算内存使用(带超时)"""
if not keys:
return 0
@@ -529,7 +530,7 @@ class MultiLevelCache:
logger.error(f"内存计算异常: {e}")
return 0
async def _calc_memory_with_lock(self, cache, keys: Set[str]) -> int:
async def _calc_memory_with_lock(self, cache, keys: builtins.set[str]) -> int:
"""在锁保护下计算内存使用"""
total_size = 0
async with cache._lock:
@@ -749,7 +750,7 @@ class MultiLevelCache:
# 全局缓存实例
_global_cache: Optional[MultiLevelCache] = None
_global_cache: MultiLevelCache | None = None
_cache_lock = asyncio.Lock()

View File

@@ -3,7 +3,6 @@ import socket
from fastapi import APIRouter, FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from rich.traceback import install
from uvicorn import Config
from uvicorn import Server as UvicornServer

View File

@@ -1098,7 +1098,7 @@ class MemoryManager:
# 2. 清理孤立边(指向已删除节点的边)
edges_to_remove = []
for source, target, edge_id in self.graph_store.graph.edges(data='edge_id'):
for source, target, edge_id in self.graph_store.graph.edges(data="edge_id"):
# 检查边的源节点和目标节点是否还存在于node_to_memories中
if source not in self.graph_store.node_to_memories or \
target not in self.graph_store.node_to_memories:
@@ -2301,7 +2301,7 @@ class MemoryManager:
# 使用 asyncio.wait_for 来支持取消
await asyncio.wait_for(
asyncio.sleep(initial_delay),
timeout=float('inf') # 允许随时取消
timeout=float("inf") # 允许随时取消
)
# 检查是否仍然需要运行

View File

@@ -559,8 +559,8 @@ class MemoryTools:
)
if len(initial_memory_ids) == 0:
logger.warning(
f"⚠️ 向量搜索未找到任何记忆!"
f"可能原因1) 嵌入模型理解问题 2) 记忆节点未建立索引 3) 查询表达与存储内容差异过大"
"⚠️ 向量搜索未找到任何记忆!"
"可能原因1) 嵌入模型理解问题 2) 记忆节点未建立索引 3) 查询表达与存储内容差异过大"
)
# 输出相似节点的详细信息用于调试
if similar_nodes:
@@ -738,12 +738,12 @@ class MemoryTools:
activation_score = memory.activation
# 🆕 动态权重计算:使用配置的基础权重 + 根据记忆类型微调
memory_type = memory.memory_type.value if hasattr(memory.memory_type, 'value') else str(memory.memory_type)
memory_type = memory.memory_type.value if hasattr(memory.memory_type, "value") else str(memory.memory_type)
# 检测记忆的主要节点类型
node_types_count = {}
for node in memory.nodes:
nt = node.node_type.value if hasattr(node.node_type, 'value') else str(node.node_type)
nt = node.node_type.value if hasattr(node.node_type, "value") else str(node.node_type)
node_types_count[nt] = node_types_count.get(nt, 0) + 1
dominant_node_type = max(node_types_count.items(), key=lambda x: x[1])[0] if node_types_count else "unknown"
@@ -1092,6 +1092,7 @@ class MemoryTools:
response, _ = await llm.generate_response_async(prompt, temperature=0.3, max_tokens=300)
import re
import orjson
# 清理Markdown代码块

View File

@@ -97,7 +97,7 @@ async def expand_memories_with_semantic_filter(
source_node_id = edge.source_id
# 🆕 根据边类型设置权重优先扩展REFERENCE、ATTRIBUTE相关的边
edge_type_str = edge.edge_type.value if hasattr(edge.edge_type, 'value') else str(edge.edge_type)
edge_type_str = edge.edge_type.value if hasattr(edge.edge_type, "value") else str(edge.edge_type)
if edge_type_str == "REFERENCE":
edge_weight = 1.3 # REFERENCE边权重最高引用关系
elif edge_type_str in ["ATTRIBUTE", "HAS_PROPERTY"]:

View File

@@ -78,11 +78,9 @@ __all__ = [
# 消息
"MaiMessages",
# 工具函数
"ManifestValidator",
"PluginInfo",
# 增强命令系统
"PlusCommand",
"PlusCommandAdapter",
"PythonDependency",
"ToolInfo",
"ToolParamType",

View File

@@ -31,4 +31,4 @@ __plugin_meta__ = PluginMetadata(
# 导入插件主类
from .plugin import AntiInjectionPlugin
__all__ = ["__plugin_meta__", "AntiInjectionPlugin"]
__all__ = ["AntiInjectionPlugin", "__plugin_meta__"]

View File

@@ -8,8 +8,8 @@ import time
from src.chat.security.interfaces import (
SecurityAction,
SecurityCheckResult,
SecurityChecker,
SecurityCheckResult,
SecurityLevel,
)
from src.common.logger import get_logger

View File

@@ -4,7 +4,7 @@
处理检测结果,执行相应的动作(允许/监控/加盾/阻止/反击)。
"""
from src.chat.security.interfaces import SecurityAction, SecurityCheckResult
from src.chat.security.interfaces import SecurityCheckResult
from src.common.logger import get_logger
from .counter_attack import CounterAttackGenerator

View File

@@ -6,10 +6,10 @@
import asyncio
import base64
import datetime
import filetype
from collections.abc import Callable
import aiohttp
import filetype
from maim_message import UserInfo
from src.chat.message_receive.chat_stream import get_chat_manager

View File

@@ -6,7 +6,7 @@
import re
from typing import ClassVar
from src.chat.utils.prompt_component_manager import prompt_component_manager
from src.plugin_system.apis import (
plugin_manage_api,
)
@@ -74,6 +74,7 @@ class SystemCommand(PlusCommand):
• `/system permission` - 权限管理
• `/system plugin` - 插件管理
• `/system schedule` - 定时任务管理
• `/system prompt` - 提示词注入管理
"""
elif target == "schedule":
help_text = """📅 定时任务管理帮助
@@ -113,8 +114,17 @@ class SystemCommand(PlusCommand):
• /system permission nodes [插件名] - 查看权限节点
• /system permission allnodes - 查看所有权限节点详情
"""
await self.send_text(help_text)
elif target == "prompt":
help_text = """📝 提示词注入管理帮助
🔎 查询命令 (需要 `system.prompt.view` 权限):
• `/system prompt help` - 显示此帮助
• `/system prompt map` - 查看全局注入关系图
• `/system prompt targets` - 列出所有可被注入的核心提示词
• `/system prompt components` - 列出所有已注册的提示词组件
• `/system prompt info <目标名>` - 查看特定核心提示词的注入详情
"""
await self.send_text(help_text)
# =================================================================
# Plugin Management Section
@@ -231,6 +241,101 @@ class SystemCommand(PlusCommand):
else:
await self.send_text(f"❌ 恢复任务失败: `{schedule_id}`")
# =================================================================
# Prompt Management Section
# =================================================================
async def _handle_prompt_commands(self, args: list[str]):
"""处理提示词管理相关命令"""
if not args or args[0].lower() in ["help", "帮助"]:
await self._show_help("prompt")
return
action = args[0].lower()
remaining_args = args[1:]
if action in ["map", "关系图"]:
await self._show_injection_map()
elif action in ["targets", "目标"]:
await self._list_core_prompts()
elif action in ["components", "组件"]:
await self._list_prompt_components()
elif action in ["info", "详情"] and remaining_args:
await self._get_prompt_injection_info(remaining_args[0])
else:
await self.send_text("❌ 提示词管理命令不合法\n使用 /system prompt help 查看帮助")
@require_permission("prompt.view", deny_message="❌ 你没有查看提示词注入信息的权限")
async def _show_injection_map(self):
"""显示全局注入关系图"""
injection_map = await prompt_component_manager.get_full_injection_map()
if not injection_map:
await self.send_text("📊 当前没有任何提示词注入关系")
return
response_parts = ["📊 全局提示词注入关系图:\n"]
for target, injections in injection_map.items():
if injections:
response_parts.append(f"🎯 **{target}** (注入源):")
for inj in injections:
source_tag = f"({inj['source']})" if inj['source'] != 'static_default' else ''
response_parts.append(f" ⎿ `{inj['name']}` (优先级: {inj['priority']}) {source_tag}")
else:
response_parts.append(f"🎯 **{target}** (无注入)")
await self._send_long_message("\n".join(response_parts))
@require_permission("prompt.view", deny_message="❌ 你没有查看提示词注入信息的权限")
async def _list_core_prompts(self):
"""列出所有可注入的核心提示词"""
targets = prompt_component_manager.get_core_prompts()
if not targets:
await self.send_text("🎯 当前没有可注入的核心提示词")
return
response = "🎯 所有可注入的核心提示词:\n" + "\n".join([f"• `{name}`" for name in targets])
await self.send_text(response)
@require_permission("prompt.view", deny_message="❌ 你没有查看提示词注入信息的权限")
async def _list_prompt_components(self):
"""列出所有已注册的提示词组件"""
components = prompt_component_manager.get_registered_prompt_component_info()
if not components:
await self.send_text("🧩 当前没有已注册的提示词组件")
return
response_parts = [f"🧩 已注册的提示词组件 (共 {len(components)} 个):"]
for comp in components:
response_parts.append(f"• `{comp.name}` (来自: `{comp.plugin_name}`)")
await self._send_long_message("\n".join(response_parts))
@require_permission("prompt.view", deny_message="❌ 你没有查看提示词注入信息的权限")
async def _get_prompt_injection_info(self, target_name: str):
"""获取特定核心提示词的注入详情"""
injections = await prompt_component_manager.get_injections_for_prompt(target_name)
core_prompts = prompt_component_manager.get_core_prompts()
if target_name not in core_prompts:
await self.send_text(f"❌ 找不到核心提示词: `{target_name}`")
return
if not injections:
await self.send_text(f"🎯 核心提示词 `{target_name}` 当前没有被任何组件注入。")
return
response_parts = [f"🔎 核心提示词 `{target_name}` 的注入详情:"]
for inj in injections:
response_parts.append(
f" • **`{inj['name']}`** (优先级: {inj['priority']})"
)
response_parts.append(f" - 来源: `{inj['source']}`")
response_parts.append(f" - 类型: `{inj['injection_type']}`")
if inj.get('target_content'):
response_parts.append(f" - 操作目标: `{inj['target_content']}`")
await self.send_text("\n".join(response_parts))
# =================================================================
# Permission Management Section
# =================================================================

View File

@@ -17,7 +17,6 @@ import uuid
import weakref
from collections import defaultdict
from collections.abc import Awaitable, Callable
from contextlib import suppress
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
@@ -31,6 +30,7 @@ logger = get_logger("unified_scheduler")
# ==================== 配置和常量 ====================
@dataclass
class SchedulerConfig:
"""调度器配置"""
@@ -61,8 +61,10 @@ class SchedulerConfig:
# ==================== 枚举类型 ====================
class TriggerType(Enum):
"""触发类型枚举"""
TIME = "time" # 时间触发
EVENT = "event" # 事件触发(通过 event_manager
CUSTOM = "custom" # 自定义条件触发
@@ -70,6 +72,7 @@ class TriggerType(Enum):
class TaskStatus(Enum):
"""任务状态枚举"""
PENDING = "pending" # 等待触发
RUNNING = "running" # 正在执行
COMPLETED = "completed" # 已完成
@@ -81,9 +84,11 @@ class TaskStatus(Enum):
# ==================== 任务模型 ====================
@dataclass
class TaskExecution:
"""任务执行记录"""
execution_id: str
started_at: datetime
ended_at: datetime | None = None
@@ -176,10 +181,7 @@ class ScheduleTask:
def start_execution(self) -> TaskExecution:
"""开始新的执行"""
execution = TaskExecution(
execution_id=str(uuid.uuid4()),
started_at=datetime.now()
)
execution = TaskExecution(execution_id=str(uuid.uuid4()), started_at=datetime.now())
self.current_execution = execution
self.status = TaskStatus.RUNNING
return execution
@@ -218,6 +220,7 @@ class ScheduleTask:
# ==================== 死锁检测器(重构版)====================
class DeadlockDetector:
"""死锁检测器(重构版)
@@ -296,6 +299,7 @@ class DeadlockDetector:
# ==================== 统一调度器(完全重构版)====================
class UnifiedScheduler:
"""统一调度器(完全重构版)
@@ -367,22 +371,14 @@ class UnifiedScheduler:
self._start_time = datetime.now()
# 启动后台任务
self._check_loop_task = asyncio.create_task(
self._check_loop(),
name="scheduler_check_loop"
)
self._deadlock_check_task = asyncio.create_task(
self._deadlock_check_loop(),
name="scheduler_deadlock_check"
)
self._cleanup_task = asyncio.create_task(
self._cleanup_loop(),
name="scheduler_cleanup"
)
self._check_loop_task = asyncio.create_task(self._check_loop(), name="scheduler_check_loop")
self._deadlock_check_task = asyncio.create_task(self._deadlock_check_loop(), name="scheduler_deadlock_check")
self._cleanup_task = asyncio.create_task(self._cleanup_loop(), name="scheduler_cleanup")
# 注册到 event_manager
try:
from src.plugin_system.core.event_manager import event_manager
event_manager.register_scheduler_callback(self._handle_event_trigger)
logger.debug("调度器已注册到 event_manager")
except ImportError:
@@ -416,6 +412,7 @@ class UnifiedScheduler:
# 取消注册 event_manager
try:
from src.plugin_system.core.event_manager import event_manager
event_manager.unregister_scheduler_callback()
logger.debug("调度器已从 event_manager 注销")
except ImportError:
@@ -426,9 +423,11 @@ class UnifiedScheduler:
# 显示最终统计
stats = self.get_statistics()
logger.info(f"调度器最终统计: 总任务={stats['total_tasks']}, "
logger.info(
f"调度器最终统计: 总任务={stats['total_tasks']}, "
f"执行次数={stats['total_executions']}, "
f"失败={stats['total_failures']}")
f"失败={stats['total_failures']}"
)
# 清理资源
self._tasks.clear()
@@ -442,8 +441,7 @@ class UnifiedScheduler:
async def _cancel_all_running_tasks(self) -> None:
"""取消所有正在运行的任务"""
running_tasks = [
task for task in self._tasks.values()
if task.status == TaskStatus.RUNNING and task._asyncio_task
task for task in self._tasks.values() if task.status == TaskStatus.RUNNING and task._asyncio_task
]
if not running_tasks:
@@ -458,15 +456,13 @@ class UnifiedScheduler:
# 第二阶段:等待取消完成(带超时)
cancel_tasks = [
task._asyncio_task for task in running_tasks
if task._asyncio_task and not task._asyncio_task.done()
task._asyncio_task for task in running_tasks if task._asyncio_task and not task._asyncio_task.done()
]
if cancel_tasks:
try:
await asyncio.wait_for(
asyncio.gather(*cancel_tasks, return_exceptions=True),
timeout=self.config.shutdown_timeout
asyncio.gather(*cancel_tasks, return_exceptions=True), timeout=self.config.shutdown_timeout
)
logger.info("所有任务已成功取消")
except asyncio.TimeoutError:
@@ -484,10 +480,7 @@ class UnifiedScheduler:
if not self._stopping:
# 使用 create_task 避免阻塞循环
asyncio.create_task(
self._check_and_trigger_tasks(),
name="check_trigger_tasks"
)
asyncio.create_task(self._check_and_trigger_tasks(), name="check_trigger_tasks")
except asyncio.CancelledError:
logger.debug("调度器主循环被取消")
@@ -505,10 +498,7 @@ class UnifiedScheduler:
if not self._stopping:
# 使用 create_task 避免阻塞循环,并限制错误传播
asyncio.create_task(
self._safe_check_and_handle_deadlocks(),
name="deadlock_check"
)
asyncio.create_task(self._safe_check_and_handle_deadlocks(), name="deadlock_check")
except asyncio.CancelledError:
logger.debug("死锁检测循环被取消")
@@ -624,10 +614,7 @@ class UnifiedScheduler:
# 为每个任务创建独立的执行 Task
execution_tasks = []
for task in tasks:
exec_task = asyncio.create_task(
self._execute_task(task),
name=f"exec_{task.task_name}"
)
exec_task = asyncio.create_task(self._execute_task(task), name=f"exec_{task.task_name}")
task._asyncio_task = exec_task
execution_tasks.append(exec_task)
@@ -647,16 +634,12 @@ class UnifiedScheduler:
timeout = task.timeout or self.config.task_default_timeout
try:
await asyncio.wait_for(
self._run_callback(task),
timeout=timeout
)
await asyncio.wait_for(self._run_callback(task), timeout=timeout)
# 执行成功
task.finish_execution(success=True)
self._total_executions += 1
logger.debug(f"任务 {task.task_name} 执行成功 "
f"(第{task.trigger_count}次)")
logger.debug(f"任务 {task.task_name} 执行成功 (第{task.trigger_count}次)")
except asyncio.TimeoutError:
# 任务超时
@@ -683,8 +666,10 @@ class UnifiedScheduler:
# 检查是否需要重试
if self.config.enable_retry and task.retry_count < task.max_retries:
task.retry_count += 1
logger.info(f"任务 {task.task_name} 将在 {self.config.retry_delay}秒后重试 "
f"({task.retry_count}/{task.max_retries})")
logger.info(
f"任务 {task.task_name} 将在 {self.config.retry_delay}秒后重试 "
f"({task.retry_count}/{task.max_retries})"
)
await asyncio.sleep(self.config.retry_delay)
task.status = TaskStatus.PENDING # 重置为待触发状态
@@ -706,8 +691,7 @@ class UnifiedScheduler:
# 同步函数在线程池中运行,避免阻塞事件循环
loop = asyncio.get_running_loop()
result = await loop.run_in_executor(
None,
lambda: task.callback(*task.callback_args, **task.callback_kwargs)
None, lambda: task.callback(*task.callback_args, **task.callback_kwargs)
)
return result
except Exception as e:
@@ -721,6 +705,7 @@ class UnifiedScheduler:
else:
# 返回一个空的上下文管理器
from contextlib import nullcontext
return nullcontext()
async def _move_to_completed(self, task: ScheduleTask) -> None:
@@ -769,8 +754,7 @@ class UnifiedScheduler:
for task in tasks_to_trigger:
# 将事件参数注入到回调
exec_task = asyncio.create_task(
self._execute_event_task(task, event_params),
name=f"event_exec_{task.task_name}"
self._execute_event_task(task, event_params), name=f"event_exec_{task.task_name}"
)
task._asyncio_task = exec_task
execution_tasks.append(exec_task)
@@ -792,18 +776,12 @@ class UnifiedScheduler:
merged_kwargs = {**task.callback_kwargs, **event_params}
if asyncio.iscoroutinefunction(task.callback):
await asyncio.wait_for(
task.callback(*task.callback_args, **merged_kwargs),
timeout=timeout
)
await asyncio.wait_for(task.callback(*task.callback_args, **merged_kwargs), timeout=timeout)
else:
loop = asyncio.get_running_loop()
await asyncio.wait_for(
loop.run_in_executor(
None,
lambda: task.callback(*task.callback_args, **merged_kwargs)
),
timeout=timeout
loop.run_in_executor(None, lambda: task.callback(*task.callback_args, **merged_kwargs)),
timeout=timeout,
)
task.finish_execution(success=True)
@@ -863,10 +841,7 @@ class UnifiedScheduler:
continue
health = self._deadlock_detector.get_health_score(task_id)
logger.warning(
f"任务 {task_name} 疑似死锁: "
f"运行时间={runtime:.1f}秒, 健康度={health:.2f}"
)
logger.warning(f"任务 {task_name} 疑似死锁: 运行时间={runtime:.1f}秒, 健康度={health:.2f}")
# 尝试取消任务(每个取消操作独立处理错误)
try:
@@ -893,19 +868,16 @@ class UnifiedScheduler:
for i, timeout in enumerate(timeouts):
try:
# 使用 asyncio.wait 代替 wait_for避免重新抛出异常
done, pending = await asyncio.wait(
{task._asyncio_task},
timeout=timeout
)
done, pending = await asyncio.wait({task._asyncio_task}, timeout=timeout)
if done:
# 任务已完成(可能是正常完成或被取消)
logger.debug(f"任务 {task.task_name} 在阶段 {i+1} 成功停止")
logger.debug(f"任务 {task.task_name} 在阶段 {i + 1} 成功停止")
return True
# 超时:继续下一阶段或放弃
if i < len(timeouts) - 1:
logger.warning(f"任务 {task.task_name} 取消阶段 {i+1} 超时,继续等待...")
logger.warning(f"任务 {task.task_name} 取消阶段 {i + 1} 超时,继续等待...")
continue
else:
logger.error(f"任务 {task.task_name} 取消失败,强制清理")
@@ -927,8 +899,7 @@ class UnifiedScheduler:
"""清理已完成的任务"""
# 清理已完成的一次性任务
completed_tasks = [
task for task in self._tasks.values()
if not task.is_recurring and task.status == TaskStatus.COMPLETED
task for task in self._tasks.values() if not task.is_recurring and task.status == TaskStatus.COMPLETED
]
for task in completed_tasks:
@@ -1116,10 +1087,7 @@ class UnifiedScheduler:
logger.info(f"强制触发任务: {task.task_name}")
# 创建执行任务
exec_task = asyncio.create_task(
self._execute_task(task),
name=f"manual_trigger_{task.task_name}"
)
exec_task = asyncio.create_task(self._execute_task(task), name=f"manual_trigger_{task.task_name}")
task._asyncio_task = exec_task
# 等待完成
@@ -1274,11 +1242,13 @@ class UnifiedScheduler:
runtime = 0.0
if task.current_execution:
runtime = (datetime.now() - task.current_execution.started_at).total_seconds()
running_tasks_info.append({
running_tasks_info.append(
{
"schedule_id": task.schedule_id[:8] + "...",
"task_name": task.task_name,
"runtime": runtime,
})
}
)
return {
"is_running": self._running,
@@ -1316,6 +1286,7 @@ class UnifiedScheduler:
# 全局调度器实例
unified_scheduler = UnifiedScheduler()
async def initialize_scheduler():
"""初始化调度器