Windpicker-owo
2025-11-12 13:38:12 +08:00
36 changed files with 934 additions and 626 deletions

2
bot.py
View File

@@ -588,7 +588,7 @@ class MaiBotMain:
async def run_async_init(self, main_system):
"""执行异步初始化步骤"""
# 初始化数据库表结构
await self.initialize_database_async()

View File

@@ -19,14 +19,13 @@
import asyncio
import sys
from pathlib import Path
-from typing import List
# 添加项目根目录到路径
sys.path.insert(0, str(Path(__file__).parent.parent))
async def generate_missing_embeddings(
-target_node_types: List[str] = None,
+target_node_types: list[str] = None,
batch_size: int = 50,
):
"""
@@ -46,13 +45,13 @@ async def generate_missing_embeddings(
target_node_types = [NodeType.TOPIC.value, NodeType.OBJECT.value]
print(f"\n{'='*80}")
-print(f"🔧 为节点生成嵌入向量")
+print("🔧 为节点生成嵌入向量")
print(f"{'='*80}\n")
print(f"目标节点类型: {', '.join(target_node_types)}")
print(f"批处理大小: {batch_size}\n")
# 1. 初始化记忆管理器
-print(f"🔧 正在初始化记忆管理器...")
+print("🔧 正在初始化记忆管理器...")
await initialize_memory_manager()
manager = get_memory_manager()
@@ -60,10 +59,10 @@ async def generate_missing_embeddings(
print("❌ 记忆管理器初始化失败")
return
-print(f"✅ 记忆管理器已初始化\n")
+print("✅ 记忆管理器已初始化\n")
# 2. 获取已索引的节点ID
-print(f"🔍 检查现有向量索引...")
+print("🔍 检查现有向量索引...")
existing_node_ids = set()
try:
vector_count = manager.vector_store.collection.count()
@@ -78,14 +77,14 @@ async def generate_missing_embeddings(
)
if result and "ids" in result:
existing_node_ids.update(result["ids"])
print(f"✅ 发现 {len(existing_node_ids)} 个已索引节点\n")
except Exception as e:
logger.warning(f"获取已索引节点ID失败: {e}")
-print(f"⚠️ 无法获取已索引节点,将尝试跳过重复项\n")
+print("⚠️ 无法获取已索引节点,将尝试跳过重复项\n")
# 3. 收集需要生成嵌入的节点
-print(f"🔍 扫描需要生成嵌入的节点...")
+print("🔍 扫描需要生成嵌入的节点...")
all_memories = manager.graph_store.get_all_memories()
nodes_to_process = []
@@ -110,7 +109,7 @@ async def generate_missing_embeddings(
})
type_stats[node.node_type.value]["need_emb"] += 1
-print(f"\n📊 扫描结果:")
+print("\n📊 扫描结果:")
for node_type in target_node_types:
stats = type_stats[node_type]
already_ok = stats["already_indexed"]
@@ -121,11 +120,11 @@ async def generate_missing_embeddings(
print(f"\n 总计: {total_target_nodes} 个目标节点, {len(nodes_to_process)} 个需要生成嵌入\n")
if len(nodes_to_process) == 0:
-print(f"✅ 所有节点已有嵌入向量,无需生成")
+print("✅ 所有节点已有嵌入向量,无需生成")
return
# 3. 批量生成嵌入
-print(f"🚀 开始生成嵌入向量...\n")
+print("🚀 开始生成嵌入向量...\n")
total_batches = (len(nodes_to_process) + batch_size - 1) // batch_size
success_count = 0
@@ -193,22 +192,22 @@ async def generate_missing_embeddings(
print(f" 📊 总进度: {total_processed}/{len(nodes_to_process)} ({progress:.1f}%)\n")
# 4. 保存图数据(更新节点的 embedding 字段)
-print(f"💾 保存图数据...")
+print("💾 保存图数据...")
try:
await manager.persistence.save_graph_store(manager.graph_store)
-print(f"✅ 图数据已保存\n")
+print("✅ 图数据已保存\n")
except Exception as e:
-logger.error(f"保存图数据失败", exc_info=True)
+logger.error("保存图数据失败", exc_info=True)
print(f"❌ 保存失败: {e}\n")
# 5. 验证结果
-print(f"🔍 验证向量索引...")
+print("🔍 验证向量索引...")
final_vector_count = manager.vector_store.collection.count()
stats = manager.graph_store.get_statistics()
total_nodes = stats["total_nodes"]
print(f"\n{'='*80}")
-print(f"📊 生成完成")
+print("📊 生成完成")
print(f"{'='*80}")
print(f"处理节点数: {len(nodes_to_process)}")
print(f"成功生成: {success_count}")
@@ -219,7 +218,7 @@ async def generate_missing_embeddings(
print(f"索引覆盖率: {final_vector_count / total_nodes * 100:.1f}%\n")
# 6. 测试搜索
-print(f"🧪 测试搜索功能...")
+print("🧪 测试搜索功能...")
test_queries = ["小红帽蕾克", "拾风", "杰瑞喵"]
for query in test_queries:
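The script above derives the batch count with ceiling division (total_batches = (len(nodes_to_process) + batch_size - 1) // batch_size). A minimal, self-contained sketch of the same batching pattern; the helper name is illustrative and not part of the script:

def split_into_batches(items: list, batch_size: int = 50) -> list[list]:
    """Split a list into consecutive batches; the last batch may be smaller."""
    total_batches = (len(items) + batch_size - 1) // batch_size  # ceiling division
    return [items[i * batch_size:(i + 1) * batch_size] for i in range(total_batches)]

# Example: 120 nodes with batch_size=50 -> 3 batches of 50, 50 and 20 items.
assert [len(b) for b in split_into_batches(list(range(120)), 50)] == [50, 50, 20]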

View File

@@ -38,7 +38,7 @@ OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")
TEMP_DIR = os.path.join(ROOT_PATH, "temp", "lpmm_cache")
# ========== 性能配置参数 ==========
#
# 知识提取步骤2txt转json并发控制
# - 控制同时进行的LLM提取请求数量
# - 推荐值: 3-10取决于API速率限制
@@ -184,7 +184,7 @@ async def extract_info_async(pg_hash, paragraph, llm_api):
tuple: (doc_item或None, failed_hash或None)
"""
temp_file_path = os.path.join(TEMP_DIR, f"{pg_hash}.json")
# 🔧 优化:使用异步文件检查,避免阻塞
if os.path.exists(temp_file_path):
try:
@@ -215,11 +215,11 @@ async def extract_info_async(pg_hash, paragraph, llm_api):
"extracted_entities": extracted_data.get("entities", []),
"extracted_triples": extracted_data.get("triples", []),
}
# 保存到缓存(异步写入)
async with aiofiles.open(temp_file_path, "wb") as f:
await f.write(orjson.dumps(doc_item))
return doc_item, None
except Exception as e:
logger.error(f"提取信息失败:{pg_hash}, 错误:{e}")
@@ -249,13 +249,13 @@ async def extract_information(paragraphs_dict, model_set):
os.makedirs(TEMP_DIR, exist_ok=True)
failed_hashes, open_ie_docs = [], []
# 🔧 关键修复:创建单个 LLM 请求实例,复用连接
llm_api = LLMRequest(model_set=model_set, request_type="lpmm_extraction")
# 🔧 并发控制:限制最大并发数,防止速率限制
semaphore = asyncio.Semaphore(MAX_EXTRACTION_CONCURRENCY)
async def extract_with_semaphore(pg_hash, paragraph):
"""带信号量控制的提取函数"""
async with semaphore:
@@ -266,7 +266,7 @@ async def extract_information(paragraphs_dict, model_set):
extract_with_semaphore(p_hash, paragraph)
for p_hash, paragraph in paragraphs_dict.items()
]
total = len(tasks)
completed = 0
@@ -284,7 +284,7 @@ async def extract_information(paragraphs_dict, model_set):
TimeRemainingColumn(),
) as progress:
task = progress.add_task("[cyan]正在提取信息...", total=total)
# 🔧 优化:使用 asyncio.gather 并发执行所有任务
# return_exceptions=True 确保单个失败不影响其他任务
for coro in asyncio.as_completed(tasks):
@@ -293,7 +293,7 @@ async def extract_information(paragraphs_dict, model_set):
failed_hashes.append(failed_hash)
elif doc_item:
open_ie_docs.append(doc_item)
completed += 1
progress.update(task, advance=1)
@@ -415,7 +415,7 @@ def rebuild_faiss_only():
logger.info("--- 重建 FAISS 索引 ---")
# 重建索引不需要并发参数(不涉及 embedding 生成)
embed_manager = EmbeddingManager()
logger.info("正在加载现有的 Embedding 库...")
try:
embed_manager.load_from_file()
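extract_with_semaphore above bounds the number of simultaneous LLM extraction calls with asyncio.Semaphore(MAX_EXTRACTION_CONCURRENCY) and then drains the tasks with asyncio.as_completed so progress can be reported per finished paragraph. A minimal sketch of that pattern, with a stub standing in for the real LLMRequest call:

import asyncio

MAX_EXTRACTION_CONCURRENCY = 5  # assumed value; the real limit comes from the config above

async def extract_one(pg_hash: str, paragraph: str) -> tuple[str, str]:
    """Placeholder for a single LLM extraction call."""
    await asyncio.sleep(0.1)  # stands in for the real LLM request
    return pg_hash, f"entities/triples for: {paragraph[:20]}"

async def extract_all(paragraphs: dict[str, str]) -> list[tuple[str, str]]:
    semaphore = asyncio.Semaphore(MAX_EXTRACTION_CONCURRENCY)

    async def guarded(pg_hash: str, paragraph: str):
        async with semaphore:  # at most N extractions run at once
            return await extract_one(pg_hash, paragraph)

    tasks = [guarded(h, p) for h, p in paragraphs.items()]
    results = []
    for coro in asyncio.as_completed(tasks):  # update progress as each task finishes
        results.append(await coro)
    return results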

View File

@@ -4,13 +4,13 @@
提供 Web API 用于可视化记忆图数据
"""
+from collections import defaultdict
from datetime import datetime
from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Any
-from collections import defaultdict
import orjson
-from fastapi import APIRouter, HTTPException, Request, Query
+from fastapi import APIRouter, HTTPException, Query, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.templating import Jinja2Templates
@@ -29,7 +29,7 @@ router = APIRouter()
templates = Jinja2Templates(directory=str(Path(__file__).parent / "templates"))
-def find_available_data_files() -> List[Path]:
+def find_available_data_files() -> list[Path]:
"""查找所有可用的记忆图数据文件"""
files = []
if not data_dir.exists():
@@ -62,7 +62,7 @@ def find_available_data_files() -> List[Path]:
return sorted(files, key=lambda f: f.stat().st_mtime, reverse=True)
-def load_graph_data_from_file(file_path: Optional[Path] = None) -> Dict[str, Any]:
+def load_graph_data_from_file(file_path: Path | None = None) -> dict[str, Any]:
"""从磁盘加载图数据"""
global graph_data_cache, current_data_file
@@ -85,7 +85,7 @@ def load_graph_data_from_file(file_path: Optional[Path] = None) -> Dict[str, Any
if not graph_file.exists():
return {"error": f"文件不存在: {graph_file}", "nodes": [], "edges": [], "stats": {}}
-with open(graph_file, "r", encoding="utf-8") as f:
+with open(graph_file, encoding="utf-8") as f:
data = orjson.loads(f.read())
nodes = data.get("nodes", [])
@@ -150,7 +150,7 @@ async def index(request: Request):
return templates.TemplateResponse("visualizer.html", {"request": request})
-def _format_graph_data_from_manager(memory_manager) -> Dict[str, Any]:
+def _format_graph_data_from_manager(memory_manager) -> dict[str, Any]:
"""从 MemoryManager 提取并格式化图数据"""
if not memory_manager.graph_store:
return {"nodes": [], "edges": [], "memories": [], "stats": {}}
@@ -188,7 +188,7 @@ def _format_graph_data_from_manager(memory_manager) -> Dict[str, Any]:
"arrows": "to",
"memory_id": memory.id,
}
edges_list = list(edges_dict.values())
stats = memory_manager.get_statistics()
@@ -261,7 +261,7 @@ async def get_paginated_graph(
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(500, ge=100, le=2000, description="每页节点数"),
min_importance: float = Query(0.0, ge=0.0, le=1.0, description="最小重要性阈值"),
-node_types: Optional[str] = Query(None, description="节点类型过滤,逗号分隔"),
+node_types: str | None = Query(None, description="节点类型过滤,逗号分隔"),
):
"""分页获取图数据,支持重要性过滤"""
try:
@@ -301,13 +301,13 @@ async def get_paginated_graph(
total_pages = (total_nodes + page_size - 1) // page_size
start_idx = (page - 1) * page_size
end_idx = min(start_idx + page_size, total_nodes)
paginated_nodes = nodes_with_importance[start_idx:end_idx]
node_ids = set(n["id"] for n in paginated_nodes)
# 只保留连接分页节点的边
paginated_edges = [
e for e in edges
if e.get("from") in node_ids and e.get("to") in node_ids
]
@@ -383,7 +383,7 @@ async def get_clustered_graph(
return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
-def _cluster_graph_data(nodes: List[Dict], edges: List[Dict], max_nodes: int, cluster_threshold: int) -> Dict:
+def _cluster_graph_data(nodes: list[dict], edges: list[dict], max_nodes: int, cluster_threshold: int) -> dict:
"""简单的图聚类算法:按类型和连接度聚类"""
# 构建邻接表
adjacency = defaultdict(set)
@@ -412,21 +412,21 @@ def _cluster_graph_data(nodes: List[Dict], edges: List[Dict], max_nodes: int, cl
for node in type_nodes:
importance = len(adjacency[node["id"]])
node_importance.append((node, importance))
node_importance.sort(key=lambda x: x[1], reverse=True)
# 保留前N个重要节点
keep_count = min(len(type_nodes), max_nodes // len(type_groups))
for node, importance in node_importance[:keep_count]:
clustered_nodes.append(node)
node_mapping[node["id"]] = node["id"]
# 其余节点聚合为一个超级节点
if len(node_importance) > keep_count:
clustered_node_ids = [n["id"] for n, _ in node_importance[keep_count:]]
cluster_id = f"cluster_{node_type}_{len(clustered_nodes)}"
cluster_label = f"{node_type} 集群 ({len(clustered_node_ids)}个节点)"
clustered_nodes.append({
"id": cluster_id,
"label": cluster_label,
@@ -436,7 +436,7 @@ def _cluster_graph_data(nodes: List[Dict], edges: List[Dict], max_nodes: int, cl
"cluster_size": len(clustered_node_ids),
"clustered_nodes": clustered_node_ids[:10], # 只保留前10个用于展示
})
for node_id in clustered_node_ids:
node_mapping[node_id] = cluster_id
@@ -445,7 +445,7 @@ def _cluster_graph_data(nodes: List[Dict], edges: List[Dict], max_nodes: int, cl
for edge in edges:
from_id = node_mapping.get(edge["from"])
to_id = node_mapping.get(edge["to"])
if from_id and to_id and from_id != to_id:
edge_key = tuple(sorted([from_id, to_id]))
if edge_key not in edge_set:
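_cluster_graph_data above ranks the nodes of each type by degree (the size of their adjacency set), keeps roughly max_nodes // len(type_groups) of them, and folds the remainder into one super-node per type. A simplified, self-contained sketch of that idea; the helper and its exact fields are illustrative, not the router's real code:

from collections import defaultdict

def cluster_by_type(nodes: list[dict], edges: list[dict], max_nodes: int) -> list[dict]:
    """Keep the best-connected nodes of each type; fold the rest into a cluster node."""
    adjacency = defaultdict(set)
    for e in edges:
        adjacency[e["from"]].add(e["to"])
        adjacency[e["to"]].add(e["from"])

    by_type = defaultdict(list)
    for n in nodes:
        by_type[n.get("type", "unknown")].append(n)

    kept = []
    for node_type, type_nodes in by_type.items():
        ranked = sorted(type_nodes, key=lambda n: len(adjacency[n["id"]]), reverse=True)
        keep_count = min(len(type_nodes), max_nodes // max(len(by_type), 1))
        kept.extend(ranked[:keep_count])
        rest = ranked[keep_count:]
        if rest:  # aggregate the overflow into a single super-node for this type
            kept.append({"id": f"cluster_{node_type}",
                         "label": f"{node_type} 集群 ({len(rest)}个节点)",
                         "cluster_size": len(rest)})
    return kept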

View File

@@ -1,6 +1,5 @@
-from collections import defaultdict
from datetime import datetime, timedelta
-from typing import Any, Literal
+from typing import Literal
from fastapi import APIRouter, HTTPException, Query

View File

@@ -161,16 +161,16 @@ class EmbeddingStore:
# 限制 chunk_size 和 max_workers 在合理范围内
chunk_size = max(MIN_CHUNK_SIZE, min(chunk_size, MAX_CHUNK_SIZE))
max_workers = max(MIN_WORKERS, min(max_workers, MAX_WORKERS))
semaphore = asyncio.Semaphore(max_workers)
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
results = {}
# 将字符串列表分成多个 chunk
chunks = []
for i in range(0, len(strs), chunk_size):
chunks.append(strs[i : i + chunk_size])
async def _process_chunk(chunk: list[str]):
"""处理一个 chunk 的字符串(批量获取 embedding"""
async with semaphore:
@@ -180,12 +180,12 @@ class EmbeddingStore:
embedding = await EmbeddingStore._get_embedding_async(llm, s)
embeddings.append(embedding)
results[s] = embedding
if progress_callback:
progress_callback(len(chunk))
return embeddings
# 并发处理所有 chunks
tasks = [_process_chunk(chunk) for chunk in chunks]
await asyncio.gather(*tasks)
@@ -418,22 +418,22 @@ class EmbeddingStore:
# 🔧 修复:检查所有 embedding 的维度是否一致
dimensions = [len(emb) for emb in array]
unique_dims = set(dimensions)
if len(unique_dims) > 1:
logger.error(f"检测到不一致的 embedding 维度: {unique_dims}")
logger.error(f"维度分布: {dict(zip(*np.unique(dimensions, return_counts=True)))}")
# 获取期望的维度(使用最常见的维度)
from collections import Counter
dim_counter = Counter(dimensions)
expected_dim = dim_counter.most_common(1)[0][0]
logger.warning(f"将使用最常见的维度: {expected_dim}")
# 过滤掉维度不匹配的 embedding
filtered_array = []
filtered_idx2hash = {}
skipped_count = 0
for i, emb in enumerate(array):
if len(emb) == expected_dim:
filtered_array.append(emb)
@@ -442,11 +442,11 @@ class EmbeddingStore:
skipped_count += 1
hash_key = self.idx2hash[str(i)]
logger.warning(f"跳过维度不匹配的 embedding: {hash_key}, 维度={len(emb)}, 期望={expected_dim}")
logger.warning(f"已过滤 {skipped_count} 个维度不匹配的 embedding")
array = filtered_array
self.idx2hash = filtered_idx2hash
if not array:
logger.error("过滤后没有可用的 embedding无法构建索引")
embedding_dim = expected_dim
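The dimension check above uses collections.Counter to pick the most common embedding dimension and drops vectors that do not match it before the index is built. A self-contained sketch of that filtering step; the function name is made up for illustration:

from collections import Counter

def filter_consistent_embeddings(embeddings: list[list[float]]) -> tuple[list[list[float]], int]:
    """Keep only embeddings whose dimension matches the most common one."""
    dimensions = [len(e) for e in embeddings]
    if len(set(dimensions)) <= 1:
        return embeddings, dimensions[0] if dimensions else 0
    expected_dim = Counter(dimensions).most_common(1)[0][0]  # majority dimension wins
    kept = [e for e in embeddings if len(e) == expected_dim]
    return kept, expected_dim

vectors = [[0.1] * 1024, [0.2] * 1024, [0.3] * 768]  # one 768-d stray vector
kept, dim = filter_consistent_embeddings(vectors)
assert dim == 1024 and len(kept) == 2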

View File

@@ -13,4 +13,4 @@ __all__ = [
"StreamLoopManager",
"message_manager",
"stream_loop_manager",
]

View File

@@ -82,7 +82,7 @@ class SingleStreamContextManager:
self.total_messages += 1 self.total_messages += 1
self.last_access_time = time.time() self.last_access_time = time.time()
# 如果使用了缓存系统,输出调试信息 # 如果使用了缓存系统,输出调试信息
if cache_enabled and self.context.is_cache_enabled: if cache_enabled and self.context.is_cache_enabled:
if self.context.is_chatter_processing: if self.context.is_chatter_processing:

View File

@@ -111,9 +111,9 @@ class StreamLoopManager:
# 获取或创建该流的启动锁 # 获取或创建该流的启动锁
if stream_id not in self._stream_start_locks: if stream_id not in self._stream_start_locks:
self._stream_start_locks[stream_id] = asyncio.Lock() self._stream_start_locks[stream_id] = asyncio.Lock()
lock = self._stream_start_locks[stream_id] lock = self._stream_start_locks[stream_id]
# 使用锁防止并发启动同一个流的多个循环任务 # 使用锁防止并发启动同一个流的多个循环任务
async with lock: async with lock:
# 获取流上下文 # 获取流上下文
@@ -148,7 +148,7 @@ class StreamLoopManager:
# 紧急取消 # 紧急取消
context.stream_loop_task.cancel() context.stream_loop_task.cancel()
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
loop_task = asyncio.create_task(self._stream_loop_worker(stream_id), name=f"stream_loop_{stream_id}") loop_task = asyncio.create_task(self._stream_loop_worker(stream_id), name=f"stream_loop_{stream_id}")
# 将任务记录到 StreamContext 中 # 将任务记录到 StreamContext 中
@@ -252,7 +252,7 @@ class StreamLoopManager:
self.stats["total_process_cycles"] += 1 self.stats["total_process_cycles"] += 1
if success: if success:
logger.info(f"✅ [流工作器] stream={stream_id[:8]}, 任务ID={task_id}, 处理成功") logger.info(f"✅ [流工作器] stream={stream_id[:8]}, 任务ID={task_id}, 处理成功")
# 🔒 处理成功后,等待一小段时间确保清理操作完成 # 🔒 处理成功后,等待一小段时间确保清理操作完成
# 这样可以避免在 chatter_manager 清除未读消息之前就进入下一轮循环 # 这样可以避免在 chatter_manager 清除未读消息之前就进入下一轮循环
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
@@ -382,7 +382,7 @@ class StreamLoopManager:
self.chatter_manager.process_stream_context(stream_id, context), self.chatter_manager.process_stream_context(stream_id, context),
name=f"chatter_process_{stream_id}" name=f"chatter_process_{stream_id}"
) )
# 等待 chatter 任务完成 # 等待 chatter 任务完成
results = await chatter_task results = await chatter_task
success = results.get("success", False) success = results.get("success", False)
@@ -398,8 +398,8 @@ class StreamLoopManager:
else: else:
logger.warning(f"流处理失败: {stream_id} - {results.get('error_message', '未知错误')}") logger.warning(f"流处理失败: {stream_id} - {results.get('error_message', '未知错误')}")
return success return success
except asyncio.CancelledError: except asyncio.CancelledError:
if chatter_task and not chatter_task.done(): if chatter_task and not chatter_task.done():
chatter_task.cancel() chatter_task.cancel()
raise raise
@@ -709,4 +709,4 @@ class StreamLoopManager:
# 全局流循环管理器实例 # 全局流循环管理器实例
stream_loop_manager = StreamLoopManager() stream_loop_manager = StreamLoopManager()
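The startup path above keeps one asyncio.Lock per stream_id so that two concurrent callers can never spawn duplicate loop tasks for the same stream. A minimal sketch of that guard, independent of the real StreamLoopManager:

import asyncio

class LoopStarter:
    """Sketch: one lock per stream id so the same loop is never started twice concurrently."""

    def __init__(self):
        self._locks: dict[str, asyncio.Lock] = {}
        self._tasks: dict[str, asyncio.Task] = {}

    async def ensure_loop(self, stream_id: str, worker) -> asyncio.Task:
        lock = self._locks.setdefault(stream_id, asyncio.Lock())
        async with lock:  # concurrent callers for the same stream serialize here
            task = self._tasks.get(stream_id)
            if task is None or task.done():
                task = asyncio.create_task(worker(stream_id), name=f"stream_loop_{stream_id}")
                self._tasks[stream_id] = task
            return task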

View File

@@ -417,7 +417,7 @@ class MessageManager:
return return
# 记录详细信息 # 记录详细信息
msg_previews = [f"{str(msg.message_id)[:8] if msg.message_id else 'unknown'}:{msg.processed_plain_text[:20] if msg.processed_plain_text else '(空)'}" msg_previews = [f"{str(msg.message_id)[:8] if msg.message_id else 'unknown'}:{msg.processed_plain_text[:20] if msg.processed_plain_text else '(空)'}"
for msg in unread_messages[:3]] # 只显示前3条 for msg in unread_messages[:3]] # 只显示前3条
logger.info(f"🧹 [清除未读] stream={stream_id[:8]}, 开始清除 {len(unread_messages)} 条未读消息, 示例: {msg_previews}") logger.info(f"🧹 [清除未读] stream={stream_id[:8]}, 开始清除 {len(unread_messages)} 条未读消息, 示例: {msg_previews}")
@@ -446,15 +446,15 @@ class MessageManager:
context = chat_stream.context_manager.context context = chat_stream.context_manager.context
if hasattr(context, "unread_messages") and context.unread_messages: if hasattr(context, "unread_messages") and context.unread_messages:
unread_count = len(context.unread_messages) unread_count = len(context.unread_messages)
# 如果还有未读消息,说明 action_manager 可能遗漏了,标记它们 # 如果还有未读消息,说明 action_manager 可能遗漏了,标记它们
if unread_count > 0: if unread_count > 0:
# 获取所有未读消息的 ID # 获取所有未读消息的 ID
message_ids = [msg.message_id for msg in context.unread_messages] message_ids = [msg.message_id for msg in context.unread_messages]
# 标记为已读(会移到历史消息) # 标记为已读(会移到历史消息)
success = chat_stream.context_manager.mark_messages_as_read(message_ids) success = chat_stream.context_manager.mark_messages_as_read(message_ids)
if success: if success:
logger.debug(f"✅ stream={stream_id[:8]}, 成功标记 {unread_count} 条消息为已读") logger.debug(f"✅ stream={stream_id[:8]}, 成功标记 {unread_count} 条消息为已读")
else: else:
@@ -481,7 +481,7 @@ class MessageManager:
try: try:
chat_manager = get_chat_manager() chat_manager = get_chat_manager()
chat_stream = await chat_manager.get_stream(stream_id) chat_stream = await chat_manager.get_stream(stream_id)
if chat_stream and hasattr(chat_stream.context_manager.context, 'is_chatter_processing'): if chat_stream and hasattr(chat_stream.context_manager.context, "is_chatter_processing"):
chat_stream.context_manager.context.is_chatter_processing = is_processing chat_stream.context_manager.context.is_chatter_processing = is_processing
logger.debug(f"设置StreamContext处理状态: stream={stream_id}, processing={is_processing}") logger.debug(f"设置StreamContext处理状态: stream={stream_id}, processing={is_processing}")
except Exception as e: except Exception as e:
@@ -517,7 +517,7 @@ class MessageManager:
try: try:
chat_manager = get_chat_manager() chat_manager = get_chat_manager()
chat_stream = await chat_manager.get_stream(stream_id) chat_stream = await chat_manager.get_stream(stream_id)
if chat_stream and hasattr(chat_stream.context_manager.context, 'is_chatter_processing'): if chat_stream and hasattr(chat_stream.context_manager.context, "is_chatter_processing"):
return chat_stream.context_manager.context.is_chatter_processing return chat_stream.context_manager.context.is_chatter_processing
except Exception: except Exception:
pass pass
@@ -677,4 +677,4 @@ class MessageManager:
# 创建全局消息管理器实例 # 创建全局消息管理器实例
message_manager = MessageManager() message_manager = MessageManager()

View File

@@ -248,16 +248,16 @@ class ChatterActionManager:
try: try:
# 根据动作类型确定提示词模式 # 根据动作类型确定提示词模式
prompt_mode = "s4u" if action_name == "reply" else "normal" prompt_mode = "s4u" if action_name == "reply" else "normal"
# 将prompt_mode传递给generate_reply # 将prompt_mode传递给generate_reply
action_data_with_mode = (action_data or {}).copy() action_data_with_mode = (action_data or {}).copy()
action_data_with_mode["prompt_mode"] = prompt_mode action_data_with_mode["prompt_mode"] = prompt_mode
# 只传递当前正在执行的动作,而不是所有可用动作 # 只传递当前正在执行的动作,而不是所有可用动作
# 这样可以让LLM明确知道"已决定执行X动作",而不是"有这些动作可用" # 这样可以让LLM明确知道"已决定执行X动作",而不是"有这些动作可用"
current_action_info = self._using_actions.get(action_name) current_action_info = self._using_actions.get(action_name)
current_actions: dict[str, Any] = {action_name: current_action_info} if current_action_info else {} current_actions: dict[str, Any] = {action_name: current_action_info} if current_action_info else {}
# 附加目标消息信息(如果存在) # 附加目标消息信息(如果存在)
if target_message: if target_message:
# 提取目标消息的关键信息 # 提取目标消息的关键信息
@@ -268,7 +268,7 @@ class ChatterActionManager:
"time": getattr(target_message, "time", 0), "time": getattr(target_message, "time", 0),
} }
current_actions["_target_message"] = target_msg_info current_actions["_target_message"] = target_msg_info
success, response_set, _ = await generator_api.generate_reply( success, response_set, _ = await generator_api.generate_reply(
chat_stream=chat_stream, chat_stream=chat_stream,
reply_message=target_message, reply_message=target_message,
@@ -295,12 +295,12 @@ class ChatterActionManager:
should_quote_reply = None should_quote_reply = None
if action_data and isinstance(action_data, dict): if action_data and isinstance(action_data, dict):
should_quote_reply = action_data.get("should_quote_reply", None) should_quote_reply = action_data.get("should_quote_reply", None)
# respond动作默认不引用回复保持对话流畅 # respond动作默认不引用回复保持对话流畅
if action_name == "respond" and should_quote_reply is None: if action_name == "respond" and should_quote_reply is None:
should_quote_reply = False should_quote_reply = False
async def _after_reply(): async def _after_reply():
# 发送并存储回复 # 发送并存储回复
loop_info, reply_text, cycle_timers_reply = await self._send_and_store_reply( loop_info, reply_text, cycle_timers_reply = await self._send_and_store_reply(
chat_stream, chat_stream,

View File

@@ -365,7 +365,7 @@ class DefaultReplyer:
# 确保类型安全 # 确保类型安全
if isinstance(mode, str): if isinstance(mode, str):
prompt_mode_value = mode prompt_mode_value = mode
# 构建 Prompt # 构建 Prompt
with Timer("构建Prompt", {}): # 内部计时器,可选保留 with Timer("构建Prompt", {}): # 内部计时器,可选保留
prompt = await self.build_prompt_reply_context( prompt = await self.build_prompt_reply_context(
@@ -1171,16 +1171,16 @@ class DefaultReplyer:
from src.plugin_system.apis.chat_api import get_chat_manager from src.plugin_system.apis.chat_api import get_chat_manager
chat_manager = get_chat_manager() chat_manager = get_chat_manager()
chat_stream_obj = await chat_manager.get_stream(chat_id) chat_stream_obj = await chat_manager.get_stream(chat_id)
if chat_stream_obj: if chat_stream_obj:
unread_messages = chat_stream_obj.context_manager.get_unread_messages() unread_messages = chat_stream_obj.context_manager.get_unread_messages()
if unread_messages: if unread_messages:
# 使用最后一条未读消息作为参考 # 使用最后一条未读消息作为参考
last_msg = unread_messages[-1] last_msg = unread_messages[-1]
platform = last_msg.chat_info.platform if hasattr(last_msg, 'chat_info') else chat_stream.platform platform = last_msg.chat_info.platform if hasattr(last_msg, "chat_info") else chat_stream.platform
user_id = last_msg.user_info.user_id if hasattr(last_msg, 'user_info') else "" user_id = last_msg.user_info.user_id if hasattr(last_msg, "user_info") else ""
user_nickname = last_msg.user_info.user_nickname if hasattr(last_msg, 'user_info') else "" user_nickname = last_msg.user_info.user_nickname if hasattr(last_msg, "user_info") else ""
user_cardname = last_msg.user_info.user_cardname if hasattr(last_msg, 'user_info') else "" user_cardname = last_msg.user_info.user_cardname if hasattr(last_msg, "user_info") else ""
processed_plain_text = last_msg.processed_plain_text or "" processed_plain_text = last_msg.processed_plain_text or ""
else: else:
# 没有未读消息,使用默认值 # 没有未读消息,使用默认值
@@ -1263,19 +1263,19 @@ class DefaultReplyer:
if available_actions: if available_actions:
# 过滤掉特殊键以_开头 # 过滤掉特殊键以_开头
action_items = {k: v for k, v in available_actions.items() if not k.startswith("_")} action_items = {k: v for k, v in available_actions.items() if not k.startswith("_")}
# 提取目标消息信息(如果存在) # 提取目标消息信息(如果存在)
target_msg_info = available_actions.get("_target_message") # type: ignore target_msg_info = available_actions.get("_target_message") # type: ignore
if action_items: if action_items:
if len(action_items) == 1: if len(action_items) == 1:
# 单个动作 # 单个动作
action_name, action_info = list(action_items.items())[0] action_name, action_info = list(action_items.items())[0]
action_desc = action_info.description action_desc = action_info.description
# 构建基础决策信息 # 构建基础决策信息
action_descriptions = f"## 决策信息\n\n你已经决定要执行 **{action_name}** 动作({action_desc})。\n\n" action_descriptions = f"## 决策信息\n\n你已经决定要执行 **{action_name}** 动作({action_desc})。\n\n"
# 只有需要目标消息的动作才显示目标消息详情 # 只有需要目标消息的动作才显示目标消息详情
# respond 动作是统一回应所有未读消息,不应该显示特定目标消息 # respond 动作是统一回应所有未读消息,不应该显示特定目标消息
if action_name not in ["respond"] and target_msg_info and isinstance(target_msg_info, dict): if action_name not in ["respond"] and target_msg_info and isinstance(target_msg_info, dict):
@@ -1284,7 +1284,7 @@ class DefaultReplyer:
content = target_msg_info.get("content", "") content = target_msg_info.get("content", "")
msg_time = target_msg_info.get("time", 0) msg_time = target_msg_info.get("time", 0)
time_str = time_module.strftime("%H:%M:%S", time_module.localtime(msg_time)) if msg_time else "未知时间" time_str = time_module.strftime("%H:%M:%S", time_module.localtime(msg_time)) if msg_time else "未知时间"
action_descriptions += f"**目标消息**: {time_str} {sender} 说: {content}\n\n" action_descriptions += f"**目标消息**: {time_str} {sender} 说: {content}\n\n"
else: else:
# 多个动作 # 多个动作
@@ -2137,7 +2137,7 @@ class DefaultReplyer:
except Exception as e: except Exception as e:
logger.error(f"存储聊天记忆失败: {e}") logger.error(f"存储聊天记忆失败: {e}")
def weighted_sample_no_replacement(items, weights, k) -> list: def weighted_sample_no_replacement(items, weights, k) -> list:
""" """

View File

@@ -5,12 +5,12 @@
插件可以通过实现这些接口来扩展安全功能。
"""
-from .interfaces import SecurityCheckResult, SecurityChecker
+from .interfaces import SecurityChecker, SecurityCheckResult
from .manager import SecurityManager, get_security_manager
__all__ = [
-"SecurityChecker",
"SecurityCheckResult",
+"SecurityChecker",
"SecurityManager",
"get_security_manager",
]

View File

@@ -10,7 +10,7 @@ from typing import Any
from src.common.logger import get_logger
-from .interfaces import SecurityAction, SecurityCheckResult, SecurityChecker, SecurityLevel
+from .interfaces import SecurityAction, SecurityChecker, SecurityCheckResult, SecurityLevel
logger = get_logger("security.manager")

View File

@@ -1,5 +1,7 @@
import asyncio import asyncio
import copy
import re import re
from collections.abc import Awaitable, Callable
from src.chat.utils.prompt_params import PromptParameters from src.chat.utils.prompt_params import PromptParameters
from src.common.logger import get_logger from src.common.logger import get_logger
@@ -12,122 +14,205 @@ logger = get_logger("prompt_component_manager")
class PromptComponentManager: class PromptComponentManager:
""" """
管理所有 `BasePrompt` 组件的单例类 一个统一的、动态的、可观测的提示词组件管理中心
该管理器负责: 该管理器是整个提示词动态注入系统的核心,它负责:
1. 从 `component_registry` 中查询 `BasePrompt` 子类。 1. **规则加载**: 在系统启动时,自动扫描所有已注册的 `BasePrompt` 组件,
2. 根据注入点目标Prompt名称对它们进行筛选 并将其静态定义的 `injection_rules` 加载为默认的动态规则
3. 提供一个接口以便在构建核心Prompt时能够获取并执行所有相关的组件。 2. **动态管理**: 提供线程安全的 API允许在运行时动态地添加、更新或移除注入规则
使得提示词的结构可以被实时调整。
3. **状态观测**: 提供丰富的查询 API用于观测系统当前完整的注入状态
例如查询所有注入到特定目标的规则、或查询某个组件定义的所有规则。
4. **注入应用**: 在构建核心 Prompt 时,根据统一的、按优先级排序的规则集,
动态地修改和装配提示词模板,实现灵活的提示词组合。
""" """
def _get_rules_for(self, target_prompt_name: str) -> list[tuple[InjectionRule, type[BasePrompt]]]: def __init__(self):
""" """初始化管理器实例。"""
获取指定目标Prompt的所有注入规则及其关联的组件类 # _dynamic_rules 是管理器的核心状态,存储所有注入规则
# 结构: {
# "target_prompt_name": {
# "prompt_component_name": (InjectionRule, content_provider, source)
# }
# }
# content_provider 是一个异步函数,用于在应用规则时动态生成注入内容。
# source 记录了规则的来源(例如 "static_default" 或 "runtime")。
self._dynamic_rules: dict[str, dict[str, tuple[InjectionRule, Callable[..., Awaitable[str]], str]]] = {}
self._lock = asyncio.Lock() # 使用异步锁确保对 _dynamic_rules 的并发访问安全。
self._initialized = False # 标记静态规则是否已加载,防止重复加载。
Args: # --- 核心生命周期与初始化 ---
target_prompt_name (str): 目标 Prompt 的名称。
Returns: def load_static_rules(self):
list[tuple[InjectionRule, Type[BasePrompt]]]: 一个元组列表,
每个元组包含一个注入规则和其对应的 Prompt 组件类,并已根据优先级排序。
""" """
# 从注册表中获取所有已启用的 PROMPT 类型的组件 在系统启动时加载所有静态注入规则。
该方法会扫描所有已在 `component_registry` 中注册并启用的 Prompt 组件,
将其类变量 `injection_rules` 转换为管理器的动态规则。
这确保了所有插件定义的默认注入行为在系统启动时就能生效。
此操作是幂等的,一旦初始化完成就不会重复执行。
"""
if self._initialized:
return
logger.info("正在加载静态 Prompt 注入规则...")
# 从组件注册表中获取所有已启用的 Prompt 组件
enabled_prompts = component_registry.get_enabled_components_by_type(ComponentType.PROMPT) enabled_prompts = component_registry.get_enabled_components_by_type(ComponentType.PROMPT)
matching_rules = []
# 遍历所有启用的 Prompt 组件,查找与目标 Prompt 相关的注入规则
for prompt_name, prompt_info in enabled_prompts.items(): for prompt_name, prompt_info in enabled_prompts.items():
if not isinstance(prompt_info, PromptInfo): if not isinstance(prompt_info, PromptInfo):
continue continue
# prompt_info.injection_rules 已经经过了后向兼容处理,确保总是列表 component_class = component_registry.get_component_class(prompt_name, ComponentType.PROMPT)
for rule in prompt_info.injection_rules: if not (component_class and issubclass(component_class, BasePrompt)):
# 如果规则的目标是当前指定的 Prompt logger.warning(f"无法为 '{prompt_name}' 加载静态规则,因为它不是一个有效的 Prompt 组件。")
if rule.target_prompt == target_prompt_name: continue
# 获取该规则对应的组件类
component_class = component_registry.get_component_class(prompt_name, ComponentType.PROMPT)
# 确保获取到的确实是一个 BasePrompt 的子类
if component_class and issubclass(component_class, BasePrompt):
matching_rules.append((rule, component_class))
# 根据规则的优先级进行排序,数字越小,优先级越高,越先应用 def create_provider(cls: type[BasePrompt]) -> Callable[[PromptParameters], Awaitable[str]]:
matching_rules.sort(key=lambda x: x[0].priority) """
return matching_rules 为静态组件创建一个内容提供者闭包 (Content Provider Closure)。
这个闭包捕获了组件的类 `cls`,并返回一个标准的 `content_provider` 异步函数。
当 `apply_injections` 需要内容时,它会调用这个函数。
函数内部会实例化组件,并执行其 `execute` 方法来获取注入内容。
Args:
cls (type[BasePrompt]): 需要为其创建提供者的 Prompt 组件类。
Returns:
Callable[[PromptParameters], Awaitable[str]]: 一个符合管理器标准的异步内容提供者。
"""
async def content_provider(params: PromptParameters) -> str:
"""实际执行内容生成的异步函数。"""
try:
# 从注册表获取最新的组件信息,包括插件配置
p_info = component_registry.get_component_info(cls.prompt_name, ComponentType.PROMPT)
plugin_config = {}
if isinstance(p_info, PromptInfo):
plugin_config = component_registry.get_plugin_config(p_info.plugin_name)
# 实例化组件并执行
instance = cls(params=params, plugin_config=plugin_config)
result = await instance.execute()
return str(result) if result is not None else ""
except Exception as e:
logger.error(f"执行静态规则提供者 '{cls.prompt_name}' 时出错: {e}", exc_info=True)
return "" # 出错时返回空字符串,避免影响主流程
return content_provider
# 为该组件的每条静态注入规则创建并注册一个动态规则
for rule in prompt_info.injection_rules:
provider = create_provider(component_class)
target_rules = self._dynamic_rules.setdefault(rule.target_prompt, {})
target_rules[prompt_name] = (rule, provider, "static_default")
self._initialized = True
logger.info(f"静态 Prompt 注入规则加载完成,共处理 {len(enabled_prompts)} 个组件。")
# --- 运行时规则管理 API ---
async def add_injection_rule(
self,
prompt_name: str,
rule: InjectionRule,
content_provider: Callable[..., Awaitable[str]],
source: str = "runtime",
) -> bool:
"""
动态添加或更新一条注入规则。
此方法允许在系统运行时,由外部逻辑(如插件、命令)向管理器中添加新的注入行为。
如果已存在同名组件针对同一目标的规则,此方法会覆盖旧规则。
Args:
prompt_name (str): 动态注入组件的唯一名称。
rule (InjectionRule): 描述注入行为的规则对象。
content_provider (Callable[..., Awaitable[str]]):
一个异步函数,用于在应用注入时动态生成内容。
函数签名应为: `async def provider(params: "PromptParameters") -> str`
source (str, optional): 规则的来源标识,默认为 "runtime"
Returns:
bool: 如果成功添加或更新,则返回 True。
"""
async with self._lock:
target_rules = self._dynamic_rules.setdefault(rule.target_prompt, {})
target_rules[prompt_name] = (rule, content_provider, source)
logger.info(f"成功添加/更新注入规则: '{prompt_name}' -> '{rule.target_prompt}' (来源: {source})")
return True
async def remove_injection_rule(self, prompt_name: str, target_prompt: str) -> bool:
"""
移除一条动态注入规则。
Args:
prompt_name (str): 要移除的注入组件的名称。
target_prompt (str): 该组件注入的目标核心提示词名称。
Returns:
bool: 如果成功移除,则返回 True如果规则不存在则返回 False。
"""
async with self._lock:
if target_prompt in self._dynamic_rules and prompt_name in self._dynamic_rules[target_prompt]:
del self._dynamic_rules[target_prompt][prompt_name]
# 如果目标下已无任何规则,则清理掉这个键
if not self._dynamic_rules[target_prompt]:
del self._dynamic_rules[target_prompt]
logger.info(f"成功移除注入规则: '{prompt_name}' from '{target_prompt}'")
return True
logger.warning(f"尝试移除注入规则失败: 未找到 '{prompt_name}' on '{target_prompt}'")
return False
# --- 核心注入逻辑 ---
async def apply_injections( async def apply_injections(
self, target_prompt_name: str, original_template: str, params: PromptParameters self, target_prompt_name: str, original_template: str, params: PromptParameters
) -> str: ) -> str:
""" """
获取、实例化并执行所有相关组件,然后根据注入规则修改原始模板。 【核心方法】根据目标名称,应用所有匹配的注入规则,返回修改后的模板。
这是一个三步走的过程 这是提示词构建流程中的关键步骤。它会执行以下操作
1. 实例化所有需要执行的组件 1. 检查并确保静态规则已加载
2. 并行执行它们的 `execute` 方法以获取注入内容 2. 获取所有注入到 `target_prompt_name` 的规则
3. 按照优先级顺序,将内容注入到原始模板中 3. 按照规则的 `priority` 属性进行升序排序,优先级数字越小越先应用
4. 依次执行每个规则的 `content_provider` 来异步获取注入内容。
5. 根据规则的 `injection_type` (如 PREPEND, APPEND, REPLACE 等) 将内容应用到模板上。
Args: Args:
target_prompt_name (str): 目标 Prompt 的名称。 target_prompt_name (str): 目标核心提示词的名称。
original_template (str): 原始的、未经修改的 Prompt 模板字符串 original_template (str): 未经修改的原始提示词模板
params (PromptParameters): 传递给 Prompt 组件实例的参数 params (PromptParameters): 当前请求的参数,会传递给 `content_provider`
Returns: Returns:
str: 应用了所有注入规则后,修改过的 Prompt 模板字符串。 str: 应用了所有注入规则后,最终生成的提示词模板字符串。
""" """
rules_with_classes = self._get_rules_for(target_prompt_name) if not self._initialized:
# 如果没有找到任何匹配的规则,就直接返回原始模板,啥也不干 self.load_static_rules()
if not rules_with_classes:
# 步骤 1: 获取所有指向当前目标的规则
# 使用 .values() 获取 (rule, provider, source) 元组列表
rules_for_target = list(self._dynamic_rules.get(target_prompt_name, {}).values())
if not rules_for_target:
return original_template return original_template
# --- 第一步: 实例化所有需要执行的组件 --- # 步骤 2: 按优先级排序,数字越小越优先
instance_map = {} # 存储组件实例,虽然目前没直接用,但留着总没错 rules_for_target.sort(key=lambda x: x[0].priority)
tasks = [] # 存放所有需要并行执行的 execute 异步任务
components_to_execute = [] # 存放需要执行的组件类,用于后续结果映射
for rule, component_class in rules_with_classes: # 步骤 3: 依次执行内容提供者并根据注入类型修改模板
# 如果注入类型是 REMOVE那就不需要执行组件了因为它不产生内容 modified_template = original_template
for rule, provider, source in rules_for_target:
content = ""
# 对于非 REMOVE 类型的注入,需要先获取内容
if rule.injection_type != InjectionType.REMOVE: if rule.injection_type != InjectionType.REMOVE:
try: try:
# 获取组件的元信息,主要是为了拿到插件名称来读取插件配置 content = await provider(params)
prompt_info = component_registry.get_component_info(
component_class.prompt_name, ComponentType.PROMPT
)
if not isinstance(prompt_info, PromptInfo):
plugin_config = {}
else:
# 从注册表获取该组件所属插件的配置
plugin_config = component_registry.get_plugin_config(prompt_info.plugin_name)
# 实例化组件,并传入参数和插件配置
instance = component_class(params=params, plugin_config=plugin_config)
instance_map[component_class.prompt_name] = instance
# 将组件的 execute 方法作为一个任务添加到列表中
tasks.append(instance.execute())
components_to_execute.append(component_class)
except Exception as e: except Exception as e:
logger.error(f"实例化 Prompt 组件 '{component_class.prompt_name}' 失败: {e}") logger.error(f"执行规则 '{rule}' (来源: {source}) 的内容提供者时失败: {e}", exc_info=True)
# 即使失败,也添加一个立即完成的空任务,以保持与其他任务的索引同步 continue # 跳过失败的 provider不中断整个流程
tasks.append(asyncio.create_task(asyncio.sleep(0, result=e))) # type: ignore
# --- 第二步: 并行执行所有组件的 execute 方法 ---
# 使用 asyncio.gather 来同时运行所有任务,提高效率
results = await asyncio.gather(*tasks, return_exceptions=True)
# 创建一个从组件名到执行结果的映射,方便后续查找
result_map = {
components_to_execute[i].prompt_name: res
for i, res in enumerate(results)
if not isinstance(res, Exception) # 只包含成功的结果
}
# 单独处理并记录执行失败的组件
for i, res in enumerate(results):
if isinstance(res, Exception):
logger.error(f"执行 Prompt 组件 '{components_to_execute[i].prompt_name}' 失败: {res}")
# --- 第三步: 按优先级顺序应用注入规则 ---
modified_template = original_template
for rule, component_class in rules_with_classes:
# 从结果映射中获取该组件生成的内容
content = result_map.get(component_class.prompt_name)
# 应用注入逻辑
try: try:
if rule.injection_type == InjectionType.PREPEND: if rule.injection_type == InjectionType.PREPEND:
if content: if content:
@@ -136,28 +221,178 @@ class PromptComponentManager:
if content: if content:
modified_template = f"{modified_template}\n{content}" modified_template = f"{modified_template}\n{content}"
elif rule.injection_type == InjectionType.REPLACE: elif rule.injection_type == InjectionType.REPLACE:
# 使用正则表达式替换目标内容 # 只有在 content 不为 None 且 target_content 有效时才执行替换
if content and rule.target_content: if content is not None and rule.target_content:
modified_template = re.sub(rule.target_content, str(content), modified_template) modified_template = re.sub(rule.target_content, str(content), modified_template)
elif rule.injection_type == InjectionType.INSERT_AFTER: elif rule.injection_type == InjectionType.INSERT_AFTER:
# 在匹配到的内容后面插入
if content and rule.target_content: if content and rule.target_content:
# re.sub a little trick: \g<0> represents the entire matched string # 使用 `\g<0>` 在正则匹配的整个内容后添加新内容
replacement = f"\\g<0>\n{content}" replacement = f"\\g<0>\n{content}"
modified_template = re.sub(rule.target_content, replacement, modified_template) modified_template = re.sub(rule.target_content, replacement, modified_template)
elif rule.injection_type == InjectionType.REMOVE: elif rule.injection_type == InjectionType.REMOVE:
# 使用正则表达式移除目标内容
if rule.target_content: if rule.target_content:
modified_template = re.sub(rule.target_content, "", modified_template) modified_template = re.sub(rule.target_content, "", modified_template)
except re.error as e: except re.error as e:
logger.error( logger.error(f"应用规则时发生正则错误: {e} (pattern: '{rule.target_content}')")
f"在为 '{component_class.prompt_name}' 应用规则时发生正则错误: {e} (pattern: '{rule.target_content}')"
)
except Exception as e: except Exception as e:
logger.error(f"应用 Prompt 注入规则 '{rule}' 失败: {e}") logger.error(f"应用注入规则 '{rule}' (来源: {source}) 失败: {e}", exc_info=True)
return modified_template return modified_template
async def preview_prompt_injections(
self, target_prompt_name: str, params: PromptParameters
) -> str:
"""
【预览功能】模拟应用所有注入规则,返回最终生成的模板字符串,而不实际修改任何状态。
# 创建全局单例 这个方法对于调试和测试非常有用,可以查看在特定参数下,
一个核心提示词经过所有注入规则处理后会变成什么样子。
Args:
target_prompt_name (str): 希望预览的目标核心提示词名称。
params (PromptParameters): 模拟的请求参数。
Returns:
str: 模拟生成的最终提示词模板字符串。如果找不到模板,则返回错误信息。
"""
try:
# 从全局提示词管理器获取最原始的模板内容
from src.chat.utils.prompt import global_prompt_manager
original_prompt = global_prompt_manager._prompts.get(target_prompt_name)
if not original_prompt:
logger.warning(f"无法预览 '{target_prompt_name}',因为找不到这个核心 Prompt。")
return f"Error: Prompt '{target_prompt_name}' not found."
original_template = original_prompt.template
except KeyError:
logger.warning(f"无法预览 '{target_prompt_name}',因为找不到这个核心 Prompt。")
return f"Error: Prompt '{target_prompt_name}' not found."
# 直接调用核心注入逻辑来模拟结果
return await self.apply_injections(target_prompt_name, original_template, params)
# --- 状态观测与查询 API ---
def get_core_prompts(self) -> list[str]:
"""获取所有已注册的核心提示词模板名称列表(即所有可注入的目标)。"""
from src.chat.utils.prompt import global_prompt_manager
return list(global_prompt_manager._prompts.keys())
def get_core_prompt_contents(self) -> dict[str, str]:
"""获取所有核心提示词模板的原始内容。"""
from src.chat.utils.prompt import global_prompt_manager
return {name: prompt.template for name, prompt in global_prompt_manager._prompts.items()}
def get_registered_prompt_component_info(self) -> list[PromptInfo]:
"""获取所有在 ComponentRegistry 中注册的 Prompt 组件信息。"""
components = component_registry.get_components_by_type(ComponentType.PROMPT).values()
return [info for info in components if isinstance(info, PromptInfo)]
async def get_full_injection_map(self) -> dict[str, list[dict]]:
"""
获取当前完整的注入映射图。
此方法提供了一个系统全局的注入视图展示了每个核心提示词target
被哪些注入组件source以何种优先级注入。
Returns:
dict[str, list[dict]]: 一个字典,键是目标提示词名称,
值是按优先级排序的注入信息列表。
`[{"name": str, "priority": int, "source": str}]`
"""
injection_map = {}
async with self._lock:
# 合并所有动态规则的目标和所有核心提示词,确保所有潜在目标都被包含
all_targets = set(self._dynamic_rules.keys()) | set(self.get_core_prompts())
for target in sorted(all_targets):
rules = self._dynamic_rules.get(target, {})
if not rules:
injection_map[target] = []
continue
info_list = []
for prompt_name, (rule, _, source) in rules.items():
info_list.append({"name": prompt_name, "priority": rule.priority, "source": source})
# 按优先级排序后存入 map
info_list.sort(key=lambda x: x["priority"])
injection_map[target] = info_list
return injection_map
async def get_injections_for_prompt(self, target_prompt_name: str) -> list[dict]:
"""
获取指定核心提示词模板的所有注入信息(包含详细规则)。
Args:
target_prompt_name (str): 目标核心提示词的名称。
Returns:
list[dict]: 一个包含注入规则详细信息的列表,已按优先级排序。
"""
rules_for_target = self._dynamic_rules.get(target_prompt_name, {})
if not rules_for_target:
return []
info_list = []
for prompt_name, (rule, _, source) in rules_for_target.items():
info_list.append(
{
"name": prompt_name,
"priority": rule.priority,
"source": source,
"injection_type": rule.injection_type.value,
"target_content": rule.target_content,
}
)
info_list.sort(key=lambda x: x["priority"])
return info_list
def get_all_dynamic_rules(self) -> dict[str, dict[str, "InjectionRule"]]:
"""
获取所有当前的动态注入规则,以 InjectionRule 对象形式返回。
此方法返回一个深拷贝的规则副本,隐藏了 `content_provider` 等内部实现细节。
适合用于展示或序列化当前的规则配置。
"""
rules_copy = {}
for target, rules in self._dynamic_rules.items():
target_copy = {name: rule for name, (rule, _, _) in rules.items()}
rules_copy[target] = target_copy
return copy.deepcopy(rules_copy)
def get_rules_for_target(self, target_prompt: str) -> dict[str, InjectionRule]:
"""
获取所有注入到指定核心提示词的动态规则。
Args:
target_prompt (str): 目标核心提示词的名称。
Returns:
dict[str, InjectionRule]: 一个字典,键是注入组件的名称,值是 `InjectionRule` 对象。
如果找不到任何注入到该目标的规则,则返回一个空字典。
"""
target_rules = self._dynamic_rules.get(target_prompt, {})
return {name: copy.deepcopy(rule_info[0]) for name, rule_info in target_rules.items()}
def get_rules_by_component(self, component_name: str) -> dict[str, InjectionRule]:
"""
获取由指定的单个注入组件定义的所有动态规则。
Args:
component_name (str): 注入组件的名称。
Returns:
dict[str, InjectionRule]: 一个字典,键是目标核心提示词的名称,值是 `InjectionRule` 对象。
如果该组件没有定义任何注入规则,则返回一个空字典。
"""
found_rules = {}
for target, rules in self._dynamic_rules.items():
if component_name in rules:
rule_info = rules[component_name]
found_rules[target] = copy.deepcopy(rule_info[0])
return found_rules
# 创建全局单例 (Singleton)
# 在整个应用程序中,应该只使用这一个 `prompt_component_manager` 实例,
# 以确保所有部分都共享和操作同一份动态规则集。
prompt_component_manager = PromptComponentManager() prompt_component_manager = PromptComponentManager()
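The new manager exposes add_injection_rule(prompt_name, rule, content_provider, source) for runtime registration, with the provider signature async def provider(params: PromptParameters) -> str. A hedged usage sketch; the InjectionRule constructor call, the target template name and the component name below are assumptions for illustration, and the imports of InjectionRule, InjectionType and prompt_component_manager are taken to be the same as in the module above:

from src.chat.utils.prompt_params import PromptParameters

async def mood_provider(params: PromptParameters) -> str:
    # Called lazily by apply_injections() whenever the target prompt is built.
    return "当前心情:平静"

async def register_mood_rule():
    rule = InjectionRule(
        target_prompt="reply_prompt",          # hypothetical target template name
        injection_type=InjectionType.PREPEND,  # prepend the provided text to the template
        priority=100,                          # smaller numbers are applied first
    )
    await prompt_component_manager.add_injection_rule(
        prompt_name="mood_injector",
        rule=rule,
        content_provider=mood_provider,
    )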

View File

@@ -98,7 +98,7 @@ class StreamContext(BaseDataModel):
break
def mark_message_as_read(self, message_id: str):
"""标记消息为已读"""
# 先找到要标记的消息(处理 int/str 类型不匹配问题)
message_to_mark = None
for msg in self.unread_messages:
@@ -106,7 +106,7 @@ class StreamContext(BaseDataModel):
if str(msg.message_id) == str(message_id):
message_to_mark = msg
break
# 然后移动到历史消息
if message_to_mark:
message_to_mark.is_read = True
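mark_message_as_read above wraps both ids in str() so that an int id stored on the message still matches a str id passed by the caller. A tiny self-contained sketch of the same normalisation:

from dataclasses import dataclass, field

@dataclass
class TinyContext:
    """Sketch of the id-normalisation trick: compare message ids as strings."""
    unread: list = field(default_factory=list)
    history: list = field(default_factory=list)

    def mark_as_read(self, message_id) -> bool:
        for msg in list(self.unread):
            if str(msg["message_id"]) == str(message_id):  # tolerate int vs str ids
                msg["is_read"] = True
                self.unread.remove(msg)
                self.history.append(msg)
                return True
        return False

ctx = TinyContext(unread=[{"message_id": 42, "is_read": False}])
assert ctx.mark_as_read("42") and not ctx.unread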

View File

@@ -9,11 +9,12 @@
"""
import asyncio
+import builtins
import time
from collections import OrderedDict
from collections.abc import Callable
from dataclasses import dataclass
-from typing import Any, Dict, Generic, List, Optional, Set, TypeVar, Union
+from typing import Any, Generic, TypeVar
from src.common.logger import get_logger
from src.common.memory_utils import estimate_size_smart
@@ -96,7 +97,7 @@ class LRUCache(Generic[T]):
self._lock = asyncio.Lock()
self._stats = CacheStats()
-async def get(self, key: str) -> Optional[T]:
+async def get(self, key: str) -> T | None:
"""获取缓存值
Args:
@@ -137,8 +138,8 @@
self,
key: str,
value: T,
-size: Optional[int] = None,
+size: int | None = None,
-ttl: Optional[float] = None,
+ttl: float | None = None,
) -> None:
"""设置缓存值
@@ -287,8 +288,8 @@ class MultiLevelCache:
async def get(
self,
key: str,
-loader: Optional[Callable[[], Any]] = None,
+loader: Callable[[], Any] | None = None,
-) -> Optional[Any]:
+) -> Any | None:
"""从缓存获取数据
查询顺序L1 -> L2 -> loader
@@ -329,8 +330,8 @@ class MultiLevelCache:
self,
key: str,
value: Any,
-size: Optional[int] = None,
+size: int | None = None,
-ttl: Optional[float] = None,
+ttl: float | None = None,
) -> None:
"""设置缓存值
@@ -390,7 +391,7 @@ class MultiLevelCache:
await self.l2_cache.clear()
logger.info("所有缓存已清空")
-async def get_stats(self) -> Dict[str, Any]:
+async def get_stats(self) -> dict[str, Any]:
"""获取所有缓存层的统计信息(修复版:避免锁嵌套,使用超时)"""
# 🔧 修复:并行获取统计信息,避免锁嵌套
l1_stats_task = asyncio.create_task(self._get_cache_stats_safe(self.l1_cache, "L1"))
@@ -492,7 +493,7 @@ class MultiLevelCache:
logger.error(f"{cache_name}统计获取异常: {e}")
return CacheStats()
-async def _get_cache_keys_safe(self, cache) -> Set[str]:
+async def _get_cache_keys_safe(self, cache) -> builtins.set[str]:
"""安全获取缓存键集合(带超时)"""
try:
# 快速获取键集合,使用超时避免死锁
@@ -507,12 +508,12 @@ class MultiLevelCache:
logger.error(f"缓存键获取异常: {e}")
return set()
-async def _extract_keys_with_lock(self, cache) -> Set[str]:
+async def _extract_keys_with_lock(self, cache) -> builtins.set[str]:
"""在锁保护下提取键集合"""
async with cache._lock:
return set(cache._cache.keys())
-async def _calculate_memory_usage_safe(self, cache, keys: Set[str]) -> int:
+async def _calculate_memory_usage_safe(self, cache, keys: builtins.set[str]) -> int:
"""安全计算内存使用(带超时)"""
if not keys:
return 0
@@ -529,7 +530,7 @@ class MultiLevelCache:
logger.error(f"内存计算异常: {e}")
return 0
-async def _calc_memory_with_lock(self, cache, keys: Set[str]) -> int:
+async def _calc_memory_with_lock(self, cache, keys: builtins.set[str]) -> int:
"""在锁保护下计算内存使用"""
total_size = 0
async with cache._lock:
@@ -749,7 +750,7 @@ class MultiLevelCache:
# 全局缓存实例
-_global_cache: Optional[MultiLevelCache] = None
+_global_cache: MultiLevelCache | None = None
_cache_lock = asyncio.Lock()
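The cache diff above only touches type annotations; for orientation, here is a minimal sketch of the underlying pattern (an OrderedDict-based LRU guarded by an asyncio.Lock, with optional TTL). It is an illustration, not the project's actual LRUCache:

import asyncio
import time
from collections import OrderedDict

class TinyLRU:
    """Sketch of an asyncio-safe LRU with optional TTL, built on OrderedDict."""

    def __init__(self, max_size: int = 128):
        self._data: OrderedDict[str, tuple[object, float | None]] = OrderedDict()
        self._max_size = max_size
        self._lock = asyncio.Lock()

    async def get(self, key: str):
        async with self._lock:
            item = self._data.get(key)
            if item is None:
                return None
            value, expires_at = item
            if expires_at is not None and expires_at < time.monotonic():
                del self._data[key]          # expired entry
                return None
            self._data.move_to_end(key)      # mark as most recently used
            return value

    async def set(self, key: str, value, ttl: float | None = None):
        async with self._lock:
            expires_at = time.monotonic() + ttl if ttl is not None else None
            self._data[key] = (value, expires_at)
            self._data.move_to_end(key)
            while len(self._data) > self._max_size:
                self._data.popitem(last=False)  # evict least recently used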

View File

@@ -3,7 +3,6 @@ import socket
from fastapi import APIRouter, FastAPI
from fastapi.middleware.cors import CORSMiddleware
-from fastapi.staticfiles import StaticFiles
from rich.traceback import install
from uvicorn import Config
from uvicorn import Server as UvicornServer

View File

@@ -436,7 +436,7 @@ class OpenaiClient(BaseClient):
# 🔧 优化增加连接池限制支持高并发embedding请求
# 默认httpx限制为100对于高频embedding场景不够用
import httpx
limits = httpx.Limits(
max_keepalive_connections=200, # 保持活跃连接数原100
max_connections=300, # 最大总连接数原100
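The comments above raise the httpx connection-pool limits to 200 keep-alive / 300 total connections for high-concurrency embedding requests. A hedged sketch of how such limits are passed to an httpx client; the timeout value is an assumption, and the real OpenaiClient presumably wires these limits into its own HTTP client setup rather than using a bare AsyncClient:

import httpx

limits = httpx.Limits(max_keepalive_connections=200, max_connections=300)

async def make_client() -> httpx.AsyncClient:
    # The pool limits apply to every request issued through this client.
    return httpx.AsyncClient(limits=limits, timeout=httpx.Timeout(60.0))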

View File

@@ -128,7 +128,7 @@ class MemoryBuilder:
# 6. 构建 Memory 对象 # 6. 构建 Memory 对象
# 新记忆应该有较高的初始激活度 # 新记忆应该有较高的初始激活度
initial_activation = 0.75 # 新记忆初始激活度为 0.75 initial_activation = 0.75 # 新记忆初始激活度为 0.75
memory = Memory( memory = Memory(
id=memory_id, id=memory_id,
subject_id=subject_node.id, subject_id=subject_node.id,

View File

@@ -149,7 +149,7 @@ class MemoryManager:
# 读取阈值过滤配置 # 读取阈值过滤配置
search_min_importance = self.config.search_min_importance search_min_importance = self.config.search_min_importance
search_similarity_threshold = self.config.search_similarity_threshold search_similarity_threshold = self.config.search_similarity_threshold
logger.info( logger.info(
f"📊 配置检查: search_max_expand_depth={expand_depth}, " f"📊 配置检查: search_max_expand_depth={expand_depth}, "
f"search_expand_semantic_threshold={expand_semantic_threshold}, " f"search_expand_semantic_threshold={expand_semantic_threshold}, "
@@ -417,7 +417,7 @@ class MemoryManager:
# 使用配置的默认值 # 使用配置的默认值
if top_k is None: if top_k is None:
top_k = getattr(self.config, "search_top_k", 10) top_k = getattr(self.config, "search_top_k", 10)
# 准备搜索参数 # 准备搜索参数
params = { params = {
"query": query, "query": query,
@@ -951,7 +951,7 @@ class MemoryManager:
) )
else: else:
logger.debug(f"记忆已删除: {memory_id} (删除了 {deleted_vectors} 个向量)") logger.debug(f"记忆已删除: {memory_id} (删除了 {deleted_vectors} 个向量)")
# 4. 保存更新 # 4. 保存更新
await self.persistence.save_graph_store(self.graph_store) await self.persistence.save_graph_store(self.graph_store)
return True return True
@@ -984,7 +984,7 @@ class MemoryManager:
try: try:
forgotten_count = 0 forgotten_count = 0
all_memories = self.graph_store.get_all_memories() all_memories = self.graph_store.get_all_memories()
# 获取配置参数 # 获取配置参数
min_importance = getattr(self.config, "forgetting_min_importance", 0.8) min_importance = getattr(self.config, "forgetting_min_importance", 0.8)
decay_rate = getattr(self.config, "activation_decay_rate", 0.9) decay_rate = getattr(self.config, "activation_decay_rate", 0.9)
@@ -1010,10 +1010,10 @@ class MemoryManager:
try: try:
last_access_dt = datetime.fromisoformat(last_access) last_access_dt = datetime.fromisoformat(last_access)
days_passed = (datetime.now() - last_access_dt).days days_passed = (datetime.now() - last_access_dt).days
# 应用指数衰减activation = base * (decay_rate ^ days) # 应用指数衰减activation = base * (decay_rate ^ days)
current_activation = base_activation * (decay_rate ** days_passed) current_activation = base_activation * (decay_rate ** days_passed)
logger.debug( logger.debug(
f"记忆 {memory.id[:8]}: 基础激活度={base_activation:.3f}, " f"记忆 {memory.id[:8]}: 基础激活度={base_activation:.3f}, "
f"经过{days_passed}天衰减后={current_activation:.3f}" f"经过{days_passed}天衰减后={current_activation:.3f}"
@@ -1035,20 +1035,20 @@ class MemoryManager:
# 批量遗忘记忆(不立即清理孤立节点) # 批量遗忘记忆(不立即清理孤立节点)
if memories_to_forget: if memories_to_forget:
logger.info(f"开始批量遗忘 {len(memories_to_forget)} 条记忆...") logger.info(f"开始批量遗忘 {len(memories_to_forget)} 条记忆...")
for memory_id, activation in memories_to_forget: for memory_id, activation in memories_to_forget:
# cleanup_orphans=False暂不清理孤立节点 # cleanup_orphans=False暂不清理孤立节点
success = await self.forget_memory(memory_id, cleanup_orphans=False) success = await self.forget_memory(memory_id, cleanup_orphans=False)
if success: if success:
forgotten_count += 1 forgotten_count += 1
# 统一清理孤立节点和边 # 统一清理孤立节点和边
logger.info("批量遗忘完成,开始统一清理孤立节点和边...") logger.info("批量遗忘完成,开始统一清理孤立节点和边...")
orphan_nodes, orphan_edges = await self._cleanup_orphan_nodes_and_edges() orphan_nodes, orphan_edges = await self._cleanup_orphan_nodes_and_edges()
# 保存最终更新 # 保存最终更新
await self.persistence.save_graph_store(self.graph_store) await self.persistence.save_graph_store(self.graph_store)
logger.info( logger.info(
f"✅ 自动遗忘完成: 遗忘了 {forgotten_count} 条记忆, " f"✅ 自动遗忘完成: 遗忘了 {forgotten_count} 条记忆, "
f"清理了 {orphan_nodes} 个孤立节点, {orphan_edges} 条孤立边" f"清理了 {orphan_nodes} 个孤立节点, {orphan_edges} 条孤立边"
@@ -1079,31 +1079,31 @@ class MemoryManager:
# 1. 清理孤立节点 # 1. 清理孤立节点
# graph_store.node_to_memories 记录了每个节点属于哪些记忆 # graph_store.node_to_memories 记录了每个节点属于哪些记忆
nodes_to_remove = [] nodes_to_remove = []
for node_id, memory_ids in list(self.graph_store.node_to_memories.items()): for node_id, memory_ids in list(self.graph_store.node_to_memories.items()):
# 如果节点不再属于任何记忆,标记为删除 # 如果节点不再属于任何记忆,标记为删除
if not memory_ids: if not memory_ids:
nodes_to_remove.append(node_id) nodes_to_remove.append(node_id)
# 从图中删除孤立节点 # 从图中删除孤立节点
for node_id in nodes_to_remove: for node_id in nodes_to_remove:
if self.graph_store.graph.has_node(node_id): if self.graph_store.graph.has_node(node_id):
self.graph_store.graph.remove_node(node_id) self.graph_store.graph.remove_node(node_id)
orphan_nodes_count += 1 orphan_nodes_count += 1
# 从映射中删除 # 从映射中删除
if node_id in self.graph_store.node_to_memories: if node_id in self.graph_store.node_to_memories:
del self.graph_store.node_to_memories[node_id] del self.graph_store.node_to_memories[node_id]
# 2. 清理孤立边(指向已删除节点的边) # 2. 清理孤立边(指向已删除节点的边)
edges_to_remove = [] edges_to_remove = []
for source, target, edge_id in self.graph_store.graph.edges(data='edge_id'): for source, target, edge_id in self.graph_store.graph.edges(data="edge_id"):
# 检查边的源节点和目标节点是否还存在于node_to_memories中 # 检查边的源节点和目标节点是否还存在于node_to_memories中
if source not in self.graph_store.node_to_memories or \ if source not in self.graph_store.node_to_memories or \
target not in self.graph_store.node_to_memories: target not in self.graph_store.node_to_memories:
edges_to_remove.append((source, target)) edges_to_remove.append((source, target))
# 删除孤立边 # 删除孤立边
for source, target in edges_to_remove: for source, target in edges_to_remove:
try: try:
@@ -1111,12 +1111,12 @@ class MemoryManager:
orphan_edges_count += 1 orphan_edges_count += 1
except Exception as e: except Exception as e:
logger.debug(f"删除边失败 {source} -> {target}: {e}") logger.debug(f"删除边失败 {source} -> {target}: {e}")
if orphan_nodes_count > 0 or orphan_edges_count > 0: if orphan_nodes_count > 0 or orphan_edges_count > 0:
logger.info( logger.info(
f"清理完成: {orphan_nodes_count} 个孤立节点, {orphan_edges_count} 条孤立边" f"清理完成: {orphan_nodes_count} 个孤立节点, {orphan_edges_count} 条孤立边"
) )
return orphan_nodes_count, orphan_edges_count return orphan_nodes_count, orphan_edges_count
except Exception as e: except Exception as e:
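The sweep above runs against the project's graph store; the calls it makes (has_node, remove_node, edges(data="edge_id")) suggest a networkx-style graph. A compact, self-contained illustration of the same two-pass cleanup on toy data, assuming networkx (all names and values here are invented for the example):

import networkx as nx

g = nx.MultiDiGraph()
g.add_edge("a", "b", edge_id="e1")
g.add_node("c")  # node no longer referenced by any memory
node_to_memories = {"a": {"m1"}, "b": {"m1"}, "c": set()}

# Pass 1: drop nodes that belong to no memory.
for node_id in [n for n, mems in node_to_memories.items() if not mems]:
    if g.has_node(node_id):
        g.remove_node(node_id)
    node_to_memories.pop(node_id, None)

# Pass 2: drop edges whose endpoints are no longer tracked.
dangling = [(s, t) for s, t, _ in g.edges(data="edge_id")
            if s not in node_to_memories or t not in node_to_memories]
g.remove_edges_from(dangling)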
@@ -1258,7 +1258,7 @@ class MemoryManager:
mem for mem in recent_memories mem for mem in recent_memories
if mem.importance >= min_importance_for_consolidation if mem.importance >= min_importance_for_consolidation
] ]
result["importance_filtered"] = len(recent_memories) - len(important_memories) result["importance_filtered"] = len(recent_memories) - len(important_memories)
logger.info( logger.info(
f"📊 步骤2: 重要性过滤 (阈值={min_importance_for_consolidation:.2f}): " f"📊 步骤2: 重要性过滤 (阈值={min_importance_for_consolidation:.2f}): "
@@ -1382,26 +1382,26 @@ class MemoryManager:
# ===== 步骤4: 向量检索关联记忆 + LLM分析关系 ===== # ===== 步骤4: 向量检索关联记忆 + LLM分析关系 =====
# 过滤掉已删除的记忆 # 过滤掉已删除的记忆
remaining_memories = [m for m in important_memories if m.id not in deleted_ids] remaining_memories = [m for m in important_memories if m.id not in deleted_ids]
if not remaining_memories: if not remaining_memories:
logger.info("✅ 记忆整理完成: 去重后无剩余记忆") logger.info("✅ 记忆整理完成: 去重后无剩余记忆")
return return
logger.info(f"📍 步骤4: 开始关联分析 ({len(remaining_memories)} 条记忆)...") logger.info(f"📍 步骤4: 开始关联分析 ({len(remaining_memories)} 条记忆)...")
# 分批处理记忆关联 # 分批处理记忆关联
llm_batch_size = getattr(self.config, "consolidation_llm_batch_size", 10) llm_batch_size = getattr(self.config, "consolidation_llm_batch_size", 10)
max_candidates_per_memory = getattr(self.config, "consolidation_max_candidates", 5) max_candidates_per_memory = getattr(self.config, "consolidation_max_candidates", 5)
min_confidence = getattr(self.config, "consolidation_min_confidence", 0.6) min_confidence = getattr(self.config, "consolidation_min_confidence", 0.6)
all_new_edges = [] # 收集所有新建的边 all_new_edges = [] # 收集所有新建的边
for batch_start in range(0, len(remaining_memories), llm_batch_size): for batch_start in range(0, len(remaining_memories), llm_batch_size):
batch_end = min(batch_start + llm_batch_size, len(remaining_memories)) batch_end = min(batch_start + llm_batch_size, len(remaining_memories))
batch = remaining_memories[batch_start:batch_end] batch = remaining_memories[batch_start:batch_end]
logger.debug(f"处理批次 {batch_start//llm_batch_size + 1}/{(len(remaining_memories)-1)//llm_batch_size + 1}") logger.debug(f"处理批次 {batch_start//llm_batch_size + 1}/{(len(remaining_memories)-1)//llm_batch_size + 1}")
for memory in batch: for memory in batch:
# 跳过已经有很多连接的记忆 # 跳过已经有很多连接的记忆
existing_edges = len([ existing_edges = len([
@@ -1454,14 +1454,14 @@ class MemoryManager:
except Exception as e: except Exception as e:
logger.warning(f"创建关联边失败: {e}") logger.warning(f"创建关联边失败: {e}")
continue continue
# 每个批次后让出控制权 # 每个批次后让出控制权
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
# ===== 步骤5: 统一更新记忆数据 ===== # ===== 步骤5: 统一更新记忆数据 =====
if all_new_edges: if all_new_edges:
logger.info(f"📍 步骤5: 统一更新 {len(all_new_edges)} 条新关联边...") logger.info(f"📍 步骤5: 统一更新 {len(all_new_edges)} 条新关联边...")
for memory, edge, relation in all_new_edges: for memory, edge, relation in all_new_edges:
try: try:
# 添加到图 # 添加到图
@@ -2301,7 +2301,7 @@ class MemoryManager:
# 使用 asyncio.wait_for 来支持取消 # 使用 asyncio.wait_for 来支持取消
await asyncio.wait_for( await asyncio.wait_for(
asyncio.sleep(initial_delay), asyncio.sleep(initial_delay),
timeout=float('inf') # 允许随时取消 timeout=float("inf") # 允许随时取消
) )
# 检查是否仍然需要运行 # 检查是否仍然需要运行
View File
@@ -482,7 +482,7 @@ class GraphStore:
for node in memory.nodes: for node in memory.nodes:
if node.id in self.node_to_memories: if node.id in self.node_to_memories:
self.node_to_memories[node.id].discard(memory_id) self.node_to_memories[node.id].discard(memory_id)
# 可选:立即清理孤立节点 # 可选:立即清理孤立节点
if cleanup_orphans: if cleanup_orphans:
# 如果该节点不再属于任何记忆,从图中移除节点 # 如果该节点不再属于任何记忆,从图中移除节点
View File
@@ -72,12 +72,12 @@ class MemoryTools:
self.max_expand_depth = max_expand_depth self.max_expand_depth = max_expand_depth
self.expand_semantic_threshold = expand_semantic_threshold self.expand_semantic_threshold = expand_semantic_threshold
self.search_top_k = search_top_k self.search_top_k = search_top_k
# 保存权重配置 # 保存权重配置
self.base_vector_weight = search_vector_weight self.base_vector_weight = search_vector_weight
self.base_importance_weight = search_importance_weight self.base_importance_weight = search_importance_weight
self.base_recency_weight = search_recency_weight self.base_recency_weight = search_recency_weight
# 保存阈值过滤配置 # 保存阈值过滤配置
self.search_min_importance = search_min_importance self.search_min_importance = search_min_importance
self.search_similarity_threshold = search_similarity_threshold self.search_similarity_threshold = search_similarity_threshold
@@ -516,14 +516,14 @@ class MemoryTools:
# 1. 根据策略选择检索方式 # 1. 根据策略选择检索方式
llm_prefer_types = [] # LLM识别的偏好节点类型 llm_prefer_types = [] # LLM识别的偏好节点类型
if use_multi_query: if use_multi_query:
# 多查询策略(返回节点列表 + 偏好类型) # 多查询策略(返回节点列表 + 偏好类型)
similar_nodes, llm_prefer_types = await self._multi_query_search(query, top_k, context) similar_nodes, llm_prefer_types = await self._multi_query_search(query, top_k, context)
else: else:
# 传统单查询策略 # 传统单查询策略
similar_nodes = await self._single_query_search(query, top_k) similar_nodes = await self._single_query_search(query, top_k)
# 合并用户指定的偏好类型和LLM识别的偏好类型 # 合并用户指定的偏好类型和LLM识别的偏好类型
all_prefer_types = list(set(prefer_node_types + llm_prefer_types)) all_prefer_types = list(set(prefer_node_types + llm_prefer_types))
if all_prefer_types: if all_prefer_types:
@@ -551,7 +551,7 @@ class MemoryTools:
# 记录最高分数 # 记录最高分数
if mem_id not in memory_scores or similarity > memory_scores[mem_id]: if mem_id not in memory_scores or similarity > memory_scores[mem_id]:
memory_scores[mem_id] = similarity memory_scores[mem_id] = similarity
# 🔥 详细日志:检查初始召回情况 # 🔥 详细日志:检查初始召回情况
logger.info( logger.info(
f"初始向量搜索: 返回{len(similar_nodes)}个节点 → " f"初始向量搜索: 返回{len(similar_nodes)}个节点 → "
@@ -559,8 +559,8 @@ class MemoryTools:
) )
if len(initial_memory_ids) == 0: if len(initial_memory_ids) == 0:
logger.warning( logger.warning(
f"⚠️ 向量搜索未找到任何记忆!" "⚠️ 向量搜索未找到任何记忆!"
f"可能原因1) 嵌入模型理解问题 2) 记忆节点未建立索引 3) 查询表达与存储内容差异过大" "可能原因1) 嵌入模型理解问题 2) 记忆节点未建立索引 3) 查询表达与存储内容差异过大"
) )
# 输出相似节点的详细信息用于调试 # 输出相似节点的详细信息用于调试
if similar_nodes: if similar_nodes:
@@ -692,7 +692,7 @@ class MemoryTools:
key=lambda x: final_scores[x], key=lambda x: final_scores[x],
reverse=True reverse=True
) # 🔥 不再提前截断,让所有候选参与详细评分 ) # 🔥 不再提前截断,让所有候选参与详细评分
# 🔍 统计初始记忆的相似度分布(用于诊断) # 🔍 统计初始记忆的相似度分布(用于诊断)
if memory_scores: if memory_scores:
similarities = list(memory_scores.values()) similarities = list(memory_scores.values())
@@ -707,7 +707,7 @@ class MemoryTools:
# 5. 获取完整记忆并进行最终排序(优化后的动态权重系统) # 5. 获取完整记忆并进行最终排序(优化后的动态权重系统)
memories_with_scores = [] memories_with_scores = []
filter_stats = {"importance": 0, "similarity": 0, "total_checked": 0} # 过滤统计 filter_stats = {"importance": 0, "similarity": 0, "total_checked": 0} # 过滤统计
for memory_id in sorted_memory_ids: # 遍历所有候选 for memory_id in sorted_memory_ids: # 遍历所有候选
memory = self.graph_store.get_memory_by_id(memory_id) memory = self.graph_store.get_memory_by_id(memory_id)
if memory: if memory:
@@ -715,7 +715,7 @@ class MemoryTools:
# 基础分数 # 基础分数
similarity_score = final_scores[memory_id] similarity_score = final_scores[memory_id]
importance_score = memory.importance importance_score = memory.importance
# 🆕 区分记忆来源(用于过滤) # 🆕 区分记忆来源(用于过滤)
is_initial_memory = memory_id in memory_scores # 是否来自初始向量搜索 is_initial_memory = memory_id in memory_scores # 是否来自初始向量搜索
true_similarity = memory_scores.get(memory_id, 0.0) if is_initial_memory else None true_similarity = memory_scores.get(memory_id, 0.0) if is_initial_memory else None
@@ -738,16 +738,16 @@ class MemoryTools:
activation_score = memory.activation activation_score = memory.activation
# 🆕 动态权重计算:使用配置的基础权重 + 根据记忆类型微调 # 🆕 动态权重计算:使用配置的基础权重 + 根据记忆类型微调
memory_type = memory.memory_type.value if hasattr(memory.memory_type, 'value') else str(memory.memory_type) memory_type = memory.memory_type.value if hasattr(memory.memory_type, "value") else str(memory.memory_type)
# 检测记忆的主要节点类型 # 检测记忆的主要节点类型
node_types_count = {} node_types_count = {}
for node in memory.nodes: for node in memory.nodes:
nt = node.node_type.value if hasattr(node.node_type, 'value') else str(node.node_type) nt = node.node_type.value if hasattr(node.node_type, "value") else str(node.node_type)
node_types_count[nt] = node_types_count.get(nt, 0) + 1 node_types_count[nt] = node_types_count.get(nt, 0) + 1
dominant_node_type = max(node_types_count.items(), key=lambda x: x[1])[0] if node_types_count else "unknown" dominant_node_type = max(node_types_count.items(), key=lambda x: x[1])[0] if node_types_count else "unknown"
# 根据记忆类型和节点类型计算调整系数(在配置权重基础上微调) # 根据记忆类型和节点类型计算调整系数(在配置权重基础上微调)
if dominant_node_type in ["ATTRIBUTE", "REFERENCE"] or memory_type == "FACT": if dominant_node_type in ["ATTRIBUTE", "REFERENCE"] or memory_type == "FACT":
# 事实性记忆:提升相似度权重,降低时效性权重 # 事实性记忆:提升相似度权重,降低时效性权重
@@ -777,41 +777,41 @@ class MemoryTools:
"importance": 1.0, "importance": 1.0,
"recency": 1.0, "recency": 1.0,
} }
# 应用调整后的权重(基于配置的基础权重) # 应用调整后的权重(基于配置的基础权重)
weights = { weights = {
"similarity": self.base_vector_weight * type_adjustments["similarity"], "similarity": self.base_vector_weight * type_adjustments["similarity"],
"importance": self.base_importance_weight * type_adjustments["importance"], "importance": self.base_importance_weight * type_adjustments["importance"],
"recency": self.base_recency_weight * type_adjustments["recency"], "recency": self.base_recency_weight * type_adjustments["recency"],
} }
# 归一化权重确保总和为1.0 # 归一化权重确保总和为1.0
total_weight = sum(weights.values()) total_weight = sum(weights.values())
if total_weight > 0: if total_weight > 0:
weights = {k: v / total_weight for k, v in weights.items()} weights = {k: v / total_weight for k, v in weights.items()}
# 综合分数计算(🔥 移除激活度影响) # 综合分数计算(🔥 移除激活度影响)
final_score = ( final_score = (
similarity_score * weights["similarity"] + similarity_score * weights["similarity"] +
importance_score * weights["importance"] + importance_score * weights["importance"] +
recency_score * weights["recency"] recency_score * weights["recency"]
) )
# 🆕 阈值过滤策略: # 🆕 阈值过滤策略:
# 1. 重要性过滤:应用于所有记忆(过滤极低质量) # 1. 重要性过滤:应用于所有记忆(过滤极低质量)
if memory.importance < self.search_min_importance: if memory.importance < self.search_min_importance:
filter_stats["importance"] += 1 filter_stats["importance"] += 1
logger.debug(f"❌ 过滤 {memory.id[:8]}: 重要性 {memory.importance:.2f} < 阈值 {self.search_min_importance}") logger.debug(f"❌ 过滤 {memory.id[:8]}: 重要性 {memory.importance:.2f} < 阈值 {self.search_min_importance}")
continue continue
# 2. 相似度过滤:不再对初始向量搜索结果过滤(信任向量搜索的排序) # 2. 相似度过滤:不再对初始向量搜索结果过滤(信任向量搜索的排序)
# 理由:向量搜索已经按相似度排序,返回的都是最相关结果 # 理由:向量搜索已经按相似度排序,返回的都是最相关结果
# 如果再用阈值过滤,会导致"最相关的也不够相关"的矛盾 # 如果再用阈值过滤,会导致"最相关的也不够相关"的矛盾
# #
# 注意:如果未来需要对扩展记忆过滤,可以在这里添加逻辑 # 注意:如果未来需要对扩展记忆过滤,可以在这里添加逻辑
# if not is_initial_memory and some_score < threshold: # if not is_initial_memory and some_score < threshold:
# continue # continue
# 记录通过过滤的记忆(用于调试) # 记录通过过滤的记忆(用于调试)
if is_initial_memory: if is_initial_memory:
logger.debug( logger.debug(
@@ -823,11 +823,11 @@ class MemoryTools:
f"✅ 保留 {memory.id[:8]} [扩展]: 重要性={memory.importance:.2f}, " f"✅ 保留 {memory.id[:8]} [扩展]: 重要性={memory.importance:.2f}, "
f"综合分数={final_score:.4f}" f"综合分数={final_score:.4f}"
) )
# 🆕 节点类型加权对REFERENCE/ATTRIBUTE节点额外加分促进事实性信息召回 # 🆕 节点类型加权对REFERENCE/ATTRIBUTE节点额外加分促进事实性信息召回
if "REFERENCE" in node_types_count or "ATTRIBUTE" in node_types_count: if "REFERENCE" in node_types_count or "ATTRIBUTE" in node_types_count:
final_score *= 1.1 # 10% 加成 final_score *= 1.1 # 10% 加成
# 🆕 用户指定的优先节点类型额外加权 # 🆕 用户指定的优先节点类型额外加权
if prefer_node_types: if prefer_node_types:
for prefer_type in prefer_node_types: for prefer_type in prefer_node_types:
@@ -835,7 +835,7 @@ class MemoryTools:
final_score *= 1.15 # 15% 额外加成 final_score *= 1.15 # 15% 额外加成
logger.debug(f"记忆 {memory.id[:8]} 包含优先节点类型 {prefer_type},加权后分数: {final_score:.4f}") logger.debug(f"记忆 {memory.id[:8]} 包含优先节点类型 {prefer_type},加权后分数: {final_score:.4f}")
break break
memories_with_scores.append((memory, final_score, dominant_node_type)) memories_with_scores.append((memory, final_score, dominant_node_type))
# 按综合分数排序 # 按综合分数排序
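To make the weighting above concrete, a worked example with illustrative numbers (the 0.6/0.3/0.1 base weights and the FACT-style adjustment factors are assumptions; the real values come from configuration):

base = {"similarity": 0.6, "importance": 0.3, "recency": 0.1}
adjust = {"similarity": 1.2, "importance": 1.0, "recency": 0.6}  # e.g. a FACT memory
weights = {k: base[k] * adjust[k] for k in base}                 # 0.72 / 0.30 / 0.06
total = sum(weights.values())
weights = {k: v / total for k, v in weights.items()}             # ~0.667 / 0.278 / 0.056
final_score = 0.8 * weights["similarity"] + 0.5 * weights["importance"] + 0.3 * weights["recency"]
# ~0.69 before the REFERENCE/ATTRIBUTE and prefer-type bonuses are applied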
@@ -845,7 +845,7 @@ class MemoryTools:
# 统计过滤情况 # 统计过滤情况
total_candidates = len(all_memory_ids) total_candidates = len(all_memory_ids)
filtered_count = total_candidates - len(memories_with_scores) filtered_count = total_candidates - len(memories_with_scores)
# 6. 格式化结果(包含调试信息) # 6. 格式化结果(包含调试信息)
results = [] results = []
for memory, score, node_type in memories_with_scores[:top_k]: for memory, score, node_type in memories_with_scores[:top_k]:
@@ -866,7 +866,7 @@ class MemoryTools:
f"过滤{filtered_count}个 (重要性过滤) → " f"过滤{filtered_count}个 (重要性过滤) → "
f"最终返回{len(results)}条记忆" f"最终返回{len(results)}条记忆"
) )
# 如果过滤率过高,发出警告 # 如果过滤率过高,发出警告
if total_candidates > 0: if total_candidates > 0:
filter_rate = filtered_count / total_candidates filter_rate = filtered_count / total_candidates
@@ -1092,20 +1092,21 @@ class MemoryTools:
response, _ = await llm.generate_response_async(prompt, temperature=0.3, max_tokens=300) response, _ = await llm.generate_response_async(prompt, temperature=0.3, max_tokens=300)
import re import re
import orjson import orjson
# 清理Markdown代码块 # 清理Markdown代码块
response = re.sub(r"```json\s*", "", response) response = re.sub(r"```json\s*", "", response)
response = re.sub(r"```\s*$", "", response).strip() response = re.sub(r"```\s*$", "", response).strip()
# 解析JSON # 解析JSON
data = orjson.loads(response) data = orjson.loads(response)
# 提取查询列表 # 提取查询列表
queries = data.get("queries", []) queries = data.get("queries", [])
result_queries = [(item.get("text", "").strip(), float(item.get("weight", 0.5))) result_queries = [(item.get("text", "").strip(), float(item.get("weight", 0.5)))
for item in queries if item.get("text", "").strip()] for item in queries if item.get("text", "").strip()]
# 提取偏好节点类型 # 提取偏好节点类型
prefer_node_types = data.get("prefer_node_types", []) prefer_node_types = data.get("prefer_node_types", [])
# 确保类型正确且有效 # 确保类型正确且有效
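For reference, a model reply shaped like the following (an illustrative value, not captured output) survives the fence cleanup and parsing above:

import re
import orjson

response = '```json\n{"queries": [{"text": "user birthday", "weight": 0.8}], "prefer_node_types": ["ATTRIBUTE"]}\n```'
response = re.sub(r"```json\s*", "", response)
response = re.sub(r"```\s*$", "", response).strip()
data = orjson.loads(response)
queries = [(item["text"].strip(), float(item.get("weight", 0.5))) for item in data["queries"]]
# -> [("user birthday", 0.8)], with data["prefer_node_types"] == ["ATTRIBUTE"]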
@@ -1154,7 +1155,7 @@ class MemoryTools:
limit=top_k * 5, # 🔥 从2倍提升到5倍提高初始召回率 limit=top_k * 5, # 🔥 从2倍提升到5倍提高初始召回率
min_similarity=0.0, # 不在这里过滤,交给后续评分 min_similarity=0.0, # 不在这里过滤,交给后续评分
) )
logger.debug(f"单查询向量搜索: 查询='{query}', 返回节点数={len(similar_nodes)}") logger.debug(f"单查询向量搜索: 查询='{query}', 返回节点数={len(similar_nodes)}")
if similar_nodes: if similar_nodes:
logger.debug(f"Top 3相似度: {[f'{sim:.3f}' for _, sim, _ in similar_nodes[:3]]}") logger.debug(f"Top 3相似度: {[f'{sim:.3f}' for _, sim, _ in similar_nodes[:3]]}")
View File
@@ -62,7 +62,7 @@ async def expand_memories_with_semantic_filter(
try: try:
import time import time
start_time = time.time() start_time = time.time()
# 记录已访问的记忆,避免重复 # 记录已访问的记忆,避免重复
visited_memories = set(initial_memory_ids) visited_memories = set(initial_memory_ids)
# 记录扩展的记忆及其分数 # 记录扩展的记忆及其分数
@@ -87,17 +87,17 @@ async def expand_memories_with_semantic_filter(
# 获取该记忆的邻居记忆(通过边关系) # 获取该记忆的邻居记忆(通过边关系)
neighbor_memory_ids = set() neighbor_memory_ids = set()
# 🆕 遍历记忆的所有边,收集邻居记忆(带边类型权重) # 🆕 遍历记忆的所有边,收集邻居记忆(带边类型权重)
edge_weights = {} # 记录通过不同边类型到达的记忆的权重 edge_weights = {} # 记录通过不同边类型到达的记忆的权重
for edge in memory.edges: for edge in memory.edges:
# 获取边的目标节点 # 获取边的目标节点
target_node_id = edge.target_id target_node_id = edge.target_id
source_node_id = edge.source_id source_node_id = edge.source_id
# 🆕 根据边类型设置权重优先扩展REFERENCE、ATTRIBUTE相关的边 # 🆕 根据边类型设置权重优先扩展REFERENCE、ATTRIBUTE相关的边
edge_type_str = edge.edge_type.value if hasattr(edge.edge_type, 'value') else str(edge.edge_type) edge_type_str = edge.edge_type.value if hasattr(edge.edge_type, "value") else str(edge.edge_type)
if edge_type_str == "REFERENCE": if edge_type_str == "REFERENCE":
edge_weight = 1.3 # REFERENCE边权重最高引用关系 edge_weight = 1.3 # REFERENCE边权重最高引用关系
elif edge_type_str in ["ATTRIBUTE", "HAS_PROPERTY"]: elif edge_type_str in ["ATTRIBUTE", "HAS_PROPERTY"]:
@@ -108,18 +108,18 @@ async def expand_memories_with_semantic_filter(
edge_weight = 0.9 # 一般关系适中降权 edge_weight = 0.9 # 一般关系适中降权
else: else:
edge_weight = 1.0 # 默认权重 edge_weight = 1.0 # 默认权重
# 通过节点找到其他记忆 # 通过节点找到其他记忆
for node_id in [target_node_id, source_node_id]: for node_id in [target_node_id, source_node_id]:
if node_id in graph_store.node_to_memories: if node_id in graph_store.node_to_memories:
for neighbor_id in graph_store.node_to_memories[node_id]: for neighbor_id in graph_store.node_to_memories[node_id]:
if neighbor_id not in edge_weights or edge_weights[neighbor_id] < edge_weight: if neighbor_id not in edge_weights or edge_weights[neighbor_id] < edge_weight:
edge_weights[neighbor_id] = edge_weight edge_weights[neighbor_id] = edge_weight
# 将权重高的邻居记忆加入候选 # 将权重高的邻居记忆加入候选
for neighbor_id, edge_weight in edge_weights.items(): for neighbor_id, edge_weight in edge_weights.items():
neighbor_memory_ids.add((neighbor_id, edge_weight)) neighbor_memory_ids.add((neighbor_id, edge_weight))
# 过滤掉已访问的和自己 # 过滤掉已访问的和自己
filtered_neighbors = [] filtered_neighbors = []
for neighbor_id, edge_weight in neighbor_memory_ids: for neighbor_id, edge_weight in neighbor_memory_ids:
@@ -129,7 +129,7 @@ async def expand_memories_with_semantic_filter(
# 批量评估邻居记忆 # 批量评估邻居记忆
for neighbor_mem_id, edge_weight in filtered_neighbors: for neighbor_mem_id, edge_weight in filtered_neighbors:
candidates_checked += 1 candidates_checked += 1
neighbor_memory = graph_store.get_memory_by_id(neighbor_mem_id) neighbor_memory = graph_store.get_memory_by_id(neighbor_mem_id)
if not neighbor_memory: if not neighbor_memory:
continue continue
@@ -139,7 +139,7 @@ async def expand_memories_with_semantic_filter(
(n for n in neighbor_memory.nodes if n.has_embedding()), (n for n in neighbor_memory.nodes if n.has_embedding()),
None None
) )
if not topic_node or topic_node.embedding is None: if not topic_node or topic_node.embedding is None:
continue continue
@@ -179,11 +179,11 @@ async def expand_memories_with_semantic_filter(
if len(expanded_memories) >= max_expanded: if len(expanded_memories) >= max_expanded:
logger.debug(f"⏹️ 提前停止:已达到最大扩展数量 {max_expanded}") logger.debug(f"⏹️ 提前停止:已达到最大扩展数量 {max_expanded}")
break break
# 早停检查 # 早停检查
if len(expanded_memories) >= max_expanded: if len(expanded_memories) >= max_expanded:
break break
# 记录本层统计 # 记录本层统计
depth_stats.append({ depth_stats.append({
"depth": depth + 1, "depth": depth + 1,
@@ -199,20 +199,20 @@ async def expand_memories_with_semantic_filter(
# 限制下一层的记忆数量,避免爆炸性增长 # 限制下一层的记忆数量,避免爆炸性增长
current_level_memories = next_level_memories[:max_expanded] current_level_memories = next_level_memories[:max_expanded]
# 每层让出控制权 # 每层让出控制权
await asyncio.sleep(0.001) await asyncio.sleep(0.001)
# 排序并返回 # 排序并返回
sorted_results = sorted(expanded_memories.items(), key=lambda x: x[1], reverse=True)[:max_expanded] sorted_results = sorted(expanded_memories.items(), key=lambda x: x[1], reverse=True)[:max_expanded]
elapsed = time.time() - start_time elapsed = time.time() - start_time
logger.info( logger.info(
f"✅ 图扩展完成: 初始{len(initial_memory_ids)}个 → " f"✅ 图扩展完成: 初始{len(initial_memory_ids)}个 → "
f"扩展{len(sorted_results)}个新记忆 " f"扩展{len(sorted_results)}个新记忆 "
f"(深度={max_depth}, 阈值={semantic_threshold:.2f}, 耗时={elapsed:.3f}s)" f"(深度={max_depth}, 阈值={semantic_threshold:.2f}, 耗时={elapsed:.3f}s)"
) )
# 输出每层统计 # 输出每层统计
for stat in depth_stats: for stat in depth_stats:
logger.debug( logger.debug(
View File
@@ -78,11 +78,9 @@ __all__ = [
# 消息 # 消息
"MaiMessages", "MaiMessages",
# 工具函数 # 工具函数
"ManifestValidator",
"PluginInfo", "PluginInfo",
# 增强命令系统 # 增强命令系统
"PlusCommand", "PlusCommand",
"PlusCommandAdapter",
"PythonDependency", "PythonDependency",
"ToolInfo", "ToolInfo",
"ToolParamType", "ToolParamType",
View File
@@ -132,7 +132,7 @@ async def generate_reply(
prompt_mode = "s4u" # 默认使用s4u模式 prompt_mode = "s4u" # 默认使用s4u模式
if action_data and "prompt_mode" in action_data: if action_data and "prompt_mode" in action_data:
prompt_mode = action_data.get("prompt_mode", "s4u") prompt_mode = action_data.get("prompt_mode", "s4u")
# 将prompt_mode添加到available_actions中作为特殊键 # 将prompt_mode添加到available_actions中作为特殊键
# 注意这里我们需要暂时使用类型忽略因为available_actions的类型定义不支持非ActionInfo值 # 注意这里我们需要暂时使用类型忽略因为available_actions的类型定义不支持非ActionInfo值
if available_actions is None: if available_actions is None:
View File
@@ -362,7 +362,7 @@ class ChatterPlanFilter:
return "最近没有聊天内容。", "没有未读消息。", [] return "最近没有聊天内容。", "没有未读消息。", []
stream_context = chat_stream.context_manager stream_context = chat_stream.context_manager
# 获取真正的已读和未读消息 # 获取真正的已读和未读消息
read_messages = stream_context.context.history_messages # 已读消息存储在history_messages中 read_messages = stream_context.context.history_messages # 已读消息存储在history_messages中
if not read_messages: if not read_messages:
@@ -652,30 +652,30 @@ class ChatterPlanFilter:
if not action_info: if not action_info:
logger.debug(f"动作 {action_name} 不在可用动作列表中,保留所有参数") logger.debug(f"动作 {action_name} 不在可用动作列表中,保留所有参数")
return action_data return action_data
# 获取该动作定义的合法参数 # 获取该动作定义的合法参数
defined_params = set(action_info.action_parameters.keys()) defined_params = set(action_info.action_parameters.keys())
# 合法参数集合 # 合法参数集合
valid_params = defined_params valid_params = defined_params
# 过滤参数 # 过滤参数
filtered_data = {} filtered_data = {}
removed_params = [] removed_params = []
for key, value in action_data.items(): for key, value in action_data.items():
if key in valid_params: if key in valid_params:
filtered_data[key] = value filtered_data[key] = value
else: else:
removed_params.append(key) removed_params.append(key)
# 记录被移除的参数 # 记录被移除的参数
if removed_params: if removed_params:
logger.info( logger.info(
f"🧹 [参数过滤] 动作 '{action_name}' 移除了多余参数: {removed_params}. " f"🧹 [参数过滤] 动作 '{action_name}' 移除了多余参数: {removed_params}. "
f"合法参数: {sorted(valid_params)}" f"合法参数: {sorted(valid_params)}"
) )
return filtered_data return filtered_data
def _filter_no_actions(self, action_list: list[ActionPlannerInfo]) -> list[ActionPlannerInfo]: def _filter_no_actions(self, action_list: list[ActionPlannerInfo]) -> list[ActionPlannerInfo]:
View File
@@ -545,14 +545,14 @@ async def execute_proactive_thinking(stream_id: str):
# 获取或创建该聊天流的执行锁 # 获取或创建该聊天流的执行锁
if stream_id not in _execution_locks: if stream_id not in _execution_locks:
_execution_locks[stream_id] = asyncio.Lock() _execution_locks[stream_id] = asyncio.Lock()
lock = _execution_locks[stream_id] lock = _execution_locks[stream_id]
# 尝试获取锁,如果已被占用则跳过本次执行(防止重复) # 尝试获取锁,如果已被占用则跳过本次执行(防止重复)
if lock.locked(): if lock.locked():
logger.warning(f"⚠️ 主动思考跳过:聊天流 {stream_id} 已有正在执行的主动思考任务") logger.warning(f"⚠️ 主动思考跳过:聊天流 {stream_id} 已有正在执行的主动思考任务")
return return
async with lock: async with lock:
logger.debug(f"🤔 开始主动思考 {stream_id}") logger.debug(f"🤔 开始主动思考 {stream_id}")
@@ -563,13 +563,13 @@ async def execute_proactive_thinking(stream_id: str):
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
chat_manager = get_chat_manager() chat_manager = get_chat_manager()
chat_stream = await chat_manager.get_stream(stream_id) chat_stream = await chat_manager.get_stream(stream_id)
if chat_stream and chat_stream.context_manager.context.is_chatter_processing: if chat_stream and chat_stream.context_manager.context.is_chatter_processing:
logger.warning(f"⚠️ 主动思考跳过:聊天流 {stream_id} 的 chatter 正在处理消息") logger.warning(f"⚠️ 主动思考跳过:聊天流 {stream_id} 的 chatter 正在处理消息")
return return
except Exception as e: except Exception as e:
logger.warning(f"检查 chatter 处理状态时出错: {e},继续执行") logger.warning(f"检查 chatter 处理状态时出错: {e},继续执行")
# 0.1 检查白名单/黑名单 # 0.1 检查白名单/黑名单
# 从 stream_id 获取 stream_config 字符串进行验证 # 从 stream_id 获取 stream_config 字符串进行验证
try: try:
View File
@@ -31,4 +31,4 @@ __plugin_meta__ = PluginMetadata(
# 导入插件主类 # 导入插件主类
from .plugin import AntiInjectionPlugin from .plugin import AntiInjectionPlugin
__all__ = ["__plugin_meta__", "AntiInjectionPlugin"] __all__ = ["AntiInjectionPlugin", "__plugin_meta__"]
View File
@@ -8,8 +8,8 @@ import time
from src.chat.security.interfaces import ( from src.chat.security.interfaces import (
SecurityAction, SecurityAction,
SecurityCheckResult,
SecurityChecker, SecurityChecker,
SecurityCheckResult,
SecurityLevel, SecurityLevel,
) )
from src.common.logger import get_logger from src.common.logger import get_logger
View File
@@ -4,7 +4,7 @@
处理检测结果,执行相应的动作(允许/监控/加盾/阻止/反击)。 处理检测结果,执行相应的动作(允许/监控/加盾/阻止/反击)。
""" """
from src.chat.security.interfaces import SecurityAction, SecurityCheckResult from src.chat.security.interfaces import SecurityCheckResult
from src.common.logger import get_logger from src.common.logger import get_logger
from .counter_attack import CounterAttackGenerator from .counter_attack import CounterAttackGenerator
View File
@@ -64,15 +64,15 @@ class CoreActionsPlugin(BasePlugin):
# --- 根据配置注册组件 --- # --- 根据配置注册组件 ---
components: ClassVar = [] components: ClassVar = []
# 注册 reply 动作 # 注册 reply 动作
if self.get_config("components.enable_reply", True): if self.get_config("components.enable_reply", True):
components.append((ReplyAction.get_action_info(), ReplyAction)) components.append((ReplyAction.get_action_info(), ReplyAction))
# 注册 respond 动作 # 注册 respond 动作
if self.get_config("components.enable_respond", True): if self.get_config("components.enable_respond", True):
components.append((RespondAction.get_action_info(), RespondAction)) components.append((RespondAction.get_action_info(), RespondAction))
# 注册 emoji 动作 # 注册 emoji 动作
if self.get_config("components.enable_emoji", True): if self.get_config("components.enable_emoji", True):
components.append((EmojiAction.get_action_info(), EmojiAction)) components.append((EmojiAction.get_action_info(), EmojiAction))
View File
@@ -22,23 +22,23 @@ class ReplyAction(BaseAction):
- 专注于理解和回应单条消息的具体内容 - 专注于理解和回应单条消息的具体内容
- 适合 Focus 模式下的精准回复 - 适合 Focus 模式下的精准回复
""" """
# 动作基本信息 # 动作基本信息
action_name = "reply" action_name = "reply"
action_description = "针对特定消息进行精准回复。深度理解并回应单条消息的具体内容。需要指定目标消息ID。" action_description = "针对特定消息进行精准回复。深度理解并回应单条消息的具体内容。需要指定目标消息ID。"
# 激活设置 # 激活设置
activation_type = ActionActivationType.ALWAYS # 回复动作总是可用 activation_type = ActionActivationType.ALWAYS # 回复动作总是可用
mode_enable = ChatMode.ALL # 在所有模式下都可用 mode_enable = ChatMode.ALL # 在所有模式下都可用
parallel_action = False # 回复动作不能与其他动作并行 parallel_action = False # 回复动作不能与其他动作并行
# 动作参数定义 # 动作参数定义
action_parameters: ClassVar = { action_parameters: ClassVar = {
"target_message_id": "要回复的目标消息ID必需来自未读消息的 <m...> 标签)", "target_message_id": "要回复的目标消息ID必需来自未读消息的 <m...> 标签)",
"content": "回复的具体内容可选由LLM生成", "content": "回复的具体内容可选由LLM生成",
"should_quote_reply": "是否引用原消息可选true/false默认false。群聊中回复较早消息或需要明确指向时使用true", "should_quote_reply": "是否引用原消息可选true/false默认false。群聊中回复较早消息或需要明确指向时使用true",
} }
# 动作使用场景 # 动作使用场景
action_require: ClassVar = [ action_require: ClassVar = [
"需要针对特定消息进行精准回复时使用", "需要针对特定消息进行精准回复时使用",
@@ -48,10 +48,10 @@ class ReplyAction(BaseAction):
"群聊中需要明确回应某个特定用户或问题时使用", "群聊中需要明确回应某个特定用户或问题时使用",
"关注单条消息的具体内容和上下文细节", "关注单条消息的具体内容和上下文细节",
] ]
# 关联类型 # 关联类型
associated_types: ClassVar[list[str]] = ["text"] associated_types: ClassVar[list[str]] = ["text"]
async def execute(self) -> tuple[bool, str]: async def execute(self) -> tuple[bool, str]:
"""执行reply动作 """执行reply动作
@@ -70,21 +70,21 @@ class RespondAction(BaseAction):
- 适合对于群聊消息下的宏观回应 - 适合对于群聊消息下的宏观回应
- 避免与单一用户深度对话而忽略其他用户的消息 - 避免与单一用户深度对话而忽略其他用户的消息
""" """
# 动作基本信息 # 动作基本信息
action_name = "respond" action_name = "respond"
action_description = "统一回应所有未读消息。理解整体对话动态和话题走向,生成连贯的回复。无需指定目标消息。" action_description = "统一回应所有未读消息。理解整体对话动态和话题走向,生成连贯的回复。无需指定目标消息。"
# 激活设置 # 激活设置
activation_type = ActionActivationType.ALWAYS # 回应动作总是可用 activation_type = ActionActivationType.ALWAYS # 回应动作总是可用
mode_enable = ChatMode.ALL # 在所有模式下都可用 mode_enable = ChatMode.ALL # 在所有模式下都可用
parallel_action = False # 回应动作不能与其他动作并行 parallel_action = False # 回应动作不能与其他动作并行
# 动作参数定义 # 动作参数定义
action_parameters: ClassVar = { action_parameters: ClassVar = {
"content": "回复的具体内容可选由LLM生成", "content": "回复的具体内容可选由LLM生成",
} }
# 动作使用场景 # 动作使用场景
action_require: ClassVar = [ action_require: ClassVar = [
"需要统一回应多条未读消息时使用Normal 模式专用)", "需要统一回应多条未读消息时使用Normal 模式专用)",
@@ -94,10 +94,10 @@ class RespondAction(BaseAction):
"适合群聊中的自然对话流,无需精确指向特定消息", "适合群聊中的自然对话流,无需精确指向特定消息",
"可以同时回应多个话题或参与者", "可以同时回应多个话题或参与者",
] ]
# 关联类型 # 关联类型
associated_types: ClassVar[list[str]] = ["text"] associated_types: ClassVar[list[str]] = ["text"]
async def execute(self) -> tuple[bool, str]: async def execute(self) -> tuple[bool, str]:
"""执行respond动作 """执行respond动作
View File
@@ -6,10 +6,10 @@
import asyncio import asyncio
import base64 import base64
import datetime import datetime
import filetype
from collections.abc import Callable from collections.abc import Callable
import aiohttp import aiohttp
import filetype
from maim_message import UserInfo from maim_message import UserInfo
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
View File
@@ -6,7 +6,7 @@
import re import re
from typing import ClassVar from typing import ClassVar
from src.chat.utils.prompt_component_manager import prompt_component_manager
from src.plugin_system.apis import ( from src.plugin_system.apis import (
plugin_manage_api, plugin_manage_api,
) )
@@ -74,6 +74,7 @@ class SystemCommand(PlusCommand):
• `/system permission` - 权限管理 • `/system permission` - 权限管理
• `/system plugin` - 插件管理 • `/system plugin` - 插件管理
• `/system schedule` - 定时任务管理 • `/system schedule` - 定时任务管理
• `/system prompt` - 提示词注入管理
""" """
elif target == "schedule": elif target == "schedule":
help_text = """📅 定时任务管理帮助 help_text = """📅 定时任务管理帮助
@@ -113,8 +114,17 @@ class SystemCommand(PlusCommand):
• /system permission nodes [插件名] - 查看权限节点 • /system permission nodes [插件名] - 查看权限节点
• /system permission allnodes - 查看所有权限节点详情 • /system permission allnodes - 查看所有权限节点详情
""" """
await self.send_text(help_text) elif target == "prompt":
help_text = """📝 提示词注入管理帮助
🔎 查询命令 (需要 `system.prompt.view` 权限):
• `/system prompt help` - 显示此帮助
• `/system prompt map` - 查看全局注入关系图
• `/system prompt targets` - 列出所有可被注入的核心提示词
• `/system prompt components` - 列出所有已注册的提示词组件
• `/system prompt info <目标名>` - 查看特定核心提示词的注入详情
"""
await self.send_text(help_text)
# ================================================================= # =================================================================
# Plugin Management Section # Plugin Management Section
@@ -231,6 +241,101 @@ class SystemCommand(PlusCommand):
else: else:
await self.send_text(f"❌ 恢复任务失败: `{schedule_id}`") await self.send_text(f"❌ 恢复任务失败: `{schedule_id}`")
# =================================================================
# Prompt Management Section
# =================================================================
async def _handle_prompt_commands(self, args: list[str]):
"""处理提示词管理相关命令"""
if not args or args[0].lower() in ["help", "帮助"]:
await self._show_help("prompt")
return
action = args[0].lower()
remaining_args = args[1:]
if action in ["map", "关系图"]:
await self._show_injection_map()
elif action in ["targets", "目标"]:
await self._list_core_prompts()
elif action in ["components", "组件"]:
await self._list_prompt_components()
elif action in ["info", "详情"] and remaining_args:
await self._get_prompt_injection_info(remaining_args[0])
else:
await self.send_text("❌ 提示词管理命令不合法\n使用 /system prompt help 查看帮助")
@require_permission("prompt.view", deny_message="❌ 你没有查看提示词注入信息的权限")
async def _show_injection_map(self):
"""显示全局注入关系图"""
injection_map = await prompt_component_manager.get_full_injection_map()
if not injection_map:
await self.send_text("📊 当前没有任何提示词注入关系")
return
response_parts = ["📊 全局提示词注入关系图:\n"]
for target, injections in injection_map.items():
if injections:
response_parts.append(f"🎯 **{target}** (注入源):")
for inj in injections:
source_tag = f"({inj['source']})" if inj['source'] != 'static_default' else ''
response_parts.append(f" ⎿ `{inj['name']}` (优先级: {inj['priority']}) {source_tag}")
else:
response_parts.append(f"🎯 **{target}** (无注入)")
await self._send_long_message("\n".join(response_parts))
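The rendering loop above implies the shape returned by get_full_injection_map(); a toy value it would format correctly (the target and component names are invented for illustration):

injection_map = {
    "planner_prompt": [
        {"name": "mood_hint", "priority": 10, "source": "mood_plugin"},
        {"name": "style_rules", "priority": 5, "source": "static_default"},
    ],
    "reply_prompt": [],  # a target with no injections renders as "(无注入)"
}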
@require_permission("prompt.view", deny_message="❌ 你没有查看提示词注入信息的权限")
async def _list_core_prompts(self):
"""列出所有可注入的核心提示词"""
targets = prompt_component_manager.get_core_prompts()
if not targets:
await self.send_text("🎯 当前没有可注入的核心提示词")
return
response = "🎯 所有可注入的核心提示词:\n" + "\n".join([f"• `{name}`" for name in targets])
await self.send_text(response)
@require_permission("prompt.view", deny_message="❌ 你没有查看提示词注入信息的权限")
async def _list_prompt_components(self):
"""列出所有已注册的提示词组件"""
components = prompt_component_manager.get_registered_prompt_component_info()
if not components:
await self.send_text("🧩 当前没有已注册的提示词组件")
return
response_parts = [f"🧩 已注册的提示词组件 (共 {len(components)} 个):"]
for comp in components:
response_parts.append(f"• `{comp.name}` (来自: `{comp.plugin_name}`)")
await self._send_long_message("\n".join(response_parts))
@require_permission("prompt.view", deny_message="❌ 你没有查看提示词注入信息的权限")
async def _get_prompt_injection_info(self, target_name: str):
"""获取特定核心提示词的注入详情"""
injections = await prompt_component_manager.get_injections_for_prompt(target_name)
core_prompts = prompt_component_manager.get_core_prompts()
if target_name not in core_prompts:
await self.send_text(f"❌ 找不到核心提示词: `{target_name}`")
return
if not injections:
await self.send_text(f"🎯 核心提示词 `{target_name}` 当前没有被任何组件注入。")
return
response_parts = [f"🔎 核心提示词 `{target_name}` 的注入详情:"]
for inj in injections:
response_parts.append(
f" • **`{inj['name']}`** (优先级: {inj['priority']})"
)
response_parts.append(f" - 来源: `{inj['source']}`")
response_parts.append(f" - 类型: `{inj['injection_type']}`")
if inj.get('target_content'):
response_parts.append(f" - 操作目标: `{inj['target_content']}`")
await self.send_text("\n".join(response_parts))
# ================================================================= # =================================================================
# Permission Management Section # Permission Management Section
# ================================================================= # =================================================================
File diff suppressed because it is too large.