chore: perform widespread code cleanup and formatting
Perform a comprehensive code cleanup across multiple modules to improve code quality, consistency, and maintainability. Key changes include:

- Removing numerous unused imports.
- Standardizing import order.
- Eliminating trailing whitespace and inconsistent newlines.
- Updating legacy type hints to modern syntax (e.g., `List` -> `list`).
- Making minor improvements for code robustness and style.
Committed by Windpicker-owo
parent 7475f87826 · commit d12e384cc2
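Note: a few mechanical patterns repeat throughout the diff below. As a minimal, hypothetical sketch (invented function, not from this repo) of the import sorting and typing modernization being applied:

    # Before: legacy typing.List and unsorted from-import
    from typing import Any, TYPE_CHECKING, List

    def first_word(lines: List[str]) -> Any:
        return lines[0].split()[0]

    # After: TYPE_CHECKING (a constant) sorts ahead of Any under isort's
    # order-by-type convention; List[...] becomes builtin list[...] (PEP 585, Python 3.9+)
    from typing import TYPE_CHECKING, Any

    def first_word(lines: list[str]) -> Any:
        return lines[0].split()[0]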
@@ -1,6 +1,6 @@
import asyncio
import time
- from typing import Any, TYPE_CHECKING
+ from typing import TYPE_CHECKING, Any

from src.chat.planner_actions.action_manager import ChatterActionManager
from src.common.logger import get_logger
@@ -6,7 +6,7 @@

import asyncio
import time
- from typing import Any, TYPE_CHECKING
+ from typing import TYPE_CHECKING, Any

from src.chat.energy_system import energy_manager
from src.common.data_models.database_data_model import DatabaseMessages
@@ -5,7 +5,7 @@

import asyncio
import time
- from typing import Any, TYPE_CHECKING
+ from typing import TYPE_CHECKING, Any

from src.chat.chatter_manager import ChatterManager
from src.chat.energy_system import energy_manager
@@ -115,12 +115,12 @@ class StreamLoopManager:
        if not context:
            logger.warning(f"无法获取流上下文: {stream_id}")
            return False

        # 快速路径:如果流已存在且不是强制启动,无需处理
        if not force and context.stream_loop_task and not context.stream_loop_task.done():
            logger.debug(f"🔄 [流循环] stream={stream_id[:8]}, 循环已在运行,跳过启动")
            return True

        # 获取或创建该流的启动锁
        if stream_id not in self._stream_start_locks:
            self._stream_start_locks[stream_id] = asyncio.Lock()
@@ -12,7 +12,6 @@ from src.common.data_models.database_data_model import DatabaseMessages
from src.common.database.core import get_db_session
from src.common.database.core.models import Images, Messages
from src.common.logger import get_logger
- from src.config.config import global_config

from .chat_stream import ChatStream
from .message import MessageSending
@@ -242,9 +242,9 @@ class ChatterActionManager:
            }
        else:
            # 检查目标消息是否为表情包消息以及配置是否允许回复表情包
-           if target_message and getattr(target_message, 'is_emoji', False):
+           if target_message and getattr(target_message, "is_emoji", False):
                # 如果是表情包消息且配置不允许回复表情包,则跳过回复
-               if not getattr(global_config.chat, 'allow_reply_to_emoji', True):
+               if not getattr(global_config.chat, "allow_reply_to_emoji", True):
                    logger.info(f"{log_prefix} 目标消息为表情包且配置不允许回复表情包,跳过回复")
                    return {"action_type": action_name, "success": True, "reply_text": "", "skip_reason": "emoji_not_allowed"}

@@ -376,7 +376,7 @@ class DefaultReplyer:
        if not prompt:
            logger.warning("构建prompt失败,跳过回复生成")
            return False, None, None

        from src.plugin_system.core.event_manager import event_manager
        # 触发 POST_LLM 事件(请求 LLM 之前)
        if not from_plugin:
@@ -1878,8 +1878,8 @@ class DefaultReplyer:
    async def build_relation_info(self, sender: str, target: str):
        # 获取用户ID
        if sender == f"{global_config.bot.nickname}(你)":
-           return f"你将要回复的是你自己发送的消息。"
+           return "你将要回复的是你自己发送的消息。"

        person_info_manager = get_person_info_manager()
        person_id = await person_info_manager.get_person_id_by_person_name(sender)

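The `-`/`+` pair above strips an `f` prefix from a literal with no placeholders (Ruff/flake8 F541); the prefix is inert, so the change is behavior-preserving:

    s1 = f"你将要回复的是你自己发送的消息。"  # f-string with no {fields}: prefix does nothing
    s2 = "你将要回复的是你自己发送的消息。"   # plain literal, same value
    assert s1 == s2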
@@ -47,10 +47,10 @@ class BlockShuffler:

        # 复制上下文以避免修改原始字典
        shuffled_context = context_data.copy()

        # 示例:假设模板中的占位符格式为 {block_name}
        # 我们需要解析模板,找到可重排的组,并重新构建模板字符串。

        # 注意:这是一个复杂的逻辑,通常需要一个简单的模板引擎或正则表达式来完成。
        # 为保持此函数职责单一,这里仅演示核心的重排逻辑,
        # 完整的模板重建逻辑应在调用此函数的地方处理。
@@ -58,14 +58,14 @@ class BlockShuffler:
        for group in BlockShuffler.SWAPPABLE_BLOCK_GROUPS:
            # 过滤出在当前上下文中实际存在的、非空的block
            existing_blocks = [
-               block for block in group if block in context_data and context_data[block]
+               block for block in group if context_data.get(block)
            ]

            if len(existing_blocks) > 1:
                # 随机打乱顺序
                random.shuffle(existing_blocks)
                logger.debug(f"重排block组: {group} -> {existing_blocks}")

                # 这里的实现需要调用者根据 `existing_blocks` 的新顺序
                # 去动态地重新组织 `prompt_template` 字符串。
                # 例如,找到模板中与 `group` 相关的占位符部分,然后按新顺序替换它们。
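A quick check (hypothetical data, not from the repo) that the rewritten filter is equivalent: `dict.get` returns None for missing keys, so one call covers both the membership test and the truthiness check:

    context_data = {"a": "text", "b": "", "c": None}
    group = ["a", "b", "c", "d"]

    # Old filter: membership test plus truthiness of the value
    old = [b for b in group if b in context_data and context_data[b]]
    # New filter: .get() yields None for missing keys, truthiness handles the rest
    new = [b for b in group if context_data.get(b)]
    assert old == new == ["a"]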
@@ -2,7 +2,6 @@ import asyncio
import copy
import re
from collections.abc import Awaitable, Callable
- from typing import List

from src.chat.utils.prompt_params import PromptParameters
from src.common.logger import get_logger
@@ -119,7 +118,7 @@ class PromptComponentManager:
    async def add_injection_rule(
        self,
        prompt_name: str,
-       rules: List[InjectionRule],
+       rules: list[InjectionRule],
        content_provider: Callable[..., Awaitable[str]],
        source: str = "runtime",
    ) -> bool:
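This is the same `List` -> `list` modernization the commit message names: since Python 3.9, builtins are subscriptable (PEP 585), so the `typing.List` import removed in the previous hunk is redundant. A self-contained sketch with a stand-in rule type (fields invented):

    from dataclasses import dataclass

    @dataclass
    class InjectionRule:  # stand-in for the real class; fields are hypothetical
        target: str
        priority: int = 0

    def add_rules(rules: list[InjectionRule]) -> int:  # builtin generic, no typing.List
        return len(rules)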
@@ -521,7 +520,7 @@ class PromptComponentManager:
            else:
                for name, (rule, _, _) in rules_for_target.items():
                    target_copy[name] = rule

            if target_copy:
                rules_copy[target] = target_copy

@@ -63,7 +63,7 @@ class PromptParameters:
    action_descriptions: str = ""
    notice_block: str = ""
    group_chat_reminder_block: str = ""

    # 可用动作信息
    available_actions: dict[str, Any] | None = None

@@ -228,9 +228,9 @@ class HTMLReportGenerator:

        # 渲染模板
        # 读取CSS和JS文件内容
-       async with aiofiles.open(os.path.join(self.jinja_env.loader.searchpath[0], "report.css"), "r", encoding="utf-8") as f:
+       async with aiofiles.open(os.path.join(self.jinja_env.loader.searchpath[0], "report.css"), encoding="utf-8") as f:
            report_css = await f.read()
-       async with aiofiles.open(os.path.join(self.jinja_env.loader.searchpath[0], "report.js"), "r", encoding="utf-8") as f:
+       async with aiofiles.open(os.path.join(self.jinja_env.loader.searchpath[0], "report.js"), encoding="utf-8") as f:
            report_js = await f.read()
        # 渲染模板
        template = self.jinja_env.get_template("report.html")
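`"r"` is already the default mode for `open()`, and `aiofiles.open` mirrors that signature, so dropping the argument is behavior-preserving (Ruff UP015). A sketch, assuming a readable report.css exists:

    # The two calls are equivalent; text read mode is the default
    with open("report.css", "r", encoding="utf-8") as f:
        a = f.read()
    with open("report.css", encoding="utf-8") as f:
        b = f.read()
    assert a == b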
@@ -3,8 +3,6 @@ from collections import defaultdict
from datetime import datetime, timedelta
from typing import Any

- import aiofiles
-
from src.common.database.compatibility import db_get, db_query
from src.common.database.core.models import LLMUsage, Messages, OnlineTime
from src.common.logger import get_logger
@@ -16,7 +14,7 @@ logger = get_logger("maibot_statistic")
# 彻底异步化:删除原同步包装器 _sync_db_get,所有数据库访问统一使用 await db_get。


- from .report_generator import HTMLReportGenerator, format_online_time
+ from .report_generator import HTMLReportGenerator
from .statistic_keys import *


@@ -1,4 +1,3 @@
- # -*- coding: utf-8 -*-
"""
该模块用于存放统计数据相关的常量键名。
"""
@@ -61,4 +60,4 @@ STD_TIME_COST_BY_PROVIDER = "std_time_costs_by_provider"
PIE_CHART_COST_BY_PROVIDER = "pie_chart_cost_by_provider"
PIE_CHART_REQ_BY_PROVIDER = "pie_chart_req_by_provider"
BAR_CHART_COST_BY_MODEL = "bar_chart_cost_by_model"
BAR_CHART_REQ_BY_MODEL = "bar_chart_req_by_model"
@@ -537,7 +537,7 @@ class _PromptProcessor:
            else:
                is_truncated = True
            return content, reasoning, is_truncated

    @staticmethod
    async def _extract_reasoning(content: str) -> tuple[str, str]:
        """
@@ -1,4 +1,5 @@
# 再用这个就写一行注释来混提交的我直接全部🌿飞😡
+ # 🌿🌿need
import asyncio
import signal
import sys
@@ -21,7 +22,6 @@ from src.common.message import get_global_api

# 全局背景任务集合
_background_tasks = set()
- from src.common.remote import TelemetryHeartBeatTask
from src.common.server import Server, get_global_server
from src.config.config import global_config
from src.individuality.individuality import Individuality, get_individuality
@@ -507,7 +507,7 @@ class PersistenceManager:
            GraphStore 对象
        """
        try:
-           async with aiofiles.open(input_file, "r", encoding="utf-8") as f:
+           async with aiofiles.open(input_file, encoding="utf-8") as f:
                content = await f.read()
                data = json.loads(content)

@@ -98,7 +98,7 @@ class MemoryTools:
            graph_store=graph_store,
            embedding_generator=embedding_generator,
        )

        # 初始化路径扩展器(延迟初始化,仅在启用时创建)
        self.path_expander: PathScoreExpansion | None = None

@@ -573,7 +573,7 @@ class MemoryTools:
        # 检查是否启用路径扩展算法
        use_path_expansion = getattr(global_config.memory, "enable_path_expansion", False) and expand_depth > 0
        expanded_memory_scores = {}

        if expand_depth > 0 and initial_memory_ids:
            # 获取查询的embedding
            query_embedding = None
@@ -582,12 +582,12 @@ class MemoryTools:
                query_embedding = await self.builder.embedding_generator.generate(query)
            except Exception as e:
                logger.warning(f"生成查询embedding失败: {e}")

            if query_embedding is not None:
                if use_path_expansion:
                    # 🆕 使用路径评分扩展算法
                    logger.info(f"🔬 使用路径评分扩展算法: 初始{len(similar_nodes)}个节点, 深度={expand_depth}")

                    # 延迟初始化路径扩展器
                    if self.path_expander is None:
                        path_config = PathExpansionConfig(
@@ -607,7 +607,7 @@ class MemoryTools:
                            vector_store=self.vector_store,
                            config=path_config
                        )

                    try:
                        # 执行路径扩展(传递偏好类型)
                        path_results = await self.path_expander.expand_with_path_scoring(
@@ -616,11 +616,11 @@ class MemoryTools:
                            top_k=top_k,
                            prefer_node_types=all_prefer_types  # 🆕 传递偏好类型
                        )

                        # 路径扩展返回的是 [(Memory, final_score, paths), ...]
                        # 我们需要直接返回这些记忆,跳过后续的传统评分
                        logger.info(f"✅ 路径扩展返回 {len(path_results)} 条记忆")

                        # 直接构建返回结果
                        path_memories = []
                        for memory, score, paths in path_results:
@@ -635,25 +635,25 @@ class MemoryTools:
                                    "max_path_depth": max(p.depth for p in paths) if paths else 0
                                }
                            })

                        logger.info(f"🎯 路径扩展最终返回: {len(path_memories)} 条记忆")

                        return {
                            "success": True,
                            "results": path_memories,
                            "total": len(path_memories),
                            "expansion_method": "path_scoring"
                        }

                    except Exception as e:
                        logger.error(f"路径扩展失败: {e}", exc_info=True)
                        logger.info("回退到传统图扩展算法")
                        # 继续执行下面的传统图扩展

                # 传统图扩展(仅在未启用路径扩展或路径扩展失败时执行)
                if not use_path_expansion or expanded_memory_scores == {}:
                    logger.info(f"开始传统图扩展: 初始记忆{len(initial_memory_ids)}个, 深度={expand_depth}")

                    try:
                        # 使用共享的图扩展工具函数
                        expanded_results = await expand_memories_with_semantic_filter(
@@ -9,10 +9,10 @@ from src.memory_graph.utils.time_parser import TimeParser

__all__ = [
    "EmbeddingGenerator",
+   "Path",
+   "PathExpansionConfig",
+   "PathScoreExpansion",
    "TimeParser",
    "cosine_similarity",
    "get_embedding_generator",
-   "PathScoreExpansion",
-   "PathExpansionConfig",
-   "Path",
]
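The reshuffle above just sorts `__all__` (Ruff's RUF022), which here coincides with plain string ordering where uppercase code points sort before lowercase:

    __all__ = ["EmbeddingGenerator", "Path", "TimeParser", "cosine_similarity"]
    assert __all__ == sorted(__all__)  # 'E' < 'P' < 'T' < 'c' in ASCII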
@@ -12,7 +12,7 @@ from src.common.logger import get_logger
from src.memory_graph.utils.similarity import cosine_similarity

if TYPE_CHECKING:
-   from src.memory_graph.models import Memory
+   pass

logger = get_logger(__name__)

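`Memory` was imported only for type checking and is no longer referenced in this module, so the guard body collapses to `pass`. For reference, the pattern the guard exists for (usage here is an invented illustration):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # seen by type checkers only, never executed at runtime
        from src.memory_graph.models import Memory

    def describe(memory: "Memory") -> str:  # string annotation: no runtime import needed
        return f"memory with {len(memory.nodes)} nodes"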
@@ -41,52 +41,52 @@ async def deduplicate_memories_by_similarity(
    """
    if len(memories) <= 1:
        return memories

    logger.info(f"开始记忆去重: {len(memories)} 条记忆 (阈值={similarity_threshold})")

    # 准备数据结构
    memory_embeddings = []
    for memory, score, extra in memories:
        # 获取记忆的向量表示
        embedding = await _get_memory_embedding(memory)
        memory_embeddings.append((memory, score, extra, embedding))

    # 构建相似度矩阵并找出重复组
    duplicate_groups = _find_duplicate_groups(memory_embeddings, similarity_threshold)

    # 合并每个重复组
    deduplicated = []
    processed_indices = set()

    for group_indices in duplicate_groups:
        if any(i in processed_indices for i in group_indices):
            continue  # 已经处理过

        # 标记为已处理
        processed_indices.update(group_indices)

        # 合并组内记忆
        group_memories = [memory_embeddings[i] for i in group_indices]
        merged_memory = _merge_memory_group(group_memories)
        deduplicated.append(merged_memory)

    # 添加未被合并的记忆
    for i, (memory, score, extra, _) in enumerate(memory_embeddings):
        if i not in processed_indices:
            deduplicated.append((memory, score, extra))

    # 按分数排序
    deduplicated.sort(key=lambda x: x[1], reverse=True)

    # 限制数量
    if keep_top_n is not None:
        deduplicated = deduplicated[:keep_top_n]

    logger.info(
        f"去重完成: {len(memories)} → {len(deduplicated)} 条记忆 "
        f"(合并了 {len(memories) - len(deduplicated)} 条重复)"
    )

    return deduplicated

@@ -104,7 +104,7 @@ async def _get_memory_embedding(memory: Any) -> list[float] | None:
        # nodes 是 MemoryNode 对象列表
        first_node = memory.nodes[0]
        node_id = getattr(first_node, "id", None)

        if node_id:
            # 直接从 embedding 属性获取(如果存在)
            if hasattr(first_node, "embedding") and first_node.embedding is not None:
@@ -114,7 +114,7 @@ async def _get_memory_embedding(memory: Any) -> list[float] | None:
                return embedding.tolist()
            elif isinstance(embedding, list):
                return embedding

    # 无法获取 embedding
    return None

@@ -132,13 +132,13 @@ def _find_duplicate_groups(
    """
    n = len(memory_embeddings)
    similarity_matrix = [[0.0] * n for _ in range(n)]

    # 计算相似度矩阵
    for i in range(n):
        for j in range(i + 1, n):
            embedding_i = memory_embeddings[i][3]
            embedding_j = memory_embeddings[j][3]

            # 跳过 None 或零向量
            if (embedding_i is None or embedding_j is None or
                all(x == 0.0 for x in embedding_i) or all(x == 0.0 for x in embedding_j)):
@@ -146,29 +146,29 @@ def _find_duplicate_groups(
            else:
                # cosine_similarity 会自动转换为 numpy 数组
                similarity = float(cosine_similarity(embedding_i, embedding_j))  # type: ignore

            similarity_matrix[i][j] = similarity
            similarity_matrix[j][i] = similarity

    # 使用并查集找出连通分量
    parent = list(range(n))

    def find(x):
        if parent[x] != x:
            parent[x] = find(parent[x])
        return parent[x]

    def union(x, y):
        px, py = find(x), find(y)
        if px != py:
            parent[px] = py

    # 合并相似的记忆
    for i in range(n):
        for j in range(i + 1, n):
            if similarity_matrix[i][j] >= threshold:
                union(i, j)

    # 构建组
    groups_dict: dict[int, list[int]] = {}
    for i in range(n):
@@ -176,10 +176,10 @@ def _find_duplicate_groups(
        if root not in groups_dict:
            groups_dict[root] = []
        groups_dict[root].append(i)

    # 只返回大小 > 1 的组(真正的重复组)
    duplicate_groups = [group for group in groups_dict.values() if len(group) > 1]

    return duplicate_groups

@@ -196,10 +196,10 @@ def _merge_memory_group(
    """
    # 按分数排序
    sorted_group = sorted(group, key=lambda x: x[1], reverse=True)

    # 保留分数最高的记忆
    best_memory, best_score, best_extra, _ = sorted_group[0]

    # 计算合并后的分数(加权平均,权重递减)
    total_weight = 0.0
    weighted_sum = 0.0
@@ -207,17 +207,17 @@ def _merge_memory_group(
        weight = 1.0 / (i + 1)  # 第1名权重1.0,第2名0.5,第3名0.33...
        weighted_sum += score * weight
        total_weight += weight

    merged_score = weighted_sum / total_weight if total_weight > 0 else best_score

    # 增强 extra_data
    merged_extra = best_extra if isinstance(best_extra, dict) else {}
    merged_extra["merged_count"] = len(sorted_group)
    merged_extra["original_scores"] = [score for _, score, _, _ in sorted_group]

    logger.debug(
        f"合并 {len(sorted_group)} 条相似记忆: "
        f"分数 {best_score:.3f} → {merged_score:.3f}"
    )

    return (best_memory, merged_score, merged_extra)
@@ -26,7 +26,6 @@ from src.memory_graph.utils.similarity import cosine_similarity
if TYPE_CHECKING:
    import numpy as np

-   from src.memory_graph.models import Memory
    from src.memory_graph.storage.graph_store import GraphStore
    from src.memory_graph.storage.vector_store import VectorStore

@@ -71,7 +70,7 @@ class PathExpansionConfig:
    medium_score_threshold: float = 0.4  # 中分路径阈值
    max_active_paths: int = 1000  # 最大活跃路径数(防止爆炸)
    top_paths_retain: int = 500  # 超限时保留的top路径数

    # 🚀 性能优化参数
    enable_early_stop: bool = True  # 启用早停(如果路径增长很少则提前结束)
    early_stop_growth_threshold: float = 0.1  # 早停阈值(路径增长率低于10%则停止)
@@ -121,7 +120,7 @@ class PathScoreExpansion:
        self.vector_store = vector_store
        self.config = config or PathExpansionConfig()
        self.prefer_node_types: list[str] = []  # 🆕 偏好节点类型

        # 🚀 性能优化:邻居边缓存
        self._neighbor_cache: dict[str, list[Any]] = {}
        self._node_score_cache: dict[str, float] = {}
@@ -212,11 +211,11 @@ class PathScoreExpansion:
                    continue

                edge_weight = self._get_edge_weight(edge)

                # 记录候选
                path_candidates.append((path, edge, next_node, edge_weight))
                candidate_nodes_for_batch.add(next_node)

                branch_count += 1
                if branch_count >= max_branches:
                    break
@@ -281,7 +280,7 @@ class PathScoreExpansion:
            # 🚀 早停检测:如果路径增长很少,提前终止
            prev_path_count = len(active_paths)
            active_paths = next_paths

            if self.config.enable_early_stop and prev_path_count > 0:
                growth_rate = (len(active_paths) - prev_path_count) / prev_path_count
                if growth_rate < self.config.early_stop_growth_threshold:
@@ -346,18 +345,18 @@ class PathScoreExpansion:
            max_path_score = max(p.score for p in paths) if paths else 0
            rough_score = len(paths) * max_path_score * memory.importance
            memory_scores_rough.append((mem_id, rough_score))

        # 保留top候选
        memory_scores_rough.sort(key=lambda x: x[1], reverse=True)
        retained_mem_ids = set(mem_id for mem_id, _ in memory_scores_rough[:self.config.max_candidate_memories])

        # 过滤
        memory_paths = {
            mem_id: (memory, paths)
            for mem_id, (memory, paths) in memory_paths.items()
            if mem_id in retained_mem_ids
        }

        logger.info(
            f"⚡ 粗排过滤: {len(memory_scores_rough)} → {len(memory_paths)} 条候选记忆"
        )
@@ -398,7 +397,7 @@ class PathScoreExpansion:
        # 🚀 缓存检查
        if node_id in self._neighbor_cache:
            return self._neighbor_cache[node_id]

        edges = []

        # 从图存储中获取与该节点相关的所有边
@@ -454,7 +453,7 @@ class PathScoreExpansion:
        """
        # 从向量存储获取节点数据
        node_data = await self.vector_store.get_node_by_id(node_id)

        if query_embedding is None:
            base_score = 0.5  # 默认中等分数
        else:
@@ -493,27 +492,27 @@
        import numpy as np

        scores = {}

        if query_embedding is None:
            # 无查询向量时,返回默认分数
-           return {nid: 0.5 for nid in node_ids}
+           return dict.fromkeys(node_ids, 0.5)

        # 批量获取节点数据
        node_data_list = await asyncio.gather(
            *[self.vector_store.get_node_by_id(nid) for nid in node_ids],
            return_exceptions=True
        )

        # 收集有效的嵌入向量
        valid_embeddings = []
        valid_node_ids = []
        node_metadata_map = {}

        for nid, node_data in zip(node_ids, node_data_list):
            if isinstance(node_data, Exception):
                scores[nid] = 0.3
                continue

            # 类型守卫:确保 node_data 是字典
            if not node_data or not isinstance(node_data, dict) or "embedding" not in node_data:
                scores[nid] = 0.3
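The one real change in this hunk swaps a dict comprehension for `dict.fromkeys`, which is equivalent here because the shared default is an immutable float. A quick check plus the classic caveat (hypothetical IDs):

    node_ids = ["n1", "n2", "n3"]
    assert {nid: 0.5 for nid in node_ids} == dict.fromkeys(node_ids, 0.5)

    # Caveat: fromkeys shares one value object, which matters for mutable defaults
    shared = dict.fromkeys(node_ids, [])
    shared["n1"].append(1)
    assert shared["n2"] == [1]  # every key points at the same list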
@@ -521,21 +520,21 @@
            valid_embeddings.append(node_data["embedding"])
            valid_node_ids.append(nid)
            node_metadata_map[nid] = node_data.get("metadata", {})

        if valid_embeddings:
            # 批量计算相似度(使用矩阵运算)
            embeddings_matrix = np.array(valid_embeddings)
            query_norm = np.linalg.norm(query_embedding)
            embeddings_norms = np.linalg.norm(embeddings_matrix, axis=1)

            # 向量化计算余弦相似度
            similarities = np.dot(embeddings_matrix, query_embedding) / (embeddings_norms * query_norm + 1e-8)
            similarities = np.clip(similarities, 0.0, 1.0)

            # 应用偏好类型加成
            for nid, sim in zip(valid_node_ids, similarities):
                base_score = float(sim)

                # 偏好类型加成
                if self.prefer_node_types and nid in node_metadata_map:
                    node_type = node_metadata_map[nid].get("node_type")
@@ -546,7 +545,7 @@
                    scores[nid] = base_score
                else:
                    scores[nid] = base_score

        return scores

    def _calculate_path_score(self, old_score: float, edge_weight: float, node_score: float, depth: int) -> float:
@@ -689,19 +688,19 @@
        # 使用临时字典存储路径列表
        temp_paths: dict[str, list[Path]] = {}
        temp_memories: dict[str, Any] = {}  # 存储 Memory 对象

        # 🚀 性能优化:收集所有需要获取的记忆ID,然后批量获取
        all_memory_ids = set()
        path_to_memory_ids: dict[int, set[str]] = {}  # path对象id -> 记忆ID集合

        for path in paths:
            memory_ids_in_path = set()

            # 收集路径中所有节点涉及的记忆
            for node_id in path.nodes:
                memory_ids = self.graph_store.node_to_memories.get(node_id, [])
                memory_ids_in_path.update(memory_ids)

            all_memory_ids.update(memory_ids_in_path)
            path_to_memory_ids[id(path)] = memory_ids_in_path

@@ -712,11 +711,11 @@
            memory = self.graph_store.get_memory_by_id(mem_id)
            if memory:
                memory_cache[mem_id] = memory

        # 构建映射关系
        for path in paths:
            memory_ids_in_path = path_to_memory_ids[id(path)]

            for mem_id in memory_ids_in_path:
                if mem_id in memory_cache:
                    if mem_id not in temp_paths:
@@ -745,10 +744,10 @@
            [(Memory, final_score, paths), ...]
        """
        scored_memories = []

        # 🚀 性能优化:如果需要偏好类型加成,批量预加载所有节点的类型信息
        node_type_cache: dict[str, str | None] = {}

        if self.prefer_node_types:
            # 收集所有需要查询的节点ID
            all_node_ids = set()
@@ -757,7 +756,7 @@
                for node in memory_nodes:
                    node_id = node.id if hasattr(node, "id") else str(node)
                    all_node_ids.add(node_id)

            # 批量获取节点数据
            if all_node_ids:
                logger.debug(f"🔍 批量预加载 {len(all_node_ids)} 个节点的类型信息")
@@ -765,7 +764,7 @@
                    *[self.vector_store.get_node_by_id(nid) for nid in all_node_ids],
                    return_exceptions=True
                )

                # 构建类型缓存
                for nid, node_data in zip(all_node_ids, node_data_list):
                    if isinstance(node_data, Exception) or not node_data or not isinstance(node_data, dict):
@@ -805,7 +804,7 @@
                    node_type = node_type_cache.get(node_id)
                    if node_type and node_type in self.prefer_node_types:
                        matched_count += 1

                if matched_count > 0:
                    match_ratio = matched_count / len(memory_nodes)
                    # 根据匹配比例给予加成(最高10%)
@@ -870,4 +869,4 @@
        return recency_score


- __all__ = ["PathScoreExpansion", "PathExpansionConfig", "Path"]
+ __all__ = ["Path", "PathExpansionConfig", "PathScoreExpansion"]
@@ -269,7 +269,7 @@ class RelationshipFetcher:
        platform = "unknown"
        if existing_stream:
            # 从现有记录获取platform
-           platform = getattr(existing_stream, 'platform', 'unknown') or "unknown"
+           platform = getattr(existing_stream, "platform", "unknown") or "unknown"
            logger.debug(f"从现有ChatStream获取到platform: {platform}, stream_id: {stream_id}")
        else:
            logger.debug(f"未找到现有ChatStream记录,使用默认platform: unknown, stream_id: {stream_id}")
@@ -742,7 +742,7 @@ class BaseAction(ABC):
        if not case_sensitive:
            search_text = search_text.lower()

-       matched_keywords: ClassVar = []
+       matched_keywords = []
        for keyword in keywords:
            check_keyword = keyword if case_sensitive else keyword.lower()
            if check_keyword in search_text:
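The fix above removes a `ClassVar` annotation from a local variable; `ClassVar` is only meaningful on class-body attributes (PEP 526), and type checkers reject it elsewhere. A minimal contrast, with an invented class:

    from typing import ClassVar

    class KeywordMatcher:
        defaults: ClassVar[list[str]] = []  # class-level attribute: ClassVar belongs here

        def match(self, keywords: list[str], text: str) -> list[str]:
            matched: list[str] = []  # plain local: annotate with the type, not ClassVar
            for kw in keywords:
                if kw in text:
                    matched.append(kw)
            return matched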
@@ -9,6 +9,7 @@ from datetime import datetime
from typing import Any

import orjson
+ from json_repair import repair_json

from src.chat.utils.chat_message_builder import (
    build_readable_messages_with_id,
@@ -19,7 +20,6 @@ from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.mood.mood_manager import mood_manager
- from json_repair import repair_json
from src.plugin_system.base.component_types import ActionInfo, ChatType
from src.schedule.schedule_manager import schedule_manager

@@ -144,7 +144,7 @@ class ChatterPlanFilter:
            plan.decided_actions = [
                ActionPlannerInfo(action_type="no_action", reasoning=f"筛选时出错: {e}")
            ]

        # 在返回最终计划前,打印将要执行的动作
        if plan.decided_actions:
            action_types = [action.action_type for action in plan.decided_actions]
@@ -631,7 +631,6 @@ class ChatterPlanFilter:
            candidate_ids.add(normalized_id[1:])

        # 处理包含在文本中的ID格式 (如 "消息m123" -> 提取 m123)
-       import re

        # 尝试提取各种格式的ID
        id_patterns = [
@@ -10,7 +10,6 @@ from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.info_data_model import Plan, TargetPersonInfo
from src.config.config import global_config
from src.plugin_system.base.component_types import ActionInfo, ChatMode, ChatType
- from src.plugin_system.core.component_registry import component_registry


class ChatterPlanGenerator:
@@ -201,7 +201,7 @@ class ChatterActionPlanner:
        available_actions = list(initial_plan.available_actions.keys())
        plan_filter = ChatterPlanFilter(self.chat_id, available_actions)
        filtered_plan = await plan_filter.filter(initial_plan)

        # 检查reply动作是否可用
        has_reply_action = "reply" in available_actions or "respond" in available_actions
        if filtered_plan.decided_actions and has_reply_action and reply_not_available:
@@ -320,7 +320,7 @@ class QZoneService:
            return

        # 1. 将评论分为用户评论和自己的回复
        user_comments = [c for c in comments if str(c.get("qq_account")) != str(qq_account)]

        if not user_comments:
            return
@@ -295,7 +295,7 @@ class SystemCommand(PlusCommand):
        if injections:
            response_parts.append(f"🎯 **{target}** (注入源):")
            for inj in injections:
-               source_tag = f"({inj['source']})" if inj['source'] != 'static_default' else ''
+               source_tag = f"({inj['source']})" if inj["source"] != "static_default" else ""
                response_parts.append(f" ⎿ `{inj['name']}` (优先级: {inj['priority']}) {source_tag}")
        else:
            response_parts.append(f"🎯 **{target}** (无注入)")