This commit is contained in:
Windpicker-owo
2025-11-12 13:38:12 +08:00
36 changed files with 934 additions and 626 deletions

2
bot.py
View File

@@ -588,7 +588,7 @@ class MaiBotMain:
async def run_async_init(self, main_system):
"""执行异步初始化步骤"""
# 初始化数据库表结构
await self.initialize_database_async()

View File

@@ -19,14 +19,13 @@
import asyncio
import sys
from pathlib import Path
from typing import List
# 添加项目根目录到路径
sys.path.insert(0, str(Path(__file__).parent.parent))
async def generate_missing_embeddings(
target_node_types: List[str] = None,
target_node_types: list[str] = None,
batch_size: int = 50,
):
"""
@@ -46,13 +45,13 @@ async def generate_missing_embeddings(
target_node_types = [NodeType.TOPIC.value, NodeType.OBJECT.value]
print(f"\n{'='*80}")
print(f"🔧 为节点生成嵌入向量")
print("🔧 为节点生成嵌入向量")
print(f"{'='*80}\n")
print(f"目标节点类型: {', '.join(target_node_types)}")
print(f"批处理大小: {batch_size}\n")
# 1. 初始化记忆管理器
print(f"🔧 正在初始化记忆管理器...")
print("🔧 正在初始化记忆管理器...")
await initialize_memory_manager()
manager = get_memory_manager()
@@ -60,10 +59,10 @@ async def generate_missing_embeddings(
print("❌ 记忆管理器初始化失败")
return
print(f"✅ 记忆管理器已初始化\n")
print("✅ 记忆管理器已初始化\n")
# 2. 获取已索引的节点ID
print(f"🔍 检查现有向量索引...")
print("🔍 检查现有向量索引...")
existing_node_ids = set()
try:
vector_count = manager.vector_store.collection.count()
@@ -78,14 +77,14 @@ async def generate_missing_embeddings(
)
if result and "ids" in result:
existing_node_ids.update(result["ids"])
print(f"✅ 发现 {len(existing_node_ids)} 个已索引节点\n")
except Exception as e:
logger.warning(f"获取已索引节点ID失败: {e}")
print(f"⚠️ 无法获取已索引节点,将尝试跳过重复项\n")
print("⚠️ 无法获取已索引节点,将尝试跳过重复项\n")
# 3. 收集需要生成嵌入的节点
print(f"🔍 扫描需要生成嵌入的节点...")
print("🔍 扫描需要生成嵌入的节点...")
all_memories = manager.graph_store.get_all_memories()
nodes_to_process = []
@@ -110,7 +109,7 @@ async def generate_missing_embeddings(
})
type_stats[node.node_type.value]["need_emb"] += 1
print(f"\n📊 扫描结果:")
print("\n📊 扫描结果:")
for node_type in target_node_types:
stats = type_stats[node_type]
already_ok = stats["already_indexed"]
@@ -121,11 +120,11 @@ async def generate_missing_embeddings(
print(f"\n 总计: {total_target_nodes} 个目标节点, {len(nodes_to_process)} 个需要生成嵌入\n")
if len(nodes_to_process) == 0:
print(f"✅ 所有节点已有嵌入向量,无需生成")
print("✅ 所有节点已有嵌入向量,无需生成")
return
# 3. 批量生成嵌入
print(f"🚀 开始生成嵌入向量...\n")
print("🚀 开始生成嵌入向量...\n")
total_batches = (len(nodes_to_process) + batch_size - 1) // batch_size
success_count = 0
@@ -193,22 +192,22 @@ async def generate_missing_embeddings(
print(f" 📊 总进度: {total_processed}/{len(nodes_to_process)} ({progress:.1f}%)\n")
# 4. 保存图数据(更新节点的 embedding 字段)
print(f"💾 保存图数据...")
print("💾 保存图数据...")
try:
await manager.persistence.save_graph_store(manager.graph_store)
print(f"✅ 图数据已保存\n")
print("✅ 图数据已保存\n")
except Exception as e:
logger.error(f"保存图数据失败", exc_info=True)
logger.error("保存图数据失败", exc_info=True)
print(f"❌ 保存失败: {e}\n")
# 5. 验证结果
print(f"🔍 验证向量索引...")
print("🔍 验证向量索引...")
final_vector_count = manager.vector_store.collection.count()
stats = manager.graph_store.get_statistics()
total_nodes = stats["total_nodes"]
print(f"\n{'='*80}")
print(f"📊 生成完成")
print("📊 生成完成")
print(f"{'='*80}")
print(f"处理节点数: {len(nodes_to_process)}")
print(f"成功生成: {success_count}")
@@ -219,7 +218,7 @@ async def generate_missing_embeddings(
print(f"索引覆盖率: {final_vector_count / total_nodes * 100:.1f}%\n")
# 6. 测试搜索
print(f"🧪 测试搜索功能...")
print("🧪 测试搜索功能...")
test_queries = ["小红帽蕾克", "拾风", "杰瑞喵"]
for query in test_queries:

View File

@@ -38,7 +38,7 @@ OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")
TEMP_DIR = os.path.join(ROOT_PATH, "temp", "lpmm_cache")
# ========== 性能配置参数 ==========
#
#
# 知识提取步骤2txt转json并发控制
# - 控制同时进行的LLM提取请求数量
# - 推荐值: 3-10取决于API速率限制
@@ -184,7 +184,7 @@ async def extract_info_async(pg_hash, paragraph, llm_api):
tuple: (doc_item或None, failed_hash或None)
"""
temp_file_path = os.path.join(TEMP_DIR, f"{pg_hash}.json")
# 🔧 优化:使用异步文件检查,避免阻塞
if os.path.exists(temp_file_path):
try:
@@ -215,11 +215,11 @@ async def extract_info_async(pg_hash, paragraph, llm_api):
"extracted_entities": extracted_data.get("entities", []),
"extracted_triples": extracted_data.get("triples", []),
}
# 保存到缓存(异步写入)
async with aiofiles.open(temp_file_path, "wb") as f:
await f.write(orjson.dumps(doc_item))
return doc_item, None
except Exception as e:
logger.error(f"提取信息失败:{pg_hash}, 错误:{e}")
@@ -249,13 +249,13 @@ async def extract_information(paragraphs_dict, model_set):
os.makedirs(TEMP_DIR, exist_ok=True)
failed_hashes, open_ie_docs = [], []
# 🔧 关键修复:创建单个 LLM 请求实例,复用连接
llm_api = LLMRequest(model_set=model_set, request_type="lpmm_extraction")
# 🔧 并发控制:限制最大并发数,防止速率限制
semaphore = asyncio.Semaphore(MAX_EXTRACTION_CONCURRENCY)
async def extract_with_semaphore(pg_hash, paragraph):
"""带信号量控制的提取函数"""
async with semaphore:
@@ -266,7 +266,7 @@ async def extract_information(paragraphs_dict, model_set):
extract_with_semaphore(p_hash, paragraph)
for p_hash, paragraph in paragraphs_dict.items()
]
total = len(tasks)
completed = 0
@@ -284,7 +284,7 @@ async def extract_information(paragraphs_dict, model_set):
TimeRemainingColumn(),
) as progress:
task = progress.add_task("[cyan]正在提取信息...", total=total)
# 🔧 优化:使用 asyncio.gather 并发执行所有任务
# return_exceptions=True 确保单个失败不影响其他任务
for coro in asyncio.as_completed(tasks):
@@ -293,7 +293,7 @@ async def extract_information(paragraphs_dict, model_set):
failed_hashes.append(failed_hash)
elif doc_item:
open_ie_docs.append(doc_item)
completed += 1
progress.update(task, advance=1)
@@ -415,7 +415,7 @@ def rebuild_faiss_only():
logger.info("--- 重建 FAISS 索引 ---")
# 重建索引不需要并发参数(不涉及 embedding 生成)
embed_manager = EmbeddingManager()
logger.info("正在加载现有的 Embedding 库...")
try:
embed_manager.load_from_file()

View File

@@ -4,13 +4,13 @@
提供 Web API 用于可视化记忆图数据
"""
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
from collections import defaultdict
from typing import Any
import orjson
from fastapi import APIRouter, HTTPException, Request, Query
from fastapi import APIRouter, HTTPException, Query, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.templating import Jinja2Templates
@@ -29,7 +29,7 @@ router = APIRouter()
templates = Jinja2Templates(directory=str(Path(__file__).parent / "templates"))
def find_available_data_files() -> List[Path]:
def find_available_data_files() -> list[Path]:
"""查找所有可用的记忆图数据文件"""
files = []
if not data_dir.exists():
@@ -62,7 +62,7 @@ def find_available_data_files() -> List[Path]:
return sorted(files, key=lambda f: f.stat().st_mtime, reverse=True)
def load_graph_data_from_file(file_path: Optional[Path] = None) -> Dict[str, Any]:
def load_graph_data_from_file(file_path: Path | None = None) -> dict[str, Any]:
"""从磁盘加载图数据"""
global graph_data_cache, current_data_file
@@ -85,7 +85,7 @@ def load_graph_data_from_file(file_path: Optional[Path] = None) -> Dict[str, Any
if not graph_file.exists():
return {"error": f"文件不存在: {graph_file}", "nodes": [], "edges": [], "stats": {}}
with open(graph_file, "r", encoding="utf-8") as f:
with open(graph_file, encoding="utf-8") as f:
data = orjson.loads(f.read())
nodes = data.get("nodes", [])
@@ -150,7 +150,7 @@ async def index(request: Request):
return templates.TemplateResponse("visualizer.html", {"request": request})
def _format_graph_data_from_manager(memory_manager) -> Dict[str, Any]:
def _format_graph_data_from_manager(memory_manager) -> dict[str, Any]:
"""从 MemoryManager 提取并格式化图数据"""
if not memory_manager.graph_store:
return {"nodes": [], "edges": [], "memories": [], "stats": {}}
@@ -188,7 +188,7 @@ def _format_graph_data_from_manager(memory_manager) -> Dict[str, Any]:
"arrows": "to",
"memory_id": memory.id,
}
edges_list = list(edges_dict.values())
stats = memory_manager.get_statistics()
@@ -261,7 +261,7 @@ async def get_paginated_graph(
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(500, ge=100, le=2000, description="每页节点数"),
min_importance: float = Query(0.0, ge=0.0, le=1.0, description="最小重要性阈值"),
node_types: Optional[str] = Query(None, description="节点类型过滤,逗号分隔"),
node_types: str | None = Query(None, description="节点类型过滤,逗号分隔"),
):
"""分页获取图数据,支持重要性过滤"""
try:
@@ -301,13 +301,13 @@ async def get_paginated_graph(
total_pages = (total_nodes + page_size - 1) // page_size
start_idx = (page - 1) * page_size
end_idx = min(start_idx + page_size, total_nodes)
paginated_nodes = nodes_with_importance[start_idx:end_idx]
node_ids = set(n["id"] for n in paginated_nodes)
# 只保留连接分页节点的边
paginated_edges = [
e for e in edges
e for e in edges
if e.get("from") in node_ids and e.get("to") in node_ids
]
@@ -383,7 +383,7 @@ async def get_clustered_graph(
return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
def _cluster_graph_data(nodes: List[Dict], edges: List[Dict], max_nodes: int, cluster_threshold: int) -> Dict:
def _cluster_graph_data(nodes: list[dict], edges: list[dict], max_nodes: int, cluster_threshold: int) -> dict:
"""简单的图聚类算法:按类型和连接度聚类"""
# 构建邻接表
adjacency = defaultdict(set)
@@ -412,21 +412,21 @@ def _cluster_graph_data(nodes: List[Dict], edges: List[Dict], max_nodes: int, cl
for node in type_nodes:
importance = len(adjacency[node["id"]])
node_importance.append((node, importance))
node_importance.sort(key=lambda x: x[1], reverse=True)
# 保留前N个重要节点
keep_count = min(len(type_nodes), max_nodes // len(type_groups))
for node, importance in node_importance[:keep_count]:
clustered_nodes.append(node)
node_mapping[node["id"]] = node["id"]
# 其余节点聚合为一个超级节点
if len(node_importance) > keep_count:
clustered_node_ids = [n["id"] for n, _ in node_importance[keep_count:]]
cluster_id = f"cluster_{node_type}_{len(clustered_nodes)}"
cluster_label = f"{node_type} 集群 ({len(clustered_node_ids)}个节点)"
clustered_nodes.append({
"id": cluster_id,
"label": cluster_label,
@@ -436,7 +436,7 @@ def _cluster_graph_data(nodes: List[Dict], edges: List[Dict], max_nodes: int, cl
"cluster_size": len(clustered_node_ids),
"clustered_nodes": clustered_node_ids[:10], # 只保留前10个用于展示
})
for node_id in clustered_node_ids:
node_mapping[node_id] = cluster_id
@@ -445,7 +445,7 @@ def _cluster_graph_data(nodes: List[Dict], edges: List[Dict], max_nodes: int, cl
for edge in edges:
from_id = node_mapping.get(edge["from"])
to_id = node_mapping.get(edge["to"])
if from_id and to_id and from_id != to_id:
edge_key = tuple(sorted([from_id, to_id]))
if edge_key not in edge_set:

View File

@@ -1,6 +1,5 @@
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Any, Literal
from typing import Literal
from fastapi import APIRouter, HTTPException, Query

View File

@@ -161,16 +161,16 @@ class EmbeddingStore:
# 限制 chunk_size 和 max_workers 在合理范围内
chunk_size = max(MIN_CHUNK_SIZE, min(chunk_size, MAX_CHUNK_SIZE))
max_workers = max(MIN_WORKERS, min(max_workers, MAX_WORKERS))
semaphore = asyncio.Semaphore(max_workers)
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
results = {}
# 将字符串列表分成多个 chunk
chunks = []
for i in range(0, len(strs), chunk_size):
chunks.append(strs[i : i + chunk_size])
async def _process_chunk(chunk: list[str]):
"""处理一个 chunk 的字符串(批量获取 embedding"""
async with semaphore:
@@ -180,12 +180,12 @@ class EmbeddingStore:
embedding = await EmbeddingStore._get_embedding_async(llm, s)
embeddings.append(embedding)
results[s] = embedding
if progress_callback:
progress_callback(len(chunk))
return embeddings
# 并发处理所有 chunks
tasks = [_process_chunk(chunk) for chunk in chunks]
await asyncio.gather(*tasks)
@@ -418,22 +418,22 @@ class EmbeddingStore:
# 🔧 修复:检查所有 embedding 的维度是否一致
dimensions = [len(emb) for emb in array]
unique_dims = set(dimensions)
if len(unique_dims) > 1:
logger.error(f"检测到不一致的 embedding 维度: {unique_dims}")
logger.error(f"维度分布: {dict(zip(*np.unique(dimensions, return_counts=True)))}")
# 获取期望的维度(使用最常见的维度)
from collections import Counter
dim_counter = Counter(dimensions)
expected_dim = dim_counter.most_common(1)[0][0]
logger.warning(f"将使用最常见的维度: {expected_dim}")
# 过滤掉维度不匹配的 embedding
filtered_array = []
filtered_idx2hash = {}
skipped_count = 0
for i, emb in enumerate(array):
if len(emb) == expected_dim:
filtered_array.append(emb)
@@ -442,11 +442,11 @@ class EmbeddingStore:
skipped_count += 1
hash_key = self.idx2hash[str(i)]
logger.warning(f"跳过维度不匹配的 embedding: {hash_key}, 维度={len(emb)}, 期望={expected_dim}")
logger.warning(f"已过滤 {skipped_count} 个维度不匹配的 embedding")
array = filtered_array
self.idx2hash = filtered_idx2hash
if not array:
logger.error("过滤后没有可用的 embedding无法构建索引")
embedding_dim = expected_dim

View File

@@ -13,4 +13,4 @@ __all__ = [
"StreamLoopManager",
"message_manager",
"stream_loop_manager",
]
]

View File

@@ -82,7 +82,7 @@ class SingleStreamContextManager:
self.total_messages += 1
self.last_access_time = time.time()
# 如果使用了缓存系统,输出调试信息
if cache_enabled and self.context.is_cache_enabled:
if self.context.is_chatter_processing:

View File

@@ -111,9 +111,9 @@ class StreamLoopManager:
# 获取或创建该流的启动锁
if stream_id not in self._stream_start_locks:
self._stream_start_locks[stream_id] = asyncio.Lock()
lock = self._stream_start_locks[stream_id]
# 使用锁防止并发启动同一个流的多个循环任务
async with lock:
# 获取流上下文
@@ -148,7 +148,7 @@ class StreamLoopManager:
# 紧急取消
context.stream_loop_task.cancel()
await asyncio.sleep(0.1)
loop_task = asyncio.create_task(self._stream_loop_worker(stream_id), name=f"stream_loop_{stream_id}")
# 将任务记录到 StreamContext 中
@@ -252,7 +252,7 @@ class StreamLoopManager:
self.stats["total_process_cycles"] += 1
if success:
logger.info(f"✅ [流工作器] stream={stream_id[:8]}, 任务ID={task_id}, 处理成功")
# 🔒 处理成功后,等待一小段时间确保清理操作完成
# 这样可以避免在 chatter_manager 清除未读消息之前就进入下一轮循环
await asyncio.sleep(0.1)
@@ -382,7 +382,7 @@ class StreamLoopManager:
self.chatter_manager.process_stream_context(stream_id, context),
name=f"chatter_process_{stream_id}"
)
# 等待 chatter 任务完成
results = await chatter_task
success = results.get("success", False)
@@ -398,8 +398,8 @@ class StreamLoopManager:
else:
logger.warning(f"流处理失败: {stream_id} - {results.get('error_message', '未知错误')}")
return success
except asyncio.CancelledError:
return success
except asyncio.CancelledError:
if chatter_task and not chatter_task.done():
chatter_task.cancel()
raise
@@ -709,4 +709,4 @@ class StreamLoopManager:
# 全局流循环管理器实例
stream_loop_manager = StreamLoopManager()
stream_loop_manager = StreamLoopManager()

View File

@@ -417,7 +417,7 @@ class MessageManager:
return
# 记录详细信息
msg_previews = [f"{str(msg.message_id)[:8] if msg.message_id else 'unknown'}:{msg.processed_plain_text[:20] if msg.processed_plain_text else '(空)'}"
msg_previews = [f"{str(msg.message_id)[:8] if msg.message_id else 'unknown'}:{msg.processed_plain_text[:20] if msg.processed_plain_text else '(空)'}"
for msg in unread_messages[:3]] # 只显示前3条
logger.info(f"🧹 [清除未读] stream={stream_id[:8]}, 开始清除 {len(unread_messages)} 条未读消息, 示例: {msg_previews}")
@@ -446,15 +446,15 @@ class MessageManager:
context = chat_stream.context_manager.context
if hasattr(context, "unread_messages") and context.unread_messages:
unread_count = len(context.unread_messages)
# 如果还有未读消息,说明 action_manager 可能遗漏了,标记它们
if unread_count > 0:
if unread_count > 0:
# 获取所有未读消息的 ID
message_ids = [msg.message_id for msg in context.unread_messages]
# 标记为已读(会移到历史消息)
success = chat_stream.context_manager.mark_messages_as_read(message_ids)
if success:
logger.debug(f"✅ stream={stream_id[:8]}, 成功标记 {unread_count} 条消息为已读")
else:
@@ -481,7 +481,7 @@ class MessageManager:
try:
chat_manager = get_chat_manager()
chat_stream = await chat_manager.get_stream(stream_id)
if chat_stream and hasattr(chat_stream.context_manager.context, 'is_chatter_processing'):
if chat_stream and hasattr(chat_stream.context_manager.context, "is_chatter_processing"):
chat_stream.context_manager.context.is_chatter_processing = is_processing
logger.debug(f"设置StreamContext处理状态: stream={stream_id}, processing={is_processing}")
except Exception as e:
@@ -517,7 +517,7 @@ class MessageManager:
try:
chat_manager = get_chat_manager()
chat_stream = await chat_manager.get_stream(stream_id)
if chat_stream and hasattr(chat_stream.context_manager.context, 'is_chatter_processing'):
if chat_stream and hasattr(chat_stream.context_manager.context, "is_chatter_processing"):
return chat_stream.context_manager.context.is_chatter_processing
except Exception:
pass
@@ -677,4 +677,4 @@ class MessageManager:
# 创建全局消息管理器实例
message_manager = MessageManager()
message_manager = MessageManager()

View File

@@ -248,16 +248,16 @@ class ChatterActionManager:
try:
# 根据动作类型确定提示词模式
prompt_mode = "s4u" if action_name == "reply" else "normal"
# 将prompt_mode传递给generate_reply
action_data_with_mode = (action_data or {}).copy()
action_data_with_mode["prompt_mode"] = prompt_mode
# 只传递当前正在执行的动作,而不是所有可用动作
# 这样可以让LLM明确知道"已决定执行X动作",而不是"有这些动作可用"
current_action_info = self._using_actions.get(action_name)
current_actions: dict[str, Any] = {action_name: current_action_info} if current_action_info else {}
# 附加目标消息信息(如果存在)
if target_message:
# 提取目标消息的关键信息
@@ -268,7 +268,7 @@ class ChatterActionManager:
"time": getattr(target_message, "time", 0),
}
current_actions["_target_message"] = target_msg_info
success, response_set, _ = await generator_api.generate_reply(
chat_stream=chat_stream,
reply_message=target_message,
@@ -295,12 +295,12 @@ class ChatterActionManager:
should_quote_reply = None
if action_data and isinstance(action_data, dict):
should_quote_reply = action_data.get("should_quote_reply", None)
# respond动作默认不引用回复保持对话流畅
if action_name == "respond" and should_quote_reply is None:
should_quote_reply = False
async def _after_reply():
async def _after_reply():
# 发送并存储回复
loop_info, reply_text, cycle_timers_reply = await self._send_and_store_reply(
chat_stream,

View File

@@ -365,7 +365,7 @@ class DefaultReplyer:
# 确保类型安全
if isinstance(mode, str):
prompt_mode_value = mode
# 构建 Prompt
with Timer("构建Prompt", {}): # 内部计时器,可选保留
prompt = await self.build_prompt_reply_context(
@@ -1171,16 +1171,16 @@ class DefaultReplyer:
from src.plugin_system.apis.chat_api import get_chat_manager
chat_manager = get_chat_manager()
chat_stream_obj = await chat_manager.get_stream(chat_id)
if chat_stream_obj:
unread_messages = chat_stream_obj.context_manager.get_unread_messages()
if unread_messages:
# 使用最后一条未读消息作为参考
last_msg = unread_messages[-1]
platform = last_msg.chat_info.platform if hasattr(last_msg, 'chat_info') else chat_stream.platform
user_id = last_msg.user_info.user_id if hasattr(last_msg, 'user_info') else ""
user_nickname = last_msg.user_info.user_nickname if hasattr(last_msg, 'user_info') else ""
user_cardname = last_msg.user_info.user_cardname if hasattr(last_msg, 'user_info') else ""
platform = last_msg.chat_info.platform if hasattr(last_msg, "chat_info") else chat_stream.platform
user_id = last_msg.user_info.user_id if hasattr(last_msg, "user_info") else ""
user_nickname = last_msg.user_info.user_nickname if hasattr(last_msg, "user_info") else ""
user_cardname = last_msg.user_info.user_cardname if hasattr(last_msg, "user_info") else ""
processed_plain_text = last_msg.processed_plain_text or ""
else:
# 没有未读消息,使用默认值
@@ -1263,19 +1263,19 @@ class DefaultReplyer:
if available_actions:
# 过滤掉特殊键以_开头
action_items = {k: v for k, v in available_actions.items() if not k.startswith("_")}
# 提取目标消息信息(如果存在)
target_msg_info = available_actions.get("_target_message") # type: ignore
if action_items:
if len(action_items) == 1:
# 单个动作
action_name, action_info = list(action_items.items())[0]
action_desc = action_info.description
# 构建基础决策信息
action_descriptions = f"## 决策信息\n\n你已经决定要执行 **{action_name}** 动作({action_desc})。\n\n"
# 只有需要目标消息的动作才显示目标消息详情
# respond 动作是统一回应所有未读消息,不应该显示特定目标消息
if action_name not in ["respond"] and target_msg_info and isinstance(target_msg_info, dict):
@@ -1284,7 +1284,7 @@ class DefaultReplyer:
content = target_msg_info.get("content", "")
msg_time = target_msg_info.get("time", 0)
time_str = time_module.strftime("%H:%M:%S", time_module.localtime(msg_time)) if msg_time else "未知时间"
action_descriptions += f"**目标消息**: {time_str} {sender} 说: {content}\n\n"
else:
# 多个动作
@@ -2137,7 +2137,7 @@ class DefaultReplyer:
except Exception as e:
logger.error(f"存储聊天记忆失败: {e}")
def weighted_sample_no_replacement(items, weights, k) -> list:
"""

View File

@@ -5,12 +5,12 @@
插件可以通过实现这些接口来扩展安全功能。
"""
from .interfaces import SecurityCheckResult, SecurityChecker
from .interfaces import SecurityChecker, SecurityCheckResult
from .manager import SecurityManager, get_security_manager
__all__ = [
"SecurityChecker",
"SecurityCheckResult",
"SecurityChecker",
"SecurityManager",
"get_security_manager",
]

View File

@@ -10,7 +10,7 @@ from typing import Any
from src.common.logger import get_logger
from .interfaces import SecurityAction, SecurityCheckResult, SecurityChecker, SecurityLevel
from .interfaces import SecurityAction, SecurityChecker, SecurityCheckResult, SecurityLevel
logger = get_logger("security.manager")

View File

@@ -1,5 +1,7 @@
import asyncio
import copy
import re
from collections.abc import Awaitable, Callable
from src.chat.utils.prompt_params import PromptParameters
from src.common.logger import get_logger
@@ -12,122 +14,205 @@ logger = get_logger("prompt_component_manager")
class PromptComponentManager:
"""
管理所有 `BasePrompt` 组件的单例类
一个统一的、动态的、可观测的提示词组件管理中心
该管理器负责:
1. 从 `component_registry` 中查询 `BasePrompt` 子类。
2. 根据注入点目标Prompt名称对它们进行筛选
3. 提供一个接口以便在构建核心Prompt时能够获取并执行所有相关的组件。
该管理器是整个提示词动态注入系统的核心,它负责:
1. **规则加载**: 在系统启动时,自动扫描所有已注册的 `BasePrompt` 组件,
并将其静态定义的 `injection_rules` 加载为默认的动态规则
2. **动态管理**: 提供线程安全的 API允许在运行时动态地添加、更新或移除注入规则
使得提示词的结构可以被实时调整。
3. **状态观测**: 提供丰富的查询 API用于观测系统当前完整的注入状态
例如查询所有注入到特定目标的规则、或查询某个组件定义的所有规则。
4. **注入应用**: 在构建核心 Prompt 时,根据统一的、按优先级排序的规则集,
动态地修改和装配提示词模板,实现灵活的提示词组合。
"""
def _get_rules_for(self, target_prompt_name: str) -> list[tuple[InjectionRule, type[BasePrompt]]]:
"""
获取指定目标Prompt的所有注入规则及其关联的组件类
def __init__(self):
"""初始化管理器实例。"""
# _dynamic_rules 是管理器的核心状态,存储所有注入规则
# 结构: {
# "target_prompt_name": {
# "prompt_component_name": (InjectionRule, content_provider, source)
# }
# }
# content_provider 是一个异步函数,用于在应用规则时动态生成注入内容。
# source 记录了规则的来源(例如 "static_default" 或 "runtime")。
self._dynamic_rules: dict[str, dict[str, tuple[InjectionRule, Callable[..., Awaitable[str]], str]]] = {}
self._lock = asyncio.Lock() # 使用异步锁确保对 _dynamic_rules 的并发访问安全。
self._initialized = False # 标记静态规则是否已加载,防止重复加载。
Args:
target_prompt_name (str): 目标 Prompt 的名称。
# --- 核心生命周期与初始化 ---
Returns:
list[tuple[InjectionRule, Type[BasePrompt]]]: 一个元组列表,
每个元组包含一个注入规则和其对应的 Prompt 组件类,并已根据优先级排序。
def load_static_rules(self):
"""
# 从注册表中获取所有已启用的 PROMPT 类型的组件
在系统启动时加载所有静态注入规则。
该方法会扫描所有已在 `component_registry` 中注册并启用的 Prompt 组件,
将其类变量 `injection_rules` 转换为管理器的动态规则。
这确保了所有插件定义的默认注入行为在系统启动时就能生效。
此操作是幂等的,一旦初始化完成就不会重复执行。
"""
if self._initialized:
return
logger.info("正在加载静态 Prompt 注入规则...")
# 从组件注册表中获取所有已启用的 Prompt 组件
enabled_prompts = component_registry.get_enabled_components_by_type(ComponentType.PROMPT)
matching_rules = []
# 遍历所有启用的 Prompt 组件,查找与目标 Prompt 相关的注入规则
for prompt_name, prompt_info in enabled_prompts.items():
if not isinstance(prompt_info, PromptInfo):
continue
# prompt_info.injection_rules 已经经过了后向兼容处理,确保总是列表
for rule in prompt_info.injection_rules:
# 如果规则的目标是当前指定的 Prompt
if rule.target_prompt == target_prompt_name:
# 获取该规则对应的组件类
component_class = component_registry.get_component_class(prompt_name, ComponentType.PROMPT)
# 确保获取到的确实是一个 BasePrompt 的子类
if component_class and issubclass(component_class, BasePrompt):
matching_rules.append((rule, component_class))
component_class = component_registry.get_component_class(prompt_name, ComponentType.PROMPT)
if not (component_class and issubclass(component_class, BasePrompt)):
logger.warning(f"无法为 '{prompt_name}' 加载静态规则,因为它不是一个有效的 Prompt 组件。")
continue
# 根据规则的优先级进行排序,数字越小,优先级越高,越先应用
matching_rules.sort(key=lambda x: x[0].priority)
return matching_rules
def create_provider(cls: type[BasePrompt]) -> Callable[[PromptParameters], Awaitable[str]]:
"""
为静态组件创建一个内容提供者闭包 (Content Provider Closure)。
这个闭包捕获了组件的类 `cls`,并返回一个标准的 `content_provider` 异步函数。
当 `apply_injections` 需要内容时,它会调用这个函数。
函数内部会实例化组件,并执行其 `execute` 方法来获取注入内容。
Args:
cls (type[BasePrompt]): 需要为其创建提供者的 Prompt 组件类。
Returns:
Callable[[PromptParameters], Awaitable[str]]: 一个符合管理器标准的异步内容提供者。
"""
async def content_provider(params: PromptParameters) -> str:
"""实际执行内容生成的异步函数。"""
try:
# 从注册表获取最新的组件信息,包括插件配置
p_info = component_registry.get_component_info(cls.prompt_name, ComponentType.PROMPT)
plugin_config = {}
if isinstance(p_info, PromptInfo):
plugin_config = component_registry.get_plugin_config(p_info.plugin_name)
# 实例化组件并执行
instance = cls(params=params, plugin_config=plugin_config)
result = await instance.execute()
return str(result) if result is not None else ""
except Exception as e:
logger.error(f"执行静态规则提供者 '{cls.prompt_name}' 时出错: {e}", exc_info=True)
return "" # 出错时返回空字符串,避免影响主流程
return content_provider
# 为该组件的每条静态注入规则创建并注册一个动态规则
for rule in prompt_info.injection_rules:
provider = create_provider(component_class)
target_rules = self._dynamic_rules.setdefault(rule.target_prompt, {})
target_rules[prompt_name] = (rule, provider, "static_default")
self._initialized = True
logger.info(f"静态 Prompt 注入规则加载完成,共处理 {len(enabled_prompts)} 个组件。")
# --- 运行时规则管理 API ---
async def add_injection_rule(
self,
prompt_name: str,
rule: InjectionRule,
content_provider: Callable[..., Awaitable[str]],
source: str = "runtime",
) -> bool:
"""
动态添加或更新一条注入规则。
此方法允许在系统运行时,由外部逻辑(如插件、命令)向管理器中添加新的注入行为。
如果已存在同名组件针对同一目标的规则,此方法会覆盖旧规则。
Args:
prompt_name (str): 动态注入组件的唯一名称。
rule (InjectionRule): 描述注入行为的规则对象。
content_provider (Callable[..., Awaitable[str]]):
一个异步函数,用于在应用注入时动态生成内容。
函数签名应为: `async def provider(params: "PromptParameters") -> str`
source (str, optional): 规则的来源标识,默认为 "runtime"
Returns:
bool: 如果成功添加或更新,则返回 True。
"""
async with self._lock:
target_rules = self._dynamic_rules.setdefault(rule.target_prompt, {})
target_rules[prompt_name] = (rule, content_provider, source)
logger.info(f"成功添加/更新注入规则: '{prompt_name}' -> '{rule.target_prompt}' (来源: {source})")
return True
async def remove_injection_rule(self, prompt_name: str, target_prompt: str) -> bool:
"""
移除一条动态注入规则。
Args:
prompt_name (str): 要移除的注入组件的名称。
target_prompt (str): 该组件注入的目标核心提示词名称。
Returns:
bool: 如果成功移除,则返回 True如果规则不存在则返回 False。
"""
async with self._lock:
if target_prompt in self._dynamic_rules and prompt_name in self._dynamic_rules[target_prompt]:
del self._dynamic_rules[target_prompt][prompt_name]
# 如果目标下已无任何规则,则清理掉这个键
if not self._dynamic_rules[target_prompt]:
del self._dynamic_rules[target_prompt]
logger.info(f"成功移除注入规则: '{prompt_name}' from '{target_prompt}'")
return True
logger.warning(f"尝试移除注入规则失败: 未找到 '{prompt_name}' on '{target_prompt}'")
return False
# --- 核心注入逻辑 ---
async def apply_injections(
self, target_prompt_name: str, original_template: str, params: PromptParameters
) -> str:
"""
获取、实例化并执行所有相关组件,然后根据注入规则修改原始模板。
【核心方法】根据目标名称,应用所有匹配的注入规则,返回修改后的模板。
这是一个三步走的过程
1. 实例化所有需要执行的组件
2. 并行执行它们的 `execute` 方法以获取注入内容
3. 按照优先级顺序,将内容注入到原始模板中
这是提示词构建流程中的关键步骤。它会执行以下操作
1. 检查并确保静态规则已加载
2. 获取所有注入到 `target_prompt_name` 的规则
3. 按照规则的 `priority` 属性进行升序排序,优先级数字越小越先应用
4. 依次执行每个规则的 `content_provider` 来异步获取注入内容。
5. 根据规则的 `injection_type` (如 PREPEND, APPEND, REPLACE 等) 将内容应用到模板上。
Args:
target_prompt_name (str): 目标 Prompt 的名称。
original_template (str): 原始的、未经修改的 Prompt 模板字符串
params (PromptParameters): 传递给 Prompt 组件实例的参数
target_prompt_name (str): 目标核心提示词的名称。
original_template (str): 未经修改的原始提示词模板
params (PromptParameters): 当前请求的参数,会传递给 `content_provider`
Returns:
str: 应用了所有注入规则后,修改过的 Prompt 模板字符串。
str: 应用了所有注入规则后,最终生成的提示词模板字符串。
"""
rules_with_classes = self._get_rules_for(target_prompt_name)
# 如果没有找到任何匹配的规则,就直接返回原始模板,啥也不干
if not rules_with_classes:
if not self._initialized:
self.load_static_rules()
# 步骤 1: 获取所有指向当前目标的规则
# 使用 .values() 获取 (rule, provider, source) 元组列表
rules_for_target = list(self._dynamic_rules.get(target_prompt_name, {}).values())
if not rules_for_target:
return original_template
# --- 第一步: 实例化所有需要执行的组件 ---
instance_map = {} # 存储组件实例,虽然目前没直接用,但留着总没错
tasks = [] # 存放所有需要并行执行的 execute 异步任务
components_to_execute = [] # 存放需要执行的组件类,用于后续结果映射
# 步骤 2: 按优先级排序,数字越小越优先
rules_for_target.sort(key=lambda x: x[0].priority)
for rule, component_class in rules_with_classes:
# 如果注入类型是 REMOVE那就不需要执行组件了因为它不产生内容
# 步骤 3: 依次执行内容提供者并根据注入类型修改模板
modified_template = original_template
for rule, provider, source in rules_for_target:
content = ""
# 对于非 REMOVE 类型的注入,需要先获取内容
if rule.injection_type != InjectionType.REMOVE:
try:
# 获取组件的元信息,主要是为了拿到插件名称来读取插件配置
prompt_info = component_registry.get_component_info(
component_class.prompt_name, ComponentType.PROMPT
)
if not isinstance(prompt_info, PromptInfo):
plugin_config = {}
else:
# 从注册表获取该组件所属插件的配置
plugin_config = component_registry.get_plugin_config(prompt_info.plugin_name)
# 实例化组件,并传入参数和插件配置
instance = component_class(params=params, plugin_config=plugin_config)
instance_map[component_class.prompt_name] = instance
# 将组件的 execute 方法作为一个任务添加到列表中
tasks.append(instance.execute())
components_to_execute.append(component_class)
content = await provider(params)
except Exception as e:
logger.error(f"实例化 Prompt 组件 '{component_class.prompt_name}' 失败: {e}")
# 即使失败,也添加一个立即完成的空任务,以保持与其他任务的索引同步
tasks.append(asyncio.create_task(asyncio.sleep(0, result=e))) # type: ignore
# --- 第二步: 并行执行所有组件的 execute 方法 ---
# 使用 asyncio.gather 来同时运行所有任务,提高效率
results = await asyncio.gather(*tasks, return_exceptions=True)
# 创建一个从组件名到执行结果的映射,方便后续查找
result_map = {
components_to_execute[i].prompt_name: res
for i, res in enumerate(results)
if not isinstance(res, Exception) # 只包含成功的结果
}
# 单独处理并记录执行失败的组件
for i, res in enumerate(results):
if isinstance(res, Exception):
logger.error(f"执行 Prompt 组件 '{components_to_execute[i].prompt_name}' 失败: {res}")
# --- 第三步: 按优先级顺序应用注入规则 ---
modified_template = original_template
for rule, component_class in rules_with_classes:
# 从结果映射中获取该组件生成的内容
content = result_map.get(component_class.prompt_name)
logger.error(f"执行规则 '{rule}' (来源: {source}) 的内容提供者时失败: {e}", exc_info=True)
continue # 跳过失败的 provider不中断整个流程
# 应用注入逻辑
try:
if rule.injection_type == InjectionType.PREPEND:
if content:
@@ -136,28 +221,178 @@ class PromptComponentManager:
if content:
modified_template = f"{modified_template}\n{content}"
elif rule.injection_type == InjectionType.REPLACE:
# 使用正则表达式替换目标内容
if content and rule.target_content:
# 只有在 content 不为 None 且 target_content 有效时才执行替换
if content is not None and rule.target_content:
modified_template = re.sub(rule.target_content, str(content), modified_template)
elif rule.injection_type == InjectionType.INSERT_AFTER:
# 在匹配到的内容后面插入
if content and rule.target_content:
# re.sub a little trick: \g<0> represents the entire matched string
# 使用 `\g<0>` 在正则匹配的整个内容后添加新内容
replacement = f"\\g<0>\n{content}"
modified_template = re.sub(rule.target_content, replacement, modified_template)
elif rule.injection_type == InjectionType.REMOVE:
# 使用正则表达式移除目标内容
if rule.target_content:
modified_template = re.sub(rule.target_content, "", modified_template)
except re.error as e:
logger.error(
f"在为 '{component_class.prompt_name}' 应用规则时发生正则错误: {e} (pattern: '{rule.target_content}')"
)
logger.error(f"应用规则时发生正则错误: {e} (pattern: '{rule.target_content}')")
except Exception as e:
logger.error(f"应用 Prompt 注入规则 '{rule}' 失败: {e}")
logger.error(f"应用注入规则 '{rule}' (来源: {source}) 失败: {e}", exc_info=True)
return modified_template
async def preview_prompt_injections(
self, target_prompt_name: str, params: PromptParameters
) -> str:
"""
【预览功能】模拟应用所有注入规则,返回最终生成的模板字符串,而不实际修改任何状态。
# 创建全局单例
这个方法对于调试和测试非常有用,可以查看在特定参数下,
一个核心提示词经过所有注入规则处理后会变成什么样子。
Args:
target_prompt_name (str): 希望预览的目标核心提示词名称。
params (PromptParameters): 模拟的请求参数。
Returns:
str: 模拟生成的最终提示词模板字符串。如果找不到模板,则返回错误信息。
"""
try:
# 从全局提示词管理器获取最原始的模板内容
from src.chat.utils.prompt import global_prompt_manager
original_prompt = global_prompt_manager._prompts.get(target_prompt_name)
if not original_prompt:
logger.warning(f"无法预览 '{target_prompt_name}',因为找不到这个核心 Prompt。")
return f"Error: Prompt '{target_prompt_name}' not found."
original_template = original_prompt.template
except KeyError:
logger.warning(f"无法预览 '{target_prompt_name}',因为找不到这个核心 Prompt。")
return f"Error: Prompt '{target_prompt_name}' not found."
# 直接调用核心注入逻辑来模拟结果
return await self.apply_injections(target_prompt_name, original_template, params)
# --- 状态观测与查询 API ---
def get_core_prompts(self) -> list[str]:
"""获取所有已注册的核心提示词模板名称列表(即所有可注入的目标)。"""
from src.chat.utils.prompt import global_prompt_manager
return list(global_prompt_manager._prompts.keys())
def get_core_prompt_contents(self) -> dict[str, str]:
"""获取所有核心提示词模板的原始内容。"""
from src.chat.utils.prompt import global_prompt_manager
return {name: prompt.template for name, prompt in global_prompt_manager._prompts.items()}
def get_registered_prompt_component_info(self) -> list[PromptInfo]:
"""获取所有在 ComponentRegistry 中注册的 Prompt 组件信息。"""
components = component_registry.get_components_by_type(ComponentType.PROMPT).values()
return [info for info in components if isinstance(info, PromptInfo)]
async def get_full_injection_map(self) -> dict[str, list[dict]]:
"""
获取当前完整的注入映射图。
此方法提供了一个系统全局的注入视图展示了每个核心提示词target
被哪些注入组件source以何种优先级注入。
Returns:
dict[str, list[dict]]: 一个字典,键是目标提示词名称,
值是按优先级排序的注入信息列表。
`[{"name": str, "priority": int, "source": str}]`
"""
injection_map = {}
async with self._lock:
# 合并所有动态规则的目标和所有核心提示词,确保所有潜在目标都被包含
all_targets = set(self._dynamic_rules.keys()) | set(self.get_core_prompts())
for target in sorted(all_targets):
rules = self._dynamic_rules.get(target, {})
if not rules:
injection_map[target] = []
continue
info_list = []
for prompt_name, (rule, _, source) in rules.items():
info_list.append({"name": prompt_name, "priority": rule.priority, "source": source})
# 按优先级排序后存入 map
info_list.sort(key=lambda x: x["priority"])
injection_map[target] = info_list
return injection_map
async def get_injections_for_prompt(self, target_prompt_name: str) -> list[dict]:
"""
获取指定核心提示词模板的所有注入信息(包含详细规则)。
Args:
target_prompt_name (str): 目标核心提示词的名称。
Returns:
list[dict]: 一个包含注入规则详细信息的列表,已按优先级排序。
"""
rules_for_target = self._dynamic_rules.get(target_prompt_name, {})
if not rules_for_target:
return []
info_list = []
for prompt_name, (rule, _, source) in rules_for_target.items():
info_list.append(
{
"name": prompt_name,
"priority": rule.priority,
"source": source,
"injection_type": rule.injection_type.value,
"target_content": rule.target_content,
}
)
info_list.sort(key=lambda x: x["priority"])
return info_list
def get_all_dynamic_rules(self) -> dict[str, dict[str, "InjectionRule"]]:
"""
获取所有当前的动态注入规则,以 InjectionRule 对象形式返回。
此方法返回一个深拷贝的规则副本,隐藏了 `content_provider` 等内部实现细节。
适合用于展示或序列化当前的规则配置。
"""
rules_copy = {}
for target, rules in self._dynamic_rules.items():
target_copy = {name: rule for name, (rule, _, _) in rules.items()}
rules_copy[target] = target_copy
return copy.deepcopy(rules_copy)
def get_rules_for_target(self, target_prompt: str) -> dict[str, InjectionRule]:
"""
获取所有注入到指定核心提示词的动态规则。
Args:
target_prompt (str): 目标核心提示词的名称。
Returns:
dict[str, InjectionRule]: 一个字典,键是注入组件的名称,值是 `InjectionRule` 对象。
如果找不到任何注入到该目标的规则,则返回一个空字典。
"""
target_rules = self._dynamic_rules.get(target_prompt, {})
return {name: copy.deepcopy(rule_info[0]) for name, rule_info in target_rules.items()}
def get_rules_by_component(self, component_name: str) -> dict[str, InjectionRule]:
"""
获取由指定的单个注入组件定义的所有动态规则。
Args:
component_name (str): 注入组件的名称。
Returns:
dict[str, InjectionRule]: 一个字典,键是目标核心提示词的名称,值是 `InjectionRule` 对象。
如果该组件没有定义任何注入规则,则返回一个空字典。
"""
found_rules = {}
for target, rules in self._dynamic_rules.items():
if component_name in rules:
rule_info = rules[component_name]
found_rules[target] = copy.deepcopy(rule_info[0])
return found_rules
# 创建全局单例 (Singleton)
# 在整个应用程序中,应该只使用这一个 `prompt_component_manager` 实例,
# 以确保所有部分都共享和操作同一份动态规则集。
prompt_component_manager = PromptComponentManager()

View File

@@ -98,7 +98,7 @@ class StreamContext(BaseDataModel):
break
def mark_message_as_read(self, message_id: str):
"""标记消息为已读"""
"""标记消息为已读"""
# 先找到要标记的消息(处理 int/str 类型不匹配问题)
message_to_mark = None
for msg in self.unread_messages:
@@ -106,7 +106,7 @@ class StreamContext(BaseDataModel):
if str(msg.message_id) == str(message_id):
message_to_mark = msg
break
# 然后移动到历史消息
if message_to_mark:
message_to_mark.is_read = True

View File

@@ -9,11 +9,12 @@
"""
import asyncio
import builtins
import time
from collections import OrderedDict
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Dict, Generic, List, Optional, Set, TypeVar, Union
from typing import Any, Generic, TypeVar
from src.common.logger import get_logger
from src.common.memory_utils import estimate_size_smart
@@ -96,7 +97,7 @@ class LRUCache(Generic[T]):
self._lock = asyncio.Lock()
self._stats = CacheStats()
async def get(self, key: str) -> Optional[T]:
async def get(self, key: str) -> T | None:
"""获取缓存值
Args:
@@ -137,8 +138,8 @@ class LRUCache(Generic[T]):
self,
key: str,
value: T,
size: Optional[int] = None,
ttl: Optional[float] = None,
size: int | None = None,
ttl: float | None = None,
) -> None:
"""设置缓存值
@@ -287,8 +288,8 @@ class MultiLevelCache:
async def get(
self,
key: str,
loader: Optional[Callable[[], Any]] = None,
) -> Optional[Any]:
loader: Callable[[], Any] | None = None,
) -> Any | None:
"""从缓存获取数据
查询顺序L1 -> L2 -> loader
@@ -329,8 +330,8 @@ class MultiLevelCache:
self,
key: str,
value: Any,
size: Optional[int] = None,
ttl: Optional[float] = None,
size: int | None = None,
ttl: float | None = None,
) -> None:
"""设置缓存值
@@ -390,7 +391,7 @@ class MultiLevelCache:
await self.l2_cache.clear()
logger.info("所有缓存已清空")
async def get_stats(self) -> Dict[str, Any]:
async def get_stats(self) -> dict[str, Any]:
"""获取所有缓存层的统计信息(修复版:避免锁嵌套,使用超时)"""
# 🔧 修复:并行获取统计信息,避免锁嵌套
l1_stats_task = asyncio.create_task(self._get_cache_stats_safe(self.l1_cache, "L1"))
@@ -492,7 +493,7 @@ class MultiLevelCache:
logger.error(f"{cache_name}统计获取异常: {e}")
return CacheStats()
async def _get_cache_keys_safe(self, cache) -> Set[str]:
async def _get_cache_keys_safe(self, cache) -> builtins.set[str]:
"""安全获取缓存键集合(带超时)"""
try:
# 快速获取键集合,使用超时避免死锁
@@ -507,12 +508,12 @@ class MultiLevelCache:
logger.error(f"缓存键获取异常: {e}")
return set()
async def _extract_keys_with_lock(self, cache) -> Set[str]:
async def _extract_keys_with_lock(self, cache) -> builtins.set[str]:
"""在锁保护下提取键集合"""
async with cache._lock:
return set(cache._cache.keys())
async def _calculate_memory_usage_safe(self, cache, keys: Set[str]) -> int:
async def _calculate_memory_usage_safe(self, cache, keys: builtins.set[str]) -> int:
"""安全计算内存使用(带超时)"""
if not keys:
return 0
@@ -529,7 +530,7 @@ class MultiLevelCache:
logger.error(f"内存计算异常: {e}")
return 0
async def _calc_memory_with_lock(self, cache, keys: Set[str]) -> int:
async def _calc_memory_with_lock(self, cache, keys: builtins.set[str]) -> int:
"""在锁保护下计算内存使用"""
total_size = 0
async with cache._lock:
@@ -749,7 +750,7 @@ class MultiLevelCache:
# 全局缓存实例
_global_cache: Optional[MultiLevelCache] = None
_global_cache: MultiLevelCache | None = None
_cache_lock = asyncio.Lock()

View File

@@ -3,7 +3,6 @@ import socket
from fastapi import APIRouter, FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from rich.traceback import install
from uvicorn import Config
from uvicorn import Server as UvicornServer

View File

@@ -436,7 +436,7 @@ class OpenaiClient(BaseClient):
# 🔧 优化增加连接池限制支持高并发embedding请求
# 默认httpx限制为100对于高频embedding场景不够用
import httpx
limits = httpx.Limits(
max_keepalive_connections=200, # 保持活跃连接数原100
max_connections=300, # 最大总连接数原100

View File

@@ -128,7 +128,7 @@ class MemoryBuilder:
# 6. 构建 Memory 对象
# 新记忆应该有较高的初始激活度
initial_activation = 0.75 # 新记忆初始激活度为 0.75
memory = Memory(
id=memory_id,
subject_id=subject_node.id,

View File

@@ -149,7 +149,7 @@ class MemoryManager:
# 读取阈值过滤配置
search_min_importance = self.config.search_min_importance
search_similarity_threshold = self.config.search_similarity_threshold
logger.info(
f"📊 配置检查: search_max_expand_depth={expand_depth}, "
f"search_expand_semantic_threshold={expand_semantic_threshold}, "
@@ -417,7 +417,7 @@ class MemoryManager:
# 使用配置的默认值
if top_k is None:
top_k = getattr(self.config, "search_top_k", 10)
# 准备搜索参数
params = {
"query": query,
@@ -951,7 +951,7 @@ class MemoryManager:
)
else:
logger.debug(f"记忆已删除: {memory_id} (删除了 {deleted_vectors} 个向量)")
# 4. 保存更新
await self.persistence.save_graph_store(self.graph_store)
return True
@@ -984,7 +984,7 @@ class MemoryManager:
try:
forgotten_count = 0
all_memories = self.graph_store.get_all_memories()
# 获取配置参数
min_importance = getattr(self.config, "forgetting_min_importance", 0.8)
decay_rate = getattr(self.config, "activation_decay_rate", 0.9)
@@ -1010,10 +1010,10 @@ class MemoryManager:
try:
last_access_dt = datetime.fromisoformat(last_access)
days_passed = (datetime.now() - last_access_dt).days
# 应用指数衰减activation = base * (decay_rate ^ days)
current_activation = base_activation * (decay_rate ** days_passed)
logger.debug(
f"记忆 {memory.id[:8]}: 基础激活度={base_activation:.3f}, "
f"经过{days_passed}天衰减后={current_activation:.3f}"
@@ -1035,20 +1035,20 @@ class MemoryManager:
# 批量遗忘记忆(不立即清理孤立节点)
if memories_to_forget:
logger.info(f"开始批量遗忘 {len(memories_to_forget)} 条记忆...")
for memory_id, activation in memories_to_forget:
# cleanup_orphans=False暂不清理孤立节点
success = await self.forget_memory(memory_id, cleanup_orphans=False)
if success:
forgotten_count += 1
# 统一清理孤立节点和边
logger.info("批量遗忘完成,开始统一清理孤立节点和边...")
orphan_nodes, orphan_edges = await self._cleanup_orphan_nodes_and_edges()
# 保存最终更新
await self.persistence.save_graph_store(self.graph_store)
logger.info(
f"✅ 自动遗忘完成: 遗忘了 {forgotten_count} 条记忆, "
f"清理了 {orphan_nodes} 个孤立节点, {orphan_edges} 条孤立边"
@@ -1079,31 +1079,31 @@ class MemoryManager:
# 1. 清理孤立节点
# graph_store.node_to_memories 记录了每个节点属于哪些记忆
nodes_to_remove = []
for node_id, memory_ids in list(self.graph_store.node_to_memories.items()):
# 如果节点不再属于任何记忆,标记为删除
if not memory_ids:
nodes_to_remove.append(node_id)
# 从图中删除孤立节点
for node_id in nodes_to_remove:
if self.graph_store.graph.has_node(node_id):
self.graph_store.graph.remove_node(node_id)
orphan_nodes_count += 1
# 从映射中删除
if node_id in self.graph_store.node_to_memories:
del self.graph_store.node_to_memories[node_id]
# 2. 清理孤立边(指向已删除节点的边)
edges_to_remove = []
for source, target, edge_id in self.graph_store.graph.edges(data='edge_id'):
for source, target, edge_id in self.graph_store.graph.edges(data="edge_id"):
# 检查边的源节点和目标节点是否还存在于node_to_memories中
if source not in self.graph_store.node_to_memories or \
target not in self.graph_store.node_to_memories:
edges_to_remove.append((source, target))
# 删除孤立边
for source, target in edges_to_remove:
try:
@@ -1111,12 +1111,12 @@ class MemoryManager:
orphan_edges_count += 1
except Exception as e:
logger.debug(f"删除边失败 {source} -> {target}: {e}")
if orphan_nodes_count > 0 or orphan_edges_count > 0:
logger.info(
f"清理完成: {orphan_nodes_count} 个孤立节点, {orphan_edges_count} 条孤立边"
)
return orphan_nodes_count, orphan_edges_count
except Exception as e:
@@ -1258,7 +1258,7 @@ class MemoryManager:
mem for mem in recent_memories
if mem.importance >= min_importance_for_consolidation
]
result["importance_filtered"] = len(recent_memories) - len(important_memories)
logger.info(
f"📊 步骤2: 重要性过滤 (阈值={min_importance_for_consolidation:.2f}): "
@@ -1382,26 +1382,26 @@ class MemoryManager:
# ===== 步骤4: 向量检索关联记忆 + LLM分析关系 =====
# 过滤掉已删除的记忆
remaining_memories = [m for m in important_memories if m.id not in deleted_ids]
if not remaining_memories:
logger.info("✅ 记忆整理完成: 去重后无剩余记忆")
return
logger.info(f"📍 步骤4: 开始关联分析 ({len(remaining_memories)} 条记忆)...")
# 分批处理记忆关联
llm_batch_size = getattr(self.config, "consolidation_llm_batch_size", 10)
max_candidates_per_memory = getattr(self.config, "consolidation_max_candidates", 5)
min_confidence = getattr(self.config, "consolidation_min_confidence", 0.6)
all_new_edges = [] # 收集所有新建的边
for batch_start in range(0, len(remaining_memories), llm_batch_size):
batch_end = min(batch_start + llm_batch_size, len(remaining_memories))
batch = remaining_memories[batch_start:batch_end]
logger.debug(f"处理批次 {batch_start//llm_batch_size + 1}/{(len(remaining_memories)-1)//llm_batch_size + 1}")
for memory in batch:
# 跳过已经有很多连接的记忆
existing_edges = len([
@@ -1454,14 +1454,14 @@ class MemoryManager:
except Exception as e:
logger.warning(f"创建关联边失败: {e}")
continue
# 每个批次后让出控制权
await asyncio.sleep(0.01)
# ===== 步骤5: 统一更新记忆数据 =====
if all_new_edges:
logger.info(f"📍 步骤5: 统一更新 {len(all_new_edges)} 条新关联边...")
for memory, edge, relation in all_new_edges:
try:
# 添加到图
@@ -2301,7 +2301,7 @@ class MemoryManager:
# 使用 asyncio.wait_for 来支持取消
await asyncio.wait_for(
asyncio.sleep(initial_delay),
timeout=float('inf') # 允许随时取消
timeout=float("inf") # 允许随时取消
)
# 检查是否仍然需要运行

View File

@@ -482,7 +482,7 @@ class GraphStore:
for node in memory.nodes:
if node.id in self.node_to_memories:
self.node_to_memories[node.id].discard(memory_id)
# 可选:立即清理孤立节点
if cleanup_orphans:
# 如果该节点不再属于任何记忆,从图中移除节点

View File

@@ -72,12 +72,12 @@ class MemoryTools:
self.max_expand_depth = max_expand_depth
self.expand_semantic_threshold = expand_semantic_threshold
self.search_top_k = search_top_k
# 保存权重配置
self.base_vector_weight = search_vector_weight
self.base_importance_weight = search_importance_weight
self.base_recency_weight = search_recency_weight
# 保存阈值过滤配置
self.search_min_importance = search_min_importance
self.search_similarity_threshold = search_similarity_threshold
@@ -516,14 +516,14 @@ class MemoryTools:
# 1. 根据策略选择检索方式
llm_prefer_types = [] # LLM识别的偏好节点类型
if use_multi_query:
# 多查询策略(返回节点列表 + 偏好类型)
similar_nodes, llm_prefer_types = await self._multi_query_search(query, top_k, context)
else:
# 传统单查询策略
similar_nodes = await self._single_query_search(query, top_k)
# 合并用户指定的偏好类型和LLM识别的偏好类型
all_prefer_types = list(set(prefer_node_types + llm_prefer_types))
if all_prefer_types:
@@ -551,7 +551,7 @@ class MemoryTools:
# 记录最高分数
if mem_id not in memory_scores or similarity > memory_scores[mem_id]:
memory_scores[mem_id] = similarity
# 🔥 详细日志:检查初始召回情况
logger.info(
f"初始向量搜索: 返回{len(similar_nodes)}个节点 → "
@@ -559,8 +559,8 @@ class MemoryTools:
)
if len(initial_memory_ids) == 0:
logger.warning(
f"⚠️ 向量搜索未找到任何记忆!"
f"可能原因1) 嵌入模型理解问题 2) 记忆节点未建立索引 3) 查询表达与存储内容差异过大"
"⚠️ 向量搜索未找到任何记忆!"
"可能原因1) 嵌入模型理解问题 2) 记忆节点未建立索引 3) 查询表达与存储内容差异过大"
)
# 输出相似节点的详细信息用于调试
if similar_nodes:
@@ -692,7 +692,7 @@ class MemoryTools:
key=lambda x: final_scores[x],
reverse=True
) # 🔥 不再提前截断,让所有候选参与详细评分
# 🔍 统计初始记忆的相似度分布(用于诊断)
if memory_scores:
similarities = list(memory_scores.values())
@@ -707,7 +707,7 @@ class MemoryTools:
# 5. 获取完整记忆并进行最终排序(优化后的动态权重系统)
memories_with_scores = []
filter_stats = {"importance": 0, "similarity": 0, "total_checked": 0} # 过滤统计
for memory_id in sorted_memory_ids: # 遍历所有候选
memory = self.graph_store.get_memory_by_id(memory_id)
if memory:
@@ -715,7 +715,7 @@ class MemoryTools:
# 基础分数
similarity_score = final_scores[memory_id]
importance_score = memory.importance
# 🆕 区分记忆来源(用于过滤)
is_initial_memory = memory_id in memory_scores # 是否来自初始向量搜索
true_similarity = memory_scores.get(memory_id, 0.0) if is_initial_memory else None
@@ -738,16 +738,16 @@ class MemoryTools:
activation_score = memory.activation
# 🆕 动态权重计算:使用配置的基础权重 + 根据记忆类型微调
memory_type = memory.memory_type.value if hasattr(memory.memory_type, 'value') else str(memory.memory_type)
memory_type = memory.memory_type.value if hasattr(memory.memory_type, "value") else str(memory.memory_type)
# 检测记忆的主要节点类型
node_types_count = {}
for node in memory.nodes:
nt = node.node_type.value if hasattr(node.node_type, 'value') else str(node.node_type)
nt = node.node_type.value if hasattr(node.node_type, "value") else str(node.node_type)
node_types_count[nt] = node_types_count.get(nt, 0) + 1
dominant_node_type = max(node_types_count.items(), key=lambda x: x[1])[0] if node_types_count else "unknown"
# 根据记忆类型和节点类型计算调整系数(在配置权重基础上微调)
if dominant_node_type in ["ATTRIBUTE", "REFERENCE"] or memory_type == "FACT":
# 事实性记忆:提升相似度权重,降低时效性权重
@@ -777,41 +777,41 @@ class MemoryTools:
"importance": 1.0,
"recency": 1.0,
}
# 应用调整后的权重(基于配置的基础权重)
weights = {
"similarity": self.base_vector_weight * type_adjustments["similarity"],
"importance": self.base_importance_weight * type_adjustments["importance"],
"recency": self.base_recency_weight * type_adjustments["recency"],
}
# 归一化权重确保总和为1.0
total_weight = sum(weights.values())
if total_weight > 0:
weights = {k: v / total_weight for k, v in weights.items()}
# 综合分数计算(🔥 移除激活度影响)
final_score = (
similarity_score * weights["similarity"] +
importance_score * weights["importance"] +
recency_score * weights["recency"]
)
# 🆕 阈值过滤策略:
# 1. 重要性过滤:应用于所有记忆(过滤极低质量)
if memory.importance < self.search_min_importance:
filter_stats["importance"] += 1
logger.debug(f"❌ 过滤 {memory.id[:8]}: 重要性 {memory.importance:.2f} < 阈值 {self.search_min_importance}")
continue
# 2. 相似度过滤:不再对初始向量搜索结果过滤(信任向量搜索的排序)
# 理由:向量搜索已经按相似度排序,返回的都是最相关结果
# 如果再用阈值过滤,会导致"最相关的也不够相关"的矛盾
#
#
# 注意:如果未来需要对扩展记忆过滤,可以在这里添加逻辑
# if not is_initial_memory and some_score < threshold:
# continue
# 记录通过过滤的记忆(用于调试)
if is_initial_memory:
logger.debug(
@@ -823,11 +823,11 @@ class MemoryTools:
f"✅ 保留 {memory.id[:8]} [扩展]: 重要性={memory.importance:.2f}, "
f"综合分数={final_score:.4f}"
)
# 🆕 节点类型加权对REFERENCE/ATTRIBUTE节点额外加分促进事实性信息召回
if "REFERENCE" in node_types_count or "ATTRIBUTE" in node_types_count:
final_score *= 1.1 # 10% 加成
# 🆕 用户指定的优先节点类型额外加权
if prefer_node_types:
for prefer_type in prefer_node_types:
@@ -835,7 +835,7 @@ class MemoryTools:
final_score *= 1.15 # 15% 额外加成
logger.debug(f"记忆 {memory.id[:8]} 包含优先节点类型 {prefer_type},加权后分数: {final_score:.4f}")
break
memories_with_scores.append((memory, final_score, dominant_node_type))
# 按综合分数排序
@@ -845,7 +845,7 @@ class MemoryTools:
# 统计过滤情况
total_candidates = len(all_memory_ids)
filtered_count = total_candidates - len(memories_with_scores)
# 6. 格式化结果(包含调试信息)
results = []
for memory, score, node_type in memories_with_scores[:top_k]:
@@ -866,7 +866,7 @@ class MemoryTools:
f"过滤{filtered_count}个 (重要性过滤) → "
f"最终返回{len(results)}条记忆"
)
# 如果过滤率过高,发出警告
if total_candidates > 0:
filter_rate = filtered_count / total_candidates
@@ -1092,20 +1092,21 @@ class MemoryTools:
response, _ = await llm.generate_response_async(prompt, temperature=0.3, max_tokens=300)
import re
import orjson
# 清理Markdown代码块
response = re.sub(r"```json\s*", "", response)
response = re.sub(r"```\s*$", "", response).strip()
# 解析JSON
data = orjson.loads(response)
# 提取查询列表
queries = data.get("queries", [])
result_queries = [(item.get("text", "").strip(), float(item.get("weight", 0.5)))
for item in queries if item.get("text", "").strip()]
# 提取偏好节点类型
prefer_node_types = data.get("prefer_node_types", [])
# 确保类型正确且有效
@@ -1154,7 +1155,7 @@ class MemoryTools:
limit=top_k * 5, # 🔥 从2倍提升到5倍提高初始召回率
min_similarity=0.0, # 不在这里过滤,交给后续评分
)
logger.debug(f"单查询向量搜索: 查询='{query}', 返回节点数={len(similar_nodes)}")
if similar_nodes:
logger.debug(f"Top 3相似度: {[f'{sim:.3f}' for _, sim, _ in similar_nodes[:3]]}")

View File

@@ -62,7 +62,7 @@ async def expand_memories_with_semantic_filter(
try:
import time
start_time = time.time()
# 记录已访问的记忆,避免重复
visited_memories = set(initial_memory_ids)
# 记录扩展的记忆及其分数
@@ -87,17 +87,17 @@ async def expand_memories_with_semantic_filter(
# 获取该记忆的邻居记忆(通过边关系)
neighbor_memory_ids = set()
# 🆕 遍历记忆的所有边,收集邻居记忆(带边类型权重)
edge_weights = {} # 记录通过不同边类型到达的记忆的权重
for edge in memory.edges:
# 获取边的目标节点
target_node_id = edge.target_id
source_node_id = edge.source_id
# 🆕 根据边类型设置权重优先扩展REFERENCE、ATTRIBUTE相关的边
edge_type_str = edge.edge_type.value if hasattr(edge.edge_type, 'value') else str(edge.edge_type)
edge_type_str = edge.edge_type.value if hasattr(edge.edge_type, "value") else str(edge.edge_type)
if edge_type_str == "REFERENCE":
edge_weight = 1.3 # REFERENCE边权重最高引用关系
elif edge_type_str in ["ATTRIBUTE", "HAS_PROPERTY"]:
@@ -108,18 +108,18 @@ async def expand_memories_with_semantic_filter(
edge_weight = 0.9 # 一般关系适中降权
else:
edge_weight = 1.0 # 默认权重
# 通过节点找到其他记忆
for node_id in [target_node_id, source_node_id]:
if node_id in graph_store.node_to_memories:
for neighbor_id in graph_store.node_to_memories[node_id]:
if neighbor_id not in edge_weights or edge_weights[neighbor_id] < edge_weight:
edge_weights[neighbor_id] = edge_weight
# 将权重高的邻居记忆加入候选
for neighbor_id, edge_weight in edge_weights.items():
neighbor_memory_ids.add((neighbor_id, edge_weight))
# 过滤掉已访问的和自己
filtered_neighbors = []
for neighbor_id, edge_weight in neighbor_memory_ids:
@@ -129,7 +129,7 @@ async def expand_memories_with_semantic_filter(
# 批量评估邻居记忆
for neighbor_mem_id, edge_weight in filtered_neighbors:
candidates_checked += 1
neighbor_memory = graph_store.get_memory_by_id(neighbor_mem_id)
if not neighbor_memory:
continue
@@ -139,7 +139,7 @@ async def expand_memories_with_semantic_filter(
(n for n in neighbor_memory.nodes if n.has_embedding()),
None
)
if not topic_node or topic_node.embedding is None:
continue
@@ -179,11 +179,11 @@ async def expand_memories_with_semantic_filter(
if len(expanded_memories) >= max_expanded:
logger.debug(f"⏹️ 提前停止:已达到最大扩展数量 {max_expanded}")
break
# 早停检查
if len(expanded_memories) >= max_expanded:
break
# 记录本层统计
depth_stats.append({
"depth": depth + 1,
@@ -199,20 +199,20 @@ async def expand_memories_with_semantic_filter(
# 限制下一层的记忆数量,避免爆炸性增长
current_level_memories = next_level_memories[:max_expanded]
# 每层让出控制权
await asyncio.sleep(0.001)
# 排序并返回
sorted_results = sorted(expanded_memories.items(), key=lambda x: x[1], reverse=True)[:max_expanded]
elapsed = time.time() - start_time
logger.info(
f"✅ 图扩展完成: 初始{len(initial_memory_ids)}个 → "
f"扩展{len(sorted_results)}个新记忆 "
f"(深度={max_depth}, 阈值={semantic_threshold:.2f}, 耗时={elapsed:.3f}s)"
)
# 输出每层统计
for stat in depth_stats:
logger.debug(

View File

@@ -78,11 +78,9 @@ __all__ = [
# 消息
"MaiMessages",
# 工具函数
"ManifestValidator",
"PluginInfo",
# 增强命令系统
"PlusCommand",
"PlusCommandAdapter",
"PythonDependency",
"ToolInfo",
"ToolParamType",

View File

@@ -132,7 +132,7 @@ async def generate_reply(
prompt_mode = "s4u" # 默认使用s4u模式
if action_data and "prompt_mode" in action_data:
prompt_mode = action_data.get("prompt_mode", "s4u")
# 将prompt_mode添加到available_actions中作为特殊键
# 注意这里我们需要暂时使用类型忽略因为available_actions的类型定义不支持非ActionInfo值
if available_actions is None:

View File

@@ -362,7 +362,7 @@ class ChatterPlanFilter:
return "最近没有聊天内容。", "没有未读消息。", []
stream_context = chat_stream.context_manager
# 获取真正的已读和未读消息
read_messages = stream_context.context.history_messages # 已读消息存储在history_messages中
if not read_messages:
@@ -652,30 +652,30 @@ class ChatterPlanFilter:
if not action_info:
logger.debug(f"动作 {action_name} 不在可用动作列表中,保留所有参数")
return action_data
# 获取该动作定义的合法参数
defined_params = set(action_info.action_parameters.keys())
# 合法参数集合
valid_params = defined_params
# 过滤参数
filtered_data = {}
removed_params = []
for key, value in action_data.items():
if key in valid_params:
filtered_data[key] = value
else:
removed_params.append(key)
# 记录被移除的参数
if removed_params:
logger.info(
f"🧹 [参数过滤] 动作 '{action_name}' 移除了多余参数: {removed_params}. "
f"合法参数: {sorted(valid_params)}"
)
return filtered_data
def _filter_no_actions(self, action_list: list[ActionPlannerInfo]) -> list[ActionPlannerInfo]:

View File

@@ -545,14 +545,14 @@ async def execute_proactive_thinking(stream_id: str):
# 获取或创建该聊天流的执行锁
if stream_id not in _execution_locks:
_execution_locks[stream_id] = asyncio.Lock()
lock = _execution_locks[stream_id]
# 尝试获取锁,如果已被占用则跳过本次执行(防止重复)
if lock.locked():
logger.warning(f"⚠️ 主动思考跳过:聊天流 {stream_id} 已有正在执行的主动思考任务")
return
async with lock:
logger.debug(f"🤔 开始主动思考 {stream_id}")
@@ -563,13 +563,13 @@ async def execute_proactive_thinking(stream_id: str):
from src.chat.message_receive.chat_stream import get_chat_manager
chat_manager = get_chat_manager()
chat_stream = await chat_manager.get_stream(stream_id)
if chat_stream and chat_stream.context_manager.context.is_chatter_processing:
logger.warning(f"⚠️ 主动思考跳过:聊天流 {stream_id} 的 chatter 正在处理消息")
return
except Exception as e:
logger.warning(f"检查 chatter 处理状态时出错: {e},继续执行")
# 0.1 检查白名单/黑名单
# 从 stream_id 获取 stream_config 字符串进行验证
try:

View File

@@ -31,4 +31,4 @@ __plugin_meta__ = PluginMetadata(
# 导入插件主类
from .plugin import AntiInjectionPlugin
__all__ = ["__plugin_meta__", "AntiInjectionPlugin"]
__all__ = ["AntiInjectionPlugin", "__plugin_meta__"]

View File

@@ -8,8 +8,8 @@ import time
from src.chat.security.interfaces import (
SecurityAction,
SecurityCheckResult,
SecurityChecker,
SecurityCheckResult,
SecurityLevel,
)
from src.common.logger import get_logger

View File

@@ -4,7 +4,7 @@
处理检测结果,执行相应的动作(允许/监控/加盾/阻止/反击)。
"""
from src.chat.security.interfaces import SecurityAction, SecurityCheckResult
from src.chat.security.interfaces import SecurityCheckResult
from src.common.logger import get_logger
from .counter_attack import CounterAttackGenerator

View File

@@ -64,15 +64,15 @@ class CoreActionsPlugin(BasePlugin):
# --- 根据配置注册组件 ---
components: ClassVar = []
# 注册 reply 动作
if self.get_config("components.enable_reply", True):
components.append((ReplyAction.get_action_info(), ReplyAction))
# 注册 respond 动作
if self.get_config("components.enable_respond", True):
components.append((RespondAction.get_action_info(), RespondAction))
# 注册 emoji 动作
if self.get_config("components.enable_emoji", True):
components.append((EmojiAction.get_action_info(), EmojiAction))

View File

@@ -22,23 +22,23 @@ class ReplyAction(BaseAction):
- 专注于理解和回应单条消息的具体内容
- 适合 Focus 模式下的精准回复
"""
# 动作基本信息
action_name = "reply"
action_description = "针对特定消息进行精准回复。深度理解并回应单条消息的具体内容。需要指定目标消息ID。"
# 激活设置
activation_type = ActionActivationType.ALWAYS # 回复动作总是可用
mode_enable = ChatMode.ALL # 在所有模式下都可用
parallel_action = False # 回复动作不能与其他动作并行
# 动作参数定义
action_parameters: ClassVar = {
"target_message_id": "要回复的目标消息ID必需来自未读消息的 <m...> 标签)",
"content": "回复的具体内容可选由LLM生成",
"should_quote_reply": "是否引用原消息可选true/false默认false。群聊中回复较早消息或需要明确指向时使用true",
}
# 动作使用场景
action_require: ClassVar = [
"需要针对特定消息进行精准回复时使用",
@@ -48,10 +48,10 @@ class ReplyAction(BaseAction):
"群聊中需要明确回应某个特定用户或问题时使用",
"关注单条消息的具体内容和上下文细节",
]
# 关联类型
associated_types: ClassVar[list[str]] = ["text"]
async def execute(self) -> tuple[bool, str]:
"""执行reply动作
@@ -70,21 +70,21 @@ class RespondAction(BaseAction):
- 适合对于群聊消息下的宏观回应
- 避免与单一用户深度对话而忽略其他用户的消息
"""
# 动作基本信息
action_name = "respond"
action_description = "统一回应所有未读消息。理解整体对话动态和话题走向,生成连贯的回复。无需指定目标消息。"
# 激活设置
activation_type = ActionActivationType.ALWAYS # 回应动作总是可用
mode_enable = ChatMode.ALL # 在所有模式下都可用
parallel_action = False # 回应动作不能与其他动作并行
# 动作参数定义
action_parameters: ClassVar = {
"content": "回复的具体内容可选由LLM生成",
}
# 动作使用场景
action_require: ClassVar = [
"需要统一回应多条未读消息时使用Normal 模式专用)",
@@ -94,10 +94,10 @@ class RespondAction(BaseAction):
"适合群聊中的自然对话流,无需精确指向特定消息",
"可以同时回应多个话题或参与者",
]
# 关联类型
associated_types: ClassVar[list[str]] = ["text"]
async def execute(self) -> tuple[bool, str]:
"""执行respond动作

View File

@@ -6,10 +6,10 @@
import asyncio
import base64
import datetime
import filetype
from collections.abc import Callable
import aiohttp
import filetype
from maim_message import UserInfo
from src.chat.message_receive.chat_stream import get_chat_manager

View File

@@ -6,7 +6,7 @@
import re
from typing import ClassVar
from src.chat.utils.prompt_component_manager import prompt_component_manager
from src.plugin_system.apis import (
plugin_manage_api,
)
@@ -74,6 +74,7 @@ class SystemCommand(PlusCommand):
• `/system permission` - 权限管理
• `/system plugin` - 插件管理
• `/system schedule` - 定时任务管理
• `/system prompt` - 提示词注入管理
"""
elif target == "schedule":
help_text = """📅 定时任务管理帮助
@@ -113,8 +114,17 @@ class SystemCommand(PlusCommand):
• /system permission nodes [插件名] - 查看权限节点
• /system permission allnodes - 查看所有权限节点详情
"""
await self.send_text(help_text)
elif target == "prompt":
help_text = """📝 提示词注入管理帮助
🔎 查询命令 (需要 `system.prompt.view` 权限):
• `/system prompt help` - 显示此帮助
• `/system prompt map` - 查看全局注入关系图
• `/system prompt targets` - 列出所有可被注入的核心提示词
• `/system prompt components` - 列出所有已注册的提示词组件
• `/system prompt info <目标名>` - 查看特定核心提示词的注入详情
"""
await self.send_text(help_text)
# =================================================================
# Plugin Management Section
@@ -231,6 +241,101 @@ class SystemCommand(PlusCommand):
else:
await self.send_text(f"❌ 恢复任务失败: `{schedule_id}`")
# =================================================================
# Prompt Management Section
# =================================================================
async def _handle_prompt_commands(self, args: list[str]):
"""处理提示词管理相关命令"""
if not args or args[0].lower() in ["help", "帮助"]:
await self._show_help("prompt")
return
action = args[0].lower()
remaining_args = args[1:]
if action in ["map", "关系图"]:
await self._show_injection_map()
elif action in ["targets", "目标"]:
await self._list_core_prompts()
elif action in ["components", "组件"]:
await self._list_prompt_components()
elif action in ["info", "详情"] and remaining_args:
await self._get_prompt_injection_info(remaining_args[0])
else:
await self.send_text("❌ 提示词管理命令不合法\n使用 /system prompt help 查看帮助")
@require_permission("prompt.view", deny_message="❌ 你没有查看提示词注入信息的权限")
async def _show_injection_map(self):
"""显示全局注入关系图"""
injection_map = await prompt_component_manager.get_full_injection_map()
if not injection_map:
await self.send_text("📊 当前没有任何提示词注入关系")
return
response_parts = ["📊 全局提示词注入关系图:\n"]
for target, injections in injection_map.items():
if injections:
response_parts.append(f"🎯 **{target}** (注入源):")
for inj in injections:
source_tag = f"({inj['source']})" if inj['source'] != 'static_default' else ''
response_parts.append(f" ⎿ `{inj['name']}` (优先级: {inj['priority']}) {source_tag}")
else:
response_parts.append(f"🎯 **{target}** (无注入)")
await self._send_long_message("\n".join(response_parts))
@require_permission("prompt.view", deny_message="❌ 你没有查看提示词注入信息的权限")
async def _list_core_prompts(self):
"""列出所有可注入的核心提示词"""
targets = prompt_component_manager.get_core_prompts()
if not targets:
await self.send_text("🎯 当前没有可注入的核心提示词")
return
response = "🎯 所有可注入的核心提示词:\n" + "\n".join([f"• `{name}`" for name in targets])
await self.send_text(response)
@require_permission("prompt.view", deny_message="❌ 你没有查看提示词注入信息的权限")
async def _list_prompt_components(self):
"""列出所有已注册的提示词组件"""
components = prompt_component_manager.get_registered_prompt_component_info()
if not components:
await self.send_text("🧩 当前没有已注册的提示词组件")
return
response_parts = [f"🧩 已注册的提示词组件 (共 {len(components)} 个):"]
for comp in components:
response_parts.append(f"• `{comp.name}` (来自: `{comp.plugin_name}`)")
await self._send_long_message("\n".join(response_parts))
@require_permission("prompt.view", deny_message="❌ 你没有查看提示词注入信息的权限")
async def _get_prompt_injection_info(self, target_name: str):
"""获取特定核心提示词的注入详情"""
injections = await prompt_component_manager.get_injections_for_prompt(target_name)
core_prompts = prompt_component_manager.get_core_prompts()
if target_name not in core_prompts:
await self.send_text(f"❌ 找不到核心提示词: `{target_name}`")
return
if not injections:
await self.send_text(f"🎯 核心提示词 `{target_name}` 当前没有被任何组件注入。")
return
response_parts = [f"🔎 核心提示词 `{target_name}` 的注入详情:"]
for inj in injections:
response_parts.append(
f" • **`{inj['name']}`** (优先级: {inj['priority']})"
)
response_parts.append(f" - 来源: `{inj['source']}`")
response_parts.append(f" - 类型: `{inj['injection_type']}`")
if inj.get('target_content'):
response_parts.append(f" - 操作目标: `{inj['target_content']}`")
await self.send_text("\n".join(response_parts))
# =================================================================
# Permission Management Section
# =================================================================

File diff suppressed because it is too large Load Diff