style: 统一代码风格并采用现代化类型注解
对整个代码库进行了一次全面的代码风格清理和现代化改造,主要包括: - 移除了所有文件中多余的行尾空格。 - 将类型提示更新为 PEP 585 和 PEP 604 引入的现代语法(例如,使用 `list` 代替 `List`,使用 `|` 代替 `Optional`)。 - 清理了多个模块中未被使用的导入语句。 - 移除了不含插值变量的冗余 f-string。 - 调整了部分 `__init__.py` 文件中的 `__all__` 导出顺序,以保持一致性。 这些改动旨在提升代码的可读性和可维护性,使其与现代 Python 最佳实践保持一致,但未修改任何核心逻辑。
This commit is contained in:
committed by
Windpicker-owo
parent
5fa004503c
commit
f44ece0b29
2
bot.py
2
bot.py
@@ -588,7 +588,7 @@ class MaiBotMain:
|
||||
|
||||
async def run_async_init(self, main_system):
|
||||
"""执行异步初始化步骤"""
|
||||
|
||||
|
||||
# 初始化数据库表结构
|
||||
await self.initialize_database_async()
|
||||
|
||||
|
||||
@@ -19,14 +19,13 @@
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
# 添加项目根目录到路径
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
|
||||
async def generate_missing_embeddings(
|
||||
target_node_types: List[str] = None,
|
||||
target_node_types: list[str] = None,
|
||||
batch_size: int = 50,
|
||||
):
|
||||
"""
|
||||
@@ -46,13 +45,13 @@ async def generate_missing_embeddings(
|
||||
target_node_types = [NodeType.TOPIC.value, NodeType.OBJECT.value]
|
||||
|
||||
print(f"\n{'='*80}")
|
||||
print(f"🔧 为节点生成嵌入向量")
|
||||
print("🔧 为节点生成嵌入向量")
|
||||
print(f"{'='*80}\n")
|
||||
print(f"目标节点类型: {', '.join(target_node_types)}")
|
||||
print(f"批处理大小: {batch_size}\n")
|
||||
|
||||
# 1. 初始化记忆管理器
|
||||
print(f"🔧 正在初始化记忆管理器...")
|
||||
print("🔧 正在初始化记忆管理器...")
|
||||
await initialize_memory_manager()
|
||||
manager = get_memory_manager()
|
||||
|
||||
@@ -60,10 +59,10 @@ async def generate_missing_embeddings(
|
||||
print("❌ 记忆管理器初始化失败")
|
||||
return
|
||||
|
||||
print(f"✅ 记忆管理器已初始化\n")
|
||||
print("✅ 记忆管理器已初始化\n")
|
||||
|
||||
# 2. 获取已索引的节点ID
|
||||
print(f"🔍 检查现有向量索引...")
|
||||
print("🔍 检查现有向量索引...")
|
||||
existing_node_ids = set()
|
||||
try:
|
||||
vector_count = manager.vector_store.collection.count()
|
||||
@@ -78,14 +77,14 @@ async def generate_missing_embeddings(
|
||||
)
|
||||
if result and "ids" in result:
|
||||
existing_node_ids.update(result["ids"])
|
||||
|
||||
|
||||
print(f"✅ 发现 {len(existing_node_ids)} 个已索引节点\n")
|
||||
except Exception as e:
|
||||
logger.warning(f"获取已索引节点ID失败: {e}")
|
||||
print(f"⚠️ 无法获取已索引节点,将尝试跳过重复项\n")
|
||||
print("⚠️ 无法获取已索引节点,将尝试跳过重复项\n")
|
||||
|
||||
# 3. 收集需要生成嵌入的节点
|
||||
print(f"🔍 扫描需要生成嵌入的节点...")
|
||||
print("🔍 扫描需要生成嵌入的节点...")
|
||||
all_memories = manager.graph_store.get_all_memories()
|
||||
|
||||
nodes_to_process = []
|
||||
@@ -110,7 +109,7 @@ async def generate_missing_embeddings(
|
||||
})
|
||||
type_stats[node.node_type.value]["need_emb"] += 1
|
||||
|
||||
print(f"\n📊 扫描结果:")
|
||||
print("\n📊 扫描结果:")
|
||||
for node_type in target_node_types:
|
||||
stats = type_stats[node_type]
|
||||
already_ok = stats["already_indexed"]
|
||||
@@ -121,11 +120,11 @@ async def generate_missing_embeddings(
|
||||
print(f"\n 总计: {total_target_nodes} 个目标节点, {len(nodes_to_process)} 个需要生成嵌入\n")
|
||||
|
||||
if len(nodes_to_process) == 0:
|
||||
print(f"✅ 所有节点已有嵌入向量,无需生成")
|
||||
print("✅ 所有节点已有嵌入向量,无需生成")
|
||||
return
|
||||
|
||||
# 3. 批量生成嵌入
|
||||
print(f"🚀 开始生成嵌入向量...\n")
|
||||
print("🚀 开始生成嵌入向量...\n")
|
||||
|
||||
total_batches = (len(nodes_to_process) + batch_size - 1) // batch_size
|
||||
success_count = 0
|
||||
@@ -193,22 +192,22 @@ async def generate_missing_embeddings(
|
||||
print(f" 📊 总进度: {total_processed}/{len(nodes_to_process)} ({progress:.1f}%)\n")
|
||||
|
||||
# 4. 保存图数据(更新节点的 embedding 字段)
|
||||
print(f"💾 保存图数据...")
|
||||
print("💾 保存图数据...")
|
||||
try:
|
||||
await manager.persistence.save_graph_store(manager.graph_store)
|
||||
print(f"✅ 图数据已保存\n")
|
||||
print("✅ 图数据已保存\n")
|
||||
except Exception as e:
|
||||
logger.error(f"保存图数据失败", exc_info=True)
|
||||
logger.error("保存图数据失败", exc_info=True)
|
||||
print(f"❌ 保存失败: {e}\n")
|
||||
|
||||
# 5. 验证结果
|
||||
print(f"🔍 验证向量索引...")
|
||||
print("🔍 验证向量索引...")
|
||||
final_vector_count = manager.vector_store.collection.count()
|
||||
stats = manager.graph_store.get_statistics()
|
||||
total_nodes = stats["total_nodes"]
|
||||
|
||||
print(f"\n{'='*80}")
|
||||
print(f"📊 生成完成")
|
||||
print("📊 生成完成")
|
||||
print(f"{'='*80}")
|
||||
print(f"处理节点数: {len(nodes_to_process)}")
|
||||
print(f"成功生成: {success_count}")
|
||||
@@ -219,7 +218,7 @@ async def generate_missing_embeddings(
|
||||
print(f"索引覆盖率: {final_vector_count / total_nodes * 100:.1f}%\n")
|
||||
|
||||
# 6. 测试搜索
|
||||
print(f"🧪 测试搜索功能...")
|
||||
print("🧪 测试搜索功能...")
|
||||
test_queries = ["小红帽蕾克", "拾风", "杰瑞喵"]
|
||||
|
||||
for query in test_queries:
|
||||
|
||||
@@ -38,7 +38,7 @@ OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")
|
||||
TEMP_DIR = os.path.join(ROOT_PATH, "temp", "lpmm_cache")
|
||||
|
||||
# ========== 性能配置参数 ==========
|
||||
#
|
||||
#
|
||||
# 知识提取(步骤2:txt转json)并发控制
|
||||
# - 控制同时进行的LLM提取请求数量
|
||||
# - 推荐值: 3-10,取决于API速率限制
|
||||
@@ -184,7 +184,7 @@ async def extract_info_async(pg_hash, paragraph, llm_api):
|
||||
tuple: (doc_item或None, failed_hash或None)
|
||||
"""
|
||||
temp_file_path = os.path.join(TEMP_DIR, f"{pg_hash}.json")
|
||||
|
||||
|
||||
# 🔧 优化:使用异步文件检查,避免阻塞
|
||||
if os.path.exists(temp_file_path):
|
||||
try:
|
||||
@@ -215,11 +215,11 @@ async def extract_info_async(pg_hash, paragraph, llm_api):
|
||||
"extracted_entities": extracted_data.get("entities", []),
|
||||
"extracted_triples": extracted_data.get("triples", []),
|
||||
}
|
||||
|
||||
|
||||
# 保存到缓存(异步写入)
|
||||
async with aiofiles.open(temp_file_path, "wb") as f:
|
||||
await f.write(orjson.dumps(doc_item))
|
||||
|
||||
|
||||
return doc_item, None
|
||||
except Exception as e:
|
||||
logger.error(f"提取信息失败:{pg_hash}, 错误:{e}")
|
||||
@@ -249,13 +249,13 @@ async def extract_information(paragraphs_dict, model_set):
|
||||
os.makedirs(TEMP_DIR, exist_ok=True)
|
||||
|
||||
failed_hashes, open_ie_docs = [], []
|
||||
|
||||
|
||||
# 🔧 关键修复:创建单个 LLM 请求实例,复用连接
|
||||
llm_api = LLMRequest(model_set=model_set, request_type="lpmm_extraction")
|
||||
|
||||
|
||||
# 🔧 并发控制:限制最大并发数,防止速率限制
|
||||
semaphore = asyncio.Semaphore(MAX_EXTRACTION_CONCURRENCY)
|
||||
|
||||
|
||||
async def extract_with_semaphore(pg_hash, paragraph):
|
||||
"""带信号量控制的提取函数"""
|
||||
async with semaphore:
|
||||
@@ -266,7 +266,7 @@ async def extract_information(paragraphs_dict, model_set):
|
||||
extract_with_semaphore(p_hash, paragraph)
|
||||
for p_hash, paragraph in paragraphs_dict.items()
|
||||
]
|
||||
|
||||
|
||||
total = len(tasks)
|
||||
completed = 0
|
||||
|
||||
@@ -284,7 +284,7 @@ async def extract_information(paragraphs_dict, model_set):
|
||||
TimeRemainingColumn(),
|
||||
) as progress:
|
||||
task = progress.add_task("[cyan]正在提取信息...", total=total)
|
||||
|
||||
|
||||
# 🔧 优化:使用 asyncio.gather 并发执行所有任务
|
||||
# return_exceptions=True 确保单个失败不影响其他任务
|
||||
for coro in asyncio.as_completed(tasks):
|
||||
@@ -293,7 +293,7 @@ async def extract_information(paragraphs_dict, model_set):
|
||||
failed_hashes.append(failed_hash)
|
||||
elif doc_item:
|
||||
open_ie_docs.append(doc_item)
|
||||
|
||||
|
||||
completed += 1
|
||||
progress.update(task, advance=1)
|
||||
|
||||
@@ -415,7 +415,7 @@ def rebuild_faiss_only():
|
||||
logger.info("--- 重建 FAISS 索引 ---")
|
||||
# 重建索引不需要并发参数(不涉及 embedding 生成)
|
||||
embed_manager = EmbeddingManager()
|
||||
|
||||
|
||||
logger.info("正在加载现有的 Embedding 库...")
|
||||
try:
|
||||
embed_manager.load_from_file()
|
||||
|
||||
@@ -4,13 +4,13 @@
|
||||
提供 Web API 用于可视化记忆图数据
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
from fastapi import APIRouter, HTTPException, Request, Query
|
||||
from fastapi import APIRouter, HTTPException, Query, Request
|
||||
from fastapi.responses import HTMLResponse, JSONResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
|
||||
@@ -29,7 +29,7 @@ router = APIRouter()
|
||||
templates = Jinja2Templates(directory=str(Path(__file__).parent / "templates"))
|
||||
|
||||
|
||||
def find_available_data_files() -> List[Path]:
|
||||
def find_available_data_files() -> list[Path]:
|
||||
"""查找所有可用的记忆图数据文件"""
|
||||
files = []
|
||||
if not data_dir.exists():
|
||||
@@ -62,7 +62,7 @@ def find_available_data_files() -> List[Path]:
|
||||
return sorted(files, key=lambda f: f.stat().st_mtime, reverse=True)
|
||||
|
||||
|
||||
def load_graph_data_from_file(file_path: Optional[Path] = None) -> Dict[str, Any]:
|
||||
def load_graph_data_from_file(file_path: Path | None = None) -> dict[str, Any]:
|
||||
"""从磁盘加载图数据"""
|
||||
global graph_data_cache, current_data_file
|
||||
|
||||
@@ -85,7 +85,7 @@ def load_graph_data_from_file(file_path: Optional[Path] = None) -> Dict[str, Any
|
||||
if not graph_file.exists():
|
||||
return {"error": f"文件不存在: {graph_file}", "nodes": [], "edges": [], "stats": {}}
|
||||
|
||||
with open(graph_file, "r", encoding="utf-8") as f:
|
||||
with open(graph_file, encoding="utf-8") as f:
|
||||
data = orjson.loads(f.read())
|
||||
|
||||
nodes = data.get("nodes", [])
|
||||
@@ -150,7 +150,7 @@ async def index(request: Request):
|
||||
return templates.TemplateResponse("visualizer.html", {"request": request})
|
||||
|
||||
|
||||
def _format_graph_data_from_manager(memory_manager) -> Dict[str, Any]:
|
||||
def _format_graph_data_from_manager(memory_manager) -> dict[str, Any]:
|
||||
"""从 MemoryManager 提取并格式化图数据"""
|
||||
if not memory_manager.graph_store:
|
||||
return {"nodes": [], "edges": [], "memories": [], "stats": {}}
|
||||
@@ -188,7 +188,7 @@ def _format_graph_data_from_manager(memory_manager) -> Dict[str, Any]:
|
||||
"arrows": "to",
|
||||
"memory_id": memory.id,
|
||||
}
|
||||
|
||||
|
||||
edges_list = list(edges_dict.values())
|
||||
|
||||
stats = memory_manager.get_statistics()
|
||||
@@ -261,7 +261,7 @@ async def get_paginated_graph(
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(500, ge=100, le=2000, description="每页节点数"),
|
||||
min_importance: float = Query(0.0, ge=0.0, le=1.0, description="最小重要性阈值"),
|
||||
node_types: Optional[str] = Query(None, description="节点类型过滤,逗号分隔"),
|
||||
node_types: str | None = Query(None, description="节点类型过滤,逗号分隔"),
|
||||
):
|
||||
"""分页获取图数据,支持重要性过滤"""
|
||||
try:
|
||||
@@ -301,13 +301,13 @@ async def get_paginated_graph(
|
||||
total_pages = (total_nodes + page_size - 1) // page_size
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = min(start_idx + page_size, total_nodes)
|
||||
|
||||
|
||||
paginated_nodes = nodes_with_importance[start_idx:end_idx]
|
||||
node_ids = set(n["id"] for n in paginated_nodes)
|
||||
|
||||
# 只保留连接分页节点的边
|
||||
paginated_edges = [
|
||||
e for e in edges
|
||||
e for e in edges
|
||||
if e.get("from") in node_ids and e.get("to") in node_ids
|
||||
]
|
||||
|
||||
@@ -383,7 +383,7 @@ async def get_clustered_graph(
|
||||
return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
|
||||
|
||||
|
||||
def _cluster_graph_data(nodes: List[Dict], edges: List[Dict], max_nodes: int, cluster_threshold: int) -> Dict:
|
||||
def _cluster_graph_data(nodes: list[dict], edges: list[dict], max_nodes: int, cluster_threshold: int) -> dict:
|
||||
"""简单的图聚类算法:按类型和连接度聚类"""
|
||||
# 构建邻接表
|
||||
adjacency = defaultdict(set)
|
||||
@@ -412,21 +412,21 @@ def _cluster_graph_data(nodes: List[Dict], edges: List[Dict], max_nodes: int, cl
|
||||
for node in type_nodes:
|
||||
importance = len(adjacency[node["id"]])
|
||||
node_importance.append((node, importance))
|
||||
|
||||
|
||||
node_importance.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
|
||||
# 保留前N个重要节点
|
||||
keep_count = min(len(type_nodes), max_nodes // len(type_groups))
|
||||
for node, importance in node_importance[:keep_count]:
|
||||
clustered_nodes.append(node)
|
||||
node_mapping[node["id"]] = node["id"]
|
||||
|
||||
|
||||
# 其余节点聚合为一个超级节点
|
||||
if len(node_importance) > keep_count:
|
||||
clustered_node_ids = [n["id"] for n, _ in node_importance[keep_count:]]
|
||||
cluster_id = f"cluster_{node_type}_{len(clustered_nodes)}"
|
||||
cluster_label = f"{node_type} 集群 ({len(clustered_node_ids)}个节点)"
|
||||
|
||||
|
||||
clustered_nodes.append({
|
||||
"id": cluster_id,
|
||||
"label": cluster_label,
|
||||
@@ -436,7 +436,7 @@ def _cluster_graph_data(nodes: List[Dict], edges: List[Dict], max_nodes: int, cl
|
||||
"cluster_size": len(clustered_node_ids),
|
||||
"clustered_nodes": clustered_node_ids[:10], # 只保留前10个用于展示
|
||||
})
|
||||
|
||||
|
||||
for node_id in clustered_node_ids:
|
||||
node_mapping[node_id] = cluster_id
|
||||
|
||||
@@ -445,7 +445,7 @@ def _cluster_graph_data(nodes: List[Dict], edges: List[Dict], max_nodes: int, cl
|
||||
for edge in edges:
|
||||
from_id = node_mapping.get(edge["from"])
|
||||
to_id = node_mapping.get(edge["to"])
|
||||
|
||||
|
||||
if from_id and to_id and from_id != to_id:
|
||||
edge_key = tuple(sorted([from_id, to_id]))
|
||||
if edge_key not in edge_set:
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Literal
|
||||
from typing import Literal
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
|
||||
|
||||
@@ -161,16 +161,16 @@ class EmbeddingStore:
|
||||
# 限制 chunk_size 和 max_workers 在合理范围内
|
||||
chunk_size = max(MIN_CHUNK_SIZE, min(chunk_size, MAX_CHUNK_SIZE))
|
||||
max_workers = max(MIN_WORKERS, min(max_workers, MAX_WORKERS))
|
||||
|
||||
|
||||
semaphore = asyncio.Semaphore(max_workers)
|
||||
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
|
||||
results = {}
|
||||
|
||||
|
||||
# 将字符串列表分成多个 chunk
|
||||
chunks = []
|
||||
for i in range(0, len(strs), chunk_size):
|
||||
chunks.append(strs[i : i + chunk_size])
|
||||
|
||||
|
||||
async def _process_chunk(chunk: list[str]):
|
||||
"""处理一个 chunk 的字符串(批量获取 embedding)"""
|
||||
async with semaphore:
|
||||
@@ -180,12 +180,12 @@ class EmbeddingStore:
|
||||
embedding = await EmbeddingStore._get_embedding_async(llm, s)
|
||||
embeddings.append(embedding)
|
||||
results[s] = embedding
|
||||
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(len(chunk))
|
||||
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
# 并发处理所有 chunks
|
||||
tasks = [_process_chunk(chunk) for chunk in chunks]
|
||||
await asyncio.gather(*tasks)
|
||||
@@ -418,22 +418,22 @@ class EmbeddingStore:
|
||||
# 🔧 修复:检查所有 embedding 的维度是否一致
|
||||
dimensions = [len(emb) for emb in array]
|
||||
unique_dims = set(dimensions)
|
||||
|
||||
|
||||
if len(unique_dims) > 1:
|
||||
logger.error(f"检测到不一致的 embedding 维度: {unique_dims}")
|
||||
logger.error(f"维度分布: {dict(zip(*np.unique(dimensions, return_counts=True)))}")
|
||||
|
||||
|
||||
# 获取期望的维度(使用最常见的维度)
|
||||
from collections import Counter
|
||||
dim_counter = Counter(dimensions)
|
||||
expected_dim = dim_counter.most_common(1)[0][0]
|
||||
logger.warning(f"将使用最常见的维度: {expected_dim}")
|
||||
|
||||
|
||||
# 过滤掉维度不匹配的 embedding
|
||||
filtered_array = []
|
||||
filtered_idx2hash = {}
|
||||
skipped_count = 0
|
||||
|
||||
|
||||
for i, emb in enumerate(array):
|
||||
if len(emb) == expected_dim:
|
||||
filtered_array.append(emb)
|
||||
@@ -442,11 +442,11 @@ class EmbeddingStore:
|
||||
skipped_count += 1
|
||||
hash_key = self.idx2hash[str(i)]
|
||||
logger.warning(f"跳过维度不匹配的 embedding: {hash_key}, 维度={len(emb)}, 期望={expected_dim}")
|
||||
|
||||
|
||||
logger.warning(f"已过滤 {skipped_count} 个维度不匹配的 embedding")
|
||||
array = filtered_array
|
||||
self.idx2hash = filtered_idx2hash
|
||||
|
||||
|
||||
if not array:
|
||||
logger.error("过滤后没有可用的 embedding,无法构建索引")
|
||||
embedding_dim = expected_dim
|
||||
|
||||
@@ -13,4 +13,4 @@ __all__ = [
|
||||
"StreamLoopManager",
|
||||
"message_manager",
|
||||
"stream_loop_manager",
|
||||
]
|
||||
]
|
||||
|
||||
@@ -82,7 +82,7 @@ class SingleStreamContextManager:
|
||||
|
||||
self.total_messages += 1
|
||||
self.last_access_time = time.time()
|
||||
|
||||
|
||||
# 如果使用了缓存系统,输出调试信息
|
||||
if cache_enabled and self.context.is_cache_enabled:
|
||||
if self.context.is_chatter_processing:
|
||||
|
||||
@@ -111,9 +111,9 @@ class StreamLoopManager:
|
||||
# 获取或创建该流的启动锁
|
||||
if stream_id not in self._stream_start_locks:
|
||||
self._stream_start_locks[stream_id] = asyncio.Lock()
|
||||
|
||||
|
||||
lock = self._stream_start_locks[stream_id]
|
||||
|
||||
|
||||
# 使用锁防止并发启动同一个流的多个循环任务
|
||||
async with lock:
|
||||
# 获取流上下文
|
||||
@@ -148,7 +148,7 @@ class StreamLoopManager:
|
||||
# 紧急取消
|
||||
context.stream_loop_task.cancel()
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
|
||||
loop_task = asyncio.create_task(self._stream_loop_worker(stream_id), name=f"stream_loop_{stream_id}")
|
||||
|
||||
# 将任务记录到 StreamContext 中
|
||||
@@ -252,7 +252,7 @@ class StreamLoopManager:
|
||||
self.stats["total_process_cycles"] += 1
|
||||
if success:
|
||||
logger.info(f"✅ [流工作器] stream={stream_id[:8]}, 任务ID={task_id}, 处理成功")
|
||||
|
||||
|
||||
# 🔒 处理成功后,等待一小段时间确保清理操作完成
|
||||
# 这样可以避免在 chatter_manager 清除未读消息之前就进入下一轮循环
|
||||
await asyncio.sleep(0.1)
|
||||
@@ -382,7 +382,7 @@ class StreamLoopManager:
|
||||
self.chatter_manager.process_stream_context(stream_id, context),
|
||||
name=f"chatter_process_{stream_id}"
|
||||
)
|
||||
|
||||
|
||||
# 等待 chatter 任务完成
|
||||
results = await chatter_task
|
||||
success = results.get("success", False)
|
||||
@@ -398,8 +398,8 @@ class StreamLoopManager:
|
||||
else:
|
||||
logger.warning(f"流处理失败: {stream_id} - {results.get('error_message', '未知错误')}")
|
||||
|
||||
return success
|
||||
except asyncio.CancelledError:
|
||||
return success
|
||||
except asyncio.CancelledError:
|
||||
if chatter_task and not chatter_task.done():
|
||||
chatter_task.cancel()
|
||||
raise
|
||||
@@ -709,4 +709,4 @@ class StreamLoopManager:
|
||||
|
||||
|
||||
# 全局流循环管理器实例
|
||||
stream_loop_manager = StreamLoopManager()
|
||||
stream_loop_manager = StreamLoopManager()
|
||||
|
||||
@@ -417,7 +417,7 @@ class MessageManager:
|
||||
return
|
||||
|
||||
# 记录详细信息
|
||||
msg_previews = [f"{str(msg.message_id)[:8] if msg.message_id else 'unknown'}:{msg.processed_plain_text[:20] if msg.processed_plain_text else '(空)'}"
|
||||
msg_previews = [f"{str(msg.message_id)[:8] if msg.message_id else 'unknown'}:{msg.processed_plain_text[:20] if msg.processed_plain_text else '(空)'}"
|
||||
for msg in unread_messages[:3]] # 只显示前3条
|
||||
logger.info(f"🧹 [清除未读] stream={stream_id[:8]}, 开始清除 {len(unread_messages)} 条未读消息, 示例: {msg_previews}")
|
||||
|
||||
@@ -446,15 +446,15 @@ class MessageManager:
|
||||
context = chat_stream.context_manager.context
|
||||
if hasattr(context, "unread_messages") and context.unread_messages:
|
||||
unread_count = len(context.unread_messages)
|
||||
|
||||
|
||||
# 如果还有未读消息,说明 action_manager 可能遗漏了,标记它们
|
||||
if unread_count > 0:
|
||||
if unread_count > 0:
|
||||
# 获取所有未读消息的 ID
|
||||
message_ids = [msg.message_id for msg in context.unread_messages]
|
||||
|
||||
|
||||
# 标记为已读(会移到历史消息)
|
||||
success = chat_stream.context_manager.mark_messages_as_read(message_ids)
|
||||
|
||||
|
||||
if success:
|
||||
logger.debug(f"✅ stream={stream_id[:8]}, 成功标记 {unread_count} 条消息为已读")
|
||||
else:
|
||||
@@ -481,7 +481,7 @@ class MessageManager:
|
||||
try:
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = await chat_manager.get_stream(stream_id)
|
||||
if chat_stream and hasattr(chat_stream.context_manager.context, 'is_chatter_processing'):
|
||||
if chat_stream and hasattr(chat_stream.context_manager.context, "is_chatter_processing"):
|
||||
chat_stream.context_manager.context.is_chatter_processing = is_processing
|
||||
logger.debug(f"设置StreamContext处理状态: stream={stream_id}, processing={is_processing}")
|
||||
except Exception as e:
|
||||
@@ -517,7 +517,7 @@ class MessageManager:
|
||||
try:
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = await chat_manager.get_stream(stream_id)
|
||||
if chat_stream and hasattr(chat_stream.context_manager.context, 'is_chatter_processing'):
|
||||
if chat_stream and hasattr(chat_stream.context_manager.context, "is_chatter_processing"):
|
||||
return chat_stream.context_manager.context.is_chatter_processing
|
||||
except Exception:
|
||||
pass
|
||||
@@ -677,4 +677,4 @@ class MessageManager:
|
||||
|
||||
|
||||
# 创建全局消息管理器实例
|
||||
message_manager = MessageManager()
|
||||
message_manager = MessageManager()
|
||||
|
||||
@@ -248,16 +248,16 @@ class ChatterActionManager:
|
||||
try:
|
||||
# 根据动作类型确定提示词模式
|
||||
prompt_mode = "s4u" if action_name == "reply" else "normal"
|
||||
|
||||
|
||||
# 将prompt_mode传递给generate_reply
|
||||
action_data_with_mode = (action_data or {}).copy()
|
||||
action_data_with_mode["prompt_mode"] = prompt_mode
|
||||
|
||||
|
||||
# 只传递当前正在执行的动作,而不是所有可用动作
|
||||
# 这样可以让LLM明确知道"已决定执行X动作",而不是"有这些动作可用"
|
||||
current_action_info = self._using_actions.get(action_name)
|
||||
current_actions: dict[str, Any] = {action_name: current_action_info} if current_action_info else {}
|
||||
|
||||
|
||||
# 附加目标消息信息(如果存在)
|
||||
if target_message:
|
||||
# 提取目标消息的关键信息
|
||||
@@ -268,7 +268,7 @@ class ChatterActionManager:
|
||||
"time": getattr(target_message, "time", 0),
|
||||
}
|
||||
current_actions["_target_message"] = target_msg_info
|
||||
|
||||
|
||||
success, response_set, _ = await generator_api.generate_reply(
|
||||
chat_stream=chat_stream,
|
||||
reply_message=target_message,
|
||||
@@ -295,12 +295,12 @@ class ChatterActionManager:
|
||||
should_quote_reply = None
|
||||
if action_data and isinstance(action_data, dict):
|
||||
should_quote_reply = action_data.get("should_quote_reply", None)
|
||||
|
||||
|
||||
# respond动作默认不引用回复,保持对话流畅
|
||||
if action_name == "respond" and should_quote_reply is None:
|
||||
should_quote_reply = False
|
||||
|
||||
async def _after_reply():
|
||||
async def _after_reply():
|
||||
# 发送并存储回复
|
||||
loop_info, reply_text, cycle_timers_reply = await self._send_and_store_reply(
|
||||
chat_stream,
|
||||
|
||||
@@ -372,7 +372,7 @@ class DefaultReplyer:
|
||||
# 确保类型安全
|
||||
if isinstance(mode, str):
|
||||
prompt_mode_value = mode
|
||||
|
||||
|
||||
# 构建 Prompt
|
||||
with Timer("构建Prompt", {}): # 内部计时器,可选保留
|
||||
prompt = await self.build_prompt_reply_context(
|
||||
@@ -1166,16 +1166,16 @@ class DefaultReplyer:
|
||||
from src.plugin_system.apis.chat_api import get_chat_manager
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream_obj = await chat_manager.get_stream(chat_id)
|
||||
|
||||
|
||||
if chat_stream_obj:
|
||||
unread_messages = chat_stream_obj.context_manager.get_unread_messages()
|
||||
if unread_messages:
|
||||
# 使用最后一条未读消息作为参考
|
||||
last_msg = unread_messages[-1]
|
||||
platform = last_msg.chat_info.platform if hasattr(last_msg, 'chat_info') else chat_stream.platform
|
||||
user_id = last_msg.user_info.user_id if hasattr(last_msg, 'user_info') else ""
|
||||
user_nickname = last_msg.user_info.user_nickname if hasattr(last_msg, 'user_info') else ""
|
||||
user_cardname = last_msg.user_info.user_cardname if hasattr(last_msg, 'user_info') else ""
|
||||
platform = last_msg.chat_info.platform if hasattr(last_msg, "chat_info") else chat_stream.platform
|
||||
user_id = last_msg.user_info.user_id if hasattr(last_msg, "user_info") else ""
|
||||
user_nickname = last_msg.user_info.user_nickname if hasattr(last_msg, "user_info") else ""
|
||||
user_cardname = last_msg.user_info.user_cardname if hasattr(last_msg, "user_info") else ""
|
||||
processed_plain_text = last_msg.processed_plain_text or ""
|
||||
else:
|
||||
# 没有未读消息,使用默认值
|
||||
@@ -1258,19 +1258,19 @@ class DefaultReplyer:
|
||||
if available_actions:
|
||||
# 过滤掉特殊键(以_开头)
|
||||
action_items = {k: v for k, v in available_actions.items() if not k.startswith("_")}
|
||||
|
||||
|
||||
# 提取目标消息信息(如果存在)
|
||||
target_msg_info = available_actions.get("_target_message") # type: ignore
|
||||
|
||||
|
||||
if action_items:
|
||||
if len(action_items) == 1:
|
||||
# 单个动作
|
||||
action_name, action_info = list(action_items.items())[0]
|
||||
action_desc = action_info.description
|
||||
|
||||
|
||||
# 构建基础决策信息
|
||||
action_descriptions = f"## 决策信息\n\n你已经决定要执行 **{action_name}** 动作({action_desc})。\n\n"
|
||||
|
||||
|
||||
# 只有需要目标消息的动作才显示目标消息详情
|
||||
# respond 动作是统一回应所有未读消息,不应该显示特定目标消息
|
||||
if action_name not in ["respond"] and target_msg_info and isinstance(target_msg_info, dict):
|
||||
@@ -1279,7 +1279,7 @@ class DefaultReplyer:
|
||||
content = target_msg_info.get("content", "")
|
||||
msg_time = target_msg_info.get("time", 0)
|
||||
time_str = time_module.strftime("%H:%M:%S", time_module.localtime(msg_time)) if msg_time else "未知时间"
|
||||
|
||||
|
||||
action_descriptions += f"**目标消息**: {time_str} {sender} 说: {content}\n\n"
|
||||
else:
|
||||
# 多个动作
|
||||
@@ -2166,7 +2166,7 @@ class DefaultReplyer:
|
||||
except Exception as e:
|
||||
logger.error(f"存储聊天记忆失败: {e}")
|
||||
|
||||
|
||||
|
||||
|
||||
def weighted_sample_no_replacement(items, weights, k) -> list:
|
||||
"""
|
||||
|
||||
@@ -5,12 +5,12 @@
|
||||
插件可以通过实现这些接口来扩展安全功能。
|
||||
"""
|
||||
|
||||
from .interfaces import SecurityCheckResult, SecurityChecker
|
||||
from .interfaces import SecurityChecker, SecurityCheckResult
|
||||
from .manager import SecurityManager, get_security_manager
|
||||
|
||||
__all__ = [
|
||||
"SecurityChecker",
|
||||
"SecurityCheckResult",
|
||||
"SecurityChecker",
|
||||
"SecurityManager",
|
||||
"get_security_manager",
|
||||
]
|
||||
|
||||
@@ -10,7 +10,7 @@ from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from .interfaces import SecurityAction, SecurityCheckResult, SecurityChecker, SecurityLevel
|
||||
from .interfaces import SecurityAction, SecurityChecker, SecurityCheckResult, SecurityLevel
|
||||
|
||||
logger = get_logger("security.manager")
|
||||
|
||||
|
||||
@@ -98,7 +98,7 @@ class StreamContext(BaseDataModel):
|
||||
break
|
||||
|
||||
def mark_message_as_read(self, message_id: str):
|
||||
"""标记消息为已读"""
|
||||
"""标记消息为已读"""
|
||||
# 先找到要标记的消息(处理 int/str 类型不匹配问题)
|
||||
message_to_mark = None
|
||||
for msg in self.unread_messages:
|
||||
@@ -106,7 +106,7 @@ class StreamContext(BaseDataModel):
|
||||
if str(msg.message_id) == str(message_id):
|
||||
message_to_mark = msg
|
||||
break
|
||||
|
||||
|
||||
# 然后移动到历史消息
|
||||
if message_to_mark:
|
||||
message_to_mark.is_read = True
|
||||
|
||||
@@ -9,11 +9,12 @@
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import builtins
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Generic, List, Optional, Set, TypeVar, Union
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.memory_utils import estimate_size_smart
|
||||
@@ -96,7 +97,7 @@ class LRUCache(Generic[T]):
|
||||
self._lock = asyncio.Lock()
|
||||
self._stats = CacheStats()
|
||||
|
||||
async def get(self, key: str) -> Optional[T]:
|
||||
async def get(self, key: str) -> T | None:
|
||||
"""获取缓存值
|
||||
|
||||
Args:
|
||||
@@ -137,8 +138,8 @@ class LRUCache(Generic[T]):
|
||||
self,
|
||||
key: str,
|
||||
value: T,
|
||||
size: Optional[int] = None,
|
||||
ttl: Optional[float] = None,
|
||||
size: int | None = None,
|
||||
ttl: float | None = None,
|
||||
) -> None:
|
||||
"""设置缓存值
|
||||
|
||||
@@ -287,8 +288,8 @@ class MultiLevelCache:
|
||||
async def get(
|
||||
self,
|
||||
key: str,
|
||||
loader: Optional[Callable[[], Any]] = None,
|
||||
) -> Optional[Any]:
|
||||
loader: Callable[[], Any] | None = None,
|
||||
) -> Any | None:
|
||||
"""从缓存获取数据
|
||||
|
||||
查询顺序:L1 -> L2 -> loader
|
||||
@@ -329,8 +330,8 @@ class MultiLevelCache:
|
||||
self,
|
||||
key: str,
|
||||
value: Any,
|
||||
size: Optional[int] = None,
|
||||
ttl: Optional[float] = None,
|
||||
size: int | None = None,
|
||||
ttl: float | None = None,
|
||||
) -> None:
|
||||
"""设置缓存值
|
||||
|
||||
@@ -390,7 +391,7 @@ class MultiLevelCache:
|
||||
await self.l2_cache.clear()
|
||||
logger.info("所有缓存已清空")
|
||||
|
||||
async def get_stats(self) -> Dict[str, Any]:
|
||||
async def get_stats(self) -> dict[str, Any]:
|
||||
"""获取所有缓存层的统计信息(修复版:避免锁嵌套,使用超时)"""
|
||||
# 🔧 修复:并行获取统计信息,避免锁嵌套
|
||||
l1_stats_task = asyncio.create_task(self._get_cache_stats_safe(self.l1_cache, "L1"))
|
||||
@@ -492,7 +493,7 @@ class MultiLevelCache:
|
||||
logger.error(f"{cache_name}统计获取异常: {e}")
|
||||
return CacheStats()
|
||||
|
||||
async def _get_cache_keys_safe(self, cache) -> Set[str]:
|
||||
async def _get_cache_keys_safe(self, cache) -> builtins.set[str]:
|
||||
"""安全获取缓存键集合(带超时)"""
|
||||
try:
|
||||
# 快速获取键集合,使用超时避免死锁
|
||||
@@ -507,12 +508,12 @@ class MultiLevelCache:
|
||||
logger.error(f"缓存键获取异常: {e}")
|
||||
return set()
|
||||
|
||||
async def _extract_keys_with_lock(self, cache) -> Set[str]:
|
||||
async def _extract_keys_with_lock(self, cache) -> builtins.set[str]:
|
||||
"""在锁保护下提取键集合"""
|
||||
async with cache._lock:
|
||||
return set(cache._cache.keys())
|
||||
|
||||
async def _calculate_memory_usage_safe(self, cache, keys: Set[str]) -> int:
|
||||
async def _calculate_memory_usage_safe(self, cache, keys: builtins.set[str]) -> int:
|
||||
"""安全计算内存使用(带超时)"""
|
||||
if not keys:
|
||||
return 0
|
||||
@@ -529,7 +530,7 @@ class MultiLevelCache:
|
||||
logger.error(f"内存计算异常: {e}")
|
||||
return 0
|
||||
|
||||
async def _calc_memory_with_lock(self, cache, keys: Set[str]) -> int:
|
||||
async def _calc_memory_with_lock(self, cache, keys: builtins.set[str]) -> int:
|
||||
"""在锁保护下计算内存使用"""
|
||||
total_size = 0
|
||||
async with cache._lock:
|
||||
@@ -749,7 +750,7 @@ class MultiLevelCache:
|
||||
|
||||
|
||||
# 全局缓存实例
|
||||
_global_cache: Optional[MultiLevelCache] = None
|
||||
_global_cache: MultiLevelCache | None = None
|
||||
_cache_lock = asyncio.Lock()
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ import socket
|
||||
|
||||
from fastapi import APIRouter, FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from rich.traceback import install
|
||||
from uvicorn import Config
|
||||
from uvicorn import Server as UvicornServer
|
||||
|
||||
@@ -444,7 +444,7 @@ class OpenaiClient(BaseClient):
|
||||
# 🔧 优化:增加连接池限制,支持高并发embedding请求
|
||||
# 默认httpx限制为100,对于高频embedding场景不够用
|
||||
import httpx
|
||||
|
||||
|
||||
limits = httpx.Limits(
|
||||
max_keepalive_connections=200, # 保持活跃连接数(原100)
|
||||
max_connections=300, # 最大总连接数(原100)
|
||||
|
||||
@@ -128,7 +128,7 @@ class MemoryBuilder:
|
||||
# 6. 构建 Memory 对象
|
||||
# 新记忆应该有较高的初始激活度
|
||||
initial_activation = 0.75 # 新记忆初始激活度为 0.75
|
||||
|
||||
|
||||
memory = Memory(
|
||||
id=memory_id,
|
||||
subject_id=subject_node.id,
|
||||
|
||||
@@ -149,7 +149,7 @@ class MemoryManager:
|
||||
# 读取阈值过滤配置
|
||||
search_min_importance = self.config.search_min_importance
|
||||
search_similarity_threshold = self.config.search_similarity_threshold
|
||||
|
||||
|
||||
logger.info(
|
||||
f"📊 配置检查: search_max_expand_depth={expand_depth}, "
|
||||
f"search_expand_semantic_threshold={expand_semantic_threshold}, "
|
||||
@@ -417,7 +417,7 @@ class MemoryManager:
|
||||
# 使用配置的默认值
|
||||
if top_k is None:
|
||||
top_k = getattr(self.config, "search_top_k", 10)
|
||||
|
||||
|
||||
# 准备搜索参数
|
||||
params = {
|
||||
"query": query,
|
||||
@@ -951,7 +951,7 @@ class MemoryManager:
|
||||
)
|
||||
else:
|
||||
logger.debug(f"记忆已删除: {memory_id} (删除了 {deleted_vectors} 个向量)")
|
||||
|
||||
|
||||
# 4. 保存更新
|
||||
await self.persistence.save_graph_store(self.graph_store)
|
||||
return True
|
||||
@@ -984,7 +984,7 @@ class MemoryManager:
|
||||
try:
|
||||
forgotten_count = 0
|
||||
all_memories = self.graph_store.get_all_memories()
|
||||
|
||||
|
||||
# 获取配置参数
|
||||
min_importance = getattr(self.config, "forgetting_min_importance", 0.8)
|
||||
decay_rate = getattr(self.config, "activation_decay_rate", 0.9)
|
||||
@@ -1010,10 +1010,10 @@ class MemoryManager:
|
||||
try:
|
||||
last_access_dt = datetime.fromisoformat(last_access)
|
||||
days_passed = (datetime.now() - last_access_dt).days
|
||||
|
||||
|
||||
# 应用指数衰减:activation = base * (decay_rate ^ days)
|
||||
current_activation = base_activation * (decay_rate ** days_passed)
|
||||
|
||||
|
||||
logger.debug(
|
||||
f"记忆 {memory.id[:8]}: 基础激活度={base_activation:.3f}, "
|
||||
f"经过{days_passed}天衰减后={current_activation:.3f}"
|
||||
@@ -1035,20 +1035,20 @@ class MemoryManager:
|
||||
# 批量遗忘记忆(不立即清理孤立节点)
|
||||
if memories_to_forget:
|
||||
logger.info(f"开始批量遗忘 {len(memories_to_forget)} 条记忆...")
|
||||
|
||||
|
||||
for memory_id, activation in memories_to_forget:
|
||||
# cleanup_orphans=False:暂不清理孤立节点
|
||||
success = await self.forget_memory(memory_id, cleanup_orphans=False)
|
||||
if success:
|
||||
forgotten_count += 1
|
||||
|
||||
|
||||
# 统一清理孤立节点和边
|
||||
logger.info("批量遗忘完成,开始统一清理孤立节点和边...")
|
||||
orphan_nodes, orphan_edges = await self._cleanup_orphan_nodes_and_edges()
|
||||
|
||||
|
||||
# 保存最终更新
|
||||
await self.persistence.save_graph_store(self.graph_store)
|
||||
|
||||
|
||||
logger.info(
|
||||
f"✅ 自动遗忘完成: 遗忘了 {forgotten_count} 条记忆, "
|
||||
f"清理了 {orphan_nodes} 个孤立节点, {orphan_edges} 条孤立边"
|
||||
@@ -1079,31 +1079,31 @@ class MemoryManager:
|
||||
# 1. 清理孤立节点
|
||||
# graph_store.node_to_memories 记录了每个节点属于哪些记忆
|
||||
nodes_to_remove = []
|
||||
|
||||
|
||||
for node_id, memory_ids in list(self.graph_store.node_to_memories.items()):
|
||||
# 如果节点不再属于任何记忆,标记为删除
|
||||
if not memory_ids:
|
||||
nodes_to_remove.append(node_id)
|
||||
|
||||
|
||||
# 从图中删除孤立节点
|
||||
for node_id in nodes_to_remove:
|
||||
if self.graph_store.graph.has_node(node_id):
|
||||
self.graph_store.graph.remove_node(node_id)
|
||||
orphan_nodes_count += 1
|
||||
|
||||
|
||||
# 从映射中删除
|
||||
if node_id in self.graph_store.node_to_memories:
|
||||
del self.graph_store.node_to_memories[node_id]
|
||||
|
||||
|
||||
# 2. 清理孤立边(指向已删除节点的边)
|
||||
edges_to_remove = []
|
||||
|
||||
for source, target, edge_id in self.graph_store.graph.edges(data='edge_id'):
|
||||
|
||||
for source, target, edge_id in self.graph_store.graph.edges(data="edge_id"):
|
||||
# 检查边的源节点和目标节点是否还存在于node_to_memories中
|
||||
if source not in self.graph_store.node_to_memories or \
|
||||
target not in self.graph_store.node_to_memories:
|
||||
edges_to_remove.append((source, target))
|
||||
|
||||
|
||||
# 删除孤立边
|
||||
for source, target in edges_to_remove:
|
||||
try:
|
||||
@@ -1111,12 +1111,12 @@ class MemoryManager:
|
||||
orphan_edges_count += 1
|
||||
except Exception as e:
|
||||
logger.debug(f"删除边失败 {source} -> {target}: {e}")
|
||||
|
||||
|
||||
if orphan_nodes_count > 0 or orphan_edges_count > 0:
|
||||
logger.info(
|
||||
f"清理完成: {orphan_nodes_count} 个孤立节点, {orphan_edges_count} 条孤立边"
|
||||
)
|
||||
|
||||
|
||||
return orphan_nodes_count, orphan_edges_count
|
||||
|
||||
except Exception as e:
|
||||
@@ -1258,7 +1258,7 @@ class MemoryManager:
|
||||
mem for mem in recent_memories
|
||||
if mem.importance >= min_importance_for_consolidation
|
||||
]
|
||||
|
||||
|
||||
result["importance_filtered"] = len(recent_memories) - len(important_memories)
|
||||
logger.info(
|
||||
f"📊 步骤2: 重要性过滤 (阈值={min_importance_for_consolidation:.2f}): "
|
||||
@@ -1382,26 +1382,26 @@ class MemoryManager:
|
||||
# ===== 步骤4: 向量检索关联记忆 + LLM分析关系 =====
|
||||
# 过滤掉已删除的记忆
|
||||
remaining_memories = [m for m in important_memories if m.id not in deleted_ids]
|
||||
|
||||
|
||||
if not remaining_memories:
|
||||
logger.info("✅ 记忆整理完成: 去重后无剩余记忆")
|
||||
return
|
||||
|
||||
logger.info(f"📍 步骤4: 开始关联分析 ({len(remaining_memories)} 条记忆)...")
|
||||
|
||||
|
||||
# 分批处理记忆关联
|
||||
llm_batch_size = getattr(self.config, "consolidation_llm_batch_size", 10)
|
||||
max_candidates_per_memory = getattr(self.config, "consolidation_max_candidates", 5)
|
||||
min_confidence = getattr(self.config, "consolidation_min_confidence", 0.6)
|
||||
|
||||
|
||||
all_new_edges = [] # 收集所有新建的边
|
||||
|
||||
|
||||
for batch_start in range(0, len(remaining_memories), llm_batch_size):
|
||||
batch_end = min(batch_start + llm_batch_size, len(remaining_memories))
|
||||
batch = remaining_memories[batch_start:batch_end]
|
||||
|
||||
|
||||
logger.debug(f"处理批次 {batch_start//llm_batch_size + 1}/{(len(remaining_memories)-1)//llm_batch_size + 1}")
|
||||
|
||||
|
||||
for memory in batch:
|
||||
# 跳过已经有很多连接的记忆
|
||||
existing_edges = len([
|
||||
@@ -1454,14 +1454,14 @@ class MemoryManager:
|
||||
except Exception as e:
|
||||
logger.warning(f"创建关联边失败: {e}")
|
||||
continue
|
||||
|
||||
|
||||
# 每个批次后让出控制权
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# ===== 步骤5: 统一更新记忆数据 =====
|
||||
if all_new_edges:
|
||||
logger.info(f"📍 步骤5: 统一更新 {len(all_new_edges)} 条新关联边...")
|
||||
|
||||
|
||||
for memory, edge, relation in all_new_edges:
|
||||
try:
|
||||
# 添加到图
|
||||
@@ -2301,7 +2301,7 @@ class MemoryManager:
|
||||
# 使用 asyncio.wait_for 来支持取消
|
||||
await asyncio.wait_for(
|
||||
asyncio.sleep(initial_delay),
|
||||
timeout=float('inf') # 允许随时取消
|
||||
timeout=float("inf") # 允许随时取消
|
||||
)
|
||||
|
||||
# 检查是否仍然需要运行
|
||||
|
||||
@@ -482,7 +482,7 @@ class GraphStore:
|
||||
for node in memory.nodes:
|
||||
if node.id in self.node_to_memories:
|
||||
self.node_to_memories[node.id].discard(memory_id)
|
||||
|
||||
|
||||
# 可选:立即清理孤立节点
|
||||
if cleanup_orphans:
|
||||
# 如果该节点不再属于任何记忆,从图中移除节点
|
||||
|
||||
@@ -72,12 +72,12 @@ class MemoryTools:
|
||||
self.max_expand_depth = max_expand_depth
|
||||
self.expand_semantic_threshold = expand_semantic_threshold
|
||||
self.search_top_k = search_top_k
|
||||
|
||||
|
||||
# 保存权重配置
|
||||
self.base_vector_weight = search_vector_weight
|
||||
self.base_importance_weight = search_importance_weight
|
||||
self.base_recency_weight = search_recency_weight
|
||||
|
||||
|
||||
# 保存阈值过滤配置
|
||||
self.search_min_importance = search_min_importance
|
||||
self.search_similarity_threshold = search_similarity_threshold
|
||||
@@ -516,14 +516,14 @@ class MemoryTools:
|
||||
|
||||
# 1. 根据策略选择检索方式
|
||||
llm_prefer_types = [] # LLM识别的偏好节点类型
|
||||
|
||||
|
||||
if use_multi_query:
|
||||
# 多查询策略(返回节点列表 + 偏好类型)
|
||||
similar_nodes, llm_prefer_types = await self._multi_query_search(query, top_k, context)
|
||||
else:
|
||||
# 传统单查询策略
|
||||
similar_nodes = await self._single_query_search(query, top_k)
|
||||
|
||||
|
||||
# 合并用户指定的偏好类型和LLM识别的偏好类型
|
||||
all_prefer_types = list(set(prefer_node_types + llm_prefer_types))
|
||||
if all_prefer_types:
|
||||
@@ -551,7 +551,7 @@ class MemoryTools:
|
||||
# 记录最高分数
|
||||
if mem_id not in memory_scores or similarity > memory_scores[mem_id]:
|
||||
memory_scores[mem_id] = similarity
|
||||
|
||||
|
||||
# 🔥 详细日志:检查初始召回情况
|
||||
logger.info(
|
||||
f"初始向量搜索: 返回{len(similar_nodes)}个节点 → "
|
||||
@@ -559,8 +559,8 @@ class MemoryTools:
|
||||
)
|
||||
if len(initial_memory_ids) == 0:
|
||||
logger.warning(
|
||||
f"⚠️ 向量搜索未找到任何记忆!"
|
||||
f"可能原因:1) 嵌入模型理解问题 2) 记忆节点未建立索引 3) 查询表达与存储内容差异过大"
|
||||
"⚠️ 向量搜索未找到任何记忆!"
|
||||
"可能原因:1) 嵌入模型理解问题 2) 记忆节点未建立索引 3) 查询表达与存储内容差异过大"
|
||||
)
|
||||
# 输出相似节点的详细信息用于调试
|
||||
if similar_nodes:
|
||||
@@ -692,7 +692,7 @@ class MemoryTools:
|
||||
key=lambda x: final_scores[x],
|
||||
reverse=True
|
||||
) # 🔥 不再提前截断,让所有候选参与详细评分
|
||||
|
||||
|
||||
# 🔍 统计初始记忆的相似度分布(用于诊断)
|
||||
if memory_scores:
|
||||
similarities = list(memory_scores.values())
|
||||
@@ -707,7 +707,7 @@ class MemoryTools:
|
||||
# 5. 获取完整记忆并进行最终排序(优化后的动态权重系统)
|
||||
memories_with_scores = []
|
||||
filter_stats = {"importance": 0, "similarity": 0, "total_checked": 0} # 过滤统计
|
||||
|
||||
|
||||
for memory_id in sorted_memory_ids: # 遍历所有候选
|
||||
memory = self.graph_store.get_memory_by_id(memory_id)
|
||||
if memory:
|
||||
@@ -715,7 +715,7 @@ class MemoryTools:
|
||||
# 基础分数
|
||||
similarity_score = final_scores[memory_id]
|
||||
importance_score = memory.importance
|
||||
|
||||
|
||||
# 🆕 区分记忆来源(用于过滤)
|
||||
is_initial_memory = memory_id in memory_scores # 是否来自初始向量搜索
|
||||
true_similarity = memory_scores.get(memory_id, 0.0) if is_initial_memory else None
|
||||
@@ -738,16 +738,16 @@ class MemoryTools:
|
||||
activation_score = memory.activation
|
||||
|
||||
# 🆕 动态权重计算:使用配置的基础权重 + 根据记忆类型微调
|
||||
memory_type = memory.memory_type.value if hasattr(memory.memory_type, 'value') else str(memory.memory_type)
|
||||
|
||||
memory_type = memory.memory_type.value if hasattr(memory.memory_type, "value") else str(memory.memory_type)
|
||||
|
||||
# 检测记忆的主要节点类型
|
||||
node_types_count = {}
|
||||
for node in memory.nodes:
|
||||
nt = node.node_type.value if hasattr(node.node_type, 'value') else str(node.node_type)
|
||||
nt = node.node_type.value if hasattr(node.node_type, "value") else str(node.node_type)
|
||||
node_types_count[nt] = node_types_count.get(nt, 0) + 1
|
||||
|
||||
|
||||
dominant_node_type = max(node_types_count.items(), key=lambda x: x[1])[0] if node_types_count else "unknown"
|
||||
|
||||
|
||||
# 根据记忆类型和节点类型计算调整系数(在配置权重基础上微调)
|
||||
if dominant_node_type in ["ATTRIBUTE", "REFERENCE"] or memory_type == "FACT":
|
||||
# 事实性记忆:提升相似度权重,降低时效性权重
|
||||
@@ -777,41 +777,41 @@ class MemoryTools:
|
||||
"importance": 1.0,
|
||||
"recency": 1.0,
|
||||
}
|
||||
|
||||
|
||||
# 应用调整后的权重(基于配置的基础权重)
|
||||
weights = {
|
||||
"similarity": self.base_vector_weight * type_adjustments["similarity"],
|
||||
"importance": self.base_importance_weight * type_adjustments["importance"],
|
||||
"recency": self.base_recency_weight * type_adjustments["recency"],
|
||||
}
|
||||
|
||||
|
||||
# 归一化权重(确保总和为1.0)
|
||||
total_weight = sum(weights.values())
|
||||
if total_weight > 0:
|
||||
weights = {k: v / total_weight for k, v in weights.items()}
|
||||
|
||||
|
||||
# 综合分数计算(🔥 移除激活度影响)
|
||||
final_score = (
|
||||
similarity_score * weights["similarity"] +
|
||||
importance_score * weights["importance"] +
|
||||
recency_score * weights["recency"]
|
||||
)
|
||||
|
||||
|
||||
# 🆕 阈值过滤策略:
|
||||
# 1. 重要性过滤:应用于所有记忆(过滤极低质量)
|
||||
if memory.importance < self.search_min_importance:
|
||||
filter_stats["importance"] += 1
|
||||
logger.debug(f"❌ 过滤 {memory.id[:8]}: 重要性 {memory.importance:.2f} < 阈值 {self.search_min_importance}")
|
||||
continue
|
||||
|
||||
|
||||
# 2. 相似度过滤:不再对初始向量搜索结果过滤(信任向量搜索的排序)
|
||||
# 理由:向量搜索已经按相似度排序,返回的都是最相关结果
|
||||
# 如果再用阈值过滤,会导致"最相关的也不够相关"的矛盾
|
||||
#
|
||||
#
|
||||
# 注意:如果未来需要对扩展记忆过滤,可以在这里添加逻辑
|
||||
# if not is_initial_memory and some_score < threshold:
|
||||
# continue
|
||||
|
||||
|
||||
# 记录通过过滤的记忆(用于调试)
|
||||
if is_initial_memory:
|
||||
logger.debug(
|
||||
@@ -823,11 +823,11 @@ class MemoryTools:
|
||||
f"✅ 保留 {memory.id[:8]} [扩展]: 重要性={memory.importance:.2f}, "
|
||||
f"综合分数={final_score:.4f}"
|
||||
)
|
||||
|
||||
|
||||
# 🆕 节点类型加权:对REFERENCE/ATTRIBUTE节点额外加分(促进事实性信息召回)
|
||||
if "REFERENCE" in node_types_count or "ATTRIBUTE" in node_types_count:
|
||||
final_score *= 1.1 # 10% 加成
|
||||
|
||||
|
||||
# 🆕 用户指定的优先节点类型额外加权
|
||||
if prefer_node_types:
|
||||
for prefer_type in prefer_node_types:
|
||||
@@ -835,7 +835,7 @@ class MemoryTools:
|
||||
final_score *= 1.15 # 15% 额外加成
|
||||
logger.debug(f"记忆 {memory.id[:8]} 包含优先节点类型 {prefer_type},加权后分数: {final_score:.4f}")
|
||||
break
|
||||
|
||||
|
||||
memories_with_scores.append((memory, final_score, dominant_node_type))
|
||||
|
||||
# 按综合分数排序
|
||||
@@ -845,7 +845,7 @@ class MemoryTools:
|
||||
# 统计过滤情况
|
||||
total_candidates = len(all_memory_ids)
|
||||
filtered_count = total_candidates - len(memories_with_scores)
|
||||
|
||||
|
||||
# 6. 格式化结果(包含调试信息)
|
||||
results = []
|
||||
for memory, score, node_type in memories_with_scores[:top_k]:
|
||||
@@ -866,7 +866,7 @@ class MemoryTools:
|
||||
f"过滤{filtered_count}个 (重要性过滤) → "
|
||||
f"最终返回{len(results)}条记忆"
|
||||
)
|
||||
|
||||
|
||||
# 如果过滤率过高,发出警告
|
||||
if total_candidates > 0:
|
||||
filter_rate = filtered_count / total_candidates
|
||||
@@ -1092,20 +1092,21 @@ class MemoryTools:
|
||||
response, _ = await llm.generate_response_async(prompt, temperature=0.3, max_tokens=300)
|
||||
|
||||
import re
|
||||
|
||||
import orjson
|
||||
|
||||
|
||||
# 清理Markdown代码块
|
||||
response = re.sub(r"```json\s*", "", response)
|
||||
response = re.sub(r"```\s*$", "", response).strip()
|
||||
|
||||
# 解析JSON
|
||||
data = orjson.loads(response)
|
||||
|
||||
|
||||
# 提取查询列表
|
||||
queries = data.get("queries", [])
|
||||
result_queries = [(item.get("text", "").strip(), float(item.get("weight", 0.5)))
|
||||
for item in queries if item.get("text", "").strip()]
|
||||
|
||||
|
||||
# 提取偏好节点类型
|
||||
prefer_node_types = data.get("prefer_node_types", [])
|
||||
# 确保类型正确且有效
|
||||
@@ -1154,7 +1155,7 @@ class MemoryTools:
|
||||
limit=top_k * 5, # 🔥 从2倍提升到5倍,提高初始召回率
|
||||
min_similarity=0.0, # 不在这里过滤,交给后续评分
|
||||
)
|
||||
|
||||
|
||||
logger.debug(f"单查询向量搜索: 查询='{query}', 返回节点数={len(similar_nodes)}")
|
||||
if similar_nodes:
|
||||
logger.debug(f"Top 3相似度: {[f'{sim:.3f}' for _, sim, _ in similar_nodes[:3]]}")
|
||||
|
||||
@@ -62,7 +62,7 @@ async def expand_memories_with_semantic_filter(
|
||||
try:
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
# 记录已访问的记忆,避免重复
|
||||
visited_memories = set(initial_memory_ids)
|
||||
# 记录扩展的记忆及其分数
|
||||
@@ -87,17 +87,17 @@ async def expand_memories_with_semantic_filter(
|
||||
|
||||
# 获取该记忆的邻居记忆(通过边关系)
|
||||
neighbor_memory_ids = set()
|
||||
|
||||
|
||||
# 🆕 遍历记忆的所有边,收集邻居记忆(带边类型权重)
|
||||
edge_weights = {} # 记录通过不同边类型到达的记忆的权重
|
||||
|
||||
|
||||
for edge in memory.edges:
|
||||
# 获取边的目标节点
|
||||
target_node_id = edge.target_id
|
||||
source_node_id = edge.source_id
|
||||
|
||||
|
||||
# 🆕 根据边类型设置权重(优先扩展REFERENCE、ATTRIBUTE相关的边)
|
||||
edge_type_str = edge.edge_type.value if hasattr(edge.edge_type, 'value') else str(edge.edge_type)
|
||||
edge_type_str = edge.edge_type.value if hasattr(edge.edge_type, "value") else str(edge.edge_type)
|
||||
if edge_type_str == "REFERENCE":
|
||||
edge_weight = 1.3 # REFERENCE边权重最高(引用关系)
|
||||
elif edge_type_str in ["ATTRIBUTE", "HAS_PROPERTY"]:
|
||||
@@ -108,18 +108,18 @@ async def expand_memories_with_semantic_filter(
|
||||
edge_weight = 0.9 # 一般关系适中降权
|
||||
else:
|
||||
edge_weight = 1.0 # 默认权重
|
||||
|
||||
|
||||
# 通过节点找到其他记忆
|
||||
for node_id in [target_node_id, source_node_id]:
|
||||
if node_id in graph_store.node_to_memories:
|
||||
for neighbor_id in graph_store.node_to_memories[node_id]:
|
||||
if neighbor_id not in edge_weights or edge_weights[neighbor_id] < edge_weight:
|
||||
edge_weights[neighbor_id] = edge_weight
|
||||
|
||||
|
||||
# 将权重高的邻居记忆加入候选
|
||||
for neighbor_id, edge_weight in edge_weights.items():
|
||||
neighbor_memory_ids.add((neighbor_id, edge_weight))
|
||||
|
||||
|
||||
# 过滤掉已访问的和自己
|
||||
filtered_neighbors = []
|
||||
for neighbor_id, edge_weight in neighbor_memory_ids:
|
||||
@@ -129,7 +129,7 @@ async def expand_memories_with_semantic_filter(
|
||||
# 批量评估邻居记忆
|
||||
for neighbor_mem_id, edge_weight in filtered_neighbors:
|
||||
candidates_checked += 1
|
||||
|
||||
|
||||
neighbor_memory = graph_store.get_memory_by_id(neighbor_mem_id)
|
||||
if not neighbor_memory:
|
||||
continue
|
||||
@@ -139,7 +139,7 @@ async def expand_memories_with_semantic_filter(
|
||||
(n for n in neighbor_memory.nodes if n.has_embedding()),
|
||||
None
|
||||
)
|
||||
|
||||
|
||||
if not topic_node or topic_node.embedding is None:
|
||||
continue
|
||||
|
||||
@@ -179,11 +179,11 @@ async def expand_memories_with_semantic_filter(
|
||||
if len(expanded_memories) >= max_expanded:
|
||||
logger.debug(f"⏹️ 提前停止:已达到最大扩展数量 {max_expanded}")
|
||||
break
|
||||
|
||||
|
||||
# 早停检查
|
||||
if len(expanded_memories) >= max_expanded:
|
||||
break
|
||||
|
||||
|
||||
# 记录本层统计
|
||||
depth_stats.append({
|
||||
"depth": depth + 1,
|
||||
@@ -199,20 +199,20 @@ async def expand_memories_with_semantic_filter(
|
||||
|
||||
# 限制下一层的记忆数量,避免爆炸性增长
|
||||
current_level_memories = next_level_memories[:max_expanded]
|
||||
|
||||
|
||||
# 每层让出控制权
|
||||
await asyncio.sleep(0.001)
|
||||
|
||||
# 排序并返回
|
||||
sorted_results = sorted(expanded_memories.items(), key=lambda x: x[1], reverse=True)[:max_expanded]
|
||||
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
logger.info(
|
||||
f"✅ 图扩展完成: 初始{len(initial_memory_ids)}个 → "
|
||||
f"扩展{len(sorted_results)}个新记忆 "
|
||||
f"(深度={max_depth}, 阈值={semantic_threshold:.2f}, 耗时={elapsed:.3f}s)"
|
||||
)
|
||||
|
||||
|
||||
# 输出每层统计
|
||||
for stat in depth_stats:
|
||||
logger.debug(
|
||||
|
||||
@@ -137,7 +137,7 @@ async def generate_reply(
|
||||
prompt_mode = "s4u" # 默认使用s4u模式
|
||||
if action_data and "prompt_mode" in action_data:
|
||||
prompt_mode = action_data.get("prompt_mode", "s4u")
|
||||
|
||||
|
||||
# 将prompt_mode添加到available_actions中(作为特殊键)
|
||||
# 注意:这里我们需要暂时使用类型忽略,因为available_actions的类型定义不支持非ActionInfo值
|
||||
if available_actions is None:
|
||||
|
||||
@@ -362,7 +362,7 @@ class ChatterPlanFilter:
|
||||
return "最近没有聊天内容。", "没有未读消息。", []
|
||||
|
||||
stream_context = chat_stream.context_manager
|
||||
|
||||
|
||||
# 获取真正的已读和未读消息
|
||||
read_messages = stream_context.context.history_messages # 已读消息存储在history_messages中
|
||||
if not read_messages:
|
||||
@@ -660,30 +660,30 @@ class ChatterPlanFilter:
|
||||
if not action_info:
|
||||
logger.debug(f"动作 {action_name} 不在可用动作列表中,保留所有参数")
|
||||
return action_data
|
||||
|
||||
|
||||
# 获取该动作定义的合法参数
|
||||
defined_params = set(action_info.action_parameters.keys())
|
||||
|
||||
|
||||
# 合法参数集合
|
||||
valid_params = defined_params
|
||||
|
||||
|
||||
# 过滤参数
|
||||
filtered_data = {}
|
||||
removed_params = []
|
||||
|
||||
|
||||
for key, value in action_data.items():
|
||||
if key in valid_params:
|
||||
filtered_data[key] = value
|
||||
else:
|
||||
removed_params.append(key)
|
||||
|
||||
|
||||
# 记录被移除的参数
|
||||
if removed_params:
|
||||
logger.info(
|
||||
f"🧹 [参数过滤] 动作 '{action_name}' 移除了多余参数: {removed_params}. "
|
||||
f"合法参数: {sorted(valid_params)}"
|
||||
)
|
||||
|
||||
|
||||
return filtered_data
|
||||
|
||||
def _filter_no_actions(self, action_list: list[ActionPlannerInfo]) -> list[ActionPlannerInfo]:
|
||||
|
||||
@@ -615,14 +615,14 @@ async def execute_proactive_thinking(stream_id: str):
|
||||
# 获取或创建该聊天流的执行锁
|
||||
if stream_id not in _execution_locks:
|
||||
_execution_locks[stream_id] = asyncio.Lock()
|
||||
|
||||
|
||||
lock = _execution_locks[stream_id]
|
||||
|
||||
|
||||
# 尝试获取锁,如果已被占用则跳过本次执行(防止重复)
|
||||
if lock.locked():
|
||||
logger.warning(f"⚠️ 主动思考跳过:聊天流 {stream_id} 已有正在执行的主动思考任务")
|
||||
return
|
||||
|
||||
|
||||
async with lock:
|
||||
logger.debug(f"🤔 开始主动思考 {stream_id}")
|
||||
|
||||
@@ -633,13 +633,13 @@ async def execute_proactive_thinking(stream_id: str):
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = await chat_manager.get_stream(stream_id)
|
||||
|
||||
|
||||
if chat_stream and chat_stream.context_manager.context.is_chatter_processing:
|
||||
logger.warning(f"⚠️ 主动思考跳过:聊天流 {stream_id} 的 chatter 正在处理消息")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning(f"检查 chatter 处理状态时出错: {e},继续执行")
|
||||
|
||||
|
||||
# 0.1 检查白名单/黑名单
|
||||
# 从 stream_id 获取 stream_config 字符串进行验证
|
||||
try:
|
||||
|
||||
@@ -31,4 +31,4 @@ __plugin_meta__ = PluginMetadata(
|
||||
# 导入插件主类
|
||||
from .plugin import AntiInjectionPlugin
|
||||
|
||||
__all__ = ["__plugin_meta__", "AntiInjectionPlugin"]
|
||||
__all__ = ["AntiInjectionPlugin", "__plugin_meta__"]
|
||||
|
||||
@@ -8,8 +8,8 @@ import time
|
||||
|
||||
from src.chat.security.interfaces import (
|
||||
SecurityAction,
|
||||
SecurityCheckResult,
|
||||
SecurityChecker,
|
||||
SecurityCheckResult,
|
||||
SecurityLevel,
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
处理检测结果,执行相应的动作(允许/监控/加盾/阻止/反击)。
|
||||
"""
|
||||
|
||||
from src.chat.security.interfaces import SecurityAction, SecurityCheckResult
|
||||
from src.chat.security.interfaces import SecurityCheckResult
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from .counter_attack import CounterAttackGenerator
|
||||
|
||||
@@ -22,23 +22,23 @@ class ReplyAction(BaseAction):
|
||||
- 专注于理解和回应单条消息的具体内容
|
||||
- 适合 Focus 模式下的精准回复
|
||||
"""
|
||||
|
||||
|
||||
# 动作基本信息
|
||||
action_name = "reply"
|
||||
action_description = "针对特定消息进行精准回复。深度理解并回应单条消息的具体内容。需要指定目标消息ID。"
|
||||
|
||||
|
||||
# 激活设置
|
||||
activation_type = ActionActivationType.ALWAYS # 回复动作总是可用
|
||||
mode_enable = ChatMode.ALL # 在所有模式下都可用
|
||||
parallel_action = False # 回复动作不能与其他动作并行
|
||||
|
||||
|
||||
# 动作参数定义
|
||||
action_parameters: ClassVar = {
|
||||
"target_message_id": "要回复的目标消息ID(必需,来自未读消息的 <m...> 标签)",
|
||||
"content": "回复的具体内容(可选,由LLM生成)",
|
||||
"should_quote_reply": "是否引用原消息(可选,true/false,默认false。群聊中回复较早消息或需要明确指向时使用true)",
|
||||
}
|
||||
|
||||
|
||||
# 动作使用场景
|
||||
action_require: ClassVar = [
|
||||
"需要针对特定消息进行精准回复时使用",
|
||||
@@ -48,10 +48,10 @@ class ReplyAction(BaseAction):
|
||||
"群聊中需要明确回应某个特定用户或问题时使用",
|
||||
"关注单条消息的具体内容和上下文细节",
|
||||
]
|
||||
|
||||
|
||||
# 关联类型
|
||||
associated_types: ClassVar[list[str]] = ["text"]
|
||||
|
||||
|
||||
async def execute(self) -> tuple[bool, str]:
|
||||
"""执行reply动作
|
||||
|
||||
@@ -70,21 +70,21 @@ class RespondAction(BaseAction):
|
||||
- 适合对于群聊消息下的宏观回应
|
||||
- 避免与单一用户深度对话而忽略其他用户的消息
|
||||
"""
|
||||
|
||||
|
||||
# 动作基本信息
|
||||
action_name = "respond"
|
||||
action_description = "统一回应所有未读消息。理解整体对话动态和话题走向,生成连贯的回复。无需指定目标消息。"
|
||||
|
||||
|
||||
# 激活设置
|
||||
activation_type = ActionActivationType.ALWAYS # 回应动作总是可用
|
||||
mode_enable = ChatMode.ALL # 在所有模式下都可用
|
||||
parallel_action = False # 回应动作不能与其他动作并行
|
||||
|
||||
|
||||
# 动作参数定义
|
||||
action_parameters: ClassVar = {
|
||||
"content": "回复的具体内容(可选,由LLM生成)",
|
||||
}
|
||||
|
||||
|
||||
# 动作使用场景
|
||||
action_require: ClassVar = [
|
||||
"需要统一回应多条未读消息时使用(Normal 模式专用)",
|
||||
@@ -94,10 +94,10 @@ class RespondAction(BaseAction):
|
||||
"适合群聊中的自然对话流,无需精确指向特定消息",
|
||||
"可以同时回应多个话题或参与者",
|
||||
]
|
||||
|
||||
|
||||
# 关联类型
|
||||
associated_types: ClassVar[list[str]] = ["text"]
|
||||
|
||||
|
||||
async def execute(self) -> tuple[bool, str]:
|
||||
"""执行respond动作
|
||||
|
||||
|
||||
@@ -64,15 +64,15 @@ class CoreActionsPlugin(BasePlugin):
|
||||
|
||||
# --- 根据配置注册组件 ---
|
||||
components: ClassVar = []
|
||||
|
||||
|
||||
# 注册 reply 动作
|
||||
if self.get_config("components.enable_reply", True):
|
||||
components.append((ReplyAction.get_action_info(), ReplyAction))
|
||||
|
||||
|
||||
# 注册 respond 动作
|
||||
if self.get_config("components.enable_respond", True):
|
||||
components.append((RespondAction.get_action_info(), RespondAction))
|
||||
|
||||
|
||||
# 注册 emoji 动作
|
||||
if self.get_config("components.enable_emoji", True):
|
||||
components.append((EmojiAction.get_action_info(), EmojiAction))
|
||||
|
||||
@@ -6,10 +6,10 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import datetime
|
||||
import filetype
|
||||
from collections.abc import Callable
|
||||
|
||||
import aiohttp
|
||||
import filetype
|
||||
from maim_message import UserInfo
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
@@ -17,7 +17,6 @@ import uuid
|
||||
import weakref
|
||||
from collections import defaultdict
|
||||
from collections.abc import Awaitable, Callable
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
Reference in New Issue
Block a user