Merge branch 'dev' of https://github.com/MoFox-Studio/MoFox_Bot into dev
bot.py
@@ -588,7 +588,7 @@ class MaiBotMain:

    async def run_async_init(self, main_system):
        """执行异步初始化步骤"""

        # 初始化数据库表结构
        await self.initialize_database_async()
@@ -19,14 +19,13 @@
import asyncio
import sys
from pathlib import Path
-from typing import List

# 添加项目根目录到路径
sys.path.insert(0, str(Path(__file__).parent.parent))


async def generate_missing_embeddings(
-    target_node_types: List[str] = None,
+    target_node_types: list[str] = None,
    batch_size: int = 50,
):
    """
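Most of the mechanical edits in this commit replace typing.List/Dict/Optional with builtin generics and union syntax (PEP 585/604). A minimal before/after sketch, assuming Python 3.10+; the function name is hypothetical:

from pathlib import Path

# Before: from typing import Dict, List, Optional
# def load(path: Optional[Path] = None) -> Dict[str, List[int]]: ...

# After: builtin generics (PEP 585) and unions (PEP 604), no typing import needed.
def load(path: Path | None = None) -> dict[str, list[int]]:
    """Stub with modernized annotations; the body is illustrative only."""
    return {}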
@@ -46,13 +45,13 @@ async def generate_missing_embeddings(
target_node_types = [NodeType.TOPIC.value, NodeType.OBJECT.value]

print(f"\n{'='*80}")
-print(f"🔧 为节点生成嵌入向量")
+print("🔧 为节点生成嵌入向量")
print(f"{'='*80}\n")
print(f"目标节点类型: {', '.join(target_node_types)}")
print(f"批处理大小: {batch_size}\n")

# 1. 初始化记忆管理器
-print(f"🔧 正在初始化记忆管理器...")
+print("🔧 正在初始化记忆管理器...")
await initialize_memory_manager()
manager = get_memory_manager()

@@ -60,10 +59,10 @@ async def generate_missing_embeddings(
print("❌ 记忆管理器初始化失败")
return

-print(f"✅ 记忆管理器已初始化\n")
+print("✅ 记忆管理器已初始化\n")

# 2. 获取已索引的节点ID
-print(f"🔍 检查现有向量索引...")
+print("🔍 检查现有向量索引...")
existing_node_ids = set()
try:
vector_count = manager.vector_store.collection.count()

@@ -78,14 +77,14 @@ async def generate_missing_embeddings(
)
if result and "ids" in result:
existing_node_ids.update(result["ids"])

print(f"✅ 发现 {len(existing_node_ids)} 个已索引节点\n")
except Exception as e:
logger.warning(f"获取已索引节点ID失败: {e}")
-print(f"⚠️ 无法获取已索引节点,将尝试跳过重复项\n")
+print("⚠️ 无法获取已索引节点,将尝试跳过重复项\n")

# 3. 收集需要生成嵌入的节点
-print(f"🔍 扫描需要生成嵌入的节点...")
+print("🔍 扫描需要生成嵌入的节点...")
all_memories = manager.graph_store.get_all_memories()

nodes_to_process = []

@@ -110,7 +109,7 @@ async def generate_missing_embeddings(
})
type_stats[node.node_type.value]["need_emb"] += 1

-print(f"\n📊 扫描结果:")
+print("\n📊 扫描结果:")
for node_type in target_node_types:
stats = type_stats[node_type]
already_ok = stats["already_indexed"]

@@ -121,11 +120,11 @@ async def generate_missing_embeddings(
print(f"\n 总计: {total_target_nodes} 个目标节点, {len(nodes_to_process)} 个需要生成嵌入\n")

if len(nodes_to_process) == 0:
-print(f"✅ 所有节点已有嵌入向量,无需生成")
+print("✅ 所有节点已有嵌入向量,无需生成")
return

# 3. 批量生成嵌入
-print(f"🚀 开始生成嵌入向量...\n")
+print("🚀 开始生成嵌入向量...\n")

total_batches = (len(nodes_to_process) + batch_size - 1) // batch_size
success_count = 0

@@ -193,22 +192,22 @@ async def generate_missing_embeddings(
print(f" 📊 总进度: {total_processed}/{len(nodes_to_process)} ({progress:.1f}%)\n")

# 4. 保存图数据(更新节点的 embedding 字段)
-print(f"💾 保存图数据...")
+print("💾 保存图数据...")
try:
await manager.persistence.save_graph_store(manager.graph_store)
-print(f"✅ 图数据已保存\n")
+print("✅ 图数据已保存\n")
except Exception as e:
-logger.error(f"保存图数据失败", exc_info=True)
+logger.error("保存图数据失败", exc_info=True)
print(f"❌ 保存失败: {e}\n")

# 5. 验证结果
-print(f"🔍 验证向量索引...")
+print("🔍 验证向量索引...")
final_vector_count = manager.vector_store.collection.count()
stats = manager.graph_store.get_statistics()
total_nodes = stats["total_nodes"]

print(f"\n{'='*80}")
-print(f"📊 生成完成")
+print("📊 生成完成")
print(f"{'='*80}")
print(f"处理节点数: {len(nodes_to_process)}")
print(f"成功生成: {success_count}")

@@ -219,7 +218,7 @@ async def generate_missing_embeddings(
print(f"索引覆盖率: {final_vector_count / total_nodes * 100:.1f}%\n")

# 6. 测试搜索
-print(f"🧪 测试搜索功能...")
+print("🧪 测试搜索功能...")
test_queries = ["小红帽蕾克", "拾风", "杰瑞喵"]

for query in test_queries:
@@ -38,7 +38,7 @@ OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")
TEMP_DIR = os.path.join(ROOT_PATH, "temp", "lpmm_cache")

# ========== 性能配置参数 ==========
#
# 知识提取(步骤2:txt转json)并发控制
# - 控制同时进行的LLM提取请求数量
# - 推荐值: 3-10,取决于API速率限制

@@ -184,7 +184,7 @@ async def extract_info_async(pg_hash, paragraph, llm_api):
tuple: (doc_item或None, failed_hash或None)
"""
temp_file_path = os.path.join(TEMP_DIR, f"{pg_hash}.json")

# 🔧 优化:使用异步文件检查,避免阻塞
if os.path.exists(temp_file_path):
try:

@@ -215,11 +215,11 @@ async def extract_info_async(pg_hash, paragraph, llm_api):
"extracted_entities": extracted_data.get("entities", []),
"extracted_triples": extracted_data.get("triples", []),
}

# 保存到缓存(异步写入)
async with aiofiles.open(temp_file_path, "wb") as f:
await f.write(orjson.dumps(doc_item))

return doc_item, None
except Exception as e:
logger.error(f"提取信息失败:{pg_hash}, 错误:{e}")
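The hash-keyed cache in extract_info_async (skip paragraphs whose JSON file already exists, otherwise write the result asynchronously) boils down to this; a minimal sketch where cached_extract and the extract callback are hypothetical, while the aiofiles/orjson usage mirrors the diff:

import os
import aiofiles  # async file I/O, as used in the diff
import orjson

TEMP_DIR = "temp/lpmm_cache"  # assumption: mirrors the TEMP_DIR above, created beforehand

async def cached_extract(pg_hash: str, extract):
    """Return a cached extraction result, or compute and cache it.

    `extract` is a hypothetical zero-arg coroutine producing a JSON-serializable dict.
    """
    path = os.path.join(TEMP_DIR, f"{pg_hash}.json")
    if os.path.exists(path):
        async with aiofiles.open(path, "rb") as f:
            return orjson.loads(await f.read())
    doc_item = await extract()
    async with aiofiles.open(path, "wb") as f:  # write-through cache, like the diff
        await f.write(orjson.dumps(doc_item))
    return doc_item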
@@ -249,13 +249,13 @@ async def extract_information(paragraphs_dict, model_set):
os.makedirs(TEMP_DIR, exist_ok=True)

failed_hashes, open_ie_docs = [], []

# 🔧 关键修复:创建单个 LLM 请求实例,复用连接
llm_api = LLMRequest(model_set=model_set, request_type="lpmm_extraction")

# 🔧 并发控制:限制最大并发数,防止速率限制
semaphore = asyncio.Semaphore(MAX_EXTRACTION_CONCURRENCY)

async def extract_with_semaphore(pg_hash, paragraph):
"""带信号量控制的提取函数"""
async with semaphore:

@@ -266,7 +266,7 @@ async def extract_information(paragraphs_dict, model_set):
extract_with_semaphore(p_hash, paragraph)
for p_hash, paragraph in paragraphs_dict.items()
]

total = len(tasks)
completed = 0

@@ -284,7 +284,7 @@ async def extract_information(paragraphs_dict, model_set):
TimeRemainingColumn(),
) as progress:
task = progress.add_task("[cyan]正在提取信息...", total=total)

# 🔧 优化:使用 asyncio.gather 并发执行所有任务
# return_exceptions=True 确保单个失败不影响其他任务
for coro in asyncio.as_completed(tasks):

@@ -293,7 +293,7 @@ async def extract_information(paragraphs_dict, model_set):
failed_hashes.append(failed_hash)
elif doc_item:
open_ie_docs.append(doc_item)

completed += 1
progress.update(task, advance=1)

@@ -415,7 +415,7 @@ def rebuild_faiss_only():
logger.info("--- 重建 FAISS 索引 ---")
# 重建索引不需要并发参数(不涉及 embedding 生成)
embed_manager = EmbeddingManager()

logger.info("正在加载现有的 Embedding 库...")
try:
embed_manager.load_from_file()
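The concurrency shape used by extract_information above: one shared client, a semaphore to cap in-flight requests, and asyncio.as_completed so progress can advance as each task finishes (the loop in the diff uses as_completed even though the comment mentions gather). A minimal self-contained sketch with hypothetical names; only MAX_EXTRACTION_CONCURRENCY echoes the diff:

import asyncio

MAX_EXTRACTION_CONCURRENCY = 5  # assumption: a typical value in the 3-10 range noted above

async def run_all(items, worker):
    """Run `worker(item)` for every item, at most N at a time, reporting progress."""
    semaphore = asyncio.Semaphore(MAX_EXTRACTION_CONCURRENCY)

    async def bounded(item):
        async with semaphore:  # blocks when N workers are already running
            return await worker(item)

    tasks = [asyncio.ensure_future(bounded(item)) for item in items]
    results, done = [], 0
    for fut in asyncio.as_completed(tasks):  # yields futures as they finish
        results.append(await fut)
        done += 1
        print(f"progress: {done}/{len(tasks)}")
    return results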
@@ -4,13 +4,13 @@
提供 Web API 用于可视化记忆图数据
"""

from collections import defaultdict
from datetime import datetime
from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Any

import orjson
-from fastapi import APIRouter, HTTPException, Request, Query
+from fastapi import APIRouter, HTTPException, Query, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.templating import Jinja2Templates

@@ -29,7 +29,7 @@ router = APIRouter()
templates = Jinja2Templates(directory=str(Path(__file__).parent / "templates"))

-def find_available_data_files() -> List[Path]:
+def find_available_data_files() -> list[Path]:
"""查找所有可用的记忆图数据文件"""
files = []
if not data_dir.exists():

@@ -62,7 +62,7 @@ def find_available_data_files() -> List[Path]:
return sorted(files, key=lambda f: f.stat().st_mtime, reverse=True)

-def load_graph_data_from_file(file_path: Optional[Path] = None) -> Dict[str, Any]:
+def load_graph_data_from_file(file_path: Path | None = None) -> dict[str, Any]:
"""从磁盘加载图数据"""
global graph_data_cache, current_data_file

@@ -85,7 +85,7 @@ def load_graph_data_from_file(file_path: Optional[Path] = None) -> Dict[str, Any
if not graph_file.exists():
return {"error": f"文件不存在: {graph_file}", "nodes": [], "edges": [], "stats": {}}

-with open(graph_file, "r", encoding="utf-8") as f:
+with open(graph_file, encoding="utf-8") as f:
data = orjson.loads(f.read())

nodes = data.get("nodes", [])

@@ -150,7 +150,7 @@ async def index(request: Request):
return templates.TemplateResponse("visualizer.html", {"request": request})

-def _format_graph_data_from_manager(memory_manager) -> Dict[str, Any]:
+def _format_graph_data_from_manager(memory_manager) -> dict[str, Any]:
"""从 MemoryManager 提取并格式化图数据"""
if not memory_manager.graph_store:
return {"nodes": [], "edges": [], "memories": [], "stats": {}}

@@ -188,7 +188,7 @@ def _format_graph_data_from_manager(memory_manager) -> Dict[str, Any]:
"arrows": "to",
"memory_id": memory.id,
}

edges_list = list(edges_dict.values())

stats = memory_manager.get_statistics()

@@ -261,7 +261,7 @@ async def get_paginated_graph(
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(500, ge=100, le=2000, description="每页节点数"),
min_importance: float = Query(0.0, ge=0.0, le=1.0, description="最小重要性阈值"),
-node_types: Optional[str] = Query(None, description="节点类型过滤,逗号分隔"),
+node_types: str | None = Query(None, description="节点类型过滤,逗号分隔"),
):
"""分页获取图数据,支持重要性过滤"""
try:

@@ -301,13 +301,13 @@ async def get_paginated_graph(
total_pages = (total_nodes + page_size - 1) // page_size
start_idx = (page - 1) * page_size
end_idx = min(start_idx + page_size, total_nodes)

paginated_nodes = nodes_with_importance[start_idx:end_idx]
node_ids = set(n["id"] for n in paginated_nodes)

# 只保留连接分页节点的边
paginated_edges = [
e for e in edges
if e.get("from") in node_ids and e.get("to") in node_ids
]
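The pagination endpoint computes page bounds with ceiling division and then keeps only edges whose endpoints both survived the page cut. A self-contained sketch; the from/to edge fields follow the vis.js-style format used in the diff:

def paginate_graph(nodes: list[dict], edges: list[dict], page: int, page_size: int):
    """Slice nodes into one page and keep only intra-page edges."""
    total_nodes = len(nodes)
    total_pages = (total_nodes + page_size - 1) // page_size  # ceiling division
    start = (page - 1) * page_size
    end = min(start + page_size, total_nodes)
    page_nodes = nodes[start:end]
    ids = {n["id"] for n in page_nodes}
    page_edges = [e for e in edges if e.get("from") in ids and e.get("to") in ids]
    return {"nodes": page_nodes, "edges": page_edges, "total_pages": total_pages}

# e.g. paginate_graph(nodes, edges, page=2, page_size=500) returns nodes 500-999.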
@@ -383,7 +383,7 @@ async def get_clustered_graph(
return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)

-def _cluster_graph_data(nodes: List[Dict], edges: List[Dict], max_nodes: int, cluster_threshold: int) -> Dict:
+def _cluster_graph_data(nodes: list[dict], edges: list[dict], max_nodes: int, cluster_threshold: int) -> dict:
"""简单的图聚类算法:按类型和连接度聚类"""
# 构建邻接表
adjacency = defaultdict(set)

@@ -412,21 +412,21 @@ def _cluster_graph_data(nodes: List[Dict], edges: List[Dict], max_nodes: int, cl
for node in type_nodes:
importance = len(adjacency[node["id"]])
node_importance.append((node, importance))

node_importance.sort(key=lambda x: x[1], reverse=True)

# 保留前N个重要节点
keep_count = min(len(type_nodes), max_nodes // len(type_groups))
for node, importance in node_importance[:keep_count]:
clustered_nodes.append(node)
node_mapping[node["id"]] = node["id"]

# 其余节点聚合为一个超级节点
if len(node_importance) > keep_count:
clustered_node_ids = [n["id"] for n, _ in node_importance[keep_count:]]
cluster_id = f"cluster_{node_type}_{len(clustered_nodes)}"
cluster_label = f"{node_type} 集群 ({len(clustered_node_ids)}个节点)"

clustered_nodes.append({
"id": cluster_id,
"label": cluster_label,

@@ -436,7 +436,7 @@ def _cluster_graph_data(nodes: List[Dict], edges: List[Dict], max_nodes: int, cl
"cluster_size": len(clustered_node_ids),
"clustered_nodes": clustered_node_ids[:10],  # 只保留前10个用于展示
})

for node_id in clustered_node_ids:
node_mapping[node_id] = cluster_id

@@ -445,7 +445,7 @@ def _cluster_graph_data(nodes: List[Dict], edges: List[Dict], max_nodes: int, cl
for edge in edges:
from_id = node_mapping.get(edge["from"])
to_id = node_mapping.get(edge["to"])

if from_id and to_id and from_id != to_id:
edge_key = tuple(sorted([from_id, to_id]))
if edge_key not in edge_set:
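The degree-based clustering above keeps the best-connected nodes per type and folds the remainder into one super-node per type. A compact sketch of the same idea; this is a hypothetical standalone variant of _cluster_graph_data, not the project's function:

from collections import defaultdict

def cluster_by_degree(nodes: list[dict], edges: list[dict], max_nodes: int) -> list[dict]:
    """Keep the highest-degree nodes per type; fold the rest into a cluster node per type."""
    adjacency = defaultdict(set)
    for e in edges:
        adjacency[e["from"]].add(e["to"])
        adjacency[e["to"]].add(e["from"])

    by_type = defaultdict(list)
    for n in nodes:
        by_type[n.get("type", "unknown")].append(n)

    kept = []
    for node_type, group in by_type.items():
        group.sort(key=lambda n: len(adjacency[n["id"]]), reverse=True)  # degree = importance
        keep = max_nodes // len(by_type)
        kept.extend(group[:keep])
        rest = group[keep:]
        if rest:  # aggregate the tail into a single super-node
            kept.append({"id": f"cluster_{node_type}", "label": f"{node_type} ({len(rest)})"})
    return kept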
@@ -1,6 +1,5 @@
from collections import defaultdict
from datetime import datetime, timedelta
-from typing import Any, Literal
+from typing import Literal

from fastapi import APIRouter, HTTPException, Query

@@ -161,16 +161,16 @@ class EmbeddingStore:
# 限制 chunk_size 和 max_workers 在合理范围内
chunk_size = max(MIN_CHUNK_SIZE, min(chunk_size, MAX_CHUNK_SIZE))
max_workers = max(MIN_WORKERS, min(max_workers, MAX_WORKERS))

semaphore = asyncio.Semaphore(max_workers)
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
results = {}

# 将字符串列表分成多个 chunk
chunks = []
for i in range(0, len(strs), chunk_size):
chunks.append(strs[i : i + chunk_size])

async def _process_chunk(chunk: list[str]):
"""处理一个 chunk 的字符串(批量获取 embedding)"""
async with semaphore:

@@ -180,12 +180,12 @@ class EmbeddingStore:
embedding = await EmbeddingStore._get_embedding_async(llm, s)
embeddings.append(embedding)
results[s] = embedding

if progress_callback:
progress_callback(len(chunk))

return embeddings

# 并发处理所有 chunks
tasks = [_process_chunk(chunk) for chunk in chunks]
await asyncio.gather(*tasks)
@@ -418,22 +418,22 @@ class EmbeddingStore:
# 🔧 修复:检查所有 embedding 的维度是否一致
dimensions = [len(emb) for emb in array]
unique_dims = set(dimensions)

if len(unique_dims) > 1:
logger.error(f"检测到不一致的 embedding 维度: {unique_dims}")
logger.error(f"维度分布: {dict(zip(*np.unique(dimensions, return_counts=True)))}")

# 获取期望的维度(使用最常见的维度)
from collections import Counter
dim_counter = Counter(dimensions)
expected_dim = dim_counter.most_common(1)[0][0]
logger.warning(f"将使用最常见的维度: {expected_dim}")

# 过滤掉维度不匹配的 embedding
filtered_array = []
filtered_idx2hash = {}
skipped_count = 0

for i, emb in enumerate(array):
if len(emb) == expected_dim:
filtered_array.append(emb)

@@ -442,11 +442,11 @@ class EmbeddingStore:
skipped_count += 1
hash_key = self.idx2hash[str(i)]
logger.warning(f"跳过维度不匹配的 embedding: {hash_key}, 维度={len(emb)}, 期望={expected_dim}")

logger.warning(f"已过滤 {skipped_count} 个维度不匹配的 embedding")
array = filtered_array
self.idx2hash = filtered_idx2hash

if not array:
logger.error("过滤后没有可用的 embedding,无法构建索引")
embedding_dim = expected_dim
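The dimension-consistency guard above picks the most common vector length and drops outliers before building the index. The core of that logic, isolated as a sketch:

from collections import Counter

def filter_consistent_embeddings(vectors: list[list[float]]) -> tuple[list[list[float]], int]:
    """Drop vectors whose length differs from the most common dimension."""
    dimensions = [len(v) for v in vectors]
    expected_dim = Counter(dimensions).most_common(1)[0][0]  # majority dimension wins
    kept = [v for v in vectors if len(v) == expected_dim]
    return kept, expected_dim

# Example: three 4-d vectors and one stray 3-d vector; the 3-d one is skipped.
vecs = [[0.0] * 4, [0.1] * 4, [0.2] * 3, [0.3] * 4]
kept, dim = filter_consistent_embeddings(vecs)
assert dim == 4 and len(kept) == 3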
@@ -13,4 +13,4 @@ __all__ = [
"StreamLoopManager",
"message_manager",
"stream_loop_manager",
]

@@ -82,7 +82,7 @@ class SingleStreamContextManager:
self.total_messages += 1
self.last_access_time = time.time()

# 如果使用了缓存系统,输出调试信息
if cache_enabled and self.context.is_cache_enabled:
if self.context.is_chatter_processing:

@@ -111,9 +111,9 @@ class StreamLoopManager:
# 获取或创建该流的启动锁
if stream_id not in self._stream_start_locks:
self._stream_start_locks[stream_id] = asyncio.Lock()

lock = self._stream_start_locks[stream_id]

# 使用锁防止并发启动同一个流的多个循环任务
async with lock:
# 获取流上下文

@@ -148,7 +148,7 @@ class StreamLoopManager:
# 紧急取消
context.stream_loop_task.cancel()
await asyncio.sleep(0.1)

loop_task = asyncio.create_task(self._stream_loop_worker(stream_id), name=f"stream_loop_{stream_id}")

# 将任务记录到 StreamContext 中
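The per-stream startup lock in StreamLoopManager is a check-then-create guard: one asyncio.Lock per stream id, taken before spawning the loop task so concurrent callers cannot start duplicate workers. A minimal sketch with hypothetical names:

import asyncio

class LoopStarter:
    """Start at most one worker task per stream id, even under concurrent calls."""

    def __init__(self):
        self._locks: dict[str, asyncio.Lock] = {}
        self._tasks: dict[str, asyncio.Task] = {}

    async def ensure_started(self, stream_id: str, worker) -> asyncio.Task:
        lock = self._locks.setdefault(stream_id, asyncio.Lock())
        async with lock:  # serializes startup for this stream only
            task = self._tasks.get(stream_id)
            if task is None or task.done():  # re-check under the lock
                task = asyncio.create_task(worker(stream_id), name=f"stream_loop_{stream_id}")
                self._tasks[stream_id] = task
            return task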
@@ -252,7 +252,7 @@ class StreamLoopManager:
self.stats["total_process_cycles"] += 1
if success:
logger.info(f"✅ [流工作器] stream={stream_id[:8]}, 任务ID={task_id}, 处理成功")

# 🔒 处理成功后,等待一小段时间确保清理操作完成
# 这样可以避免在 chatter_manager 清除未读消息之前就进入下一轮循环
await asyncio.sleep(0.1)

@@ -382,7 +382,7 @@ class StreamLoopManager:
self.chatter_manager.process_stream_context(stream_id, context),
name=f"chatter_process_{stream_id}"
)

# 等待 chatter 任务完成
results = await chatter_task
success = results.get("success", False)

@@ -398,8 +398,8 @@ class StreamLoopManager:
else:
logger.warning(f"流处理失败: {stream_id} - {results.get('error_message', '未知错误')}")

return success
except asyncio.CancelledError:
if chatter_task and not chatter_task.done():
chatter_task.cancel()
raise

@@ -709,4 +709,4 @@ class StreamLoopManager:

# 全局流循环管理器实例
stream_loop_manager = StreamLoopManager()

@@ -417,7 +417,7 @@ class MessageManager:
return

# 记录详细信息
msg_previews = [f"{str(msg.message_id)[:8] if msg.message_id else 'unknown'}:{msg.processed_plain_text[:20] if msg.processed_plain_text else '(空)'}"
for msg in unread_messages[:3]]  # 只显示前3条
logger.info(f"🧹 [清除未读] stream={stream_id[:8]}, 开始清除 {len(unread_messages)} 条未读消息, 示例: {msg_previews}")

@@ -446,15 +446,15 @@ class MessageManager:
context = chat_stream.context_manager.context
if hasattr(context, "unread_messages") and context.unread_messages:
unread_count = len(context.unread_messages)

# 如果还有未读消息,说明 action_manager 可能遗漏了,标记它们
if unread_count > 0:
# 获取所有未读消息的 ID
message_ids = [msg.message_id for msg in context.unread_messages]

# 标记为已读(会移到历史消息)
success = chat_stream.context_manager.mark_messages_as_read(message_ids)

if success:
logger.debug(f"✅ stream={stream_id[:8]}, 成功标记 {unread_count} 条消息为已读")
else:

@@ -481,7 +481,7 @@ class MessageManager:
try:
chat_manager = get_chat_manager()
chat_stream = await chat_manager.get_stream(stream_id)
-if chat_stream and hasattr(chat_stream.context_manager.context, 'is_chatter_processing'):
+if chat_stream and hasattr(chat_stream.context_manager.context, "is_chatter_processing"):
chat_stream.context_manager.context.is_chatter_processing = is_processing
logger.debug(f"设置StreamContext处理状态: stream={stream_id}, processing={is_processing}")
except Exception as e:

@@ -517,7 +517,7 @@ class MessageManager:
try:
chat_manager = get_chat_manager()
chat_stream = await chat_manager.get_stream(stream_id)
-if chat_stream and hasattr(chat_stream.context_manager.context, 'is_chatter_processing'):
+if chat_stream and hasattr(chat_stream.context_manager.context, "is_chatter_processing"):
return chat_stream.context_manager.context.is_chatter_processing
except Exception:
pass

@@ -677,4 +677,4 @@ class MessageManager:

# 创建全局消息管理器实例
message_manager = MessageManager()
@@ -248,16 +248,16 @@ class ChatterActionManager:
try:
# 根据动作类型确定提示词模式
prompt_mode = "s4u" if action_name == "reply" else "normal"

# 将prompt_mode传递给generate_reply
action_data_with_mode = (action_data or {}).copy()
action_data_with_mode["prompt_mode"] = prompt_mode

# 只传递当前正在执行的动作,而不是所有可用动作
# 这样可以让LLM明确知道"已决定执行X动作",而不是"有这些动作可用"
current_action_info = self._using_actions.get(action_name)
current_actions: dict[str, Any] = {action_name: current_action_info} if current_action_info else {}

# 附加目标消息信息(如果存在)
if target_message:
# 提取目标消息的关键信息

@@ -268,7 +268,7 @@ class ChatterActionManager:
"time": getattr(target_message, "time", 0),
}
current_actions["_target_message"] = target_msg_info

success, response_set, _ = await generator_api.generate_reply(
chat_stream=chat_stream,
reply_message=target_message,

@@ -295,12 +295,12 @@ class ChatterActionManager:
should_quote_reply = None
if action_data and isinstance(action_data, dict):
should_quote_reply = action_data.get("should_quote_reply", None)

# respond动作默认不引用回复,保持对话流畅
if action_name == "respond" and should_quote_reply is None:
should_quote_reply = False

async def _after_reply():
# 发送并存储回复
loop_info, reply_text, cycle_timers_reply = await self._send_and_store_reply(
chat_stream,

@@ -365,7 +365,7 @@ class DefaultReplyer:
# 确保类型安全
if isinstance(mode, str):
prompt_mode_value = mode

# 构建 Prompt
with Timer("构建Prompt", {}):  # 内部计时器,可选保留
prompt = await self.build_prompt_reply_context(

@@ -1171,16 +1171,16 @@ class DefaultReplyer:
from src.plugin_system.apis.chat_api import get_chat_manager
chat_manager = get_chat_manager()
chat_stream_obj = await chat_manager.get_stream(chat_id)

if chat_stream_obj:
unread_messages = chat_stream_obj.context_manager.get_unread_messages()
if unread_messages:
# 使用最后一条未读消息作为参考
last_msg = unread_messages[-1]
-platform = last_msg.chat_info.platform if hasattr(last_msg, 'chat_info') else chat_stream.platform
-user_id = last_msg.user_info.user_id if hasattr(last_msg, 'user_info') else ""
-user_nickname = last_msg.user_info.user_nickname if hasattr(last_msg, 'user_info') else ""
-user_cardname = last_msg.user_info.user_cardname if hasattr(last_msg, 'user_info') else ""
+platform = last_msg.chat_info.platform if hasattr(last_msg, "chat_info") else chat_stream.platform
+user_id = last_msg.user_info.user_id if hasattr(last_msg, "user_info") else ""
+user_nickname = last_msg.user_info.user_nickname if hasattr(last_msg, "user_info") else ""
+user_cardname = last_msg.user_info.user_cardname if hasattr(last_msg, "user_info") else ""
processed_plain_text = last_msg.processed_plain_text or ""
else:
# 没有未读消息,使用默认值

@@ -1263,19 +1263,19 @@ class DefaultReplyer:
if available_actions:
# 过滤掉特殊键(以_开头)
action_items = {k: v for k, v in available_actions.items() if not k.startswith("_")}

# 提取目标消息信息(如果存在)
target_msg_info = available_actions.get("_target_message")  # type: ignore

if action_items:
if len(action_items) == 1:
# 单个动作
action_name, action_info = list(action_items.items())[0]
action_desc = action_info.description

# 构建基础决策信息
action_descriptions = f"## 决策信息\n\n你已经决定要执行 **{action_name}** 动作({action_desc})。\n\n"

# 只有需要目标消息的动作才显示目标消息详情
# respond 动作是统一回应所有未读消息,不应该显示特定目标消息
if action_name not in ["respond"] and target_msg_info and isinstance(target_msg_info, dict):

@@ -1284,7 +1284,7 @@ class DefaultReplyer:
content = target_msg_info.get("content", "")
msg_time = target_msg_info.get("time", 0)
time_str = time_module.strftime("%H:%M:%S", time_module.localtime(msg_time)) if msg_time else "未知时间"

action_descriptions += f"**目标消息**: {time_str} {sender} 说: {content}\n\n"
else:
# 多个动作

@@ -2137,7 +2137,7 @@ class DefaultReplyer:
except Exception as e:
logger.error(f"存储聊天记忆失败: {e}")


def weighted_sample_no_replacement(items, weights, k) -> list:
"""
@@ -5,12 +5,12 @@
插件可以通过实现这些接口来扩展安全功能。
"""

-from .interfaces import SecurityCheckResult, SecurityChecker
+from .interfaces import SecurityChecker, SecurityCheckResult
from .manager import SecurityManager, get_security_manager

__all__ = [
-    "SecurityChecker",
"SecurityCheckResult",
+    "SecurityChecker",
"SecurityManager",
"get_security_manager",
]

@@ -10,7 +10,7 @@ from typing import Any

from src.common.logger import get_logger

-from .interfaces import SecurityAction, SecurityCheckResult, SecurityChecker, SecurityLevel
+from .interfaces import SecurityAction, SecurityChecker, SecurityCheckResult, SecurityLevel

logger = get_logger("security.manager")
@@ -1,5 +1,7 @@
import asyncio
+import copy
import re
+from collections.abc import Awaitable, Callable

from src.chat.utils.prompt_params import PromptParameters
from src.common.logger import get_logger

@@ -12,122 +14,205 @@ logger = get_logger("prompt_component_manager")

class PromptComponentManager:
"""
-管理所有 `BasePrompt` 组件的单例类。
+一个统一的、动态的、可观测的提示词组件管理中心。

-该管理器负责:
-1. 从 `component_registry` 中查询 `BasePrompt` 子类。
-2. 根据注入点(目标Prompt名称)对它们进行筛选。
-3. 提供一个接口,以便在构建核心Prompt时,能够获取并执行所有相关的组件。
+该管理器是整个提示词动态注入系统的核心,它负责:
+1. **规则加载**: 在系统启动时,自动扫描所有已注册的 `BasePrompt` 组件,
+并将其静态定义的 `injection_rules` 加载为默认的动态规则。
+2. **动态管理**: 提供线程安全的 API,允许在运行时动态地添加、更新或移除注入规则,
+使得提示词的结构可以被实时调整。
+3. **状态观测**: 提供丰富的查询 API,用于观测系统当前完整的注入状态,
+例如查询所有注入到特定目标的规则、或查询某个组件定义的所有规则。
+4. **注入应用**: 在构建核心 Prompt 时,根据统一的、按优先级排序的规则集,
+动态地修改和装配提示词模板,实现灵活的提示词组合。
"""

-def _get_rules_for(self, target_prompt_name: str) -> list[tuple[InjectionRule, type[BasePrompt]]]:
-"""
-获取指定目标Prompt的所有注入规则及其关联的组件类。
+def __init__(self):
+"""初始化管理器实例。"""
+# _dynamic_rules 是管理器的核心状态,存储所有注入规则。
+# 结构: {
+#     "target_prompt_name": {
+#         "prompt_component_name": (InjectionRule, content_provider, source)
+#     }
+# }
+# content_provider 是一个异步函数,用于在应用规则时动态生成注入内容。
+# source 记录了规则的来源(例如 "static_default" 或 "runtime")。
+self._dynamic_rules: dict[str, dict[str, tuple[InjectionRule, Callable[..., Awaitable[str]], str]]] = {}
+self._lock = asyncio.Lock()  # 使用异步锁确保对 _dynamic_rules 的并发访问安全。
+self._initialized = False  # 标记静态规则是否已加载,防止重复加载。

-Args:
-target_prompt_name (str): 目标 Prompt 的名称。
+# --- 核心生命周期与初始化 ---

-Returns:
-list[tuple[InjectionRule, Type[BasePrompt]]]: 一个元组列表,
-每个元组包含一个注入规则和其对应的 Prompt 组件类,并已根据优先级排序。
+def load_static_rules(self):
"""
-# 从注册表中获取所有已启用的 PROMPT 类型的组件
+在系统启动时加载所有静态注入规则。
+
+该方法会扫描所有已在 `component_registry` 中注册并启用的 Prompt 组件,
+将其类变量 `injection_rules` 转换为管理器的动态规则。
+这确保了所有插件定义的默认注入行为在系统启动时就能生效。
+此操作是幂等的,一旦初始化完成就不会重复执行。
+"""
+if self._initialized:
+return
+logger.info("正在加载静态 Prompt 注入规则...")
+
+# 从组件注册表中获取所有已启用的 Prompt 组件
enabled_prompts = component_registry.get_enabled_components_by_type(ComponentType.PROMPT)
-matching_rules = []

# 遍历所有启用的 Prompt 组件,查找与目标 Prompt 相关的注入规则
for prompt_name, prompt_info in enabled_prompts.items():
if not isinstance(prompt_info, PromptInfo):
continue

-# prompt_info.injection_rules 已经经过了后向兼容处理,确保总是列表
-for rule in prompt_info.injection_rules:
-# 如果规则的目标是当前指定的 Prompt
-if rule.target_prompt == target_prompt_name:
-# 获取该规则对应的组件类
-component_class = component_registry.get_component_class(prompt_name, ComponentType.PROMPT)
-# 确保获取到的确实是一个 BasePrompt 的子类
-if component_class and issubclass(component_class, BasePrompt):
-matching_rules.append((rule, component_class))
+component_class = component_registry.get_component_class(prompt_name, ComponentType.PROMPT)
+if not (component_class and issubclass(component_class, BasePrompt)):
+logger.warning(f"无法为 '{prompt_name}' 加载静态规则,因为它不是一个有效的 Prompt 组件。")
+continue

-# 根据规则的优先级进行排序,数字越小,优先级越高,越先应用
-matching_rules.sort(key=lambda x: x[0].priority)
-return matching_rules
+def create_provider(cls: type[BasePrompt]) -> Callable[[PromptParameters], Awaitable[str]]:
+"""
+为静态组件创建一个内容提供者闭包 (Content Provider Closure)。
+
+这个闭包捕获了组件的类 `cls`,并返回一个标准的 `content_provider` 异步函数。
+当 `apply_injections` 需要内容时,它会调用这个函数。
+函数内部会实例化组件,并执行其 `execute` 方法来获取注入内容。
+
+Args:
+cls (type[BasePrompt]): 需要为其创建提供者的 Prompt 组件类。
+
+Returns:
+Callable[[PromptParameters], Awaitable[str]]: 一个符合管理器标准的异步内容提供者。
+"""
+
+async def content_provider(params: PromptParameters) -> str:
+"""实际执行内容生成的异步函数。"""
+try:
+# 从注册表获取最新的组件信息,包括插件配置
+p_info = component_registry.get_component_info(cls.prompt_name, ComponentType.PROMPT)
+plugin_config = {}
+if isinstance(p_info, PromptInfo):
+plugin_config = component_registry.get_plugin_config(p_info.plugin_name)
+
+# 实例化组件并执行
+instance = cls(params=params, plugin_config=plugin_config)
+result = await instance.execute()
+return str(result) if result is not None else ""
+except Exception as e:
+logger.error(f"执行静态规则提供者 '{cls.prompt_name}' 时出错: {e}", exc_info=True)
+return ""  # 出错时返回空字符串,避免影响主流程
+
+return content_provider
+
+# 为该组件的每条静态注入规则创建并注册一个动态规则
+for rule in prompt_info.injection_rules:
+provider = create_provider(component_class)
+target_rules = self._dynamic_rules.setdefault(rule.target_prompt, {})
+target_rules[prompt_name] = (rule, provider, "static_default")
+
+self._initialized = True
+logger.info(f"静态 Prompt 注入规则加载完成,共处理 {len(enabled_prompts)} 个组件。")
+
+# --- 运行时规则管理 API ---
+
+async def add_injection_rule(
+self,
+prompt_name: str,
+rule: InjectionRule,
+content_provider: Callable[..., Awaitable[str]],
+source: str = "runtime",
+) -> bool:
+"""
+动态添加或更新一条注入规则。
+
+此方法允许在系统运行时,由外部逻辑(如插件、命令)向管理器中添加新的注入行为。
+如果已存在同名组件针对同一目标的规则,此方法会覆盖旧规则。
+
+Args:
+prompt_name (str): 动态注入组件的唯一名称。
+rule (InjectionRule): 描述注入行为的规则对象。
+content_provider (Callable[..., Awaitable[str]]):
+一个异步函数,用于在应用注入时动态生成内容。
+函数签名应为: `async def provider(params: "PromptParameters") -> str`
+source (str, optional): 规则的来源标识,默认为 "runtime"。
+
+Returns:
+bool: 如果成功添加或更新,则返回 True。
+"""
+async with self._lock:
+target_rules = self._dynamic_rules.setdefault(rule.target_prompt, {})
+target_rules[prompt_name] = (rule, content_provider, source)
+logger.info(f"成功添加/更新注入规则: '{prompt_name}' -> '{rule.target_prompt}' (来源: {source})")
+return True
+
+async def remove_injection_rule(self, prompt_name: str, target_prompt: str) -> bool:
+"""
+移除一条动态注入规则。
+
+Args:
+prompt_name (str): 要移除的注入组件的名称。
+target_prompt (str): 该组件注入的目标核心提示词名称。
+
+Returns:
+bool: 如果成功移除,则返回 True;如果规则不存在,则返回 False。
+"""
+async with self._lock:
+if target_prompt in self._dynamic_rules and prompt_name in self._dynamic_rules[target_prompt]:
+del self._dynamic_rules[target_prompt][prompt_name]
+# 如果目标下已无任何规则,则清理掉这个键
+if not self._dynamic_rules[target_prompt]:
+del self._dynamic_rules[target_prompt]
+logger.info(f"成功移除注入规则: '{prompt_name}' from '{target_prompt}'")
+return True
+logger.warning(f"尝试移除注入规则失败: 未找到 '{prompt_name}' on '{target_prompt}'")
+return False
+
+# --- 核心注入逻辑 ---

async def apply_injections(
self, target_prompt_name: str, original_template: str, params: PromptParameters
) -> str:
"""
-获取、实例化并执行所有相关组件,然后根据注入规则修改原始模板。
+【核心方法】根据目标名称,应用所有匹配的注入规则,返回修改后的模板。

-这是一个三步走的过程:
-1. 实例化所有需要执行的组件。
-2. 并行执行它们的 `execute` 方法以获取注入内容。
-3. 按照优先级顺序,将内容注入到原始模板中。
+这是提示词构建流程中的关键步骤。它会执行以下操作:
+1. 检查并确保静态规则已加载。
+2. 获取所有注入到 `target_prompt_name` 的规则。
+3. 按照规则的 `priority` 属性进行升序排序,优先级数字越小越先应用。
+4. 依次执行每个规则的 `content_provider` 来异步获取注入内容。
+5. 根据规则的 `injection_type` (如 PREPEND, APPEND, REPLACE 等) 将内容应用到模板上。

Args:
-target_prompt_name (str): 目标 Prompt 的名称。
-original_template (str): 原始的、未经修改的 Prompt 模板字符串。
-params (PromptParameters): 传递给 Prompt 组件实例的参数。
+target_prompt_name (str): 目标核心提示词的名称。
+original_template (str): 未经修改的原始提示词模板。
+params (PromptParameters): 当前请求的参数,会传递给 `content_provider`。

Returns:
-str: 应用了所有注入规则后,修改过的 Prompt 模板字符串。
+str: 应用了所有注入规则后,最终生成的提示词模板字符串。
"""
-rules_with_classes = self._get_rules_for(target_prompt_name)
-# 如果没有找到任何匹配的规则,就直接返回原始模板,啥也不干
-if not rules_with_classes:
+if not self._initialized:
+self.load_static_rules()
+
+# 步骤 1: 获取所有指向当前目标的规则
+# 使用 .values() 获取 (rule, provider, source) 元组列表
+rules_for_target = list(self._dynamic_rules.get(target_prompt_name, {}).values())
+if not rules_for_target:
return original_template

-# --- 第一步: 实例化所有需要执行的组件 ---
-instance_map = {}  # 存储组件实例,虽然目前没直接用,但留着总没错
-tasks = []  # 存放所有需要并行执行的 execute 异步任务
-components_to_execute = []  # 存放需要执行的组件类,用于后续结果映射
+# 步骤 2: 按优先级排序,数字越小越优先
+rules_for_target.sort(key=lambda x: x[0].priority)

-for rule, component_class in rules_with_classes:
-# 如果注入类型是 REMOVE,那就不需要执行组件了,因为它不产生内容
+# 步骤 3: 依次执行内容提供者并根据注入类型修改模板
+modified_template = original_template
+for rule, provider, source in rules_for_target:
+content = ""
+# 对于非 REMOVE 类型的注入,需要先获取内容
if rule.injection_type != InjectionType.REMOVE:
try:
-# 获取组件的元信息,主要是为了拿到插件名称来读取插件配置
-prompt_info = component_registry.get_component_info(
-component_class.prompt_name, ComponentType.PROMPT
-)
-if not isinstance(prompt_info, PromptInfo):
-plugin_config = {}
-else:
-# 从注册表获取该组件所属插件的配置
-plugin_config = component_registry.get_plugin_config(prompt_info.plugin_name)
-
-# 实例化组件,并传入参数和插件配置
-instance = component_class(params=params, plugin_config=plugin_config)
-instance_map[component_class.prompt_name] = instance
-# 将组件的 execute 方法作为一个任务添加到列表中
-tasks.append(instance.execute())
-components_to_execute.append(component_class)
+content = await provider(params)
except Exception as e:
-logger.error(f"实例化 Prompt 组件 '{component_class.prompt_name}' 失败: {e}")
-# 即使失败,也添加一个立即完成的空任务,以保持与其他任务的索引同步
-tasks.append(asyncio.create_task(asyncio.sleep(0, result=e)))  # type: ignore
-
-# --- 第二步: 并行执行所有组件的 execute 方法 ---
-# 使用 asyncio.gather 来同时运行所有任务,提高效率
-results = await asyncio.gather(*tasks, return_exceptions=True)
-# 创建一个从组件名到执行结果的映射,方便后续查找
-result_map = {
-components_to_execute[i].prompt_name: res
-for i, res in enumerate(results)
-if not isinstance(res, Exception)  # 只包含成功的结果
-}
-# 单独处理并记录执行失败的组件
-for i, res in enumerate(results):
-if isinstance(res, Exception):
-logger.error(f"执行 Prompt 组件 '{components_to_execute[i].prompt_name}' 失败: {res}")
-
-# --- 第三步: 按优先级顺序应用注入规则 ---
-modified_template = original_template
-for rule, component_class in rules_with_classes:
-# 从结果映射中获取该组件生成的内容
-content = result_map.get(component_class.prompt_name)
+logger.error(f"执行规则 '{rule}' (来源: {source}) 的内容提供者时失败: {e}", exc_info=True)
+continue  # 跳过失败的 provider,不中断整个流程

# 应用注入逻辑
try:
if rule.injection_type == InjectionType.PREPEND:
if content:

@@ -136,28 +221,178 @@ class PromptComponentManager:
if content:
modified_template = f"{modified_template}\n{content}"
elif rule.injection_type == InjectionType.REPLACE:
-# 使用正则表达式替换目标内容
-if content and rule.target_content:
+# 只有在 content 不为 None 且 target_content 有效时才执行替换
+if content is not None and rule.target_content:
modified_template = re.sub(rule.target_content, str(content), modified_template)
elif rule.injection_type == InjectionType.INSERT_AFTER:
# 在匹配到的内容后面插入
if content and rule.target_content:
-# re.sub a little trick: \g<0> represents the entire matched string
+# 使用 `\g<0>` 在正则匹配的整个内容后添加新内容
replacement = f"\\g<0>\n{content}"
modified_template = re.sub(rule.target_content, replacement, modified_template)
elif rule.injection_type == InjectionType.REMOVE:
# 使用正则表达式移除目标内容
if rule.target_content:
modified_template = re.sub(rule.target_content, "", modified_template)
except re.error as e:
-logger.error(
-f"在为 '{component_class.prompt_name}' 应用规则时发生正则错误: {e} (pattern: '{rule.target_content}')"
-)
+logger.error(f"应用规则时发生正则错误: {e} (pattern: '{rule.target_content}')")
except Exception as e:
-logger.error(f"应用 Prompt 注入规则 '{rule}' 失败: {e}")
+logger.error(f"应用注入规则 '{rule}' (来源: {source}) 失败: {e}", exc_info=True)

return modified_template

+async def preview_prompt_injections(
+self, target_prompt_name: str, params: PromptParameters
+) -> str:
+"""
+【预览功能】模拟应用所有注入规则,返回最终生成的模板字符串,而不实际修改任何状态。
+
+这个方法对于调试和测试非常有用,可以查看在特定参数下,
+一个核心提示词经过所有注入规则处理后会变成什么样子。
+
+Args:
+target_prompt_name (str): 希望预览的目标核心提示词名称。
+params (PromptParameters): 模拟的请求参数。
+
+Returns:
+str: 模拟生成的最终提示词模板字符串。如果找不到模板,则返回错误信息。
+"""
+try:
+# 从全局提示词管理器获取最原始的模板内容
+from src.chat.utils.prompt import global_prompt_manager
+original_prompt = global_prompt_manager._prompts.get(target_prompt_name)
+if not original_prompt:
+logger.warning(f"无法预览 '{target_prompt_name}',因为找不到这个核心 Prompt。")
+return f"Error: Prompt '{target_prompt_name}' not found."
+original_template = original_prompt.template
+except KeyError:
+logger.warning(f"无法预览 '{target_prompt_name}',因为找不到这个核心 Prompt。")
+return f"Error: Prompt '{target_prompt_name}' not found."
+
+# 直接调用核心注入逻辑来模拟结果
+return await self.apply_injections(target_prompt_name, original_template, params)
+
+# --- 状态观测与查询 API ---
+
+def get_core_prompts(self) -> list[str]:
+"""获取所有已注册的核心提示词模板名称列表(即所有可注入的目标)。"""
+from src.chat.utils.prompt import global_prompt_manager
+return list(global_prompt_manager._prompts.keys())
+
+def get_core_prompt_contents(self) -> dict[str, str]:
+"""获取所有核心提示词模板的原始内容。"""
+from src.chat.utils.prompt import global_prompt_manager
+return {name: prompt.template for name, prompt in global_prompt_manager._prompts.items()}
+
+def get_registered_prompt_component_info(self) -> list[PromptInfo]:
+"""获取所有在 ComponentRegistry 中注册的 Prompt 组件信息。"""
+components = component_registry.get_components_by_type(ComponentType.PROMPT).values()
+return [info for info in components if isinstance(info, PromptInfo)]
+
+async def get_full_injection_map(self) -> dict[str, list[dict]]:
+"""
+获取当前完整的注入映射图。
+
+此方法提供了一个系统全局的注入视图,展示了每个核心提示词(target)
+被哪些注入组件(source)以何种优先级注入。
+
+Returns:
+dict[str, list[dict]]: 一个字典,键是目标提示词名称,
+值是按优先级排序的注入信息列表。
+`[{"name": str, "priority": int, "source": str}]`
+"""
+injection_map = {}
+async with self._lock:
+# 合并所有动态规则的目标和所有核心提示词,确保所有潜在目标都被包含
+all_targets = set(self._dynamic_rules.keys()) | set(self.get_core_prompts())
+for target in sorted(all_targets):
+rules = self._dynamic_rules.get(target, {})
+if not rules:
+injection_map[target] = []
+continue
+
+info_list = []
+for prompt_name, (rule, _, source) in rules.items():
+info_list.append({"name": prompt_name, "priority": rule.priority, "source": source})
+
+# 按优先级排序后存入 map
+info_list.sort(key=lambda x: x["priority"])
+injection_map[target] = info_list
+return injection_map
+
+async def get_injections_for_prompt(self, target_prompt_name: str) -> list[dict]:
+"""
+获取指定核心提示词模板的所有注入信息(包含详细规则)。
+
+Args:
+target_prompt_name (str): 目标核心提示词的名称。
+
+Returns:
+list[dict]: 一个包含注入规则详细信息的列表,已按优先级排序。
+"""
+rules_for_target = self._dynamic_rules.get(target_prompt_name, {})
+if not rules_for_target:
+return []
+
+info_list = []
+for prompt_name, (rule, _, source) in rules_for_target.items():
+info_list.append(
+{
+"name": prompt_name,
+"priority": rule.priority,
+"source": source,
+"injection_type": rule.injection_type.value,
+"target_content": rule.target_content,
+}
+)
+info_list.sort(key=lambda x: x["priority"])
+return info_list
+
+def get_all_dynamic_rules(self) -> dict[str, dict[str, "InjectionRule"]]:
+"""
+获取所有当前的动态注入规则,以 InjectionRule 对象形式返回。
+
+此方法返回一个深拷贝的规则副本,隐藏了 `content_provider` 等内部实现细节。
+适合用于展示或序列化当前的规则配置。
+"""
+rules_copy = {}
+for target, rules in self._dynamic_rules.items():
+target_copy = {name: rule for name, (rule, _, _) in rules.items()}
+rules_copy[target] = target_copy
+return copy.deepcopy(rules_copy)
+
+def get_rules_for_target(self, target_prompt: str) -> dict[str, InjectionRule]:
+"""
+获取所有注入到指定核心提示词的动态规则。
+
+Args:
+target_prompt (str): 目标核心提示词的名称。
+
+Returns:
+dict[str, InjectionRule]: 一个字典,键是注入组件的名称,值是 `InjectionRule` 对象。
+如果找不到任何注入到该目标的规则,则返回一个空字典。
+"""
+target_rules = self._dynamic_rules.get(target_prompt, {})
+return {name: copy.deepcopy(rule_info[0]) for name, rule_info in target_rules.items()}
+
+def get_rules_by_component(self, component_name: str) -> dict[str, InjectionRule]:
+"""
+获取由指定的单个注入组件定义的所有动态规则。
+
+Args:
+component_name (str): 注入组件的名称。
+
+Returns:
+dict[str, InjectionRule]: 一个字典,键是目标核心提示词的名称,值是 `InjectionRule` 对象。
+如果该组件没有定义任何注入规则,则返回一个空字典。
+"""
+found_rules = {}
+for target, rules in self._dynamic_rules.items():
+if component_name in rules:
+rule_info = rules[component_name]
+found_rules[target] = copy.deepcopy(rule_info[0])
+return found_rules


-# 创建全局单例
+# 创建全局单例 (Singleton)
+# 在整个应用程序中,应该只使用这一个 `prompt_component_manager` 实例,
+# 以确保所有部分都共享和操作同一份动态规则集。
prompt_component_manager = PromptComponentManager()
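Stripped of the registry plumbing, the rewritten apply_injections walks rules in ascending priority and mutates the template per injection type. A runnable sketch of that control flow; InjectionType and Rule here are simplified stand-ins for the project's types, and Python 3.10+ is assumed:

import re
from dataclasses import dataclass
from enum import Enum

class InjectionType(Enum):  # simplified stand-in for the project's enum
    PREPEND = "prepend"
    APPEND = "append"
    REPLACE = "replace"
    REMOVE = "remove"

@dataclass
class Rule:  # simplified stand-in for InjectionRule
    priority: int
    injection_type: InjectionType
    target_content: str | None = None

def apply_rules(template: str, rules: list[tuple[Rule, str]]) -> str:
    """Apply (rule, content) pairs in ascending priority; a smaller number wins first."""
    for rule, content in sorted(rules, key=lambda rc: rc[0].priority):
        if rule.injection_type is InjectionType.PREPEND:
            template = f"{content}\n{template}"
        elif rule.injection_type is InjectionType.APPEND:
            template = f"{template}\n{content}"
        elif rule.injection_type is InjectionType.REPLACE and rule.target_content:
            template = re.sub(rule.target_content, content, template)
        elif rule.injection_type is InjectionType.REMOVE and rule.target_content:
            template = re.sub(rule.target_content, "", template)
    return template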
@@ -98,7 +98,7 @@ class StreamContext(BaseDataModel):
break

def mark_message_as_read(self, message_id: str):
"""标记消息为已读"""
# 先找到要标记的消息(处理 int/str 类型不匹配问题)
message_to_mark = None
for msg in self.unread_messages:

@@ -106,7 +106,7 @@ class StreamContext(BaseDataModel):
if str(msg.message_id) == str(message_id):
message_to_mark = msg
break

# 然后移动到历史消息
if message_to_mark:
message_to_mark.is_read = True
@@ -9,11 +9,12 @@
"""

import asyncio
+import builtins
import time
from collections import OrderedDict
+from collections.abc import Callable
from dataclasses import dataclass
-from typing import Any, Dict, Generic, List, Optional, Set, TypeVar, Union
+from typing import Any, Generic, TypeVar

from src.common.logger import get_logger
from src.common.memory_utils import estimate_size_smart

@@ -96,7 +97,7 @@ class LRUCache(Generic[T]):
self._lock = asyncio.Lock()
self._stats = CacheStats()

-async def get(self, key: str) -> Optional[T]:
+async def get(self, key: str) -> T | None:
"""获取缓存值

Args:

@@ -137,8 +138,8 @@ class LRUCache(Generic[T]):
self,
key: str,
value: T,
-size: Optional[int] = None,
-ttl: Optional[float] = None,
+size: int | None = None,
+ttl: float | None = None,
) -> None:
"""设置缓存值

@@ -287,8 +288,8 @@ class MultiLevelCache:
async def get(
self,
key: str,
-loader: Optional[Callable[[], Any]] = None,
-) -> Optional[Any]:
+loader: Callable[[], Any] | None = None,
+) -> Any | None:
"""从缓存获取数据

查询顺序:L1 -> L2 -> loader

@@ -329,8 +330,8 @@ class MultiLevelCache:
self,
key: str,
value: Any,
-size: Optional[int] = None,
-ttl: Optional[float] = None,
+size: int | None = None,
+ttl: float | None = None,
) -> None:
"""设置缓存值

@@ -390,7 +391,7 @@ class MultiLevelCache:
await self.l2_cache.clear()
logger.info("所有缓存已清空")

-async def get_stats(self) -> Dict[str, Any]:
+async def get_stats(self) -> dict[str, Any]:
"""获取所有缓存层的统计信息(修复版:避免锁嵌套,使用超时)"""
# 🔧 修复:并行获取统计信息,避免锁嵌套
l1_stats_task = asyncio.create_task(self._get_cache_stats_safe(self.l1_cache, "L1"))

@@ -492,7 +493,7 @@ class MultiLevelCache:
logger.error(f"{cache_name}统计获取异常: {e}")
return CacheStats()

-async def _get_cache_keys_safe(self, cache) -> Set[str]:
+async def _get_cache_keys_safe(self, cache) -> builtins.set[str]:
"""安全获取缓存键集合(带超时)"""
try:
# 快速获取键集合,使用超时避免死锁

@@ -507,12 +508,12 @@ class MultiLevelCache:
logger.error(f"缓存键获取异常: {e}")
return set()

-async def _extract_keys_with_lock(self, cache) -> Set[str]:
+async def _extract_keys_with_lock(self, cache) -> builtins.set[str]:
"""在锁保护下提取键集合"""
async with cache._lock:
return set(cache._cache.keys())

-async def _calculate_memory_usage_safe(self, cache, keys: Set[str]) -> int:
+async def _calculate_memory_usage_safe(self, cache, keys: builtins.set[str]) -> int:
"""安全计算内存使用(带超时)"""
if not keys:
return 0

@@ -529,7 +530,7 @@ class MultiLevelCache:
logger.error(f"内存计算异常: {e}")
return 0

-async def _calc_memory_with_lock(self, cache, keys: Set[str]) -> int:
+async def _calc_memory_with_lock(self, cache, keys: builtins.set[str]) -> int:
"""在锁保护下计算内存使用"""
total_size = 0
async with cache._lock:

@@ -749,7 +750,7 @@ class MultiLevelCache:

# 全局缓存实例
-_global_cache: Optional[MultiLevelCache] = None
+_global_cache: MultiLevelCache | None = None
_cache_lock = asyncio.Lock()
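The LRUCache touched above pairs an OrderedDict with an asyncio.Lock. The essential recipe, as a minimal sketch without the TTL and size accounting of the real class:

import asyncio
from collections import OrderedDict
from typing import Generic, TypeVar

T = TypeVar("T")

class TinyLRU(Generic[T]):
    """Least-recently-used cache: reads refresh recency, writes evict the oldest."""

    def __init__(self, capacity: int = 128):
        self._cache: OrderedDict[str, T] = OrderedDict()
        self._capacity = capacity
        self._lock = asyncio.Lock()

    async def get(self, key: str) -> T | None:
        async with self._lock:
            if key not in self._cache:
                return None
            self._cache.move_to_end(key)  # mark as most recently used
            return self._cache[key]

    async def set(self, key: str, value: T) -> None:
        async with self._lock:
            self._cache[key] = value
            self._cache.move_to_end(key)
            if len(self._cache) > self._capacity:
                self._cache.popitem(last=False)  # evict the least recently used entry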
@@ -3,7 +3,6 @@ import socket

from fastapi import APIRouter, FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from rich.traceback import install
from uvicorn import Config
from uvicorn import Server as UvicornServer
@@ -436,7 +436,7 @@ class OpenaiClient(BaseClient):
# 🔧 优化:增加连接池限制,支持高并发embedding请求
# 默认httpx限制为100,对于高频embedding场景不够用
import httpx

limits = httpx.Limits(
max_keepalive_connections=200,  # 保持活跃连接数(原100)
max_connections=300,  # 最大总连接数(原100)
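For reference, such pool limits take effect by passing the Limits object into the client; a minimal sketch where the 200/300 figures mirror the diff and the base_url is hypothetical:

import httpx

# Larger connection pool for high-frequency embedding calls, as in the diff above.
limits = httpx.Limits(max_keepalive_connections=200, max_connections=300)

async def make_client() -> httpx.AsyncClient:
    return httpx.AsyncClient(
        base_url="https://api.example.com/v1",  # hypothetical endpoint
        limits=limits,
        timeout=httpx.Timeout(30.0),
    )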
@@ -128,7 +128,7 @@ class MemoryBuilder:
# 6. 构建 Memory 对象
# 新记忆应该有较高的初始激活度
initial_activation = 0.75  # 新记忆初始激活度为 0.75

memory = Memory(
id=memory_id,
subject_id=subject_node.id,

@@ -149,7 +149,7 @@ class MemoryManager:
# 读取阈值过滤配置
search_min_importance = self.config.search_min_importance
search_similarity_threshold = self.config.search_similarity_threshold

logger.info(
f"📊 配置检查: search_max_expand_depth={expand_depth}, "
f"search_expand_semantic_threshold={expand_semantic_threshold}, "

@@ -417,7 +417,7 @@ class MemoryManager:
# 使用配置的默认值
if top_k is None:
top_k = getattr(self.config, "search_top_k", 10)

# 准备搜索参数
params = {
"query": query,

@@ -951,7 +951,7 @@ class MemoryManager:
)
else:
logger.debug(f"记忆已删除: {memory_id} (删除了 {deleted_vectors} 个向量)")

# 4. 保存更新
await self.persistence.save_graph_store(self.graph_store)
return True

@@ -984,7 +984,7 @@ class MemoryManager:
try:
forgotten_count = 0
all_memories = self.graph_store.get_all_memories()

# 获取配置参数
min_importance = getattr(self.config, "forgetting_min_importance", 0.8)
decay_rate = getattr(self.config, "activation_decay_rate", 0.9)

@@ -1010,10 +1010,10 @@ class MemoryManager:
try:
last_access_dt = datetime.fromisoformat(last_access)
days_passed = (datetime.now() - last_access_dt).days

# 应用指数衰减:activation = base * (decay_rate ^ days)
current_activation = base_activation * (decay_rate ** days_passed)

logger.debug(
f"记忆 {memory.id[:8]}: 基础激活度={base_activation:.3f}, "
f"经过{days_passed}天衰减后={current_activation:.3f}"
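The decay above is activation = base * decay_rate ** days. With the defaults shown (decay_rate 0.9) an initial activation of 0.75 drops to about 0.44 after 5 days and 0.26 after 10. A worked sketch of the same formula:

from datetime import datetime

def decayed_activation(base: float, last_access: str, decay_rate: float = 0.9) -> float:
    """Exponential forgetting curve, matching the diff's formula."""
    days = (datetime.now() - datetime.fromisoformat(last_access)).days
    return base * (decay_rate ** days)

# With base=0.75 (the initial activation above) and decay_rate=0.9:
#   5 days  -> 0.75 * 0.9**5  ≈ 0.44
#   10 days -> 0.75 * 0.9**10 ≈ 0.26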
@@ -1035,20 +1035,20 @@ class MemoryManager:
# 批量遗忘记忆(不立即清理孤立节点)
if memories_to_forget:
logger.info(f"开始批量遗忘 {len(memories_to_forget)} 条记忆...")

for memory_id, activation in memories_to_forget:
# cleanup_orphans=False:暂不清理孤立节点
success = await self.forget_memory(memory_id, cleanup_orphans=False)
if success:
forgotten_count += 1

# 统一清理孤立节点和边
logger.info("批量遗忘完成,开始统一清理孤立节点和边...")
orphan_nodes, orphan_edges = await self._cleanup_orphan_nodes_and_edges()

# 保存最终更新
await self.persistence.save_graph_store(self.graph_store)

logger.info(
f"✅ 自动遗忘完成: 遗忘了 {forgotten_count} 条记忆, "
f"清理了 {orphan_nodes} 个孤立节点, {orphan_edges} 条孤立边"

@@ -1079,31 +1079,31 @@ class MemoryManager:
# 1. 清理孤立节点
# graph_store.node_to_memories 记录了每个节点属于哪些记忆
nodes_to_remove = []

for node_id, memory_ids in list(self.graph_store.node_to_memories.items()):
# 如果节点不再属于任何记忆,标记为删除
if not memory_ids:
nodes_to_remove.append(node_id)

# 从图中删除孤立节点
for node_id in nodes_to_remove:
if self.graph_store.graph.has_node(node_id):
self.graph_store.graph.remove_node(node_id)
orphan_nodes_count += 1

# 从映射中删除
if node_id in self.graph_store.node_to_memories:
del self.graph_store.node_to_memories[node_id]

# 2. 清理孤立边(指向已删除节点的边)
edges_to_remove = []

-for source, target, edge_id in self.graph_store.graph.edges(data='edge_id'):
+for source, target, edge_id in self.graph_store.graph.edges(data="edge_id"):
# 检查边的源节点和目标节点是否还存在于node_to_memories中
if source not in self.graph_store.node_to_memories or \
target not in self.graph_store.node_to_memories:
edges_to_remove.append((source, target))

# 删除孤立边
for source, target in edges_to_remove:
try:

@@ -1111,12 +1111,12 @@ class MemoryManager:
orphan_edges_count += 1
except Exception as e:
logger.debug(f"删除边失败 {source} -> {target}: {e}")

if orphan_nodes_count > 0 or orphan_edges_count > 0:
logger.info(
f"清理完成: {orphan_nodes_count} 个孤立节点, {orphan_edges_count} 条孤立边"
)

return orphan_nodes_count, orphan_edges_count

except Exception as e:
@@ -1258,7 +1258,7 @@ class MemoryManager:
mem for mem in recent_memories
if mem.importance >= min_importance_for_consolidation
]

result["importance_filtered"] = len(recent_memories) - len(important_memories)
logger.info(
f"📊 步骤2: 重要性过滤 (阈值={min_importance_for_consolidation:.2f}): "

@@ -1382,26 +1382,26 @@ class MemoryManager:
# ===== 步骤4: 向量检索关联记忆 + LLM分析关系 =====
# 过滤掉已删除的记忆
remaining_memories = [m for m in important_memories if m.id not in deleted_ids]

if not remaining_memories:
logger.info("✅ 记忆整理完成: 去重后无剩余记忆")
return

logger.info(f"📍 步骤4: 开始关联分析 ({len(remaining_memories)} 条记忆)...")

# 分批处理记忆关联
llm_batch_size = getattr(self.config, "consolidation_llm_batch_size", 10)
max_candidates_per_memory = getattr(self.config, "consolidation_max_candidates", 5)
min_confidence = getattr(self.config, "consolidation_min_confidence", 0.6)

all_new_edges = []  # 收集所有新建的边

for batch_start in range(0, len(remaining_memories), llm_batch_size):
batch_end = min(batch_start + llm_batch_size, len(remaining_memories))
batch = remaining_memories[batch_start:batch_end]

logger.debug(f"处理批次 {batch_start//llm_batch_size + 1}/{(len(remaining_memories)-1)//llm_batch_size + 1}")

for memory in batch:
# 跳过已经有很多连接的记忆
existing_edges = len([

@@ -1454,14 +1454,14 @@ class MemoryManager:
except Exception as e:
logger.warning(f"创建关联边失败: {e}")
continue

# 每个批次后让出控制权
await asyncio.sleep(0.01)

# ===== 步骤5: 统一更新记忆数据 =====
if all_new_edges:
logger.info(f"📍 步骤5: 统一更新 {len(all_new_edges)} 条新关联边...")

for memory, edge, relation in all_new_edges:
try:
# 添加到图

@@ -2301,7 +2301,7 @@ class MemoryManager:
# 使用 asyncio.wait_for 来支持取消
await asyncio.wait_for(
asyncio.sleep(initial_delay),
-timeout=float('inf')  # 允许随时取消
+timeout=float("inf")  # 允许随时取消
)

# 检查是否仍然需要运行
@@ -482,7 +482,7 @@ class GraphStore:
for node in memory.nodes:
if node.id in self.node_to_memories:
self.node_to_memories[node.id].discard(memory_id)

# 可选:立即清理孤立节点
if cleanup_orphans:
# 如果该节点不再属于任何记忆,从图中移除节点

@@ -72,12 +72,12 @@ class MemoryTools:
self.max_expand_depth = max_expand_depth
self.expand_semantic_threshold = expand_semantic_threshold
self.search_top_k = search_top_k

# 保存权重配置
self.base_vector_weight = search_vector_weight
self.base_importance_weight = search_importance_weight
self.base_recency_weight = search_recency_weight

# 保存阈值过滤配置
self.search_min_importance = search_min_importance
self.search_similarity_threshold = search_similarity_threshold
@@ -516,14 +516,14 @@ class MemoryTools:
|
||||
|
||||
# 1. 根据策略选择检索方式
|
||||
llm_prefer_types = [] # LLM识别的偏好节点类型
|
||||
|
||||
|
||||
if use_multi_query:
|
||||
# 多查询策略(返回节点列表 + 偏好类型)
|
||||
similar_nodes, llm_prefer_types = await self._multi_query_search(query, top_k, context)
|
||||
else:
|
||||
# 传统单查询策略
|
||||
similar_nodes = await self._single_query_search(query, top_k)
|
||||
|
||||
|
||||
# 合并用户指定的偏好类型和LLM识别的偏好类型
|
||||
all_prefer_types = list(set(prefer_node_types + llm_prefer_types))
|
||||
if all_prefer_types:
|
||||
@@ -551,7 +551,7 @@ class MemoryTools:
                # Record the highest score
                if mem_id not in memory_scores or similarity > memory_scores[mem_id]:
                    memory_scores[mem_id] = similarity

        # 🔥 Detailed logging: inspect the initial recall
        logger.info(
            f"初始向量搜索: 返回{len(similar_nodes)}个节点 → "
@@ -559,8 +559,8 @@ class MemoryTools:
        )
        if len(initial_memory_ids) == 0:
            logger.warning(
                f"⚠️ 向量搜索未找到任何记忆!"
                f"可能原因:1) 嵌入模型理解问题 2) 记忆节点未建立索引 3) 查询表达与存储内容差异过大"
                "⚠️ 向量搜索未找到任何记忆!"
                "可能原因:1) 嵌入模型理解问题 2) 记忆节点未建立索引 3) 查询表达与存储内容差异过大"
            )
            # Dump details of the similar nodes for debugging
            if similar_nodes:
@@ -692,7 +692,7 @@ class MemoryTools:
            key=lambda x: final_scores[x],
            reverse=True
        )  # 🔥 no more early truncation; every candidate goes through detailed scoring

        # 🔍 Collect the similarity distribution of the initial memories (for diagnostics)
        if memory_scores:
            similarities = list(memory_scores.values())
@@ -707,7 +707,7 @@ class MemoryTools:
        # 5. Fetch the full memories and do the final ranking (the reworked dynamic-weight system)
        memories_with_scores = []
        filter_stats = {"importance": 0, "similarity": 0, "total_checked": 0}  # filtering statistics

        for memory_id in sorted_memory_ids:  # iterate over all candidates
            memory = self.graph_store.get_memory_by_id(memory_id)
            if memory:
@@ -715,7 +715,7 @@ class MemoryTools:
                # Base scores
                similarity_score = final_scores[memory_id]
                importance_score = memory.importance

                # 🆕 Distinguish the memory's origin (used for filtering)
                is_initial_memory = memory_id in memory_scores  # whether it came from the initial vector search
                true_similarity = memory_scores.get(memory_id, 0.0) if is_initial_memory else None
@@ -738,16 +738,16 @@ class MemoryTools:
                activation_score = memory.activation

                # 🆕 Dynamic weighting: start from the configured base weights and fine-tune by memory type
                memory_type = memory.memory_type.value if hasattr(memory.memory_type, 'value') else str(memory.memory_type)
                memory_type = memory.memory_type.value if hasattr(memory.memory_type, "value") else str(memory.memory_type)

                # Detect the memory's dominant node type
                node_types_count = {}
                for node in memory.nodes:
                    nt = node.node_type.value if hasattr(node.node_type, 'value') else str(node.node_type)
                    nt = node.node_type.value if hasattr(node.node_type, "value") else str(node.node_type)
                    node_types_count[nt] = node_types_count.get(nt, 0) + 1

                dominant_node_type = max(node_types_count.items(), key=lambda x: x[1])[0] if node_types_count else "unknown"

                # Compute adjustment factors from the memory type and node type (fine-tuning on top of the configured weights)
                if dominant_node_type in ["ATTRIBUTE", "REFERENCE"] or memory_type == "FACT":
                    # Factual memories: raise the similarity weight, lower the recency weight
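Counting node types by hand with a dict works, but the same "most common label wins" step reads more directly with collections.Counter. A tiny equivalent sketch (the type strings are illustrative):

from collections import Counter

def dominant_type(node_types: list[str]) -> str:
    """Return the most frequent node type, or 'unknown' for an empty list."""
    counts = Counter(node_types)
    return counts.most_common(1)[0][0] if counts else "unknown"

print(dominant_type(["ATTRIBUTE", "TOPIC", "ATTRIBUTE"]))  # -> ATTRIBUTE
print(dominant_type([]))                                   # -> unknown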
@@ -777,41 +777,41 @@ class MemoryTools:
                        "importance": 1.0,
                        "recency": 1.0,
                    }

                # Apply the adjusted weights (on top of the configured base weights)
                weights = {
                    "similarity": self.base_vector_weight * type_adjustments["similarity"],
                    "importance": self.base_importance_weight * type_adjustments["importance"],
                    "recency": self.base_recency_weight * type_adjustments["recency"],
                }

                # Normalize the weights (make sure they sum to 1.0)
                total_weight = sum(weights.values())
                if total_weight > 0:
                    weights = {k: v / total_weight for k, v in weights.items()}

                # Composite score (🔥 activation no longer contributes)
                final_score = (
                    similarity_score * weights["similarity"] +
                    importance_score * weights["importance"] +
                    recency_score * weights["recency"]
                )

                # 🆕 Threshold-filtering strategy:
                # 1. Importance filter: applied to all memories (drops extremely low-quality ones)
                if memory.importance < self.search_min_importance:
                    filter_stats["importance"] += 1
                    logger.debug(f"❌ 过滤 {memory.id[:8]}: 重要性 {memory.importance:.2f} < 阈值 {self.search_min_importance}")
                    continue

                # 2. Similarity filter: no longer applied to initial vector-search results (trust the vector ranking).
                # Rationale: the vector search already ranks by similarity, so everything it returns is the most relevant;
                # re-filtering by a threshold creates the contradiction that "even the most relevant is not relevant enough".
                #
                # Note: if expanded memories ever need filtering, the logic can go here:
                # if not is_initial_memory and some_score < threshold:
                #     continue

                # Log the memories that pass the filters (for debugging)
                if is_initial_memory:
                    logger.debug(
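To make the weighting concrete, here is a worked version of the same adjust-normalize-combine computation. The base weights 0.5/0.3/0.2 and the adjustment factors are assumed values for illustration, not the project's configured defaults:

def composite_score(similarity, importance, recency,
                    base=(0.5, 0.3, 0.2),
                    adjustments=(1.2, 1.0, 0.7)):
    """Adjust base weights per memory type, renormalize, then combine scores."""
    raw = [b * a for b, a in zip(base, adjustments)]   # e.g. [0.6, 0.3, 0.14]
    total = sum(raw)                                   # 1.04
    w_sim, w_imp, w_rec = (r / total for r in raw)     # ≈ 0.577, 0.288, 0.135
    return similarity * w_sim + importance * w_imp + recency * w_rec

# A match with similarity 0.8, importance 0.5, recency 0.2 scores ≈ 0.63:
print(round(composite_score(0.8, 0.5, 0.2), 2))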
@@ -823,11 +823,11 @@ class MemoryTools:
                        f"✅ 保留 {memory.id[:8]} [扩展]: 重要性={memory.importance:.2f}, "
                        f"综合分数={final_score:.4f}"
                    )

                # 🆕 Node-type weighting: extra credit for REFERENCE/ATTRIBUTE nodes (improves recall of factual information)
                if "REFERENCE" in node_types_count or "ATTRIBUTE" in node_types_count:
                    final_score *= 1.1  # 10% bonus

                # 🆕 Extra weighting for user-specified preferred node types
                if prefer_node_types:
                    for prefer_type in prefer_node_types:
@@ -835,7 +835,7 @@ class MemoryTools:
                            final_score *= 1.15  # 15% extra bonus
                            logger.debug(f"记忆 {memory.id[:8]} 包含优先节点类型 {prefer_type},加权后分数: {final_score:.4f}")
                            break

                memories_with_scores.append((memory, final_score, dominant_node_type))

        # Sort by composite score
@@ -845,7 +845,7 @@ class MemoryTools:
        # Filtering statistics
        total_candidates = len(all_memory_ids)
        filtered_count = total_candidates - len(memories_with_scores)

        # 6. Format the results (including debug information)
        results = []
        for memory, score, node_type in memories_with_scores[:top_k]:
@@ -866,7 +866,7 @@ class MemoryTools:
            f"过滤{filtered_count}个 (重要性过滤) → "
            f"最终返回{len(results)}条记忆"
        )

        # Warn if the filter rate is too high
        if total_candidates > 0:
            filter_rate = filtered_count / total_candidates
@@ -1092,20 +1092,21 @@ class MemoryTools:
            response, _ = await llm.generate_response_async(prompt, temperature=0.3, max_tokens=300)

            import re

            import orjson

            # Strip Markdown code fences
            response = re.sub(r"```json\s*", "", response)
            response = re.sub(r"```\s*$", "", response).strip()

            # Parse the JSON
            data = orjson.loads(response)

            # Extract the query list
            queries = data.get("queries", [])
            result_queries = [(item.get("text", "").strip(), float(item.get("weight", 0.5)))
                              for item in queries if item.get("text", "").strip()]

            # Extract the preferred node types
            prefer_node_types = data.get("prefer_node_types", [])
            # Make sure the types are valid and well-formed
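LLMs frequently wrap JSON replies in Markdown fences, so stripping them before parsing is a common defensive step. A compact standalone version of the same clean-then-parse routine, using the stdlib json module instead of orjson so it runs with no extra dependency:

import json
import re

def parse_llm_json(response: str) -> dict:
    """Strip Markdown code fences, then parse the remaining text as JSON."""
    response = re.sub(r"```json\s*", "", response)
    response = re.sub(r"```\s*$", "", response).strip()
    return json.loads(response)

raw = '```json\n{"queries": [{"text": "cat diet", "weight": 0.8}]}\n```'
print(parse_llm_json(raw)["queries"][0]["text"])  # -> cat diet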
@@ -1154,7 +1155,7 @@ class MemoryTools:
            limit=top_k * 5,  # 🔥 raised from 2x to 5x to improve initial recall
            min_similarity=0.0,  # do not filter here; leave that to the later scoring
        )

        logger.debug(f"单查询向量搜索: 查询='{query}', 返回节点数={len(similar_nodes)}")
        if similar_nodes:
            logger.debug(f"Top 3相似度: {[f'{sim:.3f}' for _, sim, _ in similar_nodes[:3]]}")
@@ -62,7 +62,7 @@ async def expand_memories_with_semantic_filter(
    try:
        import time
        start_time = time.time()

        # Track visited memories to avoid duplicates
        visited_memories = set(initial_memory_ids)
        # Track the expanded memories and their scores
@@ -87,17 +87,17 @@ async def expand_memories_with_semantic_filter(
            # Get this memory's neighbor memories (via edge relations)
            neighbor_memory_ids = set()

            # 🆕 Walk all edges of the memory and collect neighbors (with edge-type weights)
            edge_weights = {}  # weight of each memory reached through a given edge type

            for edge in memory.edges:
                # Get the edge's endpoints
                target_node_id = edge.target_id
                source_node_id = edge.source_id

                # 🆕 Weight by edge type (expand REFERENCE- and ATTRIBUTE-related edges first)
                edge_type_str = edge.edge_type.value if hasattr(edge.edge_type, 'value') else str(edge.edge_type)
                edge_type_str = edge.edge_type.value if hasattr(edge.edge_type, "value") else str(edge.edge_type)
                if edge_type_str == "REFERENCE":
                    edge_weight = 1.3  # REFERENCE edges get the highest weight (citation relations)
                elif edge_type_str in ["ATTRIBUTE", "HAS_PROPERTY"]:
@@ -108,18 +108,18 @@ async def expand_memories_with_semantic_filter(
                    edge_weight = 0.9  # generic relations are slightly down-weighted
                else:
                    edge_weight = 1.0  # default weight

                # Find other memories through the endpoint nodes
                for node_id in [target_node_id, source_node_id]:
                    if node_id in graph_store.node_to_memories:
                        for neighbor_id in graph_store.node_to_memories[node_id]:
                            if neighbor_id not in edge_weights or edge_weights[neighbor_id] < edge_weight:
                                edge_weights[neighbor_id] = edge_weight

            # Add the highly-weighted neighbor memories as candidates
            for neighbor_id, edge_weight in edge_weights.items():
                neighbor_memory_ids.add((neighbor_id, edge_weight))

            # Filter out already-visited memories and the memory itself
            filtered_neighbors = []
            for neighbor_id, edge_weight in neighbor_memory_ids:
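The edge-type priorities reduce to a lookup table plus a max-merge per neighbor. A sketch of that logic is below; the REFERENCE (1.3), generic (0.9), and default (1.0) weights come from the hunks above, while the ATTRIBUTE/HAS_PROPERTY value falls between hunks and is assumed here, as are the container shapes:

EDGE_TYPE_WEIGHTS = {
    "REFERENCE": 1.3,      # citation relations expand first
    "ATTRIBUTE": 1.2,      # assumed: not shown in the visible hunks
    "HAS_PROPERTY": 1.2,   # assumed: not shown in the visible hunks
    "RELATED_TO": 0.9,     # generic relations slightly down-weighted
}

def merge_neighbor_weights(edges, node_to_memories):
    """For each reachable memory, keep the strongest edge-type weight seen."""
    weights: dict[str, float] = {}
    for edge_type, node_ids in edges:  # edges given as (type, [node ids]) pairs
        w = EDGE_TYPE_WEIGHTS.get(edge_type, 1.0)
        for node_id in node_ids:
            for mem_id in node_to_memories.get(node_id, ()):
                if weights.get(mem_id, 0.0) < w:
                    weights[mem_id] = w
    return weights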
@@ -129,7 +129,7 @@ async def expand_memories_with_semantic_filter(
            # Evaluate the neighbor memories in batch
            for neighbor_mem_id, edge_weight in filtered_neighbors:
                candidates_checked += 1

                neighbor_memory = graph_store.get_memory_by_id(neighbor_mem_id)
                if not neighbor_memory:
                    continue
@@ -139,7 +139,7 @@ async def expand_memories_with_semantic_filter(
                    (n for n in neighbor_memory.nodes if n.has_embedding()),
                    None
                )

                if not topic_node or topic_node.embedding is None:
                    continue
@@ -179,11 +179,11 @@ async def expand_memories_with_semantic_filter(
                if len(expanded_memories) >= max_expanded:
                    logger.debug(f"⏹️ 提前停止:已达到最大扩展数量 {max_expanded}")
                    break

            # Early-stop check
            if len(expanded_memories) >= max_expanded:
                break

            # Record statistics for this depth level
            depth_stats.append({
                "depth": depth + 1,
@@ -199,20 +199,20 @@ async def expand_memories_with_semantic_filter(
            # Cap the number of memories carried into the next level to avoid explosive growth
            current_level_memories = next_level_memories[:max_expanded]

            # Yield control at each level
            await asyncio.sleep(0.001)

        # Sort and return
        sorted_results = sorted(expanded_memories.items(), key=lambda x: x[1], reverse=True)[:max_expanded]

        elapsed = time.time() - start_time
        logger.info(
            f"✅ 图扩展完成: 初始{len(initial_memory_ids)}个 → "
            f"扩展{len(sorted_results)}个新记忆 "
            f"(深度={max_depth}, 阈值={semantic_threshold:.2f}, 耗时={elapsed:.3f}s)"
        )

        # Log per-level statistics
        for stat in depth_stats:
            logger.debug(
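Each expansion level only keeps a neighbor whose embedding is semantically close enough, which at its core is a cosine-similarity threshold test. A minimal sketch of that gate in pure Python (the threshold value and the multiplication by the edge weight are assumptions; the real code presumably delegates to its vector store):

import math

def cosine_similarity(a: list[float], b: list[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0

def passes_semantic_gate(query_emb, node_emb, threshold=0.55, edge_weight=1.0):
    """Keep a neighbor only if its weighted similarity clears the threshold."""
    return cosine_similarity(query_emb, node_emb) * edge_weight >= threshold

print(passes_semantic_gate([1.0, 0.0], [0.9, 0.1], threshold=0.55))  # True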
@@ -78,11 +78,9 @@ __all__ = [
    # Messages
    "MaiMessages",
    # Utility functions
    "ManifestValidator",
    "PluginInfo",
    # Enhanced command system
    "PlusCommand",
    "PlusCommandAdapter",
    "PythonDependency",
    "ToolInfo",
    "ToolParamType",
@@ -132,7 +132,7 @@ async def generate_reply(
    prompt_mode = "s4u"  # default to the s4u mode
    if action_data and "prompt_mode" in action_data:
        prompt_mode = action_data.get("prompt_mode", "s4u")

    # Add prompt_mode to available_actions (as a special key)
    # Note: a type-ignore is needed for now, since available_actions' type does not allow non-ActionInfo values
    if available_actions is None:
@@ -362,7 +362,7 @@ class ChatterPlanFilter:
            return "最近没有聊天内容。", "没有未读消息。", []

        stream_context = chat_stream.context_manager

        # Get the truly read and unread messages
        read_messages = stream_context.context.history_messages  # read messages are stored in history_messages
        if not read_messages:
@@ -652,30 +652,30 @@ class ChatterPlanFilter:
        if not action_info:
            logger.debug(f"动作 {action_name} 不在可用动作列表中,保留所有参数")
            return action_data

        # Get the legal parameters defined for this action
        defined_params = set(action_info.action_parameters.keys())

        # The set of legal parameters
        valid_params = defined_params

        # Filter the parameters
        filtered_data = {}
        removed_params = []

        for key, value in action_data.items():
            if key in valid_params:
                filtered_data[key] = value
            else:
                removed_params.append(key)

        # Log the removed parameters
        if removed_params:
            logger.info(
                f"🧹 [参数过滤] 动作 '{action_name}' 移除了多余参数: {removed_params}. "
                f"合法参数: {sorted(valid_params)}"
            )

        return filtered_data

    def _filter_no_actions(self, action_list: list[ActionPlannerInfo]) -> list[ActionPlannerInfo]:
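Whitelisting keyword arguments against the declared parameter schema before dispatch is a simple guard against an LLM planner inventing parameters. The same filter in isolation, with logging elided:

def filter_action_params(action_data: dict, defined_params: set[str]) -> tuple[dict, list[str]]:
    """Split proposed parameters into accepted ones and rejected extras."""
    accepted = {k: v for k, v in action_data.items() if k in defined_params}
    rejected = [k for k in action_data if k not in defined_params]
    return accepted, rejected

data, extras = filter_action_params(
    {"target_message_id": "m12", "content": "hi", "volume": 11},
    {"target_message_id", "content", "should_quote_reply"},
)
print(data)    # {'target_message_id': 'm12', 'content': 'hi'}
print(extras)  # ['volume']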
@@ -545,14 +545,14 @@ async def execute_proactive_thinking(stream_id: str):
    # Get or create the execution lock for this chat stream
    if stream_id not in _execution_locks:
        _execution_locks[stream_id] = asyncio.Lock()

    lock = _execution_locks[stream_id]

    # Try to take the lock; skip this run if it is already held (prevents duplicates)
    if lock.locked():
        logger.warning(f"⚠️ 主动思考跳过:聊天流 {stream_id} 已有正在执行的主动思考任务")
        return

    async with lock:
        logger.debug(f"🤔 开始主动思考 {stream_id}")
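The skip-if-busy pattern (one asyncio.Lock per stream, checked with lock.locked() before entering) allows at most one proactive-thinking run per chat stream without queueing stale runs behind it. A runnable sketch under assumed names:

import asyncio

_locks: dict[str, asyncio.Lock] = {}

async def run_exclusive(key: str, coro_factory):
    """Run the task unless another run for the same key is in flight."""
    lock = _locks.setdefault(key, asyncio.Lock())
    if lock.locked():
        print(f"skip {key}: previous run still executing")
        return
    async with lock:
        await coro_factory()

async def main():
    async def job():
        await asyncio.sleep(0.2)
    # The second call overlaps the first and is skipped instead of queued
    await asyncio.gather(run_exclusive("s1", job), run_exclusive("s1", job))

asyncio.run(main())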
@@ -563,13 +563,13 @@ async def execute_proactive_thinking(stream_id: str):
        from src.chat.message_receive.chat_stream import get_chat_manager
        chat_manager = get_chat_manager()
        chat_stream = await chat_manager.get_stream(stream_id)

        if chat_stream and chat_stream.context_manager.context.is_chatter_processing:
            logger.warning(f"⚠️ 主动思考跳过:聊天流 {stream_id} 的 chatter 正在处理消息")
            return
    except Exception as e:
        logger.warning(f"检查 chatter 处理状态时出错: {e},继续执行")

    # 0.1 Check the whitelist/blacklist
    # Get the stream_config string from the stream_id for validation
    try:
@@ -31,4 +31,4 @@ __plugin_meta__ = PluginMetadata(
# Import the plugin main class
from .plugin import AntiInjectionPlugin

__all__ = ["__plugin_meta__", "AntiInjectionPlugin"]
__all__ = ["AntiInjectionPlugin", "__plugin_meta__"]
@@ -8,8 +8,8 @@ import time
from src.chat.security.interfaces import (
    SecurityAction,
    SecurityCheckResult,
    SecurityChecker,
    SecurityCheckResult,
    SecurityLevel,
)
from src.common.logger import get_logger
@@ -4,7 +4,7 @@
Handles detection results and executes the corresponding action (allow / monitor / shield / block / counter-attack).
"""

from src.chat.security.interfaces import SecurityAction, SecurityCheckResult
from src.chat.security.interfaces import SecurityCheckResult
from src.common.logger import get_logger

from .counter_attack import CounterAttackGenerator
@@ -64,15 +64,15 @@ class CoreActionsPlugin(BasePlugin):
        # --- Register components according to configuration ---
        components: ClassVar = []

        # Register the reply action
        if self.get_config("components.enable_reply", True):
            components.append((ReplyAction.get_action_info(), ReplyAction))

        # Register the respond action
        if self.get_config("components.enable_respond", True):
            components.append((RespondAction.get_action_info(), RespondAction))

        # Register the emoji action
        if self.get_config("components.enable_emoji", True):
            components.append((EmojiAction.get_action_info(), EmojiAction))
@@ -22,23 +22,23 @@ class ReplyAction(BaseAction):
    - Focuses on understanding and responding to the specific content of a single message
    - Suited to precise replies in Focus mode
    """

    # Basic action information
    action_name = "reply"
    action_description = "针对特定消息进行精准回复。深度理解并回应单条消息的具体内容。需要指定目标消息ID。"

    # Activation settings
    activation_type = ActionActivationType.ALWAYS  # the reply action is always available
    mode_enable = ChatMode.ALL  # available in all modes
    parallel_action = False  # reply cannot run in parallel with other actions

    # Action parameter definitions
    action_parameters: ClassVar = {
        "target_message_id": "要回复的目标消息ID(必需,来自未读消息的 <m...> 标签)",
        "content": "回复的具体内容(可选,由LLM生成)",
        "should_quote_reply": "是否引用原消息(可选,true/false,默认false。群聊中回复较早消息或需要明确指向时使用true)",
    }

    # Action usage scenarios
    action_require: ClassVar = [
        "需要针对特定消息进行精准回复时使用",
@@ -48,10 +48,10 @@ class ReplyAction(BaseAction):
        "群聊中需要明确回应某个特定用户或问题时使用",
        "关注单条消息的具体内容和上下文细节",
    ]

    # Associated types
    associated_types: ClassVar[list[str]] = ["text"]

    async def execute(self) -> tuple[bool, str]:
        """Execute the reply action
@@ -70,21 +70,21 @@ class RespondAction(BaseAction):
    - Suited to macro-level responses across group-chat messages
    - Avoids deep one-on-one conversation that ignores other users' messages
    """

    # Basic action information
    action_name = "respond"
    action_description = "统一回应所有未读消息。理解整体对话动态和话题走向,生成连贯的回复。无需指定目标消息。"

    # Activation settings
    activation_type = ActionActivationType.ALWAYS  # the respond action is always available
    mode_enable = ChatMode.ALL  # available in all modes
    parallel_action = False  # respond cannot run in parallel with other actions

    # Action parameter definitions
    action_parameters: ClassVar = {
        "content": "回复的具体内容(可选,由LLM生成)",
    }

    # Action usage scenarios
    action_require: ClassVar = [
        "需要统一回应多条未读消息时使用(Normal 模式专用)",
@@ -94,10 +94,10 @@ class RespondAction(BaseAction):
        "适合群聊中的自然对话流,无需精确指向特定消息",
        "可以同时回应多个话题或参与者",
    ]

    # Associated types
    associated_types: ClassVar[list[str]] = ["text"]

    async def execute(self) -> tuple[bool, str]:
        """Execute the respond action
@@ -6,10 +6,10 @@
import asyncio
import base64
import datetime
import filetype
from collections.abc import Callable

import aiohttp
import filetype
from maim_message import UserInfo

from src.chat.message_receive.chat_stream import get_chat_manager
@@ -6,7 +6,7 @@
import re
from typing import ClassVar

from src.chat.utils.prompt_component_manager import prompt_component_manager
from src.plugin_system.apis import (
    plugin_manage_api,
)
@@ -74,6 +74,7 @@ class SystemCommand(PlusCommand):
• `/system permission` - 权限管理
• `/system plugin` - 插件管理
• `/system schedule` - 定时任务管理
• `/system prompt` - 提示词注入管理
"""
        elif target == "schedule":
            help_text = """📅 定时任务管理帮助
@@ -113,8 +114,17 @@ class SystemCommand(PlusCommand):
• /system permission nodes [插件名] - 查看权限节点
• /system permission allnodes - 查看所有权限节点详情
"""
            await self.send_text(help_text)
        elif target == "prompt":
            help_text = """📝 提示词注入管理帮助

🔎 查询命令 (需要 `system.prompt.view` 权限):
• `/system prompt help` - 显示此帮助
• `/system prompt map` - 查看全局注入关系图
• `/system prompt targets` - 列出所有可被注入的核心提示词
• `/system prompt components` - 列出所有已注册的提示词组件
• `/system prompt info <目标名>` - 查看特定核心提示词的注入详情
"""
            await self.send_text(help_text)

    # =================================================================
    # Plugin Management Section
@@ -231,6 +241,101 @@ class SystemCommand(PlusCommand):
        else:
            await self.send_text(f"❌ 恢复任务失败: `{schedule_id}`")

    # =================================================================
    # Prompt Management Section
    # =================================================================
    async def _handle_prompt_commands(self, args: list[str]):
        """Handle the prompt-management subcommands"""
        if not args or args[0].lower() in ["help", "帮助"]:
            await self._show_help("prompt")
            return

        action = args[0].lower()
        remaining_args = args[1:]

        if action in ["map", "关系图"]:
            await self._show_injection_map()
        elif action in ["targets", "目标"]:
            await self._list_core_prompts()
        elif action in ["components", "组件"]:
            await self._list_prompt_components()
        elif action in ["info", "详情"] and remaining_args:
            await self._get_prompt_injection_info(remaining_args[0])
        else:
            await self.send_text("❌ 提示词管理命令不合法\n使用 /system prompt help 查看帮助")

    @require_permission("prompt.view", deny_message="❌ 你没有查看提示词注入信息的权限")
    async def _show_injection_map(self):
        """Show the global injection map"""
        injection_map = await prompt_component_manager.get_full_injection_map()
        if not injection_map:
            await self.send_text("📊 当前没有任何提示词注入关系")
            return

        response_parts = ["📊 全局提示词注入关系图:\n"]
        for target, injections in injection_map.items():
            if injections:
                response_parts.append(f"🎯 **{target}** (注入源):")
                for inj in injections:
                    source_tag = f"({inj['source']})" if inj['source'] != 'static_default' else ''
                    response_parts.append(f" ⎿ `{inj['name']}` (优先级: {inj['priority']}) {source_tag}")
            else:
                response_parts.append(f"🎯 **{target}** (无注入)")

        await self._send_long_message("\n".join(response_parts))

    @require_permission("prompt.view", deny_message="❌ 你没有查看提示词注入信息的权限")
    async def _list_core_prompts(self):
        """List all core prompts that can be injected into"""
        targets = prompt_component_manager.get_core_prompts()
        if not targets:
            await self.send_text("🎯 当前没有可注入的核心提示词")
            return

        response = "🎯 所有可注入的核心提示词:\n" + "\n".join([f"• `{name}`" for name in targets])
        await self.send_text(response)

    @require_permission("prompt.view", deny_message="❌ 你没有查看提示词注入信息的权限")
    async def _list_prompt_components(self):
        """List all registered prompt components"""
        components = prompt_component_manager.get_registered_prompt_component_info()
        if not components:
            await self.send_text("🧩 当前没有已注册的提示词组件")
            return

        response_parts = [f"🧩 已注册的提示词组件 (共 {len(components)} 个):"]
        for comp in components:
            response_parts.append(f"• `{comp.name}` (来自: `{comp.plugin_name}`)")

        await self._send_long_message("\n".join(response_parts))

    @require_permission("prompt.view", deny_message="❌ 你没有查看提示词注入信息的权限")
    async def _get_prompt_injection_info(self, target_name: str):
        """Get the injection details for a specific core prompt"""
        injections = await prompt_component_manager.get_injections_for_prompt(target_name)

        core_prompts = prompt_component_manager.get_core_prompts()
        if target_name not in core_prompts:
            await self.send_text(f"❌ 找不到核心提示词: `{target_name}`")
            return

        if not injections:
            await self.send_text(f"🎯 核心提示词 `{target_name}` 当前没有被任何组件注入。")
            return

        response_parts = [f"🔎 核心提示词 `{target_name}` 的注入详情:"]
        for inj in injections:
            response_parts.append(
                f" • **`{inj['name']}`** (优先级: {inj['priority']})"
            )
            response_parts.append(f" - 来源: `{inj['source']}`")
            response_parts.append(f" - 类型: `{inj['injection_type']}`")
            if inj.get('target_content'):
                response_parts.append(f" - 操作目标: `{inj['target_content']}`")

        await self.send_text("\n".join(response_parts))

    # =================================================================
    # Permission Management Section
    # =================================================================
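The listing helpers hand long output to self._send_long_message, which this diff does not show; presumably it splits text into platform-safe chunks. A hypothetical sketch of such a helper (its name is taken from the calls above, but the chunk limit and line-based splitting are assumptions):

    async def _send_long_message(self, text: str, limit: int = 1500):
        """Send text in chunks, splitting on line boundaries where possible."""
        buffer = ""
        for line in text.split("\n"):
            if len(buffer) + len(line) + 1 > limit and buffer:
                await self.send_text(buffer)
                buffer = line
            else:
                buffer = f"{buffer}\n{line}" if buffer else line
        if buffer:
            await self.send_text(buffer)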
File diff suppressed because it is too large