style: unify code style and adopt modern type annotations

Performed a comprehensive code-style cleanup and modernization pass across the entire codebase, mainly including:

- Removed trailing whitespace from all files.
- Updated type hints to the modern syntax introduced by PEP 585 and PEP 604 (e.g., `list` instead of `List`, `|` instead of `Optional`); a short before/after sketch follows the summary below.
- Removed unused import statements from several modules.
- Removed redundant f-strings that contain no interpolated variables.
- Adjusted the ordering of `__all__` exports in some `__init__.py` files for consistency.

These changes are intended to improve readability and maintainability and bring the code in line with modern Python best practices; no core logic was modified.
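To make the type-annotation bullet concrete, here is a minimal before/after sketch of the kind of change this commit applies. `find_files` and the file names inside it are hypothetical, invented only for illustration and not taken from this repository:

```python
from __future__ import annotations  # lets the new annotation syntax run on Python < 3.10 too

# Old style (what the cleanup removes):
#     from typing import List, Optional
#     def find_files(pattern: Optional[str] = None) -> List[str]: ...
#     print(f"scanning...")  # f-string with nothing to interpolate

# New style after the cleanup: PEP 585 built-in generics, PEP 604 unions, plain strings.
def find_files(pattern: str | None = None) -> list[str]:
    """Hypothetical helper used only to illustrate the annotation style."""
    print("scanning...")  # redundant f-prefix dropped
    candidates = ["bot.py", "server.py", "cache.py"]
    return [name for name in candidates if pattern is None or pattern in name]


if __name__ == "__main__":
    print(find_files("bot"))  # -> ['bot.py']
```

The diff hunks below show the same patterns (built-in generics, `X | None` unions, removed unused imports, and f-string/quote-style cleanups) applied across the modules.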
minecraft1024a
2025-11-12 12:49:40 +08:00
parent daf8ea7e6a
commit 0e1e9935b2
33 changed files with 227 additions and 229 deletions

bot.py
View File

@@ -588,7 +588,7 @@ class MaiBotMain:
async def run_async_init(self, main_system):
"""执行异步初始化步骤"""
# 初始化数据库表结构
await self.initialize_database_async()

View File

@@ -19,14 +19,13 @@
import asyncio
import sys
from pathlib import Path
-from typing import List
# 添加项目根目录到路径
sys.path.insert(0, str(Path(__file__).parent.parent))
async def generate_missing_embeddings(
-target_node_types: List[str] = None,
+target_node_types: list[str] = None,
batch_size: int = 50,
):
"""
@@ -46,13 +45,13 @@ async def generate_missing_embeddings(
target_node_types = [NodeType.TOPIC.value, NodeType.OBJECT.value]
print(f"\n{'='*80}")
-print(f"🔧 为节点生成嵌入向量")
+print("🔧 为节点生成嵌入向量")
print(f"{'='*80}\n")
print(f"目标节点类型: {', '.join(target_node_types)}")
print(f"批处理大小: {batch_size}\n")
# 1. 初始化记忆管理器
-print(f"🔧 正在初始化记忆管理器...")
+print("🔧 正在初始化记忆管理器...")
await initialize_memory_manager()
manager = get_memory_manager()
@@ -60,10 +59,10 @@ async def generate_missing_embeddings(
print("❌ 记忆管理器初始化失败")
return
-print(f"✅ 记忆管理器已初始化\n")
+print("✅ 记忆管理器已初始化\n")
# 2. 获取已索引的节点ID
-print(f"🔍 检查现有向量索引...")
+print("🔍 检查现有向量索引...")
existing_node_ids = set()
try:
vector_count = manager.vector_store.collection.count()
@@ -78,14 +77,14 @@ async def generate_missing_embeddings(
)
if result and "ids" in result:
existing_node_ids.update(result["ids"])
print(f"✅ 发现 {len(existing_node_ids)} 个已索引节点\n")
except Exception as e:
logger.warning(f"获取已索引节点ID失败: {e}")
-print(f"⚠️ 无法获取已索引节点,将尝试跳过重复项\n")
+print("⚠️ 无法获取已索引节点,将尝试跳过重复项\n")
# 3. 收集需要生成嵌入的节点
-print(f"🔍 扫描需要生成嵌入的节点...")
+print("🔍 扫描需要生成嵌入的节点...")
all_memories = manager.graph_store.get_all_memories()
nodes_to_process = []
@@ -110,7 +109,7 @@ async def generate_missing_embeddings(
})
type_stats[node.node_type.value]["need_emb"] += 1
-print(f"\n📊 扫描结果:")
+print("\n📊 扫描结果:")
for node_type in target_node_types:
stats = type_stats[node_type]
already_ok = stats["already_indexed"]
@@ -121,11 +120,11 @@ async def generate_missing_embeddings(
print(f"\n 总计: {total_target_nodes} 个目标节点, {len(nodes_to_process)} 个需要生成嵌入\n")
if len(nodes_to_process) == 0:
-print(f"✅ 所有节点已有嵌入向量,无需生成")
+print("✅ 所有节点已有嵌入向量,无需生成")
return
# 3. 批量生成嵌入
-print(f"🚀 开始生成嵌入向量...\n")
+print("🚀 开始生成嵌入向量...\n")
total_batches = (len(nodes_to_process) + batch_size - 1) // batch_size
success_count = 0
@@ -193,22 +192,22 @@ async def generate_missing_embeddings(
print(f" 📊 总进度: {total_processed}/{len(nodes_to_process)} ({progress:.1f}%)\n")
# 4. 保存图数据(更新节点的 embedding 字段)
-print(f"💾 保存图数据...")
+print("💾 保存图数据...")
try:
await manager.persistence.save_graph_store(manager.graph_store)
-print(f"✅ 图数据已保存\n")
+print("✅ 图数据已保存\n")
except Exception as e:
-logger.error(f"保存图数据失败", exc_info=True)
+logger.error("保存图数据失败", exc_info=True)
print(f"❌ 保存失败: {e}\n")
# 5. 验证结果
-print(f"🔍 验证向量索引...")
+print("🔍 验证向量索引...")
final_vector_count = manager.vector_store.collection.count()
stats = manager.graph_store.get_statistics()
total_nodes = stats["total_nodes"]
print(f"\n{'='*80}")
-print(f"📊 生成完成")
+print("📊 生成完成")
print(f"{'='*80}")
print(f"处理节点数: {len(nodes_to_process)}")
print(f"成功生成: {success_count}")
@@ -219,7 +218,7 @@ async def generate_missing_embeddings(
print(f"索引覆盖率: {final_vector_count / total_nodes * 100:.1f}%\n")
# 6. 测试搜索
-print(f"🧪 测试搜索功能...")
+print("🧪 测试搜索功能...")
test_queries = ["小红帽蕾克", "拾风", "杰瑞喵"]
for query in test_queries:

View File

@@ -38,7 +38,7 @@ OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")
TEMP_DIR = os.path.join(ROOT_PATH, "temp", "lpmm_cache")
# ========== 性能配置参数 ==========
#
# 知识提取步骤2txt转json并发控制
# - 控制同时进行的LLM提取请求数量
# - 推荐值: 3-10取决于API速率限制
@@ -184,7 +184,7 @@ async def extract_info_async(pg_hash, paragraph, llm_api):
tuple: (doc_item或None, failed_hash或None)
"""
temp_file_path = os.path.join(TEMP_DIR, f"{pg_hash}.json")
# 🔧 优化:使用异步文件检查,避免阻塞
if os.path.exists(temp_file_path):
try:
@@ -215,11 +215,11 @@ async def extract_info_async(pg_hash, paragraph, llm_api):
"extracted_entities": extracted_data.get("entities", []),
"extracted_triples": extracted_data.get("triples", []),
}
# 保存到缓存(异步写入)
async with aiofiles.open(temp_file_path, "wb") as f:
await f.write(orjson.dumps(doc_item))
return doc_item, None
except Exception as e:
logger.error(f"提取信息失败:{pg_hash}, 错误:{e}")
@@ -249,13 +249,13 @@ async def extract_information(paragraphs_dict, model_set):
os.makedirs(TEMP_DIR, exist_ok=True)
failed_hashes, open_ie_docs = [], []
# 🔧 关键修复:创建单个 LLM 请求实例,复用连接
llm_api = LLMRequest(model_set=model_set, request_type="lpmm_extraction")
# 🔧 并发控制:限制最大并发数,防止速率限制
semaphore = asyncio.Semaphore(MAX_EXTRACTION_CONCURRENCY)
async def extract_with_semaphore(pg_hash, paragraph):
"""带信号量控制的提取函数"""
async with semaphore:
@@ -266,7 +266,7 @@ async def extract_information(paragraphs_dict, model_set):
extract_with_semaphore(p_hash, paragraph)
for p_hash, paragraph in paragraphs_dict.items()
]
total = len(tasks)
completed = 0
@@ -284,7 +284,7 @@ async def extract_information(paragraphs_dict, model_set):
TimeRemainingColumn(),
) as progress:
task = progress.add_task("[cyan]正在提取信息...", total=total)
# 🔧 优化:使用 asyncio.gather 并发执行所有任务
# return_exceptions=True 确保单个失败不影响其他任务
for coro in asyncio.as_completed(tasks):
@@ -293,7 +293,7 @@ async def extract_information(paragraphs_dict, model_set):
failed_hashes.append(failed_hash)
elif doc_item:
open_ie_docs.append(doc_item)
completed += 1
progress.update(task, advance=1)
@@ -415,7 +415,7 @@ def rebuild_faiss_only():
logger.info("--- 重建 FAISS 索引 ---")
# 重建索引不需要并发参数(不涉及 embedding 生成)
embed_manager = EmbeddingManager()
logger.info("正在加载现有的 Embedding 库...")
try:
embed_manager.load_from_file()

View File

@@ -4,13 +4,13 @@
提供 Web API 用于可视化记忆图数据
"""
+from collections import defaultdict
from datetime import datetime
from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Any
-from collections import defaultdict
import orjson
-from fastapi import APIRouter, HTTPException, Request, Query
+from fastapi import APIRouter, HTTPException, Query, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.templating import Jinja2Templates
@@ -29,7 +29,7 @@ router = APIRouter()
templates = Jinja2Templates(directory=str(Path(__file__).parent / "templates"))
-def find_available_data_files() -> List[Path]:
+def find_available_data_files() -> list[Path]:
"""查找所有可用的记忆图数据文件"""
files = []
if not data_dir.exists():
@@ -62,7 +62,7 @@ def find_available_data_files() -> List[Path]:
return sorted(files, key=lambda f: f.stat().st_mtime, reverse=True)
-def load_graph_data_from_file(file_path: Optional[Path] = None) -> Dict[str, Any]:
+def load_graph_data_from_file(file_path: Path | None = None) -> dict[str, Any]:
"""从磁盘加载图数据"""
global graph_data_cache, current_data_file
@@ -85,7 +85,7 @@ def load_graph_data_from_file(file_path: Optional[Path] = None) -> Dict[str, Any
if not graph_file.exists():
return {"error": f"文件不存在: {graph_file}", "nodes": [], "edges": [], "stats": {}}
-with open(graph_file, "r", encoding="utf-8") as f:
+with open(graph_file, encoding="utf-8") as f:
data = orjson.loads(f.read())
nodes = data.get("nodes", [])
@@ -150,7 +150,7 @@ async def index(request: Request):
return templates.TemplateResponse("visualizer.html", {"request": request})
-def _format_graph_data_from_manager(memory_manager) -> Dict[str, Any]:
+def _format_graph_data_from_manager(memory_manager) -> dict[str, Any]:
"""从 MemoryManager 提取并格式化图数据"""
if not memory_manager.graph_store:
return {"nodes": [], "edges": [], "memories": [], "stats": {}}
@@ -188,7 +188,7 @@ def _format_graph_data_from_manager(memory_manager) -> Dict[str, Any]:
"arrows": "to",
"memory_id": memory.id,
}
edges_list = list(edges_dict.values())
stats = memory_manager.get_statistics()
@@ -261,7 +261,7 @@ async def get_paginated_graph(
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(500, ge=100, le=2000, description="每页节点数"),
min_importance: float = Query(0.0, ge=0.0, le=1.0, description="最小重要性阈值"),
-node_types: Optional[str] = Query(None, description="节点类型过滤,逗号分隔"),
+node_types: str | None = Query(None, description="节点类型过滤,逗号分隔"),
):
"""分页获取图数据,支持重要性过滤"""
try:
@@ -301,13 +301,13 @@ async def get_paginated_graph(
total_pages = (total_nodes + page_size - 1) // page_size
start_idx = (page - 1) * page_size
end_idx = min(start_idx + page_size, total_nodes)
paginated_nodes = nodes_with_importance[start_idx:end_idx]
node_ids = set(n["id"] for n in paginated_nodes)
# 只保留连接分页节点的边
paginated_edges = [
e for e in edges
if e.get("from") in node_ids and e.get("to") in node_ids
]
@@ -383,7 +383,7 @@ async def get_clustered_graph(
return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
-def _cluster_graph_data(nodes: List[Dict], edges: List[Dict], max_nodes: int, cluster_threshold: int) -> Dict:
+def _cluster_graph_data(nodes: list[dict], edges: list[dict], max_nodes: int, cluster_threshold: int) -> dict:
"""简单的图聚类算法:按类型和连接度聚类"""
# 构建邻接表
adjacency = defaultdict(set)
@@ -412,21 +412,21 @@ def _cluster_graph_data(nodes: List[Dict], edges: List[Dict], max_nodes: int, cl
for node in type_nodes:
importance = len(adjacency[node["id"]])
node_importance.append((node, importance))
node_importance.sort(key=lambda x: x[1], reverse=True)
# 保留前N个重要节点
keep_count = min(len(type_nodes), max_nodes // len(type_groups))
for node, importance in node_importance[:keep_count]:
clustered_nodes.append(node)
node_mapping[node["id"]] = node["id"]
# 其余节点聚合为一个超级节点
if len(node_importance) > keep_count:
clustered_node_ids = [n["id"] for n, _ in node_importance[keep_count:]]
cluster_id = f"cluster_{node_type}_{len(clustered_nodes)}"
cluster_label = f"{node_type} 集群 ({len(clustered_node_ids)}个节点)"
clustered_nodes.append({
"id": cluster_id,
"label": cluster_label,
@@ -436,7 +436,7 @@ def _cluster_graph_data(nodes: List[Dict], edges: List[Dict], max_nodes: int, cl
"cluster_size": len(clustered_node_ids),
"clustered_nodes": clustered_node_ids[:10], # 只保留前10个用于展示
})
for node_id in clustered_node_ids:
node_mapping[node_id] = cluster_id
@@ -445,7 +445,7 @@ def _cluster_graph_data(nodes: List[Dict], edges: List[Dict], max_nodes: int, cl
for edge in edges:
from_id = node_mapping.get(edge["from"])
to_id = node_mapping.get(edge["to"])
if from_id and to_id and from_id != to_id:
edge_key = tuple(sorted([from_id, to_id]))
if edge_key not in edge_set:

View File

@@ -1,6 +1,5 @@
-from collections import defaultdict
from datetime import datetime, timedelta
-from typing import Any, Literal
+from typing import Literal
from fastapi import APIRouter, HTTPException, Query

View File

@@ -161,16 +161,16 @@ class EmbeddingStore:
# 限制 chunk_size 和 max_workers 在合理范围内
chunk_size = max(MIN_CHUNK_SIZE, min(chunk_size, MAX_CHUNK_SIZE))
max_workers = max(MIN_WORKERS, min(max_workers, MAX_WORKERS))
semaphore = asyncio.Semaphore(max_workers)
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
results = {}
# 将字符串列表分成多个 chunk
chunks = []
for i in range(0, len(strs), chunk_size):
chunks.append(strs[i : i + chunk_size])
async def _process_chunk(chunk: list[str]):
"""处理一个 chunk 的字符串(批量获取 embedding"""
async with semaphore:
@@ -180,12 +180,12 @@ class EmbeddingStore:
embedding = await EmbeddingStore._get_embedding_async(llm, s)
embeddings.append(embedding)
results[s] = embedding
if progress_callback:
progress_callback(len(chunk))
return embeddings
# 并发处理所有 chunks
tasks = [_process_chunk(chunk) for chunk in chunks]
await asyncio.gather(*tasks)
@@ -418,22 +418,22 @@ class EmbeddingStore:
# 🔧 修复:检查所有 embedding 的维度是否一致
dimensions = [len(emb) for emb in array]
unique_dims = set(dimensions)
if len(unique_dims) > 1:
logger.error(f"检测到不一致的 embedding 维度: {unique_dims}")
logger.error(f"维度分布: {dict(zip(*np.unique(dimensions, return_counts=True)))}")
# 获取期望的维度(使用最常见的维度)
from collections import Counter
dim_counter = Counter(dimensions)
expected_dim = dim_counter.most_common(1)[0][0]
logger.warning(f"将使用最常见的维度: {expected_dim}")
# 过滤掉维度不匹配的 embedding
filtered_array = []
filtered_idx2hash = {}
skipped_count = 0
for i, emb in enumerate(array):
if len(emb) == expected_dim:
filtered_array.append(emb)
@@ -442,11 +442,11 @@ class EmbeddingStore:
skipped_count += 1
hash_key = self.idx2hash[str(i)]
logger.warning(f"跳过维度不匹配的 embedding: {hash_key}, 维度={len(emb)}, 期望={expected_dim}")
logger.warning(f"已过滤 {skipped_count} 个维度不匹配的 embedding")
array = filtered_array
self.idx2hash = filtered_idx2hash
if not array:
logger.error("过滤后没有可用的 embedding无法构建索引")
embedding_dim = expected_dim

View File

@@ -13,4 +13,4 @@ __all__ = [
"StreamLoopManager", "StreamLoopManager",
"message_manager", "message_manager",
"stream_loop_manager", "stream_loop_manager",
] ]

View File

@@ -82,7 +82,7 @@ class SingleStreamContextManager:
self.total_messages += 1
self.last_access_time = time.time()
# 如果使用了缓存系统,输出调试信息
if cache_enabled and self.context.is_cache_enabled:
if self.context.is_chatter_processing:

View File

@@ -111,9 +111,9 @@ class StreamLoopManager:
# 获取或创建该流的启动锁
if stream_id not in self._stream_start_locks:
self._stream_start_locks[stream_id] = asyncio.Lock()
lock = self._stream_start_locks[stream_id]
# 使用锁防止并发启动同一个流的多个循环任务
async with lock:
# 获取流上下文
@@ -148,7 +148,7 @@ class StreamLoopManager:
# 紧急取消
context.stream_loop_task.cancel()
await asyncio.sleep(0.1)
loop_task = asyncio.create_task(self._stream_loop_worker(stream_id), name=f"stream_loop_{stream_id}")
# 将任务记录到 StreamContext 中
@@ -249,7 +249,7 @@ class StreamLoopManager:
self.stats["total_process_cycles"] += 1
if success:
logger.info(f"✅ [流工作器] stream={stream_id[:8]}, 任务ID={task_id}, 处理成功")
# 🔒 处理成功后,等待一小段时间确保清理操作完成
# 这样可以避免在 chatter_manager 清除未读消息之前就进入下一轮循环
await asyncio.sleep(0.1)
@@ -379,7 +379,7 @@ class StreamLoopManager:
self.chatter_manager.process_stream_context(stream_id, context),
name=f"chatter_process_{stream_id}"
)
# 等待 chatter 任务完成
results = await chatter_task
success = results.get("success", False)
@@ -395,8 +395,8 @@ class StreamLoopManager:
else:
logger.warning(f"流处理失败: {stream_id} - {results.get('error_message', '未知错误')}")
return success
except asyncio.CancelledError:
if chatter_task and not chatter_task.done():
chatter_task.cancel()
raise
@@ -706,4 +706,4 @@ class StreamLoopManager:
# 全局流循环管理器实例
stream_loop_manager = StreamLoopManager()

View File

@@ -417,7 +417,7 @@ class MessageManager:
return
# 记录详细信息
msg_previews = [f"{str(msg.message_id)[:8] if msg.message_id else 'unknown'}:{msg.processed_plain_text[:20] if msg.processed_plain_text else '(空)'}"
for msg in unread_messages[:3]] # 只显示前3条
logger.info(f"🧹 [清除未读] stream={stream_id[:8]}, 开始清除 {len(unread_messages)} 条未读消息, 示例: {msg_previews}")
@@ -446,15 +446,15 @@ class MessageManager:
context = chat_stream.context_manager.context
if hasattr(context, "unread_messages") and context.unread_messages:
unread_count = len(context.unread_messages)
# 如果还有未读消息,说明 action_manager 可能遗漏了,标记它们
if unread_count > 0:
# 获取所有未读消息的 ID
message_ids = [msg.message_id for msg in context.unread_messages]
# 标记为已读(会移到历史消息)
success = chat_stream.context_manager.mark_messages_as_read(message_ids)
if success:
logger.debug(f"✅ stream={stream_id[:8]}, 成功标记 {unread_count} 条消息为已读")
else:
@@ -481,7 +481,7 @@ class MessageManager:
try:
chat_manager = get_chat_manager()
chat_stream = await chat_manager.get_stream(stream_id)
-if chat_stream and hasattr(chat_stream.context_manager.context, 'is_chatter_processing'):
+if chat_stream and hasattr(chat_stream.context_manager.context, "is_chatter_processing"):
chat_stream.context_manager.context.is_chatter_processing = is_processing
logger.debug(f"设置StreamContext处理状态: stream={stream_id}, processing={is_processing}")
except Exception as e:
@@ -517,7 +517,7 @@ class MessageManager:
try:
chat_manager = get_chat_manager()
chat_stream = await chat_manager.get_stream(stream_id)
-if chat_stream and hasattr(chat_stream.context_manager.context, 'is_chatter_processing'):
+if chat_stream and hasattr(chat_stream.context_manager.context, "is_chatter_processing"):
return chat_stream.context_manager.context.is_chatter_processing
except Exception:
pass
@@ -677,4 +677,4 @@ class MessageManager:
# 创建全局消息管理器实例
message_manager = MessageManager()

View File

@@ -248,16 +248,16 @@ class ChatterActionManager:
try:
# 根据动作类型确定提示词模式
prompt_mode = "s4u" if action_name == "reply" else "normal"
# 将prompt_mode传递给generate_reply
action_data_with_mode = (action_data or {}).copy()
action_data_with_mode["prompt_mode"] = prompt_mode
# 只传递当前正在执行的动作,而不是所有可用动作
# 这样可以让LLM明确知道"已决定执行X动作",而不是"有这些动作可用"
current_action_info = self._using_actions.get(action_name)
current_actions: dict[str, Any] = {action_name: current_action_info} if current_action_info else {}
# 附加目标消息信息(如果存在)
if target_message:
# 提取目标消息的关键信息
@@ -268,7 +268,7 @@ class ChatterActionManager:
"time": getattr(target_message, "time", 0),
}
current_actions["_target_message"] = target_msg_info
success, response_set, _ = await generator_api.generate_reply(
chat_stream=chat_stream,
reply_message=target_message,
@@ -295,12 +295,12 @@ class ChatterActionManager:
should_quote_reply = None
if action_data and isinstance(action_data, dict):
should_quote_reply = action_data.get("should_quote_reply", None)
# respond动作默认不引用回复保持对话流畅
if action_name == "respond" and should_quote_reply is None:
should_quote_reply = False
async def _after_reply():
# 发送并存储回复
loop_info, reply_text, cycle_timers_reply = await self._send_and_store_reply(
chat_stream,

View File

@@ -365,7 +365,7 @@ class DefaultReplyer:
# 确保类型安全
if isinstance(mode, str):
prompt_mode_value = mode
# 构建 Prompt
with Timer("构建Prompt", {}): # 内部计时器,可选保留
prompt = await self.build_prompt_reply_context(
@@ -1171,16 +1171,16 @@ class DefaultReplyer:
from src.plugin_system.apis.chat_api import get_chat_manager
chat_manager = get_chat_manager()
chat_stream_obj = await chat_manager.get_stream(chat_id)
if chat_stream_obj:
unread_messages = chat_stream_obj.context_manager.get_unread_messages()
if unread_messages:
# 使用最后一条未读消息作为参考
last_msg = unread_messages[-1]
-platform = last_msg.chat_info.platform if hasattr(last_msg, 'chat_info') else chat_stream.platform
+platform = last_msg.chat_info.platform if hasattr(last_msg, "chat_info") else chat_stream.platform
-user_id = last_msg.user_info.user_id if hasattr(last_msg, 'user_info') else ""
+user_id = last_msg.user_info.user_id if hasattr(last_msg, "user_info") else ""
-user_nickname = last_msg.user_info.user_nickname if hasattr(last_msg, 'user_info') else ""
+user_nickname = last_msg.user_info.user_nickname if hasattr(last_msg, "user_info") else ""
-user_cardname = last_msg.user_info.user_cardname if hasattr(last_msg, 'user_info') else ""
+user_cardname = last_msg.user_info.user_cardname if hasattr(last_msg, "user_info") else ""
processed_plain_text = last_msg.processed_plain_text or ""
else:
# 没有未读消息,使用默认值
@@ -1263,19 +1263,19 @@ class DefaultReplyer:
if available_actions:
# 过滤掉特殊键以_开头
action_items = {k: v for k, v in available_actions.items() if not k.startswith("_")}
# 提取目标消息信息(如果存在)
target_msg_info = available_actions.get("_target_message") # type: ignore
if action_items:
if len(action_items) == 1:
# 单个动作
action_name, action_info = list(action_items.items())[0]
action_desc = action_info.description
# 构建基础决策信息
action_descriptions = f"## 决策信息\n\n你已经决定要执行 **{action_name}** 动作({action_desc})。\n\n"
# 只有需要目标消息的动作才显示目标消息详情
# respond 动作是统一回应所有未读消息,不应该显示特定目标消息
if action_name not in ["respond"] and target_msg_info and isinstance(target_msg_info, dict):
@@ -1284,7 +1284,7 @@ class DefaultReplyer:
content = target_msg_info.get("content", "")
msg_time = target_msg_info.get("time", 0)
time_str = time_module.strftime("%H:%M:%S", time_module.localtime(msg_time)) if msg_time else "未知时间"
action_descriptions += f"**目标消息**: {time_str} {sender} 说: {content}\n\n"
else:
# 多个动作
@@ -2137,7 +2137,7 @@ class DefaultReplyer:
except Exception as e:
logger.error(f"存储聊天记忆失败: {e}")
def weighted_sample_no_replacement(items, weights, k) -> list:
"""

View File

@@ -5,12 +5,12 @@
插件可以通过实现这些接口来扩展安全功能。
"""
-from .interfaces import SecurityCheckResult, SecurityChecker
+from .interfaces import SecurityChecker, SecurityCheckResult
from .manager import SecurityManager, get_security_manager
__all__ = [
-"SecurityChecker",
"SecurityCheckResult",
+"SecurityChecker",
"SecurityManager",
"get_security_manager",
]

View File

@@ -10,7 +10,7 @@ from typing import Any
from src.common.logger import get_logger
-from .interfaces import SecurityAction, SecurityCheckResult, SecurityChecker, SecurityLevel
+from .interfaces import SecurityAction, SecurityChecker, SecurityCheckResult, SecurityLevel
logger = get_logger("security.manager")

View File

@@ -98,7 +98,7 @@ class StreamContext(BaseDataModel):
break
def mark_message_as_read(self, message_id: str):
"""标记消息为已读"""
# 先找到要标记的消息(处理 int/str 类型不匹配问题)
message_to_mark = None
for msg in self.unread_messages:
@@ -106,7 +106,7 @@ class StreamContext(BaseDataModel):
if str(msg.message_id) == str(message_id):
message_to_mark = msg
break
# 然后移动到历史消息
if message_to_mark:
message_to_mark.is_read = True

View File

@@ -9,11 +9,12 @@
""" """
import asyncio import asyncio
import builtins
import time import time
from collections import OrderedDict from collections import OrderedDict
from collections.abc import Callable from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, Generic, List, Optional, Set, TypeVar, Union from typing import Any, Generic, TypeVar
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.memory_utils import estimate_size_smart from src.common.memory_utils import estimate_size_smart
@@ -96,7 +97,7 @@ class LRUCache(Generic[T]):
self._lock = asyncio.Lock() self._lock = asyncio.Lock()
self._stats = CacheStats() self._stats = CacheStats()
async def get(self, key: str) -> Optional[T]: async def get(self, key: str) -> T | None:
"""获取缓存值 """获取缓存值
Args: Args:
@@ -137,8 +138,8 @@ class LRUCache(Generic[T]):
self, self,
key: str, key: str,
value: T, value: T,
size: Optional[int] = None, size: int | None = None,
ttl: Optional[float] = None, ttl: float | None = None,
) -> None: ) -> None:
"""设置缓存值 """设置缓存值
@@ -287,8 +288,8 @@ class MultiLevelCache:
async def get( async def get(
self, self,
key: str, key: str,
loader: Optional[Callable[[], Any]] = None, loader: Callable[[], Any] | None = None,
) -> Optional[Any]: ) -> Any | None:
"""从缓存获取数据 """从缓存获取数据
查询顺序L1 -> L2 -> loader 查询顺序L1 -> L2 -> loader
@@ -329,8 +330,8 @@ class MultiLevelCache:
self, self,
key: str, key: str,
value: Any, value: Any,
size: Optional[int] = None, size: int | None = None,
ttl: Optional[float] = None, ttl: float | None = None,
) -> None: ) -> None:
"""设置缓存值 """设置缓存值
@@ -390,7 +391,7 @@ class MultiLevelCache:
await self.l2_cache.clear() await self.l2_cache.clear()
logger.info("所有缓存已清空") logger.info("所有缓存已清空")
async def get_stats(self) -> Dict[str, Any]: async def get_stats(self) -> dict[str, Any]:
"""获取所有缓存层的统计信息(修复版:避免锁嵌套,使用超时)""" """获取所有缓存层的统计信息(修复版:避免锁嵌套,使用超时)"""
# 🔧 修复:并行获取统计信息,避免锁嵌套 # 🔧 修复:并行获取统计信息,避免锁嵌套
l1_stats_task = asyncio.create_task(self._get_cache_stats_safe(self.l1_cache, "L1")) l1_stats_task = asyncio.create_task(self._get_cache_stats_safe(self.l1_cache, "L1"))
@@ -492,7 +493,7 @@ class MultiLevelCache:
logger.error(f"{cache_name}统计获取异常: {e}") logger.error(f"{cache_name}统计获取异常: {e}")
return CacheStats() return CacheStats()
async def _get_cache_keys_safe(self, cache) -> Set[str]: async def _get_cache_keys_safe(self, cache) -> builtins.set[str]:
"""安全获取缓存键集合(带超时)""" """安全获取缓存键集合(带超时)"""
try: try:
# 快速获取键集合,使用超时避免死锁 # 快速获取键集合,使用超时避免死锁
@@ -507,12 +508,12 @@ class MultiLevelCache:
logger.error(f"缓存键获取异常: {e}") logger.error(f"缓存键获取异常: {e}")
return set() return set()
async def _extract_keys_with_lock(self, cache) -> Set[str]: async def _extract_keys_with_lock(self, cache) -> builtins.set[str]:
"""在锁保护下提取键集合""" """在锁保护下提取键集合"""
async with cache._lock: async with cache._lock:
return set(cache._cache.keys()) return set(cache._cache.keys())
async def _calculate_memory_usage_safe(self, cache, keys: Set[str]) -> int: async def _calculate_memory_usage_safe(self, cache, keys: builtins.set[str]) -> int:
"""安全计算内存使用(带超时)""" """安全计算内存使用(带超时)"""
if not keys: if not keys:
return 0 return 0
@@ -529,7 +530,7 @@ class MultiLevelCache:
logger.error(f"内存计算异常: {e}") logger.error(f"内存计算异常: {e}")
return 0 return 0
async def _calc_memory_with_lock(self, cache, keys: Set[str]) -> int: async def _calc_memory_with_lock(self, cache, keys: builtins.set[str]) -> int:
"""在锁保护下计算内存使用""" """在锁保护下计算内存使用"""
total_size = 0 total_size = 0
async with cache._lock: async with cache._lock:
@@ -749,7 +750,7 @@ class MultiLevelCache:
# 全局缓存实例 # 全局缓存实例
_global_cache: Optional[MultiLevelCache] = None _global_cache: MultiLevelCache | None = None
_cache_lock = asyncio.Lock() _cache_lock = asyncio.Lock()

View File

@@ -3,7 +3,6 @@ import socket
from fastapi import APIRouter, FastAPI
from fastapi.middleware.cors import CORSMiddleware
-from fastapi.staticfiles import StaticFiles
from rich.traceback import install
from uvicorn import Config
from uvicorn import Server as UvicornServer

View File

@@ -436,7 +436,7 @@ class OpenaiClient(BaseClient):
# 🔧 优化增加连接池限制支持高并发embedding请求
# 默认httpx限制为100对于高频embedding场景不够用
import httpx
limits = httpx.Limits(
max_keepalive_connections=200, # 保持活跃连接数原100
max_connections=300, # 最大总连接数原100

View File

@@ -128,7 +128,7 @@ class MemoryBuilder:
# 6. 构建 Memory 对象
# 新记忆应该有较高的初始激活度
initial_activation = 0.75 # 新记忆初始激活度为 0.75
memory = Memory(
id=memory_id,
subject_id=subject_node.id,

View File

@@ -149,7 +149,7 @@ class MemoryManager:
# 读取阈值过滤配置
search_min_importance = self.config.search_min_importance
search_similarity_threshold = self.config.search_similarity_threshold
logger.info(
f"📊 配置检查: search_max_expand_depth={expand_depth}, "
f"search_expand_semantic_threshold={expand_semantic_threshold}, "
@@ -415,7 +415,7 @@ class MemoryManager:
# 使用配置的默认值
if top_k is None:
top_k = getattr(self.config, "search_top_k", 10)
# 准备搜索参数
params = {
"query": query,
@@ -948,7 +948,7 @@ class MemoryManager:
)
else:
logger.debug(f"记忆已删除: {memory_id} (删除了 {deleted_vectors} 个向量)")
# 4. 保存更新
await self.persistence.save_graph_store(self.graph_store)
return True
@@ -981,7 +981,7 @@ class MemoryManager:
try:
forgotten_count = 0
all_memories = self.graph_store.get_all_memories()
# 获取配置参数
min_importance = getattr(self.config, "forgetting_min_importance", 0.8)
decay_rate = getattr(self.config, "activation_decay_rate", 0.9)
@@ -1007,10 +1007,10 @@ class MemoryManager:
try:
last_access_dt = datetime.fromisoformat(last_access)
days_passed = (datetime.now() - last_access_dt).days
# 应用指数衰减activation = base * (decay_rate ^ days)
current_activation = base_activation * (decay_rate ** days_passed)
logger.debug(
f"记忆 {memory.id[:8]}: 基础激活度={base_activation:.3f}, "
f"经过{days_passed}天衰减后={current_activation:.3f}"
@@ -1032,20 +1032,20 @@ class MemoryManager:
# 批量遗忘记忆(不立即清理孤立节点)
if memories_to_forget:
logger.info(f"开始批量遗忘 {len(memories_to_forget)} 条记忆...")
for memory_id, activation in memories_to_forget:
# cleanup_orphans=False暂不清理孤立节点
success = await self.forget_memory(memory_id, cleanup_orphans=False)
if success:
forgotten_count += 1
# 统一清理孤立节点和边
logger.info("批量遗忘完成,开始统一清理孤立节点和边...")
orphan_nodes, orphan_edges = await self._cleanup_orphan_nodes_and_edges()
# 保存最终更新
await self.persistence.save_graph_store(self.graph_store)
logger.info(
f"✅ 自动遗忘完成: 遗忘了 {forgotten_count} 条记忆, "
f"清理了 {orphan_nodes} 个孤立节点, {orphan_edges} 条孤立边"
@@ -1076,31 +1076,31 @@ class MemoryManager:
# 1. 清理孤立节点
# graph_store.node_to_memories 记录了每个节点属于哪些记忆
nodes_to_remove = []
for node_id, memory_ids in list(self.graph_store.node_to_memories.items()):
# 如果节点不再属于任何记忆,标记为删除
if not memory_ids:
nodes_to_remove.append(node_id)
# 从图中删除孤立节点
for node_id in nodes_to_remove:
if self.graph_store.graph.has_node(node_id):
self.graph_store.graph.remove_node(node_id)
orphan_nodes_count += 1
# 从映射中删除
if node_id in self.graph_store.node_to_memories:
del self.graph_store.node_to_memories[node_id]
# 2. 清理孤立边(指向已删除节点的边)
edges_to_remove = []
-for source, target, edge_id in self.graph_store.graph.edges(data='edge_id'):
+for source, target, edge_id in self.graph_store.graph.edges(data="edge_id"):
# 检查边的源节点和目标节点是否还存在于node_to_memories中
if source not in self.graph_store.node_to_memories or \
target not in self.graph_store.node_to_memories:
edges_to_remove.append((source, target))
# 删除孤立边
for source, target in edges_to_remove:
try:
@@ -1108,12 +1108,12 @@ class MemoryManager:
orphan_edges_count += 1
except Exception as e:
logger.debug(f"删除边失败 {source} -> {target}: {e}")
if orphan_nodes_count > 0 or orphan_edges_count > 0:
logger.info(
f"清理完成: {orphan_nodes_count} 个孤立节点, {orphan_edges_count} 条孤立边"
)
return orphan_nodes_count, orphan_edges_count
except Exception as e:
@@ -1255,7 +1255,7 @@ class MemoryManager:
mem for mem in recent_memories
if mem.importance >= min_importance_for_consolidation
]
result["importance_filtered"] = len(recent_memories) - len(important_memories)
logger.info(
f"📊 步骤2: 重要性过滤 (阈值={min_importance_for_consolidation:.2f}): "
@@ -1379,26 +1379,26 @@ class MemoryManager:
# ===== 步骤4: 向量检索关联记忆 + LLM分析关系 =====
# 过滤掉已删除的记忆
remaining_memories = [m for m in important_memories if m.id not in deleted_ids]
if not remaining_memories:
logger.info("✅ 记忆整理完成: 去重后无剩余记忆")
return
logger.info(f"📍 步骤4: 开始关联分析 ({len(remaining_memories)} 条记忆)...")
# 分批处理记忆关联
llm_batch_size = getattr(self.config, "consolidation_llm_batch_size", 10)
max_candidates_per_memory = getattr(self.config, "consolidation_max_candidates", 5)
min_confidence = getattr(self.config, "consolidation_min_confidence", 0.6)
all_new_edges = [] # 收集所有新建的边
for batch_start in range(0, len(remaining_memories), llm_batch_size):
batch_end = min(batch_start + llm_batch_size, len(remaining_memories))
batch = remaining_memories[batch_start:batch_end]
logger.debug(f"处理批次 {batch_start//llm_batch_size + 1}/{(len(remaining_memories)-1)//llm_batch_size + 1}")
for memory in batch:
# 跳过已经有很多连接的记忆
existing_edges = len([
@@ -1451,14 +1451,14 @@ class MemoryManager:
except Exception as e:
logger.warning(f"创建关联边失败: {e}")
continue
# 每个批次后让出控制权
await asyncio.sleep(0.01)
# ===== 步骤5: 统一更新记忆数据 =====
if all_new_edges:
logger.info(f"📍 步骤5: 统一更新 {len(all_new_edges)} 条新关联边...")
for memory, edge, relation in all_new_edges:
try:
# 添加到图
@@ -2298,7 +2298,7 @@ class MemoryManager:
# 使用 asyncio.wait_for 来支持取消
await asyncio.wait_for(
asyncio.sleep(initial_delay),
-timeout=float('inf') # 允许随时取消
+timeout=float("inf") # 允许随时取消
)
# 检查是否仍然需要运行

View File

@@ -482,7 +482,7 @@ class GraphStore:
for node in memory.nodes:
if node.id in self.node_to_memories:
self.node_to_memories[node.id].discard(memory_id)
# 可选:立即清理孤立节点
if cleanup_orphans:
# 如果该节点不再属于任何记忆,从图中移除节点

View File

@@ -70,12 +70,12 @@ class MemoryTools:
self.max_expand_depth = max_expand_depth self.max_expand_depth = max_expand_depth
self.expand_semantic_threshold = expand_semantic_threshold self.expand_semantic_threshold = expand_semantic_threshold
self.search_top_k = search_top_k self.search_top_k = search_top_k
# 保存权重配置 # 保存权重配置
self.base_vector_weight = search_vector_weight self.base_vector_weight = search_vector_weight
self.base_importance_weight = search_importance_weight self.base_importance_weight = search_importance_weight
self.base_recency_weight = search_recency_weight self.base_recency_weight = search_recency_weight
# 保存阈值过滤配置 # 保存阈值过滤配置
self.search_min_importance = search_min_importance self.search_min_importance = search_min_importance
self.search_similarity_threshold = search_similarity_threshold self.search_similarity_threshold = search_similarity_threshold
@@ -511,14 +511,14 @@ class MemoryTools:
# 1. 根据策略选择检索方式 # 1. 根据策略选择检索方式
llm_prefer_types = [] # LLM识别的偏好节点类型 llm_prefer_types = [] # LLM识别的偏好节点类型
if use_multi_query: if use_multi_query:
# 多查询策略(返回节点列表 + 偏好类型) # 多查询策略(返回节点列表 + 偏好类型)
similar_nodes, llm_prefer_types = await self._multi_query_search(query, top_k, context) similar_nodes, llm_prefer_types = await self._multi_query_search(query, top_k, context)
else: else:
# 传统单查询策略 # 传统单查询策略
similar_nodes = await self._single_query_search(query, top_k) similar_nodes = await self._single_query_search(query, top_k)
# 合并用户指定的偏好类型和LLM识别的偏好类型 # 合并用户指定的偏好类型和LLM识别的偏好类型
all_prefer_types = list(set(prefer_node_types + llm_prefer_types)) all_prefer_types = list(set(prefer_node_types + llm_prefer_types))
if all_prefer_types: if all_prefer_types:
@@ -546,7 +546,7 @@ class MemoryTools:
# 记录最高分数 # 记录最高分数
if mem_id not in memory_scores or similarity > memory_scores[mem_id]: if mem_id not in memory_scores or similarity > memory_scores[mem_id]:
memory_scores[mem_id] = similarity memory_scores[mem_id] = similarity
# 🔥 详细日志:检查初始召回情况 # 🔥 详细日志:检查初始召回情况
logger.info( logger.info(
f"初始向量搜索: 返回{len(similar_nodes)}个节点 → " f"初始向量搜索: 返回{len(similar_nodes)}个节点 → "
@@ -554,8 +554,8 @@ class MemoryTools:
) )
if len(initial_memory_ids) == 0: if len(initial_memory_ids) == 0:
logger.warning( logger.warning(
f"⚠️ 向量搜索未找到任何记忆!" "⚠️ 向量搜索未找到任何记忆!"
f"可能原因1) 嵌入模型理解问题 2) 记忆节点未建立索引 3) 查询表达与存储内容差异过大" "可能原因1) 嵌入模型理解问题 2) 记忆节点未建立索引 3) 查询表达与存储内容差异过大"
) )
# 输出相似节点的详细信息用于调试 # 输出相似节点的详细信息用于调试
if similar_nodes: if similar_nodes:
@@ -613,7 +613,7 @@ class MemoryTools:
key=lambda x: final_scores[x], key=lambda x: final_scores[x],
reverse=True reverse=True
) # 🔥 不再提前截断,让所有候选参与详细评分 ) # 🔥 不再提前截断,让所有候选参与详细评分
# 🔍 统计初始记忆的相似度分布(用于诊断) # 🔍 统计初始记忆的相似度分布(用于诊断)
if memory_scores: if memory_scores:
similarities = list(memory_scores.values()) similarities = list(memory_scores.values())
@@ -628,7 +628,7 @@ class MemoryTools:
# 5. 获取完整记忆并进行最终排序(优化后的动态权重系统) # 5. 获取完整记忆并进行最终排序(优化后的动态权重系统)
memories_with_scores = [] memories_with_scores = []
filter_stats = {"importance": 0, "similarity": 0, "total_checked": 0} # 过滤统计 filter_stats = {"importance": 0, "similarity": 0, "total_checked": 0} # 过滤统计
for memory_id in sorted_memory_ids: # 遍历所有候选 for memory_id in sorted_memory_ids: # 遍历所有候选
memory = self.graph_store.get_memory_by_id(memory_id) memory = self.graph_store.get_memory_by_id(memory_id)
if memory: if memory:
@@ -636,7 +636,7 @@ class MemoryTools:
# 基础分数 # 基础分数
similarity_score = final_scores[memory_id] similarity_score = final_scores[memory_id]
importance_score = memory.importance importance_score = memory.importance
# 🆕 区分记忆来源(用于过滤) # 🆕 区分记忆来源(用于过滤)
is_initial_memory = memory_id in memory_scores # 是否来自初始向量搜索 is_initial_memory = memory_id in memory_scores # 是否来自初始向量搜索
true_similarity = memory_scores.get(memory_id, 0.0) if is_initial_memory else None true_similarity = memory_scores.get(memory_id, 0.0) if is_initial_memory else None
@@ -659,16 +659,16 @@ class MemoryTools:
activation_score = memory.activation activation_score = memory.activation
# 🆕 动态权重计算:使用配置的基础权重 + 根据记忆类型微调 # 🆕 动态权重计算:使用配置的基础权重 + 根据记忆类型微调
memory_type = memory.memory_type.value if hasattr(memory.memory_type, 'value') else str(memory.memory_type) memory_type = memory.memory_type.value if hasattr(memory.memory_type, "value") else str(memory.memory_type)
# 检测记忆的主要节点类型 # 检测记忆的主要节点类型
node_types_count = {} node_types_count = {}
for node in memory.nodes: for node in memory.nodes:
nt = node.node_type.value if hasattr(node.node_type, 'value') else str(node.node_type) nt = node.node_type.value if hasattr(node.node_type, "value") else str(node.node_type)
node_types_count[nt] = node_types_count.get(nt, 0) + 1 node_types_count[nt] = node_types_count.get(nt, 0) + 1
dominant_node_type = max(node_types_count.items(), key=lambda x: x[1])[0] if node_types_count else "unknown" dominant_node_type = max(node_types_count.items(), key=lambda x: x[1])[0] if node_types_count else "unknown"
# 根据记忆类型和节点类型计算调整系数(在配置权重基础上微调) # 根据记忆类型和节点类型计算调整系数(在配置权重基础上微调)
if dominant_node_type in ["ATTRIBUTE", "REFERENCE"] or memory_type == "FACT": if dominant_node_type in ["ATTRIBUTE", "REFERENCE"] or memory_type == "FACT":
# 事实性记忆:提升相似度权重,降低时效性权重 # 事实性记忆:提升相似度权重,降低时效性权重
@@ -698,41 +698,41 @@ class MemoryTools:
"importance": 1.0, "importance": 1.0,
"recency": 1.0, "recency": 1.0,
} }
# 应用调整后的权重(基于配置的基础权重) # 应用调整后的权重(基于配置的基础权重)
weights = { weights = {
"similarity": self.base_vector_weight * type_adjustments["similarity"], "similarity": self.base_vector_weight * type_adjustments["similarity"],
"importance": self.base_importance_weight * type_adjustments["importance"], "importance": self.base_importance_weight * type_adjustments["importance"],
"recency": self.base_recency_weight * type_adjustments["recency"], "recency": self.base_recency_weight * type_adjustments["recency"],
} }
# 归一化权重确保总和为1.0 # 归一化权重确保总和为1.0
total_weight = sum(weights.values()) total_weight = sum(weights.values())
if total_weight > 0: if total_weight > 0:
weights = {k: v / total_weight for k, v in weights.items()} weights = {k: v / total_weight for k, v in weights.items()}
# 综合分数计算(🔥 移除激活度影响) # 综合分数计算(🔥 移除激活度影响)
final_score = ( final_score = (
similarity_score * weights["similarity"] + similarity_score * weights["similarity"] +
importance_score * weights["importance"] + importance_score * weights["importance"] +
recency_score * weights["recency"] recency_score * weights["recency"]
) )
# 🆕 阈值过滤策略: # 🆕 阈值过滤策略:
# 1. 重要性过滤:应用于所有记忆(过滤极低质量) # 1. 重要性过滤:应用于所有记忆(过滤极低质量)
if memory.importance < self.search_min_importance: if memory.importance < self.search_min_importance:
filter_stats["importance"] += 1 filter_stats["importance"] += 1
logger.debug(f"❌ 过滤 {memory.id[:8]}: 重要性 {memory.importance:.2f} < 阈值 {self.search_min_importance}") logger.debug(f"❌ 过滤 {memory.id[:8]}: 重要性 {memory.importance:.2f} < 阈值 {self.search_min_importance}")
continue continue
# 2. 相似度过滤:不再对初始向量搜索结果过滤(信任向量搜索的排序) # 2. 相似度过滤:不再对初始向量搜索结果过滤(信任向量搜索的排序)
# 理由:向量搜索已经按相似度排序,返回的都是最相关结果 # 理由:向量搜索已经按相似度排序,返回的都是最相关结果
# 如果再用阈值过滤,会导致"最相关的也不够相关"的矛盾 # 如果再用阈值过滤,会导致"最相关的也不够相关"的矛盾
# #
# 注意:如果未来需要对扩展记忆过滤,可以在这里添加逻辑 # 注意:如果未来需要对扩展记忆过滤,可以在这里添加逻辑
# if not is_initial_memory and some_score < threshold: # if not is_initial_memory and some_score < threshold:
# continue # continue
# 记录通过过滤的记忆(用于调试) # 记录通过过滤的记忆(用于调试)
if is_initial_memory: if is_initial_memory:
logger.debug( logger.debug(
@@ -744,11 +744,11 @@ class MemoryTools:
f"✅ 保留 {memory.id[:8]} [扩展]: 重要性={memory.importance:.2f}, " f"✅ 保留 {memory.id[:8]} [扩展]: 重要性={memory.importance:.2f}, "
f"综合分数={final_score:.4f}" f"综合分数={final_score:.4f}"
) )
# 🆕 节点类型加权对REFERENCE/ATTRIBUTE节点额外加分促进事实性信息召回 # 🆕 节点类型加权对REFERENCE/ATTRIBUTE节点额外加分促进事实性信息召回
if "REFERENCE" in node_types_count or "ATTRIBUTE" in node_types_count: if "REFERENCE" in node_types_count or "ATTRIBUTE" in node_types_count:
final_score *= 1.1 # 10% 加成 final_score *= 1.1 # 10% 加成
# 🆕 用户指定的优先节点类型额外加权 # 🆕 用户指定的优先节点类型额外加权
if prefer_node_types: if prefer_node_types:
for prefer_type in prefer_node_types: for prefer_type in prefer_node_types:
@@ -756,7 +756,7 @@ class MemoryTools:
final_score *= 1.15 # 15% 额外加成 final_score *= 1.15 # 15% 额外加成
logger.debug(f"记忆 {memory.id[:8]} 包含优先节点类型 {prefer_type},加权后分数: {final_score:.4f}") logger.debug(f"记忆 {memory.id[:8]} 包含优先节点类型 {prefer_type},加权后分数: {final_score:.4f}")
break break
memories_with_scores.append((memory, final_score, dominant_node_type)) memories_with_scores.append((memory, final_score, dominant_node_type))
# 按综合分数排序 # 按综合分数排序
@@ -766,7 +766,7 @@ class MemoryTools:
# 统计过滤情况 # 统计过滤情况
total_candidates = len(all_memory_ids) total_candidates = len(all_memory_ids)
filtered_count = total_candidates - len(memories_with_scores) filtered_count = total_candidates - len(memories_with_scores)
# 6. 格式化结果(包含调试信息) # 6. 格式化结果(包含调试信息)
results = [] results = []
for memory, score, node_type in memories_with_scores[:top_k]: for memory, score, node_type in memories_with_scores[:top_k]:
@@ -787,7 +787,7 @@ class MemoryTools:
f"过滤{filtered_count}个 (重要性过滤) → " f"过滤{filtered_count}个 (重要性过滤) → "
f"最终返回{len(results)}条记忆" f"最终返回{len(results)}条记忆"
) )
# 如果过滤率过高,发出警告 # 如果过滤率过高,发出警告
if total_candidates > 0: if total_candidates > 0:
filter_rate = filtered_count / total_candidates filter_rate = filtered_count / total_candidates
@@ -1000,20 +1000,21 @@ class MemoryTools:
response, _ = await llm.generate_response_async(prompt, temperature=0.3, max_tokens=300) response, _ = await llm.generate_response_async(prompt, temperature=0.3, max_tokens=300)
import re import re
import orjson import orjson
# 清理Markdown代码块 # 清理Markdown代码块
response = re.sub(r"```json\s*", "", response) response = re.sub(r"```json\s*", "", response)
response = re.sub(r"```\s*$", "", response).strip() response = re.sub(r"```\s*$", "", response).strip()
# 解析JSON # 解析JSON
data = orjson.loads(response) data = orjson.loads(response)
# 提取查询列表 # 提取查询列表
queries = data.get("queries", []) queries = data.get("queries", [])
result_queries = [(item.get("text", "").strip(), float(item.get("weight", 0.5))) result_queries = [(item.get("text", "").strip(), float(item.get("weight", 0.5)))
for item in queries if item.get("text", "").strip()] for item in queries if item.get("text", "").strip()]
# 提取偏好节点类型 # 提取偏好节点类型
prefer_node_types = data.get("prefer_node_types", []) prefer_node_types = data.get("prefer_node_types", [])
# 确保类型正确且有效 # 确保类型正确且有效
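A self-contained sketch of the response-parsing step above (fence stripping, JSON decoding, query/weight extraction). The sample response is hard-coded so the snippet runs without an LLM call:

import re
import orjson

raw = '```json\n{"queries": [{"text": "用户的生日", "weight": 0.9}, {"text": " ", "weight": 0.2}], "prefer_node_types": ["ATTRIBUTE", "REFERENCE"]}\n```'

cleaned = re.sub(r"```json\s*", "", raw)           # drop the opening Markdown fence
cleaned = re.sub(r"```\s*$", "", cleaned).strip()  # drop the closing fence
data = orjson.loads(cleaned)

queries = [(item.get("text", "").strip(), float(item.get("weight", 0.5)))
           for item in data.get("queries", []) if item.get("text", "").strip()]
prefer_node_types = data.get("prefer_node_types", [])

print(queries)            # [('用户的生日', 0.9)]  (the blank-text query is dropped)
print(prefer_node_types)  # ['ATTRIBUTE', 'REFERENCE']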
@@ -1062,7 +1063,7 @@ class MemoryTools:
limit=top_k * 5, # 🔥 从2倍提升到5倍提高初始召回率 limit=top_k * 5, # 🔥 从2倍提升到5倍提高初始召回率
min_similarity=0.0, # 不在这里过滤,交给后续评分 min_similarity=0.0, # 不在这里过滤,交给后续评分
) )
logger.debug(f"单查询向量搜索: 查询='{query}', 返回节点数={len(similar_nodes)}") logger.debug(f"单查询向量搜索: 查询='{query}', 返回节点数={len(similar_nodes)}")
if similar_nodes: if similar_nodes:
logger.debug(f"Top 3相似度: {[f'{sim:.3f}' for _, sim, _ in similar_nodes[:3]]}") logger.debug(f"Top 3相似度: {[f'{sim:.3f}' for _, sim, _ in similar_nodes[:3]]}")
View File
@@ -62,7 +62,7 @@ async def expand_memories_with_semantic_filter(
try: try:
import time import time
start_time = time.time() start_time = time.time()
# 记录已访问的记忆,避免重复 # 记录已访问的记忆,避免重复
visited_memories = set(initial_memory_ids) visited_memories = set(initial_memory_ids)
# 记录扩展的记忆及其分数 # 记录扩展的记忆及其分数
@@ -87,17 +87,17 @@ async def expand_memories_with_semantic_filter(
# 获取该记忆的邻居记忆(通过边关系) # 获取该记忆的邻居记忆(通过边关系)
neighbor_memory_ids = set() neighbor_memory_ids = set()
# 🆕 遍历记忆的所有边,收集邻居记忆(带边类型权重) # 🆕 遍历记忆的所有边,收集邻居记忆(带边类型权重)
edge_weights = {} # 记录通过不同边类型到达的记忆的权重 edge_weights = {} # 记录通过不同边类型到达的记忆的权重
for edge in memory.edges: for edge in memory.edges:
# 获取边的目标节点 # 获取边的目标节点
target_node_id = edge.target_id target_node_id = edge.target_id
source_node_id = edge.source_id source_node_id = edge.source_id
# 🆕 根据边类型设置权重优先扩展REFERENCE、ATTRIBUTE相关的边 # 🆕 根据边类型设置权重优先扩展REFERENCE、ATTRIBUTE相关的边
edge_type_str = edge.edge_type.value if hasattr(edge.edge_type, 'value') else str(edge.edge_type) edge_type_str = edge.edge_type.value if hasattr(edge.edge_type, "value") else str(edge.edge_type)
if edge_type_str == "REFERENCE": if edge_type_str == "REFERENCE":
edge_weight = 1.3 # REFERENCE边权重最高引用关系 edge_weight = 1.3 # REFERENCE边权重最高引用关系
elif edge_type_str in ["ATTRIBUTE", "HAS_PROPERTY"]: elif edge_type_str in ["ATTRIBUTE", "HAS_PROPERTY"]:
@@ -108,18 +108,18 @@ async def expand_memories_with_semantic_filter(
edge_weight = 0.9 # 一般关系适中降权 edge_weight = 0.9 # 一般关系适中降权
else: else:
edge_weight = 1.0 # 默认权重 edge_weight = 1.0 # 默认权重
# 通过节点找到其他记忆 # 通过节点找到其他记忆
for node_id in [target_node_id, source_node_id]: for node_id in [target_node_id, source_node_id]:
if node_id in graph_store.node_to_memories: if node_id in graph_store.node_to_memories:
for neighbor_id in graph_store.node_to_memories[node_id]: for neighbor_id in graph_store.node_to_memories[node_id]:
if neighbor_id not in edge_weights or edge_weights[neighbor_id] < edge_weight: if neighbor_id not in edge_weights or edge_weights[neighbor_id] < edge_weight:
edge_weights[neighbor_id] = edge_weight edge_weights[neighbor_id] = edge_weight
# 将权重高的邻居记忆加入候选 # 将权重高的邻居记忆加入候选
for neighbor_id, edge_weight in edge_weights.items(): for neighbor_id, edge_weight in edge_weights.items():
neighbor_memory_ids.add((neighbor_id, edge_weight)) neighbor_memory_ids.add((neighbor_id, edge_weight))
# 过滤掉已访问的和自己 # 过滤掉已访问的和自己
filtered_neighbors = [] filtered_neighbors = []
for neighbor_id, edge_weight in neighbor_memory_ids: for neighbor_id, edge_weight in neighbor_memory_ids:
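A compact sketch of the edge-type weighting used for neighbour expansion in the hunks above. The Edge dataclass, the 1.2 value for attribute-style edges, and the RELATED_TO branch are assumptions; those exact lines fall between the hunks shown here:

from dataclasses import dataclass

@dataclass
class Edge:  # simplified stand-in for the project's edge objects
    edge_type: str
    source_id: str
    target_id: str

def edge_type_weight(edge_type: str) -> float:
    if edge_type == "REFERENCE":
        return 1.3                    # reference edges favoured most
    if edge_type in ("ATTRIBUTE", "HAS_PROPERTY"):
        return 1.2                    # assumed value; not visible in these hunks
    if edge_type == "RELATED_TO":     # assumed name for the generic-relation branch
        return 0.9                    # generic relations slightly demoted
    return 1.0

def collect_neighbor_weights(edges: list[Edge], node_to_memories: dict[str, set[str]]) -> dict[str, float]:
    """Keep, per neighbouring memory, the strongest edge weight that reaches it."""
    weights: dict[str, float] = {}
    for edge in edges:
        w = edge_type_weight(edge.edge_type)
        for node_id in (edge.target_id, edge.source_id):
            for neighbor_id in node_to_memories.get(node_id, ()):
                if weights.get(neighbor_id, 0.0) < w:
                    weights[neighbor_id] = w
    return weights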
@@ -129,7 +129,7 @@ async def expand_memories_with_semantic_filter(
# 批量评估邻居记忆 # 批量评估邻居记忆
for neighbor_mem_id, edge_weight in filtered_neighbors: for neighbor_mem_id, edge_weight in filtered_neighbors:
candidates_checked += 1 candidates_checked += 1
neighbor_memory = graph_store.get_memory_by_id(neighbor_mem_id) neighbor_memory = graph_store.get_memory_by_id(neighbor_mem_id)
if not neighbor_memory: if not neighbor_memory:
continue continue
@@ -139,7 +139,7 @@ async def expand_memories_with_semantic_filter(
(n for n in neighbor_memory.nodes if n.has_embedding()), (n for n in neighbor_memory.nodes if n.has_embedding()),
None None
) )
if not topic_node or topic_node.embedding is None: if not topic_node or topic_node.embedding is None:
continue continue
@@ -179,11 +179,11 @@ async def expand_memories_with_semantic_filter(
if len(expanded_memories) >= max_expanded: if len(expanded_memories) >= max_expanded:
logger.debug(f"⏹️ 提前停止:已达到最大扩展数量 {max_expanded}") logger.debug(f"⏹️ 提前停止:已达到最大扩展数量 {max_expanded}")
break break
# 早停检查 # 早停检查
if len(expanded_memories) >= max_expanded: if len(expanded_memories) >= max_expanded:
break break
# 记录本层统计 # 记录本层统计
depth_stats.append({ depth_stats.append({
"depth": depth + 1, "depth": depth + 1,
@@ -199,20 +199,20 @@ async def expand_memories_with_semantic_filter(
# 限制下一层的记忆数量,避免爆炸性增长 # 限制下一层的记忆数量,避免爆炸性增长
current_level_memories = next_level_memories[:max_expanded] current_level_memories = next_level_memories[:max_expanded]
# 每层让出控制权 # 每层让出控制权
await asyncio.sleep(0.001) await asyncio.sleep(0.001)
# 排序并返回 # 排序并返回
sorted_results = sorted(expanded_memories.items(), key=lambda x: x[1], reverse=True)[:max_expanded] sorted_results = sorted(expanded_memories.items(), key=lambda x: x[1], reverse=True)[:max_expanded]
elapsed = time.time() - start_time elapsed = time.time() - start_time
logger.info( logger.info(
f"✅ 图扩展完成: 初始{len(initial_memory_ids)}个 → " f"✅ 图扩展完成: 初始{len(initial_memory_ids)}个 → "
f"扩展{len(sorted_results)}个新记忆 " f"扩展{len(sorted_results)}个新记忆 "
f"(深度={max_depth}, 阈值={semantic_threshold:.2f}, 耗时={elapsed:.3f}s)" f"(深度={max_depth}, 阈值={semantic_threshold:.2f}, 耗时={elapsed:.3f}s)"
) )
# 输出每层统计 # 输出每层统计
for stat in depth_stats: for stat in depth_stats:
logger.debug( logger.debug(
View File
@@ -132,7 +132,7 @@ async def generate_reply(
prompt_mode = "s4u" # 默认使用s4u模式 prompt_mode = "s4u" # 默认使用s4u模式
if action_data and "prompt_mode" in action_data: if action_data and "prompt_mode" in action_data:
prompt_mode = action_data.get("prompt_mode", "s4u") prompt_mode = action_data.get("prompt_mode", "s4u")
# 将prompt_mode添加到available_actions中作为特殊键 # 将prompt_mode添加到available_actions中作为特殊键
# 注意这里我们需要暂时使用类型忽略因为available_actions的类型定义不支持非ActionInfo值 # 注意这里我们需要暂时使用类型忽略因为available_actions的类型定义不支持非ActionInfo值
if available_actions is None: if available_actions is None:
View File
@@ -362,7 +362,7 @@ class ChatterPlanFilter:
return "最近没有聊天内容。", "没有未读消息。", [] return "最近没有聊天内容。", "没有未读消息。", []
stream_context = chat_stream.context_manager stream_context = chat_stream.context_manager
# 获取真正的已读和未读消息 # 获取真正的已读和未读消息
read_messages = stream_context.context.history_messages # 已读消息存储在history_messages中 read_messages = stream_context.context.history_messages # 已读消息存储在history_messages中
if not read_messages: if not read_messages:
@@ -652,30 +652,30 @@ class ChatterPlanFilter:
if not action_info: if not action_info:
logger.debug(f"动作 {action_name} 不在可用动作列表中,保留所有参数") logger.debug(f"动作 {action_name} 不在可用动作列表中,保留所有参数")
return action_data return action_data
# 获取该动作定义的合法参数 # 获取该动作定义的合法参数
defined_params = set(action_info.action_parameters.keys()) defined_params = set(action_info.action_parameters.keys())
# 合法参数集合 # 合法参数集合
valid_params = defined_params valid_params = defined_params
# 过滤参数 # 过滤参数
filtered_data = {} filtered_data = {}
removed_params = [] removed_params = []
for key, value in action_data.items(): for key, value in action_data.items():
if key in valid_params: if key in valid_params:
filtered_data[key] = value filtered_data[key] = value
else: else:
removed_params.append(key) removed_params.append(key)
# 记录被移除的参数 # 记录被移除的参数
if removed_params: if removed_params:
logger.info( logger.info(
f"🧹 [参数过滤] 动作 '{action_name}' 移除了多余参数: {removed_params}. " f"🧹 [参数过滤] 动作 '{action_name}' 移除了多余参数: {removed_params}. "
f"合法参数: {sorted(valid_params)}" f"合法参数: {sorted(valid_params)}"
) )
return filtered_data return filtered_data
def _filter_no_actions(self, action_list: list[ActionPlannerInfo]) -> list[ActionPlannerInfo]: def _filter_no_actions(self, action_list: list[ActionPlannerInfo]) -> list[ActionPlannerInfo]:
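The parameter-filtering step above, reduced to a standalone function for illustration. A plain dict stands in for action_info.action_parameters, and the logger call is replaced by print:

def filter_action_data(action_name: str, action_data: dict, declared_params: dict) -> dict:
    """Keep only the keys the action declares; report anything that was dropped."""
    valid = set(declared_params)
    filtered = {k: v for k, v in action_data.items() if k in valid}
    removed = [k for k in action_data if k not in valid]
    if removed:
        print(f"🧹 [参数过滤] '{action_name}' 移除了多余参数: {removed}, 合法参数: {sorted(valid)}")
    return filtered

print(filter_action_data(
    "reply",
    {"target_message_id": "m123", "content": "好的", "mood": "happy"},
    {"target_message_id": "", "content": "", "should_quote_reply": ""},
))  # {'target_message_id': 'm123', 'content': '好的'}  ("mood" is filtered out)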
View File
@@ -545,14 +545,14 @@ async def execute_proactive_thinking(stream_id: str):
# 获取或创建该聊天流的执行锁 # 获取或创建该聊天流的执行锁
if stream_id not in _execution_locks: if stream_id not in _execution_locks:
_execution_locks[stream_id] = asyncio.Lock() _execution_locks[stream_id] = asyncio.Lock()
lock = _execution_locks[stream_id] lock = _execution_locks[stream_id]
# 尝试获取锁,如果已被占用则跳过本次执行(防止重复) # 尝试获取锁,如果已被占用则跳过本次执行(防止重复)
if lock.locked(): if lock.locked():
logger.warning(f"⚠️ 主动思考跳过:聊天流 {stream_id} 已有正在执行的主动思考任务") logger.warning(f"⚠️ 主动思考跳过:聊天流 {stream_id} 已有正在执行的主动思考任务")
return return
async with lock: async with lock:
logger.debug(f"🤔 开始主动思考 {stream_id}") logger.debug(f"🤔 开始主动思考 {stream_id}")
@@ -563,13 +563,13 @@ async def execute_proactive_thinking(stream_id: str):
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
chat_manager = get_chat_manager() chat_manager = get_chat_manager()
chat_stream = await chat_manager.get_stream(stream_id) chat_stream = await chat_manager.get_stream(stream_id)
if chat_stream and chat_stream.context_manager.context.is_chatter_processing: if chat_stream and chat_stream.context_manager.context.is_chatter_processing:
logger.warning(f"⚠️ 主动思考跳过:聊天流 {stream_id} 的 chatter 正在处理消息") logger.warning(f"⚠️ 主动思考跳过:聊天流 {stream_id} 的 chatter 正在处理消息")
return return
except Exception as e: except Exception as e:
logger.warning(f"检查 chatter 处理状态时出错: {e},继续执行") logger.warning(f"检查 chatter 处理状态时出错: {e},继续执行")
# 0.1 检查白名单/黑名单 # 0.1 检查白名单/黑名单
# 从 stream_id 获取 stream_config 字符串进行验证 # 从 stream_id 获取 stream_config 字符串进行验证
try: try:
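The hunks above guard each chat stream's proactive-thinking run with a per-stream asyncio.Lock and skip the run outright if the lock is already held. A small self-contained sketch of that skip-if-busy pattern, with the chatter/whitelist checks simplified away:

import asyncio

_execution_locks: dict[str, asyncio.Lock] = {}

async def run_once(stream_id: str) -> bool:
    lock = _execution_locks.setdefault(stream_id, asyncio.Lock())
    if lock.locked():                 # a run is already in flight: skip rather than queue
        print(f"skip {stream_id}: already running")
        return False
    async with lock:
        await asyncio.sleep(0.1)      # stand-in for the actual proactive-thinking work
        return True

async def main():
    # Two concurrent triggers for the same stream: the second is skipped, not queued.
    print(await asyncio.gather(run_once("s1"), run_once("s1")))  # [True, False]

asyncio.run(main())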
View File
@@ -31,4 +31,4 @@ __plugin_meta__ = PluginMetadata(
# 导入插件主类 # 导入插件主类
from .plugin import AntiInjectionPlugin from .plugin import AntiInjectionPlugin
__all__ = ["__plugin_meta__", "AntiInjectionPlugin"] __all__ = ["AntiInjectionPlugin", "__plugin_meta__"]
View File
@@ -8,8 +8,8 @@ import time
from src.chat.security.interfaces import ( from src.chat.security.interfaces import (
SecurityAction, SecurityAction,
SecurityCheckResult,
SecurityChecker, SecurityChecker,
SecurityCheckResult,
SecurityLevel, SecurityLevel,
) )
from src.common.logger import get_logger from src.common.logger import get_logger
View File
@@ -4,7 +4,7 @@
处理检测结果,执行相应的动作(允许/监控/加盾/阻止/反击)。 处理检测结果,执行相应的动作(允许/监控/加盾/阻止/反击)。
""" """
from src.chat.security.interfaces import SecurityAction, SecurityCheckResult from src.chat.security.interfaces import SecurityCheckResult
from src.common.logger import get_logger from src.common.logger import get_logger
from .counter_attack import CounterAttackGenerator from .counter_attack import CounterAttackGenerator
View File
@@ -64,15 +64,15 @@ class CoreActionsPlugin(BasePlugin):
# --- 根据配置注册组件 --- # --- 根据配置注册组件 ---
components: ClassVar = [] components: ClassVar = []
# 注册 reply 动作 # 注册 reply 动作
if self.get_config("components.enable_reply", True): if self.get_config("components.enable_reply", True):
components.append((ReplyAction.get_action_info(), ReplyAction)) components.append((ReplyAction.get_action_info(), ReplyAction))
# 注册 respond 动作 # 注册 respond 动作
if self.get_config("components.enable_respond", True): if self.get_config("components.enable_respond", True):
components.append((RespondAction.get_action_info(), RespondAction)) components.append((RespondAction.get_action_info(), RespondAction))
# 注册 emoji 动作 # 注册 emoji 动作
if self.get_config("components.enable_emoji", True): if self.get_config("components.enable_emoji", True):
components.append((EmojiAction.get_action_info(), EmojiAction)) components.append((EmojiAction.get_action_info(), EmojiAction))
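A tiny sketch of the config-gated registration shown above; the config dict, get_config helper, and string class names are stand-ins, not the plugin's real objects:

config = {
    "components.enable_reply": True,
    "components.enable_respond": False,
    "components.enable_emoji": True,
}

def get_config(key: str, default: bool = True) -> bool:
    return config.get(key, default)

components: list[str] = []
for key, component in [
    ("components.enable_reply", "ReplyAction"),
    ("components.enable_respond", "RespondAction"),
    ("components.enable_emoji", "EmojiAction"),
]:
    if get_config(key, True):
        components.append(component)

print(components)  # ['ReplyAction', 'EmojiAction']; respond stays unregistered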
View File
@@ -22,23 +22,23 @@ class ReplyAction(BaseAction):
- 专注于理解和回应单条消息的具体内容 - 专注于理解和回应单条消息的具体内容
- 适合 Focus 模式下的精准回复 - 适合 Focus 模式下的精准回复
""" """
# 动作基本信息 # 动作基本信息
action_name = "reply" action_name = "reply"
action_description = "针对特定消息进行精准回复。深度理解并回应单条消息的具体内容。需要指定目标消息ID。" action_description = "针对特定消息进行精准回复。深度理解并回应单条消息的具体内容。需要指定目标消息ID。"
# 激活设置 # 激活设置
activation_type = ActionActivationType.ALWAYS # 回复动作总是可用 activation_type = ActionActivationType.ALWAYS # 回复动作总是可用
mode_enable = ChatMode.ALL # 在所有模式下都可用 mode_enable = ChatMode.ALL # 在所有模式下都可用
parallel_action = False # 回复动作不能与其他动作并行 parallel_action = False # 回复动作不能与其他动作并行
# 动作参数定义 # 动作参数定义
action_parameters: ClassVar = { action_parameters: ClassVar = {
"target_message_id": "要回复的目标消息ID必需来自未读消息的 <m...> 标签)", "target_message_id": "要回复的目标消息ID必需来自未读消息的 <m...> 标签)",
"content": "回复的具体内容可选由LLM生成", "content": "回复的具体内容可选由LLM生成",
"should_quote_reply": "是否引用原消息可选true/false默认false。群聊中回复较早消息或需要明确指向时使用true", "should_quote_reply": "是否引用原消息可选true/false默认false。群聊中回复较早消息或需要明确指向时使用true",
} }
# 动作使用场景 # 动作使用场景
action_require: ClassVar = [ action_require: ClassVar = [
"需要针对特定消息进行精准回复时使用", "需要针对特定消息进行精准回复时使用",
@@ -48,10 +48,10 @@ class ReplyAction(BaseAction):
"群聊中需要明确回应某个特定用户或问题时使用", "群聊中需要明确回应某个特定用户或问题时使用",
"关注单条消息的具体内容和上下文细节", "关注单条消息的具体内容和上下文细节",
] ]
# 关联类型 # 关联类型
associated_types: ClassVar[list[str]] = ["text"] associated_types: ClassVar[list[str]] = ["text"]
async def execute(self) -> tuple[bool, str]: async def execute(self) -> tuple[bool, str]:
"""执行reply动作 """执行reply动作
@@ -70,21 +70,21 @@ class RespondAction(BaseAction):
- 适合对于群聊消息下的宏观回应 - 适合对于群聊消息下的宏观回应
- 避免与单一用户深度对话而忽略其他用户的消息 - 避免与单一用户深度对话而忽略其他用户的消息
""" """
# 动作基本信息 # 动作基本信息
action_name = "respond" action_name = "respond"
action_description = "统一回应所有未读消息。理解整体对话动态和话题走向,生成连贯的回复。无需指定目标消息。" action_description = "统一回应所有未读消息。理解整体对话动态和话题走向,生成连贯的回复。无需指定目标消息。"
# 激活设置 # 激活设置
activation_type = ActionActivationType.ALWAYS # 回应动作总是可用 activation_type = ActionActivationType.ALWAYS # 回应动作总是可用
mode_enable = ChatMode.ALL # 在所有模式下都可用 mode_enable = ChatMode.ALL # 在所有模式下都可用
parallel_action = False # 回应动作不能与其他动作并行 parallel_action = False # 回应动作不能与其他动作并行
# 动作参数定义 # 动作参数定义
action_parameters: ClassVar = { action_parameters: ClassVar = {
"content": "回复的具体内容可选由LLM生成", "content": "回复的具体内容可选由LLM生成",
} }
# 动作使用场景 # 动作使用场景
action_require: ClassVar = [ action_require: ClassVar = [
"需要统一回应多条未读消息时使用Normal 模式专用)", "需要统一回应多条未读消息时使用Normal 模式专用)",
@@ -94,10 +94,10 @@ class RespondAction(BaseAction):
"适合群聊中的自然对话流,无需精确指向特定消息", "适合群聊中的自然对话流,无需精确指向特定消息",
"可以同时回应多个话题或参与者", "可以同时回应多个话题或参与者",
] ]
# 关联类型 # 关联类型
associated_types: ClassVar[list[str]] = ["text"] associated_types: ClassVar[list[str]] = ["text"]
async def execute(self) -> tuple[bool, str]: async def execute(self) -> tuple[bool, str]:
"""执行respond动作 """执行respond动作
View File
@@ -6,10 +6,10 @@
import asyncio import asyncio
import base64 import base64
import datetime import datetime
import filetype
from collections.abc import Callable from collections.abc import Callable
import aiohttp import aiohttp
import filetype
from maim_message import UserInfo from maim_message import UserInfo
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
View File
@@ -17,7 +17,6 @@ import uuid
import weakref import weakref
from collections import defaultdict from collections import defaultdict
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from contextlib import suppress
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum