diff --git a/src/chat/chatter_manager.py b/src/chat/chatter_manager.py
index 1cf21d7ed..36f2dd2e9 100644
--- a/src/chat/chatter_manager.py
+++ b/src/chat/chatter_manager.py
@@ -1,6 +1,6 @@
 import asyncio
 import time
-from typing import Any, TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
 from src.chat.planner_actions.action_manager import ChatterActionManager
 from src.common.logger import get_logger
diff --git a/src/chat/message_manager/context_manager.py b/src/chat/message_manager/context_manager.py
index ac8d96e69..d4338eb90 100644
--- a/src/chat/message_manager/context_manager.py
+++ b/src/chat/message_manager/context_manager.py
@@ -6,7 +6,7 @@
 import asyncio
 import time
-from typing import Any, TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
 from src.chat.energy_system import energy_manager
 from src.common.data_models.database_data_model import DatabaseMessages
diff --git a/src/chat/message_manager/distribution_manager.py b/src/chat/message_manager/distribution_manager.py
index 097410d29..b8e940748 100644
--- a/src/chat/message_manager/distribution_manager.py
+++ b/src/chat/message_manager/distribution_manager.py
@@ -5,7 +5,7 @@
 import asyncio
 import time
-from typing import Any, TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
 from src.chat.chatter_manager import ChatterManager
 from src.chat.energy_system import energy_manager
@@ -115,12 +115,12 @@ class StreamLoopManager:
         if not context:
             logger.warning(f"无法获取流上下文: {stream_id}")
             return False
-        
+
         # 快速路径:如果流已存在且不是强制启动,无需处理
         if not force and context.stream_loop_task and not context.stream_loop_task.done():
             logger.debug(f"🔄 [流循环] stream={stream_id[:8]}, 循环已在运行,跳过启动")
             return True
-        
+
         # 获取或创建该流的启动锁
         if stream_id not in self._stream_start_locks:
             self._stream_start_locks[stream_id] = asyncio.Lock()
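
[Reviewer note] The hunk above only strips trailing whitespace, but it is worth flagging the pattern it touches: a lock-free fast path followed by a per-stream asyncio.Lock with a re-check inside — classic double-checked locking. A minimal sketch of the idea; class and attribute names here are simplified stand-ins, not the project's actual API:

    import asyncio

    class StreamLoopManager:
        # Hypothetical, trimmed-down illustration of the per-stream start lock.
        def __init__(self) -> None:
            self._stream_start_locks: dict[str, asyncio.Lock] = {}
            self._running: set[str] = set()

        async def start_stream_loop(self, stream_id: str, force: bool = False) -> bool:
            # Fast path: loop already running and no forced restart requested.
            if not force and stream_id in self._running:
                return True
            # One lock per stream, created on demand.
            lock = self._stream_start_locks.setdefault(stream_id, asyncio.Lock())
            async with lock:
                # Re-check under the lock: another coroutine may have won the race.
                if not force and stream_id in self._running:
                    return True
                self._running.add(stream_id)
                return True

The re-check inside the lock is what makes the fast path safe: two coroutines can both pass the unlocked check, but only one starts the loop.
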
diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py
index 8cd4fc456..4dee0745d 100644
--- a/src/chat/message_receive/storage.py
+++ b/src/chat/message_receive/storage.py
@@ -12,7 +12,6 @@
 from src.common.data_models.database_data_model import DatabaseMessages
 from src.common.database.core import get_db_session
 from src.common.database.core.models import Images, Messages
 from src.common.logger import get_logger
-from src.config.config import global_config
 
 from .chat_stream import ChatStream
 from .message import MessageSending
diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py
index 0c83314c5..a0e72ed73 100644
--- a/src/chat/planner_actions/action_manager.py
+++ b/src/chat/planner_actions/action_manager.py
@@ -242,9 +242,9 @@ class ChatterActionManager:
                 }
             else:
                 # 检查目标消息是否为表情包消息以及配置是否允许回复表情包
-                if target_message and getattr(target_message, 'is_emoji', False):
+                if target_message and getattr(target_message, "is_emoji", False):
                     # 如果是表情包消息且配置不允许回复表情包,则跳过回复
-                    if not getattr(global_config.chat, 'allow_reply_to_emoji', True):
+                    if not getattr(global_config.chat, "allow_reply_to_emoji", True):
                         logger.info(f"{log_prefix} 目标消息为表情包且配置不允许回复表情包,跳过回复")
                         return {"action_type": action_name, "success": True, "reply_text": "", "skip_reason": "emoji_not_allowed"}
diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py
index d145c6db0..de986791a 100644
--- a/src/chat/replyer/default_generator.py
+++ b/src/chat/replyer/default_generator.py
@@ -376,7 +376,7 @@ class DefaultReplyer:
         if not prompt:
             logger.warning("构建prompt失败,跳过回复生成")
             return False, None, None
-        
+
         from src.plugin_system.core.event_manager import event_manager
 
         # 触发 POST_LLM 事件(请求 LLM 之前)
         if not from_plugin:
@@ -1878,8 +1878,8 @@ class DefaultReplyer:
     async def build_relation_info(self, sender: str, target: str):
         # 获取用户ID
         if sender == f"{global_config.bot.nickname}(你)":
-            return f"你将要回复的是你自己发送的消息。"
-        
+            return "你将要回复的是你自己发送的消息。"
+
         person_info_manager = get_person_info_manager()
         person_id = await person_info_manager.get_person_id_by_person_name(sender)
diff --git a/src/chat/utils/attention_optimizer.py b/src/chat/utils/attention_optimizer.py
index 27365177b..8ab669228 100644
--- a/src/chat/utils/attention_optimizer.py
+++ b/src/chat/utils/attention_optimizer.py
@@ -47,10 +47,10 @@ class BlockShuffler:
         # 复制上下文以避免修改原始字典
         shuffled_context = context_data.copy()
-        
+
         # 示例:假设模板中的占位符格式为 {block_name}
         # 我们需要解析模板,找到可重排的组,并重新构建模板字符串。
-        
+
         # 注意:这是一个复杂的逻辑,通常需要一个简单的模板引擎或正则表达式来完成。
         # 为保持此函数职责单一,这里仅演示核心的重排逻辑,
         # 完整的模板重建逻辑应在调用此函数的地方处理。
@@ -58,14 +58,14 @@
         for group in BlockShuffler.SWAPPABLE_BLOCK_GROUPS:
             # 过滤出在当前上下文中实际存在的、非空的block
             existing_blocks = [
-                block for block in group if block in context_data and context_data[block]
+                block for block in group if context_data.get(block)
             ]
 
             if len(existing_blocks) > 1:
                 # 随机打乱顺序
                 random.shuffle(existing_blocks)
                 logger.debug(f"重排block组: {group} -> {existing_blocks}")
-                
+
                 # 这里的实现需要调用者根据 `existing_blocks` 的新顺序
                 # 去动态地重新组织 `prompt_template` 字符串。
                 # 例如,找到模板中与 `group` 相关的占位符部分,然后按新顺序替换它们。
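
[Reviewer note] The BlockShuffler comments above deliberately leave template reconstruction to the caller. One hedged way a caller could reorder the {block_name} placeholders to match a shuffled group — the group contents and the function name are assumptions for illustration, not the project's API:

    import random
    import re

    # Assumed example group; the real groups live on BlockShuffler.
    SWAPPABLE_BLOCK_GROUPS = [["memory_block", "relation_info_block", "schedule_block"]]

    def shuffle_template(prompt_template: str, context_data: dict[str, str]) -> str:
        """Reorder swappable {block} placeholders; block contents stay keyed by name."""
        for group in SWAPPABLE_BLOCK_GROUPS:
            existing = [b for b in group if context_data.get(b)]
            if len(existing) < 2:
                continue
            shuffled = existing[:]
            random.shuffle(shuffled)
            # Each original placeholder slot (in textual order) receives the
            # next name from the shuffled order.
            slots = iter(shuffled)
            pattern = "|".join(re.escape(b) for b in existing)
            prompt_template = re.sub(
                r"\{(" + pattern + r")\}",
                lambda _m: "{" + next(slots) + "}",
                prompt_template,
            )
        return prompt_template

    template = "{memory_block}\n{relation_info_block}\n{schedule_block}"
    ctx = {"memory_block": "m", "relation_info_block": "r", "schedule_block": "s"}
    print(shuffle_template(template, ctx))
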
diff --git a/src/chat/utils/prompt_component_manager.py b/src/chat/utils/prompt_component_manager.py
index 135e48883..976ad488b 100644
--- a/src/chat/utils/prompt_component_manager.py
+++ b/src/chat/utils/prompt_component_manager.py
@@ -2,7 +2,6 @@
 import asyncio
 import copy
 import re
 from collections.abc import Awaitable, Callable
-from typing import List
 
 from src.chat.utils.prompt_params import PromptParameters
 from src.common.logger import get_logger
@@ -119,7 +118,7 @@ class PromptComponentManager:
     async def add_injection_rule(
         self,
         prompt_name: str,
-        rules: List[InjectionRule],
+        rules: list[InjectionRule],
         content_provider: Callable[..., Awaitable[str]],
         source: str = "runtime",
     ) -> bool:
@@ -521,7 +520,7 @@
             else:
                 for name, (rule, _, _) in rules_for_target.items():
                     target_copy[name] = rule
-            
+
             if target_copy:
                 rules_copy[target] = target_copy
diff --git a/src/chat/utils/prompt_params.py b/src/chat/utils/prompt_params.py
index ab07e1688..707b18575 100644
--- a/src/chat/utils/prompt_params.py
+++ b/src/chat/utils/prompt_params.py
@@ -63,7 +63,7 @@ class PromptParameters:
     action_descriptions: str = ""
     notice_block: str = ""
     group_chat_reminder_block: str = ""
-    
+
     # 可用动作信息
     available_actions: dict[str, Any] | None = None
diff --git a/src/chat/utils/report_generator.py b/src/chat/utils/report_generator.py
index e23a1d75e..8c8756070 100644
--- a/src/chat/utils/report_generator.py
+++ b/src/chat/utils/report_generator.py
@@ -228,9 +228,9 @@ class HTMLReportGenerator:
         # 渲染模板
         # 读取CSS和JS文件内容
-        async with aiofiles.open(os.path.join(self.jinja_env.loader.searchpath[0], "report.css"), "r", encoding="utf-8") as f:
+        async with aiofiles.open(os.path.join(self.jinja_env.loader.searchpath[0], "report.css"), encoding="utf-8") as f:
             report_css = await f.read()
-        async with aiofiles.open(os.path.join(self.jinja_env.loader.searchpath[0], "report.js"), "r", encoding="utf-8") as f:
+        async with aiofiles.open(os.path.join(self.jinja_env.loader.searchpath[0], "report.js"), encoding="utf-8") as f:
            report_js = await f.read()
         # 渲染模板
         template = self.jinja_env.get_template("report.html")
diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py
index 21467d0f5..5b4b811c0 100644
--- a/src/chat/utils/statistic.py
+++ b/src/chat/utils/statistic.py
@@ -3,8 +3,6 @@
 from collections import defaultdict
 from datetime import datetime, timedelta
 from typing import Any
 
-import aiofiles
-
 from src.common.database.compatibility import db_get, db_query
 from src.common.database.core.models import LLMUsage, Messages, OnlineTime
 from src.common.logger import get_logger
@@ -16,7 +14,7 @@
 logger = get_logger("maibot_statistic")
 
 # 彻底异步化:删除原同步包装器 _sync_db_get,所有数据库访问统一使用 await db_get。
-from .report_generator import HTMLReportGenerator, format_online_time
+from .report_generator import HTMLReportGenerator
 from .statistic_keys import *
diff --git a/src/chat/utils/statistic_keys.py b/src/chat/utils/statistic_keys.py
index 67b01faeb..2a552ac1a 100644
--- a/src/chat/utils/statistic_keys.py
+++ b/src/chat/utils/statistic_keys.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 """
 该模块用于存放统计数据相关的常量键名。
 """
@@ -61,4 +60,4 @@
 STD_TIME_COST_BY_PROVIDER = "std_time_costs_by_provider"
 PIE_CHART_COST_BY_PROVIDER = "pie_chart_cost_by_provider"
 PIE_CHART_REQ_BY_PROVIDER = "pie_chart_req_by_provider"
 BAR_CHART_COST_BY_MODEL = "bar_chart_cost_by_model"
-BAR_CHART_REQ_BY_MODEL = "bar_chart_req_by_model"
\ No newline at end of file
+BAR_CHART_REQ_BY_MODEL = "bar_chart_req_by_model"
diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py
index 7e89d9c9f..c26bb752d 100644
--- a/src/llm_models/utils_model.py
+++ b/src/llm_models/utils_model.py
@@ -537,7 +537,7 @@ class _PromptProcessor:
         else:
             is_truncated = True
         return content, reasoning, is_truncated
-    
+
     @staticmethod
     async def _extract_reasoning(content: str) -> tuple[str, str]:
         """
diff --git a/src/main.py b/src/main.py
index f39d3f956..a5afe6ef2 100644
--- a/src/main.py
+++ b/src/main.py
@@ -1,4 +1,5 @@
 # 再用这个就写一行注释来混提交的我直接全部🌿飞😡
+# 🌿🌿need
 import asyncio
 import signal
 import sys
@@ -21,7 +22,6 @@
 from src.common.message import get_global_api
 
 # 全局背景任务集合
 _background_tasks = set()
 
-from src.common.remote import TelemetryHeartBeatTask
 from src.common.server import Server, get_global_server
 from src.config.config import global_config
 from src.individuality.individuality import Individuality, get_individuality
diff --git a/src/memory_graph/storage/persistence.py b/src/memory_graph/storage/persistence.py
index 46ed90ba1..452604e4e 100644
--- a/src/memory_graph/storage/persistence.py
+++ b/src/memory_graph/storage/persistence.py
@@ -507,7 +507,7 @@ class PersistenceManager:
             GraphStore 对象
         """
         try:
-            async with aiofiles.open(input_file, "r", encoding="utf-8") as f:
+            async with aiofiles.open(input_file, encoding="utf-8") as f:
                 content = await f.read()
 
             data = json.loads(content)
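
[Reviewer note] The aiofiles changes in report_generator.py and persistence.py drop the explicit "r" argument; like the built-in open(), aiofiles.open() defaults to text-read mode, so the shorter calls are behaviorally identical. A tiny sketch of the equivalent read (the path is illustrative):

    import asyncio
    import aiofiles

    async def read_text(path: str) -> str:
        # mode defaults to "r", mirroring the built-in open()
        async with aiofiles.open(path, encoding="utf-8") as f:
            return await f.read()

    # asyncio.run(read_text("example.json"))
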
logger.warning(f"生成查询embedding失败: {e}") - + if query_embedding is not None: if use_path_expansion: # 🆕 使用路径评分扩展算法 logger.info(f"🔬 使用路径评分扩展算法: 初始{len(similar_nodes)}个节点, 深度={expand_depth}") - + # 延迟初始化路径扩展器 if self.path_expander is None: path_config = PathExpansionConfig( @@ -607,7 +607,7 @@ class MemoryTools: vector_store=self.vector_store, config=path_config ) - + try: # 执行路径扩展(传递偏好类型) path_results = await self.path_expander.expand_with_path_scoring( @@ -616,11 +616,11 @@ class MemoryTools: top_k=top_k, prefer_node_types=all_prefer_types # 🆕 传递偏好类型 ) - + # 路径扩展返回的是 [(Memory, final_score, paths), ...] # 我们需要直接返回这些记忆,跳过后续的传统评分 logger.info(f"✅ 路径扩展返回 {len(path_results)} 条记忆") - + # 直接构建返回结果 path_memories = [] for memory, score, paths in path_results: @@ -635,25 +635,25 @@ class MemoryTools: "max_path_depth": max(p.depth for p in paths) if paths else 0 } }) - + logger.info(f"🎯 路径扩展最终返回: {len(path_memories)} 条记忆") - + return { "success": True, "results": path_memories, "total": len(path_memories), "expansion_method": "path_scoring" } - + except Exception as e: logger.error(f"路径扩展失败: {e}", exc_info=True) logger.info("回退到传统图扩展算法") # 继续执行下面的传统图扩展 - + # 传统图扩展(仅在未启用路径扩展或路径扩展失败时执行) if not use_path_expansion or expanded_memory_scores == {}: logger.info(f"开始传统图扩展: 初始记忆{len(initial_memory_ids)}个, 深度={expand_depth}") - + try: # 使用共享的图扩展工具函数 expanded_results = await expand_memories_with_semantic_filter( diff --git a/src/memory_graph/utils/__init__.py b/src/memory_graph/utils/__init__.py index fffb59ba4..72b64e611 100644 --- a/src/memory_graph/utils/__init__.py +++ b/src/memory_graph/utils/__init__.py @@ -9,10 +9,10 @@ from src.memory_graph.utils.time_parser import TimeParser __all__ = [ "EmbeddingGenerator", + "Path", + "PathExpansionConfig", + "PathScoreExpansion", "TimeParser", "cosine_similarity", "get_embedding_generator", - "PathScoreExpansion", - "PathExpansionConfig", - "Path", ] diff --git a/src/memory_graph/utils/memory_deduplication.py b/src/memory_graph/utils/memory_deduplication.py index 42079ff39..f506dfa54 100644 --- a/src/memory_graph/utils/memory_deduplication.py +++ b/src/memory_graph/utils/memory_deduplication.py @@ -12,7 +12,7 @@ from src.common.logger import get_logger from src.memory_graph.utils.similarity import cosine_similarity if TYPE_CHECKING: - from src.memory_graph.models import Memory + pass logger = get_logger(__name__) @@ -41,52 +41,52 @@ async def deduplicate_memories_by_similarity( """ if len(memories) <= 1: return memories - + logger.info(f"开始记忆去重: {len(memories)} 条记忆 (阈值={similarity_threshold})") - + # 准备数据结构 memory_embeddings = [] for memory, score, extra in memories: # 获取记忆的向量表示 embedding = await _get_memory_embedding(memory) memory_embeddings.append((memory, score, extra, embedding)) - + # 构建相似度矩阵并找出重复组 duplicate_groups = _find_duplicate_groups(memory_embeddings, similarity_threshold) - + # 合并每个重复组 deduplicated = [] processed_indices = set() - + for group_indices in duplicate_groups: if any(i in processed_indices for i in group_indices): continue # 已经处理过 - + # 标记为已处理 processed_indices.update(group_indices) - + # 合并组内记忆 group_memories = [memory_embeddings[i] for i in group_indices] merged_memory = _merge_memory_group(group_memories) deduplicated.append(merged_memory) - + # 添加未被合并的记忆 for i, (memory, score, extra, _) in enumerate(memory_embeddings): if i not in processed_indices: deduplicated.append((memory, score, extra)) - + # 按分数排序 deduplicated.sort(key=lambda x: x[1], reverse=True) - + # 限制数量 if keep_top_n is not None: deduplicated = deduplicated[:keep_top_n] - + 
diff --git a/src/memory_graph/utils/memory_deduplication.py b/src/memory_graph/utils/memory_deduplication.py
index 42079ff39..f506dfa54 100644
--- a/src/memory_graph/utils/memory_deduplication.py
+++ b/src/memory_graph/utils/memory_deduplication.py
@@ -12,7 +12,7 @@
 from src.common.logger import get_logger
 from src.memory_graph.utils.similarity import cosine_similarity
 
 if TYPE_CHECKING:
-    from src.memory_graph.models import Memory
+    pass
 
 logger = get_logger(__name__)
@@ -41,52 +41,52 @@
     """
     if len(memories) <= 1:
         return memories
-    
+
     logger.info(f"开始记忆去重: {len(memories)} 条记忆 (阈值={similarity_threshold})")
-    
+
     # 准备数据结构
     memory_embeddings = []
     for memory, score, extra in memories:
         # 获取记忆的向量表示
         embedding = await _get_memory_embedding(memory)
         memory_embeddings.append((memory, score, extra, embedding))
-    
+
     # 构建相似度矩阵并找出重复组
     duplicate_groups = _find_duplicate_groups(memory_embeddings, similarity_threshold)
-    
+
     # 合并每个重复组
     deduplicated = []
     processed_indices = set()
-    
+
     for group_indices in duplicate_groups:
         if any(i in processed_indices for i in group_indices):
             continue  # 已经处理过
-        
+
         # 标记为已处理
         processed_indices.update(group_indices)
-        
+
         # 合并组内记忆
         group_memories = [memory_embeddings[i] for i in group_indices]
         merged_memory = _merge_memory_group(group_memories)
         deduplicated.append(merged_memory)
-    
+
     # 添加未被合并的记忆
     for i, (memory, score, extra, _) in enumerate(memory_embeddings):
         if i not in processed_indices:
             deduplicated.append((memory, score, extra))
-    
+
     # 按分数排序
     deduplicated.sort(key=lambda x: x[1], reverse=True)
-    
+
     # 限制数量
     if keep_top_n is not None:
         deduplicated = deduplicated[:keep_top_n]
-    
+
     logger.info(
         f"去重完成: {len(memories)} → {len(deduplicated)} 条记忆 "
         f"(合并了 {len(memories) - len(deduplicated)} 条重复)"
     )
-    
+
     return deduplicated
@@ -104,7 +104,7 @@
     # nodes 是 MemoryNode 对象列表
     first_node = memory.nodes[0]
     node_id = getattr(first_node, "id", None)
-    
+
     if node_id:
         # 直接从 embedding 属性获取(如果存在)
         if hasattr(first_node, "embedding") and first_node.embedding is not None:
@@ -114,7 +114,7 @@
                 return embedding.tolist()
             elif isinstance(embedding, list):
                 return embedding
-    
+
     # 无法获取 embedding
     return None
@@ -132,13 +132,13 @@
     """
     n = len(memory_embeddings)
     similarity_matrix = [[0.0] * n for _ in range(n)]
-    
+
     # 计算相似度矩阵
     for i in range(n):
         for j in range(i + 1, n):
            embedding_i = memory_embeddings[i][3]
            embedding_j = memory_embeddings[j][3]
-            
+
             # 跳过 None 或零向量
             if (embedding_i is None or embedding_j is None or
                 all(x == 0.0 for x in embedding_i) or all(x == 0.0 for x in embedding_j)):
                 similarity = 0.0
@@ -146,29 +146,29 @@
             else:
                 # cosine_similarity 会自动转换为 numpy 数组
                 similarity = float(cosine_similarity(embedding_i, embedding_j))  # type: ignore
-            
+
             similarity_matrix[i][j] = similarity
             similarity_matrix[j][i] = similarity
-    
+
     # 使用并查集找出连通分量
     parent = list(range(n))
-    
+
     def find(x):
         if parent[x] != x:
             parent[x] = find(parent[x])
         return parent[x]
-    
+
     def union(x, y):
         px, py = find(x), find(y)
         if px != py:
             parent[px] = py
-    
+
     # 合并相似的记忆
     for i in range(n):
         for j in range(i + 1, n):
             if similarity_matrix[i][j] >= threshold:
                 union(i, j)
-    
+
     # 构建组
     groups_dict: dict[int, list[int]] = {}
     for i in range(n):
@@ -176,10 +176,10 @@
         root = find(i)
         if root not in groups_dict:
             groups_dict[root] = []
         groups_dict[root].append(i)
-    
+
     # 只返回大小 > 1 的组(真正的重复组)
     duplicate_groups = [group for group in groups_dict.values() if len(group) > 1]
-    
+
     return duplicate_groups
@@ -196,10 +196,10 @@
     """
     # 按分数排序
     sorted_group = sorted(group, key=lambda x: x[1], reverse=True)
-    
+
     # 保留分数最高的记忆
     best_memory, best_score, best_extra, _ = sorted_group[0]
-    
+
     # 计算合并后的分数(加权平均,权重递减)
     total_weight = 0.0
     weighted_sum = 0.0
@@ -207,17 +207,17 @@
         weight = 1.0 / (i + 1)  # 第1名权重1.0,第2名0.5,第3名0.33...
         weighted_sum += score * weight
         total_weight += weight
-    
+
     merged_score = weighted_sum / total_weight if total_weight > 0 else best_score
-    
+
     # 增强 extra_data
     merged_extra = best_extra if isinstance(best_extra, dict) else {}
     merged_extra["merged_count"] = len(sorted_group)
     merged_extra["original_scores"] = [score for _, score, _, _ in sorted_group]
-    
+
     logger.debug(
         f"合并 {len(sorted_group)} 条相似记忆: "
         f"分数 {best_score:.3f} → {merged_score:.3f}"
     )
-    
+
     return (best_memory, merged_score, merged_extra)
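
[Reviewer note] The merge in _merge_memory_group weights duplicate scores harmonically (1, 1/2, 1/3, ...) so the top-ranked duplicate dominates the merged score. A standalone check of that arithmetic:

    def merged_score(scores: list[float]) -> float:
        # Scores must be sorted descending; the weight of rank i is 1/(i+1).
        weighted = sum(s / (i + 1) for i, s in enumerate(scores))
        total = sum(1.0 / (i + 1) for i in range(len(scores)))
        return weighted / total if total > 0 else 0.0

    # Duplicates scored 0.9, 0.6, 0.3 merge to ~0.71 — pulled toward the
    # best score, whereas a plain mean would give 0.6.
    print(round(merged_score([0.9, 0.6, 0.3]), 2))
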
diff --git a/src/memory_graph/utils/path_expansion.py b/src/memory_graph/utils/path_expansion.py
index f24445495..4c80e7553 100644
--- a/src/memory_graph/utils/path_expansion.py
+++ b/src/memory_graph/utils/path_expansion.py
@@ -26,7 +26,6 @@
 from src.memory_graph.utils.similarity import cosine_similarity
 
 if TYPE_CHECKING:
     import numpy as np
-    from src.memory_graph.models import Memory
     from src.memory_graph.storage.graph_store import GraphStore
     from src.memory_graph.storage.vector_store import VectorStore
@@ -71,7 +70,7 @@ class PathExpansionConfig:
     medium_score_threshold: float = 0.4  # 中分路径阈值
     max_active_paths: int = 1000  # 最大活跃路径数(防止爆炸)
     top_paths_retain: int = 500  # 超限时保留的top路径数
-    
+
     # 🚀 性能优化参数
     enable_early_stop: bool = True  # 启用早停(如果路径增长很少则提前结束)
     early_stop_growth_threshold: float = 0.1  # 早停阈值(路径增长率低于10%则停止)
@@ -121,7 +120,7 @@ class PathScoreExpansion:
         self.vector_store = vector_store
         self.config = config or PathExpansionConfig()
         self.prefer_node_types: list[str] = []  # 🆕 偏好节点类型
-        
+
         # 🚀 性能优化:邻居边缓存
         self._neighbor_cache: dict[str, list[Any]] = {}
         self._node_score_cache: dict[str, float] = {}
@@ -212,11 +211,11 @@
                     continue
 
                 edge_weight = self._get_edge_weight(edge)
-                
+
                 # 记录候选
                 path_candidates.append((path, edge, next_node, edge_weight))
                 candidate_nodes_for_batch.add(next_node)
-                
+
                 branch_count += 1
                 if branch_count >= max_branches:
                     break
@@ -281,7 +280,7 @@
             # 🚀 早停检测:如果路径增长很少,提前终止
             prev_path_count = len(active_paths)
             active_paths = next_paths
-            
+
             if self.config.enable_early_stop and prev_path_count > 0:
                 growth_rate = (len(active_paths) - prev_path_count) / prev_path_count
                 if growth_rate < self.config.early_stop_growth_threshold:
@@ -346,18 +345,18 @@
             max_path_score = max(p.score for p in paths) if paths else 0
             rough_score = len(paths) * max_path_score * memory.importance
             memory_scores_rough.append((mem_id, rough_score))
-        
+
         # 保留top候选
         memory_scores_rough.sort(key=lambda x: x[1], reverse=True)
         retained_mem_ids = set(mem_id for mem_id, _ in memory_scores_rough[:self.config.max_candidate_memories])
-        
+
         # 过滤
         memory_paths = {
             mem_id: (memory, paths)
             for mem_id, (memory, paths) in memory_paths.items()
             if mem_id in retained_mem_ids
         }
-        
+
         logger.info(
             f"⚡ 粗排过滤: {len(memory_scores_rough)} → {len(memory_paths)} 条候选记忆"
         )
@@ -398,7 +397,7 @@
         # 🚀 缓存检查
         if node_id in self._neighbor_cache:
             return self._neighbor_cache[node_id]
-        
+
         edges = []
 
         # 从图存储中获取与该节点相关的所有边
@@ -454,7 +453,7 @@
         """
         # 从向量存储获取节点数据
         node_data = await self.vector_store.get_node_by_id(node_id)
-        
+
         if query_embedding is None:
             base_score = 0.5  # 默认中等分数
         else:
@@ -493,27 +492,27 @@
         import numpy as np
 
         scores = {}
-        
+
         if query_embedding is None:
             # 无查询向量时,返回默认分数
-            return {nid: 0.5 for nid in node_ids}
-        
+            return dict.fromkeys(node_ids, 0.5)
+
         # 批量获取节点数据
         node_data_list = await asyncio.gather(
             *[self.vector_store.get_node_by_id(nid) for nid in node_ids],
             return_exceptions=True
         )
-        
+
         # 收集有效的嵌入向量
         valid_embeddings = []
         valid_node_ids = []
         node_metadata_map = {}
-        
+
         for nid, node_data in zip(node_ids, node_data_list):
             if isinstance(node_data, Exception):
                 scores[nid] = 0.3
                 continue
-            
+
             # 类型守卫:确保 node_data 是字典
             if not node_data or not isinstance(node_data, dict) or "embedding" not in node_data:
                 scores[nid] = 0.3
                 continue
@@ -521,21 +520,21 @@
             valid_embeddings.append(node_data["embedding"])
             valid_node_ids.append(nid)
             node_metadata_map[nid] = node_data.get("metadata", {})
-        
+
         if valid_embeddings:
             # 批量计算相似度(使用矩阵运算)
             embeddings_matrix = np.array(valid_embeddings)
             query_norm = np.linalg.norm(query_embedding)
             embeddings_norms = np.linalg.norm(embeddings_matrix, axis=1)
-            
+
             # 向量化计算余弦相似度
             similarities = np.dot(embeddings_matrix, query_embedding) / (embeddings_norms * query_norm + 1e-8)
             similarities = np.clip(similarities, 0.0, 1.0)
-            
+
             # 应用偏好类型加成
             for nid, sim in zip(valid_node_ids, similarities):
                 base_score = float(sim)
-                
+
                 # 偏好类型加成
                 if self.prefer_node_types and nid in node_metadata_map:
                     node_type = node_metadata_map[nid].get("node_type")
@@ -546,7 +545,7 @@
                     scores[nid] = base_score
             else:
                 scores[nid] = base_score
-        
+
         return scores
 
     def _calculate_path_score(self, old_score: float, edge_weight: float, node_score: float, depth: int) -> float:
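
[Reviewer note] The batch-scoring hunk above replaces per-node cosine calls with one matrix operation, which is the main win of _batch_calculate_node_scores. A self-contained sketch of the same vectorization:

    import numpy as np

    def batch_cosine(query: np.ndarray, embeddings: np.ndarray) -> np.ndarray:
        """Cosine similarity of one query vector against N row vectors at once."""
        norms = np.linalg.norm(embeddings, axis=1) * np.linalg.norm(query)
        sims = embeddings @ query / (norms + 1e-8)  # epsilon guards zero vectors
        return np.clip(sims, 0.0, 1.0)

    query = np.array([1.0, 0.0])
    embeddings = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
    print(batch_cosine(query, embeddings))  # ~[1.0, 0.0, 0.707]

One matrix product and one vectorized norm replace N separate similarity calls, which is exactly what the 批量计算相似度 comment in the hunk describes.
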
@@ -689,19 +688,19 @@
         # 使用临时字典存储路径列表
         temp_paths: dict[str, list[Path]] = {}
         temp_memories: dict[str, Any] = {}  # 存储 Memory 对象
-        
+
         # 🚀 性能优化:收集所有需要获取的记忆ID,然后批量获取
         all_memory_ids = set()
         path_to_memory_ids: dict[int, set[str]] = {}  # path对象id -> 记忆ID集合
 
         for path in paths:
             memory_ids_in_path = set()
-            
+
             # 收集路径中所有节点涉及的记忆
             for node_id in path.nodes:
                 memory_ids = self.graph_store.node_to_memories.get(node_id, [])
                 memory_ids_in_path.update(memory_ids)
-            
+
             all_memory_ids.update(memory_ids_in_path)
             path_to_memory_ids[id(path)] = memory_ids_in_path
@@ -712,11 +711,11 @@
             memory = self.graph_store.get_memory_by_id(mem_id)
             if memory:
                 memory_cache[mem_id] = memory
-        
+
         # 构建映射关系
         for path in paths:
             memory_ids_in_path = path_to_memory_ids[id(path)]
-            
+
             for mem_id in memory_ids_in_path:
                 if mem_id in memory_cache:
                     if mem_id not in temp_paths:
@@ -745,10 +744,10 @@
         [(Memory, final_score, paths), ...]
         """
         scored_memories = []
-        
+
         # 🚀 性能优化:如果需要偏好类型加成,批量预加载所有节点的类型信息
         node_type_cache: dict[str, str | None] = {}
-        
+
         if self.prefer_node_types:
             # 收集所有需要查询的节点ID
             all_node_ids = set()
@@ -757,7 +756,7 @@
             for node in memory_nodes:
                 node_id = node.id if hasattr(node, "id") else str(node)
                 all_node_ids.add(node_id)
-            
+
             # 批量获取节点数据
             if all_node_ids:
                 logger.debug(f"🔍 批量预加载 {len(all_node_ids)} 个节点的类型信息")
                 node_data_list = await asyncio.gather(
                     *[self.vector_store.get_node_by_id(nid) for nid in all_node_ids],
                     return_exceptions=True
                 )
-                
+
                 # 构建类型缓存
                 for nid, node_data in zip(all_node_ids, node_data_list):
                     if isinstance(node_data, Exception) or not node_data or not isinstance(node_data, dict):
@@ -805,7 +804,7 @@
                     node_type = node_type_cache.get(node_id)
                     if node_type and node_type in self.prefer_node_types:
                         matched_count += 1
-                
+
                 if matched_count > 0:
                     match_ratio = matched_count / len(memory_nodes)
                     # 根据匹配比例给予加成(最高10%)
@@ -870,4 +869,4 @@
             return recency_score
 
 
-__all__ = ["PathScoreExpansion", "PathExpansionConfig", "Path"]
+__all__ = ["Path", "PathExpansionConfig", "PathScoreExpansion"]
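
[Reviewer note] Several hunks above share one optimization: collect all IDs first, fetch them in a single asyncio.gather(..., return_exceptions=True) batch, then type-guard each result. A minimal sketch of that pattern; fetch_one is a stand-in for an async store getter such as vector_store.get_node_by_id:

    import asyncio

    async def fetch_one(item_id: str) -> dict:
        # Stand-in for an async store lookup.
        if item_id == "bad":
            raise KeyError(item_id)
        return {"id": item_id}

    async def fetch_many(ids: list[str]) -> dict[str, dict | None]:
        results = await asyncio.gather(*[fetch_one(i) for i in ids], return_exceptions=True)
        out: dict[str, dict | None] = {}
        for item_id, result in zip(ids, results):
            # Exceptions come back as values, so one bad ID cannot sink the batch.
            out[item_id] = None if isinstance(result, BaseException) else result
        return out

    print(asyncio.run(fetch_many(["a", "bad", "b"])))
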
""" scored_memories = [] - + # 🚀 性能优化:如果需要偏好类型加成,批量预加载所有节点的类型信息 node_type_cache: dict[str, str | None] = {} - + if self.prefer_node_types: # 收集所有需要查询的节点ID all_node_ids = set() @@ -757,7 +756,7 @@ class PathScoreExpansion: for node in memory_nodes: node_id = node.id if hasattr(node, "id") else str(node) all_node_ids.add(node_id) - + # 批量获取节点数据 if all_node_ids: logger.debug(f"🔍 批量预加载 {len(all_node_ids)} 个节点的类型信息") @@ -765,7 +764,7 @@ class PathScoreExpansion: *[self.vector_store.get_node_by_id(nid) for nid in all_node_ids], return_exceptions=True ) - + # 构建类型缓存 for nid, node_data in zip(all_node_ids, node_data_list): if isinstance(node_data, Exception) or not node_data or not isinstance(node_data, dict): @@ -805,7 +804,7 @@ class PathScoreExpansion: node_type = node_type_cache.get(node_id) if node_type and node_type in self.prefer_node_types: matched_count += 1 - + if matched_count > 0: match_ratio = matched_count / len(memory_nodes) # 根据匹配比例给予加成(最高10%) @@ -870,4 +869,4 @@ class PathScoreExpansion: return recency_score -__all__ = ["PathScoreExpansion", "PathExpansionConfig", "Path"] +__all__ = ["Path", "PathExpansionConfig", "PathScoreExpansion"] diff --git a/src/person_info/relationship_fetcher.py b/src/person_info/relationship_fetcher.py index 5ac6ba9d9..d1f3a5c21 100644 --- a/src/person_info/relationship_fetcher.py +++ b/src/person_info/relationship_fetcher.py @@ -269,7 +269,7 @@ class RelationshipFetcher: platform = "unknown" if existing_stream: # 从现有记录获取platform - platform = getattr(existing_stream, 'platform', 'unknown') or "unknown" + platform = getattr(existing_stream, "platform", "unknown") or "unknown" logger.debug(f"从现有ChatStream获取到platform: {platform}, stream_id: {stream_id}") else: logger.debug(f"未找到现有ChatStream记录,使用默认platform: unknown, stream_id: {stream_id}") diff --git a/src/plugin_system/base/base_action.py b/src/plugin_system/base/base_action.py index 365395172..a715b98b0 100644 --- a/src/plugin_system/base/base_action.py +++ b/src/plugin_system/base/base_action.py @@ -742,7 +742,7 @@ class BaseAction(ABC): if not case_sensitive: search_text = search_text.lower() - matched_keywords: ClassVar = [] + matched_keywords = [] for keyword in keywords: check_keyword = keyword if case_sensitive else keyword.lower() if check_keyword in search_text: diff --git a/src/plugins/built_in/affinity_flow_chatter/planner/plan_filter.py b/src/plugins/built_in/affinity_flow_chatter/planner/plan_filter.py index c9773140d..61892c1ed 100644 --- a/src/plugins/built_in/affinity_flow_chatter/planner/plan_filter.py +++ b/src/plugins/built_in/affinity_flow_chatter/planner/plan_filter.py @@ -9,6 +9,7 @@ from datetime import datetime from typing import Any import orjson +from json_repair import repair_json from src.chat.utils.chat_message_builder import ( build_readable_messages_with_id, @@ -19,7 +20,6 @@ from src.common.logger import get_logger from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest from src.mood.mood_manager import mood_manager -from json_repair import repair_json from src.plugin_system.base.component_types import ActionInfo, ChatType from src.schedule.schedule_manager import schedule_manager @@ -144,7 +144,7 @@ class ChatterPlanFilter: plan.decided_actions = [ ActionPlannerInfo(action_type="no_action", reasoning=f"筛选时出错: {e}") ] - + # 在返回最终计划前,打印将要执行的动作 if plan.decided_actions: action_types = [action.action_type for action in plan.decided_actions] @@ -631,7 +631,6 @@ class ChatterPlanFilter: candidate_ids.add(normalized_id[1:]) # 
diff --git a/src/plugins/built_in/affinity_flow_chatter/planner/plan_generator.py b/src/plugins/built_in/affinity_flow_chatter/planner/plan_generator.py
index f8142d696..5a71fad5e 100644
--- a/src/plugins/built_in/affinity_flow_chatter/planner/plan_generator.py
+++ b/src/plugins/built_in/affinity_flow_chatter/planner/plan_generator.py
@@ -10,7 +10,6 @@
 from src.common.data_models.database_data_model import DatabaseMessages
 from src.common.data_models.info_data_model import Plan, TargetPersonInfo
 from src.config.config import global_config
 from src.plugin_system.base.component_types import ActionInfo, ChatMode, ChatType
-from src.plugin_system.core.component_registry import component_registry
 
 
 class ChatterPlanGenerator:
diff --git a/src/plugins/built_in/affinity_flow_chatter/planner/planner.py b/src/plugins/built_in/affinity_flow_chatter/planner/planner.py
index 2d42cc426..83a280fa6 100644
--- a/src/plugins/built_in/affinity_flow_chatter/planner/planner.py
+++ b/src/plugins/built_in/affinity_flow_chatter/planner/planner.py
@@ -201,7 +201,7 @@ class ChatterActionPlanner:
             available_actions = list(initial_plan.available_actions.keys())
             plan_filter = ChatterPlanFilter(self.chat_id, available_actions)
             filtered_plan = await plan_filter.filter(initial_plan)
-            
+
             # 检查reply动作是否可用
             has_reply_action = "reply" in available_actions or "respond" in available_actions
             if filtered_plan.decided_actions and has_reply_action and reply_not_available:
diff --git a/src/plugins/built_in/maizone_refactored/services/qzone_service.py b/src/plugins/built_in/maizone_refactored/services/qzone_service.py
index 8a1f45f07..5e2d8411a 100644
--- a/src/plugins/built_in/maizone_refactored/services/qzone_service.py
+++ b/src/plugins/built_in/maizone_refactored/services/qzone_service.py
@@ -320,7 +320,7 @@ class QZoneService:
             return
 
         # 1. 将评论分为用户评论和自己的回复
-        user_comments = [c for c in comments if str(c.get("qq_account")) != str(qq_account)] 
+        user_comments = [c for c in comments if str(c.get("qq_account")) != str(qq_account)]
 
         if not user_comments:
             return
diff --git a/src/plugins/built_in/system_management/plugin.py b/src/plugins/built_in/system_management/plugin.py
index 2b2df7b01..d3f9ed83e 100644
--- a/src/plugins/built_in/system_management/plugin.py
+++ b/src/plugins/built_in/system_management/plugin.py
@@ -295,7 +295,7 @@ class SystemCommand(PlusCommand):
                 if injections:
                     response_parts.append(f"🎯 **{target}** (注入源):")
                     for inj in injections:
-                        source_tag = f"({inj['source']})" if inj['source'] != 'static_default' else ''
+                        source_tag = f"({inj['source']})" if inj["source"] != "static_default" else ""
                         response_parts.append(f" ⎿ `{inj['name']}` (优先级: {inj['priority']}) {source_tag}")
                 else:
                     response_parts.append(f"🎯 **{target}** (无注入)")