style: format code

Committed by: Windpicker-owo
Parent: e7aaafde2f
Commit: 00ba07e0e1
@@ -31,9 +31,11 @@ class AntiInjectionStatistics:
         try:
             async with get_db_session() as session:
                 # 获取最新的统计记录,如果没有则创建
-                stats = (await session.execute(
-                    select(AntiInjectionStats).order_by(AntiInjectionStats.id.desc())
-                )).scalars().first()
+                stats = (
+                    (await session.execute(select(AntiInjectionStats).order_by(AntiInjectionStats.id.desc())))
+                    .scalars()
+                    .first()
+                )
                 if not stats:
                     stats = AntiInjectionStats()
                     session.add(stats)
@@ -49,9 +51,11 @@ class AntiInjectionStatistics:
         """更新统计数据"""
         try:
             async with get_db_session() as session:
-                stats = (await session.execute(
-                    select(AntiInjectionStats).order_by(AntiInjectionStats.id.desc())
-                )).scalars().first()
+                stats = (
+                    (await session.execute(select(AntiInjectionStats).order_by(AntiInjectionStats.id.desc())))
+                    .scalars()
+                    .first()
+                )
                 if not stats:
                     stats = AntiInjectionStats()
                     session.add(stats)

@@ -2,13 +2,13 @@ from typing import Dict, List, Optional, Any
 import time
 from src.plugin_system.base.base_chatter import BaseChatter
 from src.common.data_models.message_manager_data_model import StreamContext
 from src.plugins.built_in.affinity_flow_chatter.planner import ChatterActionPlanner as ActionPlanner
 from src.chat.planner_actions.action_manager import ChatterActionManager
-from src.plugin_system.base.component_types import ChatType, ComponentType
+from src.plugin_system.base.component_types import ChatType
 from src.common.logger import get_logger
 
 logger = get_logger("chatter_manager")
 
 
 class ChatterManager:
     def __init__(self, action_manager: ChatterActionManager):
         self.action_manager = action_manager
@@ -27,6 +27,7 @@ class ChatterManager:
         """从组件注册表自动注册已注册的chatter组件"""
         try:
             from src.plugin_system.core.component_registry import component_registry
+
             # 获取所有CHATTER类型的组件
             chatter_components = component_registry.get_enabled_chatter_registry()
             for chatter_name, chatter_class in chatter_components.items():
@@ -70,7 +71,7 @@ class ChatterManager:
 
         inactive_streams = []
         for stream_id, instance in self.instances.items():
-            if hasattr(instance, 'get_activity_time'):
+            if hasattr(instance, "get_activity_time"):
                 activity_time = instance.get_activity_time()
                 if (current_time - activity_time) > max_inactive_seconds:
                     inactive_streams.append(stream_id)
@@ -91,6 +92,7 @@ class ChatterManager:
         if not chatter_class:
             # 如果没有找到精确匹配,尝试查找支持ALL类型的chatter
             from src.plugin_system.base.component_types import ChatType
+
             all_chatter_class = self.get_chatter_class(ChatType.ALL)
             if all_chatter_class:
                 chatter_class = all_chatter_class
@@ -110,6 +112,7 @@ class ChatterManager:
         # 从 mood_manager 获取最新的 chat_stream 并同步回 StreamContext
         try:
             from src.mood.mood_manager import mood_manager
+
             mood = mood_manager.get_mood_by_chat_id(stream_id)
             if mood and mood.chat_stream:
                 context.chat_stream = mood.chat_stream
@@ -125,6 +128,7 @@ class ChatterManager:
         # 在处理完成后,清除该流的未读消息
         try:
             from src.chat.message_manager.message_manager import message_manager
+
             await message_manager.clear_stream_unread_messages(stream_id)
         except Exception as clear_e:
             logger.error(f"清除流 {stream_id} 未读消息时发生错误: {clear_e}")
@@ -149,4 +153,4 @@ class ChatterManager:
             "streams_processed": 0,
             "successful_executions": 0,
             "failed_executions": 0,
-        }
+        }

@@ -3,7 +3,6 @@
 表情包发送历史记录模块
 """
 
 import os
 from typing import List, Dict
 from collections import deque
-from typing import List, Dict
@@ -204,9 +204,7 @@ class MaiEmoji:
         # 2. 删除数据库记录
         try:
             async with get_db_session() as session:
-                result = await session.execute(
-                    select(Emoji).where(Emoji.emoji_hash == self.hash)
-                )
+                result = await session.execute(select(Emoji).where(Emoji.emoji_hash == self.hash))
                 will_delete_emoji = result.scalar_one_or_none()
                 if will_delete_emoji is None:
                     logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。")
@@ -947,10 +945,7 @@ class EmojiManager:
         existing_description = None
         try:
             async with get_db_session() as session:
-                stmt = select(Images).where(
-                    Images.emoji_hash == image_hash,
-                    Images.type == "emoji"
-                )
+                stmt = select(Images).where(Images.emoji_hash == image_hash, Images.type == "emoji")
                 result = await session.execute(stmt)
                 existing_image = result.scalar_one_or_none()
                 if existing_image and existing_image.description:

@@ -12,7 +12,7 @@ from .energy_manager import (
     ActivityEnergyCalculator,
     RecencyEnergyCalculator,
     RelationshipEnergyCalculator,
-    energy_manager
+    energy_manager,
 )
 
 __all__ = [
@@ -24,5 +24,5 @@ __all__ = [
     "ActivityEnergyCalculator",
     "RecencyEnergyCalculator",
    "RelationshipEnergyCalculator",
-    "energy_manager"
-]
+    "energy_manager",
+]

@@ -17,16 +17,18 @@ logger = get_logger("energy_system")
 
 class EnergyLevel(Enum):
     """能量等级"""
+
     VERY_LOW = 0.1  # 非常低
-    LOW = 0.3        # 低
-    NORMAL = 0.5     # 正常
-    HIGH = 0.7       # 高
-    VERY_HIGH = 0.9  # 非常高
+    LOW = 0.3  # 低
+    NORMAL = 0.5  # 正常
+    HIGH = 0.7  # 高
+    VERY_HIGH = 0.9  # 非常高
 
 
 @dataclass
 class EnergyComponent:
     """能量组件"""
+
     name: str
     value: float
     weight: float = 1.0
@@ -47,6 +49,7 @@ class EnergyComponent:
 
 class EnergyContext(TypedDict):
     """能量计算上下文"""
+
     stream_id: str
     messages: List[Any]
     user_id: Optional[str]
@@ -54,6 +57,7 @@ class EnergyContext(TypedDict):
 
 class EnergyResult(TypedDict):
     """能量计算结果"""
+
     energy: float
     level: EnergyLevel
     distribution_interval: float
@@ -114,12 +118,7 @@ class ActivityEnergyCalculator(EnergyCalculator):
     """活跃度能量计算器"""
 
     def __init__(self):
-        self.action_weights = {
-            "reply": 0.4,
-            "react": 0.3,
-            "mention": 0.2,
-            "other": 0.1
-        }
+        self.action_weights = {"reply": 0.4, "react": 0.3, "mention": 0.2, "other": 0.1}
 
     def calculate(self, context: Dict[str, Any]) -> float:
         """基于活跃度计算能量"""
@@ -188,7 +187,7 @@ class RecencyEnergyCalculator(EnergyCalculator):
         else:
             recency_score = 0.1
 
-        logger.debug(f"最近性分数: {recency_score:.3f} (年龄: {age/3600:.1f}小时)")
+        logger.debug(f"最近性分数: {recency_score:.3f} (年龄: {age / 3600:.1f}小时)")
         return recency_score
 
     def get_weight(self) -> float:
@@ -236,11 +235,7 @@ class EnergyManager:
         self.cache_ttl: int = 60  # 1分钟缓存
 
         # AFC阈值配置
-        self.thresholds: Dict[str, float] = {
-            "high_match": 0.8,
-            "reply": 0.4,
-            "non_reply": 0.2
-        }
+        self.thresholds: Dict[str, float] = {"high_match": 0.8, "reply": 0.4, "non_reply": 0.2}
 
         # 统计信息
         self.stats: Dict[str, Union[int, float, str]] = {
@@ -260,9 +255,13 @@ class EnergyManager:
         """从配置加载AFC阈值"""
         try:
             if hasattr(global_config, "affinity_flow") and global_config.affinity_flow is not None:
-                self.thresholds["high_match"] = getattr(global_config.affinity_flow, "high_match_interest_threshold", 0.8)
+                self.thresholds["high_match"] = getattr(
+                    global_config.affinity_flow, "high_match_interest_threshold", 0.8
+                )
                 self.thresholds["reply"] = getattr(global_config.affinity_flow, "reply_action_interest_threshold", 0.4)
-                self.thresholds["non_reply"] = getattr(global_config.affinity_flow, "non_reply_action_interest_threshold", 0.2)
+                self.thresholds["non_reply"] = getattr(
+                    global_config.affinity_flow, "non_reply_action_interest_threshold", 0.2
+                )
 
                 # 确保阈值关系合理
                 self.thresholds["high_match"] = max(self.thresholds["high_match"], self.thresholds["reply"] + 0.1)
@@ -306,6 +305,7 @@ class EnergyManager:
             # 支持同步和异步计算器
             if callable(calculator.calculate):
                 import inspect
+
                 if inspect.iscoroutinefunction(calculator.calculate):
                     score = await calculator.calculate(context)
                 else:
@@ -347,11 +347,12 @@ class EnergyManager:
         calculation_time = time.time() - start_time
         total_calculations = self.stats["total_calculations"]
         self.stats["average_calculation_time"] = (
-            (self.stats["average_calculation_time"] * (total_calculations - 1) + calculation_time)
-            / total_calculations
-        )
+            self.stats["average_calculation_time"] * (total_calculations - 1) + calculation_time
+        ) / total_calculations
 
-        logger.debug(f"聊天流 {stream_id} 最终能量: {final_energy:.3f} (原始: {total_energy:.3f}, 耗时: {calculation_time:.3f}s)")
+        logger.debug(
+            f"聊天流 {stream_id} 最终能量: {final_energy:.3f} (原始: {total_energy:.3f}, 耗时: {calculation_time:.3f}s)"
+        )
         return final_energy
 
     def _apply_threshold_adjustment(self, energy: float) -> float:
@@ -405,6 +406,7 @@ class EnergyManager:
 
         # 添加随机扰动避免同步
         import random
+
         jitter = random.uniform(0.8, 1.2)
         final_interval = base_interval * jitter
 
@@ -424,7 +426,8 @@ class EnergyManager:
         """清理过期缓存"""
         current_time = time.time()
         expired_keys = [
-            stream_id for stream_id, (_, timestamp) in self.energy_cache.items()
+            stream_id
+            for stream_id, (_, timestamp) in self.energy_cache.items()
             if current_time - timestamp > self.cache_ttl
         ]
 
@@ -479,4 +482,4 @@ class EnergyManager:
 
 
 # 全局能量管理器实例
-energy_manager = EnergyManager()
+energy_manager = EnergyManager()

@@ -352,7 +352,7 @@ class ExpressionLearner:
-                    create_date=current_time, # 手动设置创建日期
+                    create_date=current_time,  # 手动设置创建日期
                 )
                 session.add(new_expression)
 
                 # 限制最大数量
                 exprs_result = await session.execute(
                     select(Expression)
@@ -456,11 +456,10 @@ class ExpressionLearnerManager:
 
         self._ensure_expression_directories()
 
-
     async def get_expression_learner(self, chat_id: str) -> ExpressionLearner:
         await self._auto_migrate_json_to_db()
         await self._migrate_old_data_create_date()
 
         if chat_id not in self.expression_learners:
             self.expression_learners[chat_id] = ExpressionLearner(chat_id)
         return self.expression_learners[chat_id]
@@ -604,7 +603,9 @@ class ExpressionLearnerManager:
         try:
             async with get_db_session() as session:
                 # 查找所有create_date为空的表达方式
-                old_expressions_result = await session.execute(select(Expression).where(Expression.create_date.is_(None)))
+                old_expressions_result = await session.execute(
+                    select(Expression).where(Expression.create_date.is_(None))
+                )
                 old_expressions = old_expressions_result.scalars().all()
                 updated_count = 0
 

@@ -131,7 +131,9 @@ class BotInterestManager:
             self.current_interests = generated_interests
             active_count = len(generated_interests.get_active_tags())
             logger.info(f"成功生成 {active_count} 个新兴趣标签。")
-            tags_info = [f" - '{tag.tag_name}' (权重: {tag.weight:.2f})" for tag in generated_interests.get_active_tags()]
+            tags_info = [
+                f" - '{tag.tag_name}' (权重: {tag.weight:.2f})" for tag in generated_interests.get_active_tags()
+            ]
             tags_str = "\n".join(tags_info)
             logger.info(f"当前兴趣标签:\n{tags_str}")
 
@@ -639,11 +641,19 @@ class BotInterestManager:
 
             async with get_db_session() as session:
                 # 查询最新的兴趣标签配置
-                db_interests = (await session.execute(
-                    select(DBBotPersonalityInterests)
-                    .where(DBBotPersonalityInterests.personality_id == personality_id)
-                    .order_by(DBBotPersonalityInterests.version.desc(), DBBotPersonalityInterests.last_updated.desc())
-                )).scalars().first()
+                db_interests = (
+                    (
+                        await session.execute(
+                            select(DBBotPersonalityInterests)
+                            .where(DBBotPersonalityInterests.personality_id == personality_id)
+                            .order_by(
+                                DBBotPersonalityInterests.version.desc(), DBBotPersonalityInterests.last_updated.desc()
+                            )
+                        )
+                    )
+                    .scalars()
+                    .first()
+                )
 
                 if db_interests:
                     logger.debug(f"在数据库中找到兴趣标签配置, 版本: {db_interests.version}")
@@ -728,10 +738,17 @@ class BotInterestManager:
 
             async with get_db_session() as session:
                 # 检查是否已存在相同personality_id的记录
-                existing_record = (await session.execute(
-                    select(DBBotPersonalityInterests)
-                    .where(DBBotPersonalityInterests.personality_id == interests.personality_id)
-                )).scalars().first()
+                existing_record = (
+                    (
+                        await session.execute(
+                            select(DBBotPersonalityInterests).where(
+                                DBBotPersonalityInterests.personality_id == interests.personality_id
+                            )
+                        )
+                    )
+                    .scalars()
+                    .first()
+                )
 
                 if existing_record:
                     # 更新现有记录
@@ -763,10 +780,17 @@ class BotInterestManager:
 
            # 验证保存是否成功
            async with get_db_session() as session:
-                saved_record = (await session.execute(
-                    select(DBBotPersonalityInterests)
-                    .where(DBBotPersonalityInterests.personality_id == interests.personality_id)
-                )).scalars().first()
+                saved_record = (
+                    (
+                        await session.execute(
+                            select(DBBotPersonalityInterests).where(
+                                DBBotPersonalityInterests.personality_id == interests.personality_id
+                            )
+                        )
+                    )
+                    .scalars()
+                    .first()
+                )
                 if saved_record:
                     logger.info(f"✅ 验证成功:数据库中存在personality_id为 {interests.personality_id} 的记录")
                     logger.info(f"  版本: {saved_record.version}")

@@ -161,7 +161,7 @@ class EmbeddingStore:
 
     @staticmethod
     def _get_embeddings_batch_threaded(
-        strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None
+        strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None
     ) -> List[Tuple[str, List[float]]]:
         """使用多线程批量获取嵌入向量
 

@@ -101,7 +101,7 @@ class QAManager:
     async def get_knowledge(self, question: str) -> Optional[Dict[str, Any]]:
         """
         获取知识,返回结构化字典
 
         Args:
             question: 用户提出的问题
 
@@ -114,30 +114,27 @@ class QAManager:
             return None
 
        query_res = processed_result[0]
 
         knowledge_items = []
         for res_hash, relevance, *_ in query_res:
             if store_item := self.embed_manager.paragraphs_embedding_store.store.get(res_hash):
-                knowledge_items.append({
-                    "content": store_item.str,
-                    "source": "内部知识库",
-                    "relevance": f"{relevance:.4f}"
-                })
+                knowledge_items.append(
+                    {"content": store_item.str, "source": "内部知识库", "relevance": f"{relevance:.4f}"}
+                )
 
         if not knowledge_items:
             return None
 
         # 使用LLM生成总结
-        knowledge_text_for_summary = "\n\n".join([item['content'] for item in knowledge_items[:5]])  # 最多总结前5条
-        summary_prompt = f"根据以下信息,为问题 '{question}' 生成一个简洁的、不超过50字的摘要:\n\n{knowledge_text_for_summary}"
+        knowledge_text_for_summary = "\n\n".join([item["content"] for item in knowledge_items[:5]])  # 最多总结前5条
+        summary_prompt = (
+            f"根据以下信息,为问题 '{question}' 生成一个简洁的、不超过50字的摘要:\n\n{knowledge_text_for_summary}"
+        )
 
         try:
             summary, (_, _, _) = await self.qa_model.generate_response_async(summary_prompt)
         except Exception as e:
             logger.error(f"生成知识摘要失败: {e}")
             summary = "无法生成摘要。"
 
-        return {
-            "knowledge_items": knowledge_items,
-            "summary": summary.strip() if summary else "没有可用的摘要。"
-        }
+        return {"knowledge_items": knowledge_items, "summary": summary.strip() if summary else "没有可用的摘要。"}

@@ -12,44 +12,23 @@ from .memory_chunk import (
     MemoryType,
     ImportanceLevel,
     ConfidenceLevel,
-    create_memory_chunk
+    create_memory_chunk,
 )
 
 # 遗忘引擎
-from .memory_forgetting_engine import (
-    MemoryForgettingEngine,
-    ForgettingConfig,
-    get_memory_forgetting_engine
-)
+from .memory_forgetting_engine import MemoryForgettingEngine, ForgettingConfig, get_memory_forgetting_engine
 
 # Vector DB存储系统
-from .vector_memory_storage_v2 import (
-    VectorMemoryStorage,
-    VectorStorageConfig,
-    get_vector_memory_storage
-)
+from .vector_memory_storage_v2 import VectorMemoryStorage, VectorStorageConfig, get_vector_memory_storage
 
 # 记忆核心系统
-from .memory_system import (
-    MemorySystem,
-    MemorySystemConfig,
-    get_memory_system,
-    initialize_memory_system
-)
+from .memory_system import MemorySystem, MemorySystemConfig, get_memory_system, initialize_memory_system
 
 # 记忆管理器
-from .memory_manager import (
-    MemoryManager,
-    MemoryResult,
-    memory_manager
-)
+from .memory_manager import MemoryManager, MemoryResult, memory_manager
 
 # 激活器
-from .enhanced_memory_activator import (
-    MemoryActivator,
-    memory_activator,
-    enhanced_memory_activator
-)
+from .enhanced_memory_activator import MemoryActivator, memory_activator, enhanced_memory_activator
 
 # 兼容性别名
 from .memory_chunk import MemoryChunk as Memory
@@ -64,28 +43,23 @@ __all__ = [
     "ImportanceLevel",
     "ConfidenceLevel",
     "create_memory_chunk",
-
     # 遗忘引擎
     "MemoryForgettingEngine",
     "ForgettingConfig",
     "get_memory_forgetting_engine",
-
     # Vector DB存储
     "VectorMemoryStorage",
-    "VectorStorageConfig",
+    "VectorStorageConfig",
     "get_vector_memory_storage",
-
     # 记忆系统
     "MemorySystem",
     "MemorySystemConfig",
     "get_memory_system",
     "initialize_memory_system",
-
     # 记忆管理器
     "MemoryManager",
-    "MemoryResult",
+    "MemoryResult",
     "memory_manager",
-
     # 激活器
     "MemoryActivator",
     "memory_activator",
@@ -95,4 +69,4 @@ __all__ = [
 # 版本信息
 __version__ = "3.0.0"
 __author__ = "MoFox Team"
-__description__ = "简化记忆系统 - 统一记忆架构与智能遗忘机制"
+__description__ = "简化记忆系统 - 统一记忆架构与智能遗忘机制"

@@ -4,15 +4,14 @@
 将增强记忆系统集成到现有MoFox Bot架构中
 """
 
-import asyncio
 import time
-from typing import Dict, List, Optional, Any, Tuple
+from typing import Dict, List, Optional, Any
 from dataclasses import dataclass
 
 from src.common.logger import get_logger
 from src.chat.memory_system.integration_layer import MemoryIntegrationLayer, IntegrationConfig, IntegrationMode
 from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
-from src.chat.memory_system.memory_formatter import MemoryFormatter, FormatterConfig, format_memories_for_llm
+from src.chat.memory_system.memory_formatter import FormatterConfig, format_memories_for_llm
 from src.llm_models.utils_model import LLMRequest
 
 logger = get_logger(__name__)
@@ -36,6 +35,7 @@ MEMORY_TYPE_LABELS = {
 @dataclass
 class AdapterConfig:
     """适配器配置"""
+
     enable_enhanced_memory: bool = True
     integration_mode: str = "enhanced_only"  # replace, enhanced_only
     auto_migration: bool = True
@@ -61,7 +61,7 @@ class EnhancedMemoryAdapter:
             "hybrid_used": 0,
             "memories_created": 0,
             "memories_retrieved": 0,
-            "average_processing_time": 0.0
+            "average_processing_time": 0.0,
         }
 
     async def initialize(self):
@@ -79,14 +79,11 @@ class EnhancedMemoryAdapter:
                 memory_value_threshold=self.config.memory_value_threshold,
                 fusion_threshold=self.config.fusion_threshold,
                 max_retrieval_results=self.config.max_retrieval_results,
-                enable_learning=True  # 启用学习功能
+                enable_learning=True,  # 启用学习功能
             )
 
             # 创建集成层
-            self.integration_layer = MemoryIntegrationLayer(
-                llm_model=self.llm_model,
-                config=integration_config
-            )
+            self.integration_layer = MemoryIntegrationLayer(llm_model=self.llm_model, config=integration_config)
 
             # 初始化集成层
             await self.integration_layer.initialize()
@@ -99,10 +96,7 @@ class EnhancedMemoryAdapter:
             # 如果初始化失败,禁用增强记忆功能
             self.config.enable_enhanced_memory = False
 
-    async def process_conversation_memory(
-        self,
-        context: Optional[Dict[str, Any]] = None
-    ) -> Dict[str, Any]:
+    async def process_conversation_memory(self, context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
         """处理对话记忆,以上下文为唯一输入"""
         if not self._initialized or not self.config.enable_enhanced_memory:
             return {"success": False, "error": "Enhanced memory not available"}
@@ -152,11 +146,7 @@ class EnhancedMemoryAdapter:
             return {"success": False, "error": str(e)}
 
     async def retrieve_relevant_memories(
-        self,
-        query: str,
-        user_id: str,
-        context: Optional[Dict[str, Any]] = None,
-        limit: Optional[int] = None
+        self, query: str, user_id: str, context: Optional[Dict[str, Any]] = None, limit: Optional[int] = None
     ) -> List[MemoryChunk]:
         """检索相关记忆"""
         if not self._initialized or not self.config.enable_enhanced_memory:
@@ -164,9 +154,7 @@ class EnhancedMemoryAdapter:
 
         try:
             limit = limit or self.config.max_retrieval_results
-            memories = await self.integration_layer.retrieve_relevant_memories(
-                query, None, context, limit
-            )
+            memories = await self.integration_layer.retrieve_relevant_memories(query, None, context, limit)
 
             self.adapter_stats["memories_retrieved"] += len(memories)
             logger.debug(f"检索到 {len(memories)} 条相关记忆")
@@ -178,11 +166,7 @@ class EnhancedMemoryAdapter:
             return []
 
     async def get_memory_context_for_prompt(
-        self,
-        query: str,
-        user_id: str,
-        context: Optional[Dict[str, Any]] = None,
-        max_memories: int = 5
+        self, query: str, user_id: str, context: Optional[Dict[str, Any]] = None, max_memories: int = 5
     ) -> str:
         """获取用于提示词的记忆上下文"""
         memories = await self.retrieve_relevant_memories(query, user_id, context, max_memories)
@@ -197,15 +181,11 @@ class EnhancedMemoryAdapter:
             include_confidence=False,
             use_emoji_icons=True,
             group_by_type=False,
-            max_display_length=150
-        )
-
-        return format_memories_for_llm(
-            memories=memories,
-            query_context=query,
-            config=formatter_config
-        )
+            max_display_length=150,
+        )
+
+        return format_memories_for_llm(memories=memories, query_context=query, config=formatter_config)
 
     async def get_enhanced_memory_summary(self, user_id: str) -> Dict[str, Any]:
         """获取增强记忆系统摘要"""
         if not self._initialized or not self.config.enable_enhanced_memory:
@@ -227,7 +207,7 @@ class EnhancedMemoryAdapter:
                 "adapter_stats": adapter_stats,
                 "integration_stats": integration_stats,
                 "total_memories_created": adapter_stats["memories_created"],
-                "total_memories_retrieved": adapter_stats["memories_retrieved"]
+                "total_memories_retrieved": adapter_stats["memories_retrieved"],
             }
 
         except Exception as e:
@@ -285,12 +265,12 @@ async def get_enhanced_memory_adapter(llm_model: LLMRequest) -> EnhancedMemoryAd
     from src.config.config import global_config
 
     adapter_config = AdapterConfig(
-        enable_enhanced_memory=getattr(global_config.memory, 'enable_enhanced_memory', True),
-        integration_mode=getattr(global_config.memory, 'enhanced_memory_mode', 'enhanced_only'),
-        auto_migration=getattr(global_config.memory, 'enable_memory_migration', True),
-        memory_value_threshold=getattr(global_config.memory, 'memory_value_threshold', 0.6),
-        fusion_threshold=getattr(global_config.memory, 'fusion_threshold', 0.85),
-        max_retrieval_results=getattr(global_config.memory, 'max_retrieval_results', 10)
+        enable_enhanced_memory=getattr(global_config.memory, "enable_enhanced_memory", True),
+        integration_mode=getattr(global_config.memory, "enhanced_memory_mode", "enhanced_only"),
+        auto_migration=getattr(global_config.memory, "enable_memory_migration", True),
+        memory_value_threshold=getattr(global_config.memory, "memory_value_threshold", 0.6),
+        fusion_threshold=getattr(global_config.memory, "fusion_threshold", 0.85),
+        max_retrieval_results=getattr(global_config.memory, "max_retrieval_results", 10),
     )
 
     _enhanced_memory_adapter = EnhancedMemoryAdapter(llm_model, adapter_config)
@@ -312,13 +292,13 @@ async def initialize_enhanced_memory_system(llm_model: LLMRequest):
 
 
 async def process_conversation_with_enhanced_memory(
-    context: Dict[str, Any],
-    llm_model: Optional[LLMRequest] = None
+    context: Dict[str, Any], llm_model: Optional[LLMRequest] = None
 ) -> Dict[str, Any]:
     """使用增强记忆系统处理对话,上下文需包含 conversation_text 等信息"""
     if not llm_model:
         # 获取默认的LLM模型
         from src.llm_models.utils_model import get_global_llm_model
+
         llm_model = get_global_llm_model()
 
     try:
@@ -345,12 +325,13 @@ async def retrieve_memories_with_enhanced_system(
     user_id: str,
     context: Optional[Dict[str, Any]] = None,
     limit: int = 10,
-    llm_model: Optional[LLMRequest] = None
+    llm_model: Optional[LLMRequest] = None,
 ) -> List[MemoryChunk]:
     """使用增强记忆系统检索记忆"""
     if not llm_model:
         # 获取默认的LLM模型
         from src.llm_models.utils_model import get_global_llm_model
+
         llm_model = get_global_llm_model()
 
     try:
@@ -366,12 +347,13 @@ async def get_memory_context_for_prompt(
     user_id: str,
     context: Optional[Dict[str, Any]] = None,
     max_memories: int = 5,
-    llm_model: Optional[LLMRequest] = None
+    llm_model: Optional[LLMRequest] = None,
 ) -> str:
     """获取用于提示词的记忆上下文"""
     if not llm_model:
         # 获取默认的LLM模型
         from src.llm_models.utils_model import get_global_llm_model
+
         llm_model = get_global_llm_model()
 
     try:
@@ -379,4 +361,4 @@ async def get_memory_context_for_prompt(
         return await adapter.get_memory_context_for_prompt(query, user_id, context, max_memories)
     except Exception as e:
         logger.error(f"获取记忆上下文失败: {e}", exc_info=True)
-        return ""
+        return ""

@@ -4,7 +4,6 @@
 用于在消息处理过程中自动构建和检索记忆
 """
 
-import asyncio
 from typing import Dict, List, Any, Optional
 from datetime import datetime
 
@@ -19,8 +18,7 @@ class EnhancedMemoryHooks:
     """增强记忆系统钩子 - 自动处理消息的记忆构建和检索"""
 
     def __init__(self):
-        self.enabled = (global_config.memory.enable_memory and
-                        global_config.memory.enable_enhanced_memory)
+        self.enabled = global_config.memory.enable_memory and global_config.memory.enable_enhanced_memory
         self.processed_messages = set()  # 避免重复处理
 
     async def process_message_for_memory(
@@ -29,7 +27,7 @@ class EnhancedMemoryHooks:
         user_id: str,
         chat_id: str,
         message_id: str,
-        context: Optional[Dict[str, Any]] = None
+        context: Optional[Dict[str, Any]] = None,
     ) -> bool:
         """
         处理消息并构建记忆
@@ -76,7 +74,7 @@ class EnhancedMemoryHooks:
                 "timestamp": datetime.now().timestamp(),
                 "message_type": "user_message",
                 **bot_context,
-                **(context or {})
+                **(context or {}),
             }
 
             # 处理对话并构建记忆
@@ -84,7 +82,7 @@ class EnhancedMemoryHooks:
                 conversation_text=message_content,
                 context=memory_context,
                 user_id=user_id,
-                timestamp=memory_context["timestamp"]
+                timestamp=memory_context["timestamp"],
             )
 
             # 标记消息已处理
@@ -108,7 +106,7 @@ class EnhancedMemoryHooks:
         user_id: str,
         chat_id: str,
         limit: int = 5,
-        extra_context: Optional[Dict[str, Any]] = None
+        extra_context: Optional[Dict[str, Any]] = None,
     ) -> List[Dict[str, Any]]:
         """
         为回复获取相关记忆
@@ -134,9 +132,7 @@ class EnhancedMemoryHooks:
             context = {
                 "chat_id": chat_id,
                 "query_intent": "response_generation",
-                "expected_memory_types": [
-                    "personal_fact", "event", "preference", "opinion"
-                ]
+                "expected_memory_types": ["personal_fact", "event", "preference", "opinion"],
             }
 
             if extra_context:
@@ -144,10 +140,7 @@ class EnhancedMemoryHooks:
 
             # 获取相关记忆
             enhanced_results = await enhanced_memory_manager.get_enhanced_memory_context(
-                query_text=query_text,
-                user_id=user_id,
-                context=context,
-                limit=limit
+                query_text=query_text, user_id=user_id, context=context, limit=limit
             )
 
             # 转换为字典格式
@@ -199,4 +192,4 @@ class EnhancedMemoryHooks:
 
 
 # 创建全局实例
-enhanced_memory_hooks = EnhancedMemoryHooks()
+enhanced_memory_hooks = EnhancedMemoryHooks()

@@ -4,7 +4,6 @@
 用于在现有系统中无缝集成增强记忆功能
 """
 
-import asyncio
 from typing import Dict, Any, Optional
 
 from src.common.logger import get_logger
@@ -14,11 +13,7 @@ logger = get_logger(__name__)
 
 
 async def process_user_message_memory(
-    message_content: str,
-    user_id: str,
-    chat_id: str,
-    message_id: str,
-    context: Optional[Dict[str, Any]] = None
+    message_content: str, user_id: str, chat_id: str, message_id: str, context: Optional[Dict[str, Any]] = None
 ) -> bool:
     """
     处理用户消息并构建记忆
@@ -35,11 +30,7 @@ async def process_user_message_memory(
     """
     try:
         success = await enhanced_memory_hooks.process_message_for_memory(
-            message_content=message_content,
-            user_id=user_id,
-            chat_id=chat_id,
-            message_id=message_id,
-            context=context
+            message_content=message_content, user_id=user_id, chat_id=chat_id, message_id=message_id, context=context
         )
 
         if success:
@@ -53,11 +44,7 @@ async def process_user_message_memory(
 
 
 async def get_relevant_memories_for_response(
-    query_text: str,
-    user_id: str,
-    chat_id: str,
-    limit: int = 5,
-    extra_context: Optional[Dict[str, Any]] = None
+    query_text: str, user_id: str, chat_id: str, limit: int = 5, extra_context: Optional[Dict[str, Any]] = None
 ) -> Dict[str, Any]:
     """
     为回复获取相关记忆
@@ -74,29 +61,17 @@ async def get_relevant_memories_for_response(
     """
     try:
         memories = await enhanced_memory_hooks.get_memory_for_response(
-            query_text=query_text,
-            user_id=user_id,
-            chat_id=chat_id,
-            limit=limit,
-            extra_context=extra_context
+            query_text=query_text, user_id=user_id, chat_id=chat_id, limit=limit, extra_context=extra_context
         )
 
-        result = {
-            "has_memories": len(memories) > 0,
-            "memories": memories,
-            "memory_count": len(memories)
-        }
+        result = {"has_memories": len(memories) > 0, "memories": memories, "memory_count": len(memories)}
 
         logger.debug(f"为回复获取到 {len(memories)} 条相关记忆")
         return result
 
     except Exception as e:
         logger.error(f"获取回复记忆失败: {e}")
-        return {
-            "has_memories": False,
-            "memories": [],
-            "memory_count": 0
-        }
+        return {"has_memories": False, "memories": [], "memory_count": 0}
 
 
 def format_memories_for_prompt(memories: Dict[str, Any]) -> str:
@@ -152,16 +127,13 @@ def get_memory_system_status() -> Dict[str, Any]:
         "enabled": enhanced_memory_hooks.enabled,
         "enhanced_system_initialized": enhanced_memory_manager.is_initialized,
         "processed_messages_count": len(enhanced_memory_hooks.processed_messages),
-        "system_type": "enhanced_memory_system"
+        "system_type": "enhanced_memory_system",
     }
 
 
 # 便捷函数
 async def remember_message(
-    message: str,
-    user_id: str = "default_user",
-    chat_id: str = "default_chat",
-    context: Optional[Dict[str, Any]] = None
+    message: str, user_id: str = "default_user", chat_id: str = "default_chat", context: Optional[Dict[str, Any]] = None
 ) -> bool:
     """
     便捷的记忆构建函数
@@ -175,13 +147,10 @@ async def remember_message(
         bool: 是否成功
     """
     import uuid
 
     message_id = str(uuid.uuid4())
     return await process_user_message_memory(
-        message_content=message,
-        user_id=user_id,
-        chat_id=chat_id,
-        message_id=message_id,
-        context=context
+        message_content=message, user_id=user_id, chat_id=chat_id, message_id=message_id, context=context
     )
 
 
@@ -190,7 +159,7 @@ async def recall_memories(
     user_id: str = "default_user",
     chat_id: str = "default_chat",
     limit: int = 5,
-    context: Optional[Dict[str, Any]] = None
+    context: Optional[Dict[str, Any]] = None,
 ) -> Dict[str, Any]:
     """
     便捷的记忆检索函数
@@ -205,9 +174,5 @@ async def recall_memories(
         Dict: 记忆信息
     """
     return await get_relevant_memories_for_response(
-        query_text=query,
-        user_id=user_id,
-        chat_id=chat_id,
-        limit=limit,
-        extra_context=context
-    )
+        query_text=query, user_id=user_id, chat_id=chat_id, limit=limit, extra_context=context
+    )

@@ -18,32 +18,34 @@ logger = get_logger(__name__)
 
 class IntentType(Enum):
     """对话意图类型"""
-    FACT_QUERY = "fact_query"        # 事实查询
-    EVENT_RECALL = "event_recall"    # 事件回忆
+
+    FACT_QUERY = "fact_query"  # 事实查询
+    EVENT_RECALL = "event_recall"  # 事件回忆
     PREFERENCE_CHECK = "preference_check"  # 偏好检查
-    GENERAL_CHAT = "general_chat"    # 一般对话
-    UNKNOWN = "unknown"              # 未知意图
+    GENERAL_CHAT = "general_chat"  # 一般对话
+    UNKNOWN = "unknown"  # 未知意图
 
 
 @dataclass
 class ReRankingConfig:
     """重排序配置"""
+
     # 权重配置 (w1 + w2 + w3 + w4 = 1.0)
-    semantic_weight: float = 0.5     # 语义相似度权重
-    recency_weight: float = 0.2      # 时效性权重
-    usage_freq_weight: float = 0.2   # 使用频率权重
-    type_match_weight: float = 0.1   # 类型匹配权重
+    semantic_weight: float = 0.5  # 语义相似度权重
+    recency_weight: float = 0.2  # 时效性权重
+    usage_freq_weight: float = 0.2  # 使用频率权重
+    type_match_weight: float = 0.1  # 类型匹配权重
 
     # 时效性衰减参数
-    recency_decay_rate: float = 0.1   # 时效性衰减率 (天)
+    recency_decay_rate: float = 0.1  # 时效性衰减率 (天)
 
     # 使用频率计算参数
-    freq_log_base: float = 2.0       # 对数底数
-    freq_max_score: float = 5.0      # 最大频率得分
+    freq_log_base: float = 2.0  # 对数底数
+    freq_max_score: float = 5.0  # 最大频率得分
 
     # 类型匹配权重映射
     type_match_weights: Dict[str, Dict[str, float]] = None
 
     def __post_init__(self):
         """初始化类型匹配权重"""
         if self.type_match_weights is None:
@@ -53,102 +55,150 @@ class ReRankingConfig:
                     MemoryType.KNOWLEDGE.value: 0.8,
                     MemoryType.PREFERENCE.value: 0.5,
                     MemoryType.EVENT.value: 0.3,
-                    "default": 0.3
+                    "default": 0.3,
                 },
                 IntentType.EVENT_RECALL.value: {
                     MemoryType.EVENT.value: 1.0,
                     MemoryType.EXPERIENCE.value: 0.8,
                     MemoryType.EMOTION.value: 0.6,
                     MemoryType.PERSONAL_FACT.value: 0.5,
-                    "default": 0.5
+                    "default": 0.5,
                 },
                 IntentType.PREFERENCE_CHECK.value: {
                     MemoryType.PREFERENCE.value: 1.0,
                     MemoryType.OPINION.value: 0.8,
                     MemoryType.GOAL.value: 0.6,
                     MemoryType.PERSONAL_FACT.value: 0.4,
-                    "default": 0.4
+                    "default": 0.4,
                 },
-                IntentType.GENERAL_CHAT.value: {
-                    "default": 0.8
-                },
-                IntentType.UNKNOWN.value: {
-                    "default": 0.8
-                }
+                IntentType.GENERAL_CHAT.value: {"default": 0.8},
+                IntentType.UNKNOWN.value: {"default": 0.8},
             }
 
 
 class IntentClassifier:
     """轻量级意图识别器"""
 
     def __init__(self):
         # 关键词模式匹配规则
         self.patterns = {
             IntentType.FACT_QUERY: [
                 # 中文模式
-                "我是", "我的", "我叫", "我在", "我住在", "我的职业", "我的工作",
-                "什么时候", "在哪里", "是什么", "多少", "几岁", "年龄",
+                "我是",
+                "我的",
+                "我叫",
+                "我在",
+                "我住在",
+                "我的职业",
+                "我的工作",
+                "什么时候",
+                "在哪里",
+                "是什么",
+                "多少",
+                "几岁",
+                "年龄",
                 # 英文模式
-                "what is", "where is", "when is", "how old", "my name", "i am", "i live"
+                "what is",
+                "where is",
+                "when is",
+                "how old",
+                "my name",
+                "i am",
+                "i live",
             ],
             IntentType.EVENT_RECALL: [
                 # 中文模式
-                "记得", "想起", "还记得", "那次", "上次", "之前", "以前", "曾经",
-                "发生过", "经历", "做过", "去过", "见过",
+                "记得",
+                "想起",
+                "还记得",
+                "那次",
+                "上次",
+                "之前",
+                "以前",
+                "曾经",
+                "发生过",
+                "经历",
+                "做过",
+                "去过",
+                "见过",
                 # 英文模式
-                "remember", "recall", "last time", "before", "previously", "happened", "experience"
+                "remember",
+                "recall",
+                "last time",
+                "before",
+                "previously",
+                "happened",
+                "experience",
             ],
             IntentType.PREFERENCE_CHECK: [
                 # 中文模式
-                "喜欢", "不喜欢", "偏好", "爱好", "兴趣", "讨厌", "最爱", "最喜欢",
-                "习惯", "通常", "一般", "倾向于", "更喜欢",
+                "喜欢",
+                "不喜欢",
+                "偏好",
+                "爱好",
+                "兴趣",
+                "讨厌",
+                "最爱",
+                "最喜欢",
+                "习惯",
+                "通常",
+                "一般",
+                "倾向于",
+                "更喜欢",
                 # 英文模式
-                "like", "love", "hate", "prefer", "favorite", "usually", "tend to", "interest"
-            ]
+                "like",
+                "love",
+                "hate",
+                "prefer",
+                "favorite",
+                "usually",
+                "tend to",
+                "interest",
+            ],
         }
 
     def classify_intent(self, query: str, context: Dict[str, Any]) -> IntentType:
         """识别对话意图"""
         if not query:
             return IntentType.UNKNOWN
 
         query_lower = query.lower()
 
         # 统计各意图的匹配分数
         intent_scores = {intent: 0 for intent in IntentType}
 
         for intent, patterns in self.patterns.items():
             for pattern in patterns:
                 if pattern in query_lower:
                     intent_scores[intent] += 1
 
         # 返回得分最高的意图
         max_score = max(intent_scores.values())
         if max_score == 0:
             return IntentType.GENERAL_CHAT
 
         for intent, score in intent_scores.items():
             if score == max_score:
                 return intent
 
         return IntentType.GENERAL_CHAT
 
 
 class EnhancedReRanker:
     """增强重排序器 - 实现文档设计的多维度评分模型"""
 
     def __init__(self, config: Optional[ReRankingConfig] = None):
         self.config = config or ReRankingConfig()
         self.intent_classifier = IntentClassifier()
 
         # 验证权重和为1.0
         total_weight = (
-            self.config.semantic_weight +
-            self.config.recency_weight +
-            self.config.usage_freq_weight +
-            self.config.type_match_weight
+            self.config.semantic_weight
+            + self.config.recency_weight
+            + self.config.usage_freq_weight
+            + self.config.type_match_weight
         )
 
         if abs(total_weight - 1.0) > 0.01:
             logger.warning(f"重排序权重和不为1.0: {total_weight}, 将进行归一化")
             # 归一化权重
@@ -156,94 +206,94 @@ class EnhancedReRanker:
             self.config.recency_weight /= total_weight
             self.config.usage_freq_weight /= total_weight
             self.config.type_match_weight /= total_weight
 
     def rerank_memories(
         self,
         query: str,
         candidate_memories: List[Tuple[str, MemoryChunk, float]],  # (memory_id, memory, vector_similarity)
         context: Dict[str, Any],
-        limit: int = 10
+        limit: int = 10,
     ) -> List[Tuple[str, MemoryChunk, float]]:
         """
         对候选记忆进行重排序
 
         Args:
             query: 查询文本
             candidate_memories: 候选记忆列表 [(memory_id, memory, vector_similarity)]
             context: 上下文信息
             limit: 返回数量限制
 
         Returns:
             重排序后的记忆列表 [(memory_id, memory, final_score)]
         """
         if not candidate_memories:
             return []
 
         # 识别查询意图
         intent = self.intent_classifier.classify_intent(query, context)
         logger.debug(f"识别到查询意图: {intent.value}")
 
         # 计算每个候选记忆的最终得分
         scored_memories = []
         current_time = time.time()
 
         for memory_id, memory, vector_sim in candidate_memories:
             try:
                 # 1. 语义相似度得分 (已归一化到[0,1])
                 semantic_score = self._normalize_similarity(vector_sim)
 
                 # 2. 时效性得分
                 recency_score = self._calculate_recency_score(memory, current_time)
 
                 # 3. 使用频率得分
                 usage_freq_score = self._calculate_usage_frequency_score(memory)
 
                 # 4. 类型匹配得分
                 type_match_score = self._calculate_type_match_score(memory, intent)
 
                 # 计算最终得分
                 final_score = (
-                    self.config.semantic_weight * semantic_score +
-                    self.config.recency_weight * recency_score +
-                    self.config.usage_freq_weight * usage_freq_score +
-                    self.config.type_match_weight * type_match_score
+                    self.config.semantic_weight * semantic_score
+                    + self.config.recency_weight * recency_score
+                    + self.config.usage_freq_weight * usage_freq_score
+                    + self.config.type_match_weight * type_match_score
                 )
 
                 scored_memories.append((memory_id, memory, final_score))
 
                 # 记录调试信息
                 logger.debug(
                     f"记忆评分 {memory_id[:8]}: semantic={semantic_score:.3f}, "
                     f"recency={recency_score:.3f}, freq={usage_freq_score:.3f}, "
                     f"type={type_match_score:.3f}, final={final_score:.3f}"
                 )
 
             except Exception as e:
                 logger.error(f"计算记忆 {memory_id} 得分时出错: {e}")
                 # 使用向量相似度作为后备得分
                 scored_memories.append((memory_id, memory, vector_sim))
 
         # 按最终得分降序排序
         scored_memories.sort(key=lambda x: x[2], reverse=True)
 
         # 返回前N个结果
         result = scored_memories[:limit]
 
         highest_score = result[0][2] if result else 0.0
         logger.info(
             f"重排序完成: 候选={len(candidate_memories)}, 返回={len(result)}, "
             f"意图={intent.value}, 最高分={highest_score:.3f}"
         )
 
         return result
 
     def _normalize_similarity(self, raw_similarity: float) -> float:
         """归一化相似度到[0,1]区间"""
         # 假设原始相似度已经在[-1,1]或[0,1]区间
         if raw_similarity < 0:
             return (raw_similarity + 1) / 2  # 从[-1,1]映射到[0,1]
         return min(1.0, max(0.0, raw_similarity))  # 确保在[0,1]区间
 
     def _calculate_recency_score(self, memory: MemoryChunk, current_time: float) -> float:
         """
         计算时效性得分
@@ -251,13 +301,13 @@ class EnhancedReRanker:
         """
         last_accessed = memory.metadata.last_accessed or memory.metadata.created_at
         days_old = (current_time - last_accessed) / (24 * 3600)  # 转换为天数
 
         if days_old < 0:
             days_old = 0  # 处理时间异常
 
         score = 1 / (1 + self.config.recency_decay_rate * days_old)
         return min(1.0, max(0.0, score))
 
     def _calculate_usage_frequency_score(self, memory: MemoryChunk) -> float:
         """
         计算使用频率得分
@@ -266,22 +316,22 @@ class EnhancedReRanker:
         access_count = memory.metadata.access_count
         if access_count <= 0:
             return 0.0
 
         log_count = math.log2(access_count + 1)
         score = log_count / self.config.freq_max_score
         return min(1.0, max(0.0, score))
 
     def _calculate_type_match_score(self, memory: MemoryChunk, intent: IntentType) -> float:
         """计算类型匹配得分"""
         memory_type = memory.memory_type.value
         intent_value = intent.value
 
         # 获取对应意图的类型权重映射
         type_weights = self.config.type_match_weights.get(intent_value, {})
 
         # 查找具体类型的权重,如果没有则使用默认权重
         score = type_weights.get(memory_type, type_weights.get("default", 0.8))
 
         return min(1.0, max(0.0, score))
 
 
@@ -294,7 +344,7 @@ def rerank_candidate_memories(
     candidate_memories: List[Tuple[str, MemoryChunk, float]],
     context: Dict[str, Any],
     limit: int = 10,
-    config: Optional[ReRankingConfig] = None
+    config: Optional[ReRankingConfig] = None,
 ) -> List[Tuple[str, MemoryChunk, float]]:
     """
     便捷函数:对候选记忆进行重排序
@@ -303,5 +353,5 @@ def rerank_candidate_memories(
         reranker = EnhancedReRanker(config)
     else:
         reranker = default_reranker
 
-    return reranker.rerank_memories(query, candidate_memories, context, limit)
+    return reranker.rerank_memories(query, candidate_memories, context, limit)

@@ -12,7 +12,7 @@ from enum import Enum
 
 from src.common.logger import get_logger
 from src.chat.memory_system.enhanced_memory_core import EnhancedMemorySystem
-from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel
+from src.chat.memory_system.memory_chunk import MemoryChunk
 from src.llm_models.utils_model import LLMRequest
 
 logger = get_logger(__name__)
@@ -20,13 +20,15 @@ logger = get_logger(__name__)
 
 class IntegrationMode(Enum):
     """集成模式"""
-    REPLACE = "replace"              # 完全替换现有记忆系统
+
+    REPLACE = "replace"  # 完全替换现有记忆系统
     ENHANCED_ONLY = "enhanced_only"  # 仅使用增强记忆系统
 
 
 @dataclass
 class IntegrationConfig:
     """集成配置"""
+
     mode: IntegrationMode = IntegrationMode.ENHANCED_ONLY
     enable_enhanced_memory: bool = True
     memory_value_threshold: float = 0.6
@@ -51,7 +53,7 @@ class MemoryIntegrationLayer:
             "enhanced_queries": 0,
             "memory_creations": 0,
             "average_response_time": 0.0,
-            "success_rate": 0.0
+            "success_rate": 0.0,
         }
 
         # 初始化锁
@@ -88,6 +90,7 @@ class MemoryIntegrationLayer:
 
             # 创建增强记忆系统配置
             from src.chat.memory_system.enhanced_memory_core import MemorySystemConfig
+
             memory_config = MemorySystemConfig.from_global_config()
 
             # 使用集成配置覆盖部分值
@@ -96,9 +99,7 @@ class MemoryIntegrationLayer:
             memory_config.final_recall_limit = self.config.max_retrieval_results
 
             # 创建增强记忆系统
-            self.enhanced_memory = EnhancedMemorySystem(
-                config=memory_config
-            )
+            self.enhanced_memory = EnhancedMemorySystem(config=memory_config)
 
             # 如果外部提供了LLM模型,注入到系统中
             if self.llm_model is not None:
@@ -112,10 +113,7 @@ class MemoryIntegrationLayer:
             logger.error(f"❌ 增强记忆系统初始化失败: {e}", exc_info=True)
             raise
 
-    async def process_conversation(
-        self,
-        context: Dict[str, Any]
-    ) -> Dict[str, Any]:
+    async def process_conversation(self, context: Dict[str, Any]) -> Dict[str, Any]:
         """处理对话记忆,仅使用上下文信息"""
         if not self._initialized or not self.enhanced_memory:
             return {"success": False, "error": "Memory system not available"}
@@ -154,7 +152,7 @@ class MemoryIntegrationLayer:
         query: str,
         user_id: Optional[str] = None,
         context: Optional[Dict[str, Any]] = None,
-        limit: Optional[int] = None
+        limit: Optional[int] = None,
     ) -> List[MemoryChunk]:
         """检索相关记忆"""
         if not self._initialized or not self.enhanced_memory:
@@ -163,10 +161,7 @@ class MemoryIntegrationLayer:
         try:
             limit = limit or self.config.max_retrieval_results
             memories = await self.enhanced_memory.retrieve_relevant_memories(
-                query=query,
-                user_id=None,
-                context=context or {},
-                limit=limit
+                query=query, user_id=None, context=context or {}, limit=limit
             )
 
             memory_count = len(memories)
@@ -191,7 +186,7 @@ class MemoryIntegrationLayer:
                 "status": "initialized",
                 "mode": self.config.mode.value,
                 "enhanced_memory": enhanced_status,
-                "integration_stats": self.integration_stats.copy()
+                "integration_stats": self.integration_stats.copy(),
             }
 
         except Exception as e:
@@ -248,4 +243,4 @@ class MemoryIntegrationLayer:
             logger.info("✅ 记忆系统集成层已关闭")
 
         except Exception as e:
-            logger.error(f"❌ 关闭集成层失败: {e}", exc_info=True)
+            logger.error(f"❌ 关闭集成层失败: {e}", exc_info=True)

@@ -4,17 +4,15 @@
 提供与现有MoFox Bot系统的无缝集成点
 """
 
-import asyncio
 import time
-from typing import Dict, List, Optional, Any, Callable
+from typing import Dict, Optional, Any
 from dataclasses import dataclass
 
 from src.common.logger import get_logger
 from src.chat.memory_system.enhanced_memory_adapter import (
     get_enhanced_memory_adapter,
     process_conversation_with_enhanced_memory,
     retrieve_memories_with_enhanced_system,
-    get_memory_context_for_prompt
+    get_memory_context_for_prompt,
 )
 
 logger = get_logger(__name__)
@@ -23,6 +21,7 @@ logger = get_logger(__name__)
 @dataclass
 class HookResult:
     """钩子执行结果"""
+
     success: bool
     data: Any = None
     error: Optional[str] = None
@@ -39,7 +38,7 @@ class MemoryIntegrationHooks:
             "memory_retrieval_hooks": 0,
             "prompt_enhancement_hooks": 0,
             "total_hook_executions": 0,
-            "average_hook_time": 0.0
+            "average_hook_time": 0.0,
         }
 
     async def register_hooks(self):
@@ -130,10 +129,7 @@ class MemoryIntegrationHooks:
             from src.plugin_system.base.component_types import EventType
 
             # 注册消息后处理事件
-            event_manager.subscribe(
-                EventType.MESSAGE_PROCESSED,
-                self._on_message_processed_handler
-            )
+            event_manager.subscribe(EventType.MESSAGE_PROCESSED, self._on_message_processed_handler)
             logger.debug("已注册到事件系统的消息处理钩子")
 
         except ImportError:
@@ -144,10 +140,8 @@ class MemoryIntegrationHooks:
             from src.chat.message_manager import message_manager
 
             # 如果消息管理器支持钩子注册
-            if hasattr(message_manager, 'register_post_process_hook'):
-                message_manager.register_post_process_hook(
-                    self._on_message_processed_hook
-                )
+            if hasattr(message_manager, "register_post_process_hook"):
+                message_manager.register_post_process_hook(self._on_message_processed_hook)
                 logger.debug("已注册到消息管理器的处理钩子")
 
         except ImportError:
@@ -164,10 +158,8 @@ class MemoryIntegrationHooks:
             from src.chat.message_receive.chat_stream import get_chat_manager
 
             chat_manager = get_chat_manager()
-            if hasattr(chat_manager, 'register_save_hook'):
-                chat_manager.register_save_hook(
-                    self._on_chat_stream_save_hook
-                )
+            if hasattr(chat_manager, "register_save_hook"):
+                chat_manager.register_save_hook(self._on_chat_stream_save_hook)
                 logger.debug("已注册到聊天流管理器的保存钩子")
 
         except ImportError:
@@ -183,10 +175,8 @@ class MemoryIntegrationHooks:
         try:
             from src.chat.replyer.default_generator import default_generator
 
-            if hasattr(default_generator, 'register_pre_generation_hook'):
-                default_generator.register_pre_generation_hook(
-                    self._on_pre_response_hook
-                )
+            if hasattr(default_generator, "register_pre_generation_hook"):
+                default_generator.register_pre_generation_hook(self._on_pre_response_hook)
                 logger.debug("已注册到回复生成器的前置钩子")
 
         except ImportError:
@@ -202,10 +192,8 @@ class MemoryIntegrationHooks:
         try:
             from src.chat.knowledge.knowledge_lib import knowledge_manager
 
-            if hasattr(knowledge_manager, 'register_query_enhancer'):
-                knowledge_manager.register_query_enhancer(
-                    self._on_knowledge_query_hook
-                )
+            if hasattr(knowledge_manager, "register_query_enhancer"):
+                knowledge_manager.register_query_enhancer(self._on_knowledge_query_hook)
                 logger.debug("已注册到知识库的查询增强钩子")
 
         except ImportError:
@@ -221,10 +209,8 @@ class MemoryIntegrationHooks:
         try:
             from src.chat.utils.prompt import prompt_manager
 
-            if hasattr(prompt_manager, 'register_enhancer'):
-                prompt_manager.register_enhancer(
-                    self._on_prompt_building_hook
-                )
+            if hasattr(prompt_manager, "register_enhancer"):
+                prompt_manager.register_enhancer(self._on_prompt_building_hook)
                 logger.debug("已注册到提示词管理器的增强钩子")
 
         except ImportError:
@@ -278,7 +264,7 @@ class MemoryIntegrationHooks:
                 "platform": message_info.get("platform", "unknown"),
                 "interest_value": message_data.get("interest_value", 0.0),
                 "keywords": message_data.get("key_words", []),
-                "timestamp": message_data.get("time", time.time())
+                "timestamp": message_data.get("time", time.time()),
             }
 
             # 使用增强记忆系统处理对话
@@ -296,7 +282,7 @@ class MemoryIntegrationHooks:
                 return HookResult(success=True, data=result, processing_time=processing_time)
             else:
                 logger.warning(f"消息处理钩子执行失败: {result.get('error')}")
-                return HookResult(success=False, error=result.get('error'), processing_time=processing_time)
+                return HookResult(success=False, error=result.get("error"), processing_time=processing_time)
 
         except Exception as e:
             processing_time = time.time() - start_time
@@ -334,7 +320,7 @@ class MemoryIntegrationHooks:
                 "stream_id": chat_stream_data.get("stream_id"),
                 "platform": chat_stream_data.get("platform", "unknown"),
                 "message_count": len(messages),
-                "timestamp": time.time()
+                "timestamp": time.time(),
             }
 
             # 使用增强记忆系统处理对话
@@ -352,7 +338,7 @@ class MemoryIntegrationHooks:
                 return HookResult(success=True, data=result, processing_time=processing_time)
             else:
                 logger.warning(f"聊天流保存钩子执行失败: {result.get('error')}")
-                return HookResult(success=False, error=result.get('error'), processing_time=processing_time)
+                return HookResult(success=False, error=result.get("error"), processing_time=processing_time)
 
         except Exception as e:
             processing_time = time.time() - start_time
@@ -375,9 +361,7 @@ class MemoryIntegrationHooks:
                 return HookResult(success=True, data="No query provided")
 
             # 检索相关记忆
-            memories = await retrieve_memories_with_enhanced_system(
-                query, user_id, context, limit=5
-            )
+            memories = await retrieve_memories_with_enhanced_system(query, user_id, context, limit=5)
 
             processing_time = time.time() - start_time
             self._update_hook_stats(processing_time)
@@ -411,9 +395,7 @@ class MemoryIntegrationHooks:
                 return HookResult(success=True, data="No query provided")
 
             # 获取记忆上下文并增强查询
-            memory_context = await get_memory_context_for_prompt(
-                query, user_id, context, max_memories=3
-            )
+            memory_context = await get_memory_context_for_prompt(query, user_id, context, max_memories=3)
 
             processing_time = time.time() - start_time
             self._update_hook_stats(processing_time)
@@ -445,9 +427,7 @@ class MemoryIntegrationHooks:
                 return HookResult(success=True, data="No query provided")
 
             # 获取记忆上下文
-            memory_context = await get_memory_context_for_prompt(
-                query, user_id, context, max_memories=5
-            )
+            memory_context = await get_memory_context_for_prompt(query, user_id, context, max_memories=5)
 
             processing_time = time.time() - start_time
             self._update_hook_stats(processing_time)
@@ -499,6 +479,7 @@ class MemoryMaintenanceTask:
         # 获取适配器实例
         try:
             from src.chat.memory_system.enhanced_memory_adapter import _enhanced_memory_adapter
+
             if _enhanced_memory_adapter:
                 await _enhanced_memory_adapter.maintenance()
                 logger.info("✅ 增强记忆系统维护任务完成")
@@ -543,4 +524,4 @@ async def initialize_memory_integration_hooks():
         return hooks
     except Exception as e:
         logger.error(f"❌ 记忆集成钩子初始化失败: {e}", exc_info=True)
-        return None
+        return None

@@ -4,12 +4,10 @@
|
||||
为记忆系统提供多维度的精准过滤和查询能力
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import orjson
|
||||
from typing import Dict, List, Optional, Tuple, Set, Any, Union
|
||||
from datetime import datetime, timedelta
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
@@ -23,23 +21,25 @@ logger = get_logger(__name__)
|
||||
|
||||
class IndexType(Enum):
|
||||
"""索引类型"""
|
||||
MEMORY_TYPE = "memory_type" # 记忆类型索引
|
||||
USER_ID = "user_id" # 用户ID索引
|
||||
SUBJECT = "subject" # 主体索引
|
||||
KEYWORD = "keyword" # 关键词索引
|
||||
TAG = "tag" # 标签索引
|
||||
CATEGORY = "category" # 分类索引
|
||||
TIMESTAMP = "timestamp" # 时间索引
|
||||
CONFIDENCE = "confidence" # 置信度索引
|
||||
IMPORTANCE = "importance" # 重要性索引
|
||||
|
||||
MEMORY_TYPE = "memory_type" # 记忆类型索引
|
||||
USER_ID = "user_id" # 用户ID索引
|
||||
SUBJECT = "subject" # 主体索引
|
||||
KEYWORD = "keyword" # 关键词索引
|
||||
TAG = "tag" # 标签索引
|
||||
CATEGORY = "category" # 分类索引
|
||||
TIMESTAMP = "timestamp" # 时间索引
|
||||
CONFIDENCE = "confidence" # 置信度索引
|
||||
IMPORTANCE = "importance" # 重要性索引
|
||||
RELATIONSHIP_SCORE = "relationship_score" # 关系分索引
|
||||
ACCESS_FREQUENCY = "access_frequency" # 访问频率索引
|
||||
SEMANTIC_HASH = "semantic_hash" # 语义哈希索引
|
||||
SEMANTIC_HASH = "semantic_hash" # 语义哈希索引
|
||||
|
||||
|
||||
@dataclass
|
||||
class IndexQuery:
|
||||
"""索引查询条件"""
|
||||
|
||||
user_ids: Optional[List[str]] = None
|
||||
memory_types: Optional[List[MemoryType]] = None
|
||||
subjects: Optional[List[str]] = None
|
||||
@@ -61,6 +61,7 @@ class IndexQuery:
|
||||
@dataclass
|
||||
class IndexResult:
|
||||
"""索引结果"""
|
||||
|
||||
memory_ids: List[str]
|
||||
total_count: int
|
||||
query_time: float
|
||||
@@ -102,7 +103,7 @@ class MetadataIndexManager:
|
||||
"average_query_time": 0.0,
|
||||
"total_queries": 0,
|
||||
"cache_hit_rate": 0.0,
|
||||
"cache_hits": 0
|
||||
"cache_hits": 0,
|
||||
}
|
||||
|
||||
# 线程锁
|
||||
@@ -171,9 +172,8 @@ class MetadataIndexManager:
|
||||
|
||||
index_time = time.time() - start_time
|
||||
self.index_stats["index_build_time"] = (
|
||||
(self.index_stats["index_build_time"] * (len(memories) - 1) + index_time) /
|
||||
len(memories)
|
||||
)
|
||||
self.index_stats["index_build_time"] * (len(memories) - 1) + index_time
|
||||
) / len(memories)
|
||||
|
||||
logger.debug(f"元数据索引完成,{len(memories)} 条记忆,耗时 {index_time:.3f}秒")
|
||||
|
||||
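The index_build_time update above is an incremental mean: with n samples, new_avg = (old_avg * (n - 1) + x) / n, so no history needs to be stored. A self-contained sketch of the same update:

    def update_running_mean(old_avg: float, n: int, x: float) -> float:
        # n is the count including the newest sample x
        return (old_avg * (n - 1) + x) / n

    avg = 0.0
    for n, x in enumerate([0.2, 0.4, 0.6], start=1):
        avg = update_running_mean(avg, n, x)
    assert abs(avg - 0.4) < 1e-9  # plain mean of the three samples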
@@ -258,7 +258,7 @@ class MetadataIndexManager:
"relationship_score": memory.metadata.relationship_score,
"relevance_score": memory.metadata.relevance_score,
"semantic_hash": memory.semantic_hash,
"subjects": memory.subjects
"subjects": memory.subjects,
}

# 记忆类型索引
@@ -355,21 +355,20 @@ class MetadataIndexManager:

# 限制数量
if query.limit and len(filtered_ids) > query.limit:
filtered_ids = filtered_ids[:query.limit]
filtered_ids = filtered_ids[: query.limit]

# 记录查询统计
query_time = time.time() - start_time
self.index_stats["total_queries"] += 1
self.index_stats["average_query_time"] = (
(self.index_stats["average_query_time"] * (self.index_stats["total_queries"] - 1) + query_time) /
self.index_stats["total_queries"]
)
self.index_stats["average_query_time"] * (self.index_stats["total_queries"] - 1) + query_time
) / self.index_stats["total_queries"]

return IndexResult(
memory_ids=filtered_ids,
total_count=len(filtered_ids),
query_time=query_time,
filtered_by=self._get_applied_filters(query)
filtered_by=self._get_applied_filters(query),
)

except Exception as e:
@@ -486,15 +485,15 @@ class MetadataIndexManager:
if query.time_range:
start_time, end_time = query.time_range
filtered_ids = [
memory_id for memory_id in filtered_ids
if self._is_in_time_range(memory_id, start_time, end_time)
memory_id for memory_id in filtered_ids if self._is_in_time_range(memory_id, start_time, end_time)
]

# 置信度过滤
if query.confidence_levels:
confidence_set = set(query.confidence_levels)
filtered_ids = [
memory_id for memory_id in filtered_ids
memory_id
for memory_id in filtered_ids
if self.memory_metadata_cache[memory_id]["confidence"] in confidence_set
]

@@ -502,27 +501,31 @@ class MetadataIndexManager:
if query.importance_levels:
importance_set = set(query.importance_levels)
filtered_ids = [
memory_id for memory_id in filtered_ids
memory_id
for memory_id in filtered_ids
if self.memory_metadata_cache[memory_id]["importance"] in importance_set
]

# 关系分范围过滤
if query.min_relationship_score is not None:
filtered_ids = [
memory_id for memory_id in filtered_ids
memory_id
for memory_id in filtered_ids
if self.memory_metadata_cache[memory_id]["relationship_score"] >= query.min_relationship_score
]

if query.max_relationship_score is not None:
filtered_ids = [
memory_id for memory_id in filtered_ids
memory_id
for memory_id in filtered_ids
if self.memory_metadata_cache[memory_id]["relationship_score"] <= query.max_relationship_score
]

# 最小访问次数过滤
if query.min_access_count is not None:
filtered_ids = [
memory_id for memory_id in filtered_ids
memory_id
for memory_id in filtered_ids
if self.memory_metadata_cache[memory_id]["access_count"] >= query.min_access_count
]

@@ -530,7 +533,8 @@ class MetadataIndexManager:
if query.semantic_hashes:
hash_set = set(query.semantic_hashes)
filtered_ids = [
memory_id for memory_id in filtered_ids
memory_id
for memory_id in filtered_ids
if self.memory_metadata_cache[memory_id]["semantic_hash"] in hash_set
]

@@ -560,8 +564,7 @@ class MetadataIndexManager:
elif sort_by == "relevance_score":
# 按相关度排序
memory_ids.sort(
key=lambda mid: self.memory_metadata_cache[mid]["relevance_score"],
reverse=(sort_order == "desc")
key=lambda mid: self.memory_metadata_cache[mid]["relevance_score"], reverse=(sort_order == "desc")
)

elif sort_by == "relationship_score":
@@ -574,8 +577,7 @@ class MetadataIndexManager:
elif sort_by == "last_accessed":
# 按最后访问时间排序
memory_ids.sort(
key=lambda mid: self.memory_metadata_cache[mid]["last_accessed"],
reverse=(sort_order == "desc")
key=lambda mid: self.memory_metadata_cache[mid]["last_accessed"], reverse=(sort_order == "desc")
)

return memory_ids
@@ -665,7 +667,9 @@ class MetadataIndexManager:
self.relationship_index = [(score, mid) for score, mid in self.relationship_index if mid != memory_id]

# 从访问频率索引中移除
self.access_frequency_index = [(count, mid) for count, mid in self.access_frequency_index if mid != memory_id]
self.access_frequency_index = [
(count, mid) for count, mid in self.access_frequency_index if mid != memory_id
]

# 注意:关键词、标签、分类索引需要从原始记忆中获取,这里简化处理
# 实际实现中可能需要重新加载记忆或维护反向索引
@@ -704,7 +708,7 @@ class MetadataIndexManager:
"average_importance": 0.0,
"average_relationship_score": 0.0,
"top_keywords": [],
"top_tags": []
"top_tags": [],
}

if user_id:
@@ -789,23 +793,23 @@ class MetadataIndexManager:
indices_data[index_type.value] = serialized_index

indices_file = self.index_path / "indices.json"
with open(indices_file, 'w', encoding='utf-8') as f:
f.write(orjson.dumps(indices_data, option=orjson.OPT_INDENT_2).decode('utf-8'))
with open(indices_file, "w", encoding="utf-8") as f:
f.write(orjson.dumps(indices_data, option=orjson.OPT_INDENT_2).decode("utf-8"))

# 保存时间索引
time_index_file = self.index_path / "time_index.json"
with open(time_index_file, 'w', encoding='utf-8') as f:
f.write(orjson.dumps(self.time_index, option=orjson.OPT_INDENT_2).decode('utf-8'))
with open(time_index_file, "w", encoding="utf-8") as f:
f.write(orjson.dumps(self.time_index, option=orjson.OPT_INDENT_2).decode("utf-8"))

# 保存关系分索引
relationship_index_file = self.index_path / "relationship_index.json"
with open(relationship_index_file, 'w', encoding='utf-8') as f:
f.write(orjson.dumps(self.relationship_index, option=orjson.OPT_INDENT_2).decode('utf-8'))
with open(relationship_index_file, "w", encoding="utf-8") as f:
f.write(orjson.dumps(self.relationship_index, option=orjson.OPT_INDENT_2).decode("utf-8"))

# 保存访问频率索引
access_frequency_index_file = self.index_path / "access_frequency_index.json"
with open(access_frequency_index_file, 'w', encoding='utf-8') as f:
f.write(orjson.dumps(self.access_frequency_index, option=orjson.OPT_INDENT_2).decode('utf-8'))
with open(access_frequency_index_file, "w", encoding="utf-8") as f:
f.write(orjson.dumps(self.access_frequency_index, option=orjson.OPT_INDENT_2).decode("utf-8"))

# 保存元数据缓存
metadata_cache_file = self.index_path / "metadata_cache.json"
@@ -813,13 +817,13 @@ class MetadataIndexManager:
memory_id: self._serialize_metadata_entry(metadata)
for memory_id, metadata in self.memory_metadata_cache.items()
}
with open(metadata_cache_file, 'w', encoding='utf-8') as f:
f.write(orjson.dumps(metadata_serialized, option=orjson.OPT_INDENT_2).decode('utf-8'))
with open(metadata_cache_file, "w", encoding="utf-8") as f:
f.write(orjson.dumps(metadata_serialized, option=orjson.OPT_INDENT_2).decode("utf-8"))

# 保存统计信息
stats_file = self.index_path / "index_stats.json"
with open(stats_file, 'w', encoding='utf-8') as f:
f.write(orjson.dumps(self.index_stats, option=orjson.OPT_INDENT_2).decode('utf-8'))
with open(stats_file, "w", encoding="utf-8") as f:
f.write(orjson.dumps(self.index_stats, option=orjson.OPT_INDENT_2).decode("utf-8"))

self._dirty = False
logger.info("✅ 元数据索引保存完成")
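Each save above repeats one round-trip: orjson.dumps() returns bytes, so the result is decoded before being written as text, and OPT_INDENT_2 pretty-prints. A minimal sketch of that pattern with a hypothetical file name:

    import orjson

    stats = {"total_queries": 0, "cache_hits": 0}
    with open("index_stats.json", "w", encoding="utf-8") as f:
        f.write(orjson.dumps(stats, option=orjson.OPT_INDENT_2).decode("utf-8"))
    with open("index_stats.json", "r", encoding="utf-8") as f:
        restored = orjson.loads(f.read())
    assert restored == stats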
@@ -835,7 +839,7 @@ class MetadataIndexManager:
# 加载各类索引
indices_file = self.index_path / "indices.json"
if indices_file.exists():
with open(indices_file, 'r', encoding='utf-8') as f:
with open(indices_file, "r", encoding="utf-8") as f:
indices_data = orjson.loads(f.read())

for index_type_value, index_data in indices_data.items():
@@ -849,25 +853,25 @@ class MetadataIndexManager:
# 加载时间索引
time_index_file = self.index_path / "time_index.json"
if time_index_file.exists():
with open(time_index_file, 'r', encoding='utf-8') as f:
with open(time_index_file, "r", encoding="utf-8") as f:
self.time_index = orjson.loads(f.read())

# 加载关系分索引
relationship_index_file = self.index_path / "relationship_index.json"
if relationship_index_file.exists():
with open(relationship_index_file, 'r', encoding='utf-8') as f:
with open(relationship_index_file, "r", encoding="utf-8") as f:
self.relationship_index = orjson.loads(f.read())

# 加载访问频率索引
access_frequency_index_file = self.index_path / "access_frequency_index.json"
if access_frequency_index_file.exists():
with open(access_frequency_index_file, 'r', encoding='utf-8') as f:
with open(access_frequency_index_file, "r", encoding="utf-8") as f:
self.access_frequency_index = orjson.loads(f.read())

# 加载元数据缓存
metadata_cache_file = self.index_path / "metadata_cache.json"
if metadata_cache_file.exists():
with open(metadata_cache_file, 'r', encoding='utf-8') as f:
with open(metadata_cache_file, "r", encoding="utf-8") as f:
cache_data = orjson.loads(f.read())

# 转换置信度和重要性为枚举类型
@@ -910,7 +914,7 @@ class MetadataIndexManager:
# 加载统计信息
stats_file = self.index_path / "index_stats.json"
if stats_file.exists():
with open(stats_file, 'r', encoding='utf-8') as f:
with open(stats_file, "r", encoding="utf-8") as f:
self.index_stats = orjson.loads(f.read())

# 更新记忆计数
@@ -937,9 +941,7 @@ class MetadataIndexManager:

# 更新统计信息
if self.index_stats["total_queries"] > 0:
self.index_stats["cache_hit_rate"] = (
self.index_stats["cache_hits"] / self.index_stats["total_queries"]
)
self.index_stats["cache_hit_rate"] = self.index_stats["cache_hits"] / self.index_stats["total_queries"]

logger.info("✅ 元数据索引优化完成")

@@ -967,7 +969,9 @@ class MetadataIndexManager:
self.relationship_index = [(score, mid) for score, mid in self.relationship_index if mid in valid_memory_ids]

# 清理访问频率索引中的无效引用
self.access_frequency_index = [(count, mid) for count, mid in self.access_frequency_index if mid in valid_memory_ids]
self.access_frequency_index = [
(count, mid) for count, mid in self.access_frequency_index if mid in valid_memory_ids
]

# 更新总记忆数
self.index_stats["total_memories"] = len(valid_memory_ids)
@@ -1017,7 +1021,7 @@ class MetadataIndexManager:
"categories": len(self.indices[IndexType.CATEGORY]),
"confidence_levels": len(self.indices[IndexType.CONFIDENCE]),
"importance_levels": len(self.indices[IndexType.IMPORTANCE]),
"semantic_hashes": len(self.indices[IndexType.SEMANTIC_HASH])
"semantic_hashes": len(self.indices[IndexType.SEMANTIC_HASH]),
}

return stats
return stats

@@ -5,15 +5,13 @@
"""

import time
import asyncio
from typing import Dict, List, Optional, Tuple, Set, Any
from typing import Dict, List, Optional, Set, Any
from dataclasses import dataclass, field
from enum import Enum
import numpy as np
import orjson

from src.common.logger import get_logger
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
from src.chat.memory_system.enhanced_reranker import EnhancedReRanker, ReRankingConfig

logger = get_logger(__name__)
@@ -21,30 +19,32 @@ logger = get_logger(__name__)

class RetrievalStage(Enum):
"""检索阶段"""
METADATA_FILTERING = "metadata_filtering" # 元数据过滤阶段
VECTOR_SEARCH = "vector_search" # 向量搜索阶段
SEMANTIC_RERANKING = "semantic_reranking" # 语义重排序阶段
CONTEXTUAL_FILTERING = "contextual_filtering" # 上下文过滤阶段

METADATA_FILTERING = "metadata_filtering" # 元数据过滤阶段
VECTOR_SEARCH = "vector_search" # 向量搜索阶段
SEMANTIC_RERANKING = "semantic_reranking" # 语义重排序阶段
CONTEXTUAL_FILTERING = "contextual_filtering" # 上下文过滤阶段


@dataclass
class RetrievalConfig:
"""检索配置"""

# 各阶段配置 - 优化召回率
metadata_filter_limit: int = 150 # 元数据过滤阶段返回数量(增加)
vector_search_limit: int = 80 # 向量搜索阶段返回数量(增加)
semantic_rerank_limit: int = 30 # 语义重排序阶段返回数量(增加)
final_result_limit: int = 10 # 最终结果数量
metadata_filter_limit: int = 150 # 元数据过滤阶段返回数量(增加)
vector_search_limit: int = 80 # 向量搜索阶段返回数量(增加)
semantic_rerank_limit: int = 30 # 语义重排序阶段返回数量(增加)
final_result_limit: int = 10 # 最终结果数量

# 相似度阈值 - 优化召回率
vector_similarity_threshold: float = 0.5 # 向量相似度阈值(降低以提升召回率)
vector_similarity_threshold: float = 0.5 # 向量相似度阈值(降低以提升召回率)
semantic_similarity_threshold: float = 0.05 # 语义相似度阈值(保持较低以获得更多相关记忆)

# 权重配置
vector_weight: float = 0.4 # 向量相似度权重
semantic_weight: float = 0.3 # 语义相似度权重
context_weight: float = 0.2 # 上下文权重
recency_weight: float = 0.1 # 时效性权重
vector_weight: float = 0.4 # 向量相似度权重
semantic_weight: float = 0.3 # 语义相似度权重
context_weight: float = 0.2 # 上下文权重
recency_weight: float = 0.1 # 时效性权重

@classmethod
def from_global_config(cls):
@@ -53,26 +53,25 @@ class RetrievalConfig:

return cls(
# 各阶段配置 - 优化召回率
metadata_filter_limit=max(150, global_config.memory.metadata_filter_limit), # 增加候选池
vector_search_limit=max(80, global_config.memory.vector_search_limit), # 增加向量搜索结果
semantic_rerank_limit=max(30, global_config.memory.semantic_rerank_limit), # 增加重排序候选
metadata_filter_limit=max(150, global_config.memory.metadata_filter_limit), # 增加候选池
vector_search_limit=max(80, global_config.memory.vector_search_limit), # 增加向量搜索结果
semantic_rerank_limit=max(30, global_config.memory.semantic_rerank_limit), # 增加重排序候选
final_result_limit=global_config.memory.final_result_limit,

# 相似度阈值 - 优化召回率
vector_similarity_threshold=max(0.5, global_config.memory.vector_similarity_threshold), # 确保不低于0.5
semantic_similarity_threshold=0.05, # 进一步降低以提升召回率

# 权重配置
vector_weight=global_config.memory.vector_weight,
semantic_weight=global_config.memory.semantic_weight,
context_weight=global_config.memory.context_weight,
recency_weight=global_config.memory.recency_weight
recency_weight=global_config.memory.recency_weight,
)


@dataclass
class StageResult:
"""阶段结果"""

stage: RetrievalStage
memory_ids: List[str]
processing_time: float
@@ -84,6 +83,7 @@ class StageResult:
@dataclass
class RetrievalResult:
"""检索结果"""

query: str
user_id: str
final_memories: List[MemoryChunk]
@@ -98,16 +98,16 @@ class MultiStageRetrieval:

def __init__(self, config: Optional[RetrievalConfig] = None):
self.config = config or RetrievalConfig.from_global_config()


# 初始化增强重排序器
reranker_config = ReRankingConfig(
semantic_weight=self.config.vector_weight,
recency_weight=self.config.recency_weight,
usage_freq_weight=0.2, # 新增的使用频率权重
type_match_weight=0.1 # 新增的类型匹配权重
type_match_weight=0.1, # 新增的类型匹配权重
)
self.reranker = EnhancedReRanker(reranker_config)


self.retrieval_stats = {
"total_queries": 0,
"average_retrieval_time": 0.0,
@@ -116,8 +116,8 @@ class MultiStageRetrieval:
"vector_search": {"calls": 0, "avg_time": 0.0},
"semantic_reranking": {"calls": 0, "avg_time": 0.0},
"contextual_filtering": {"calls": 0, "avg_time": 0.0},
"enhanced_reranking": {"calls": 0, "avg_time": 0.0} # 新增统计
}
"enhanced_reranking": {"calls": 0, "avg_time": 0.0}, # 新增统计
},
}

async def retrieve_memories(
@@ -128,7 +128,7 @@ class MultiStageRetrieval:
metadata_index,
vector_storage,
all_memories_cache: Dict[str, MemoryChunk],
limit: Optional[int] = None
limit: Optional[int] = None,
) -> RetrievalResult:
"""多阶段记忆检索"""
start_time = time.time()
@@ -143,31 +143,39 @@ class MultiStageRetrieval:

# 阶段1:元数据过滤
stage1_result = await self._metadata_filtering_stage(
query, user_id, context, metadata_index, all_memories_cache,
debug_log=memory_debug_info
query, user_id, context, metadata_index, all_memories_cache, debug_log=memory_debug_info
)
stage_results.append(stage1_result)
current_memory_ids.update(stage1_result.memory_ids)

# 阶段2:向量搜索
stage2_result = await self._vector_search_stage(
query, user_id, context, vector_storage, current_memory_ids, all_memories_cache,
debug_log=memory_debug_info
query,
user_id,
context,
vector_storage,
current_memory_ids,
all_memories_cache,
debug_log=memory_debug_info,
)
stage_results.append(stage2_result)
current_memory_ids.update(stage2_result.memory_ids)

# 阶段3:语义重排序
stage3_result = await self._semantic_reranking_stage(
query, user_id, context, current_memory_ids, all_memories_cache,
debug_log=memory_debug_info
query, user_id, context, current_memory_ids, all_memories_cache, debug_log=memory_debug_info
)
stage_results.append(stage3_result)

# 阶段4:上下文过滤
stage4_result = await self._contextual_filtering_stage(
query, user_id, context, stage3_result.memory_ids, all_memories_cache, limit,
debug_log=memory_debug_info
query,
user_id,
context,
stage3_result.memory_ids,
all_memories_cache,
limit,
debug_log=memory_debug_info,
)
stage_results.append(stage4_result)

@@ -176,18 +184,27 @@ class MultiStageRetrieval:
logger.debug(f"上下文过滤结果过少({len(stage4_result.memory_ids)}),启用回退机制")
# 回退到更宽松的检索策略
fallback_result = await self._fallback_retrieval_stage(
query, user_id, context, all_memories_cache, limit,
query,
user_id,
context,
all_memories_cache,
limit,
excluded_ids=set(stage4_result.memory_ids),
debug_log=memory_debug_info
debug_log=memory_debug_info,
)
if fallback_result.memory_ids:
stage4_result.memory_ids.extend(fallback_result.memory_ids[:limit - len(stage4_result.memory_ids)])
stage4_result.memory_ids.extend(fallback_result.memory_ids[: limit - len(stage4_result.memory_ids)])
logger.debug(f"回退机制补充了 {len(fallback_result.memory_ids)} 条记忆")

# 阶段5:增强重排序 (新增)
stage5_result = await self._enhanced_reranking_stage(
query, user_id, context, stage4_result.memory_ids, all_memories_cache, limit,
debug_log=memory_debug_info
query,
user_id,
context,
stage4_result.memory_ids,
all_memories_cache,
limit,
debug_log=memory_debug_info,
)
stage_results.append(stage5_result)

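The five stages above form a narrowing funnel: metadata filtering (limit 150) feeds vector search (80), then semantic reranking (30), contextual filtering, and finally the enhanced reranker down to the result limit. A hedged, runnable sketch of just that control flow (the stage bodies are stand-ins, not the real scoring):

    def stage(pool: list[str], limit: int) -> list[str]:
        # stand-in: the real stages score by metadata, vector similarity,
        # semantic similarity, context fit, then multi-signal reranking
        return pool[:limit]

    pool = [f"mem_{i}" for i in range(200)]
    for limit in (150, 80, 30, 10, 10):
        pool = stage(pool, limit)
    assert len(pool) == 10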
@@ -226,13 +243,21 @@ class MultiStageRetrieval:
"semantic_score": trace.get("semantic_stage", {}).get("score"),
"context_score": trace.get("context_stage", {}).get("context_score"),
"final_score": trace.get("context_stage", {}).get("final_score"),
"status": trace.get("context_stage", {}).get("status") or trace.get("vector_stage", {}).get("status") or trace.get("semantic_stage", {}).get("status"),
"status": trace.get("context_stage", {}).get("status")
or trace.get("vector_stage", {}).get("status")
or trace.get("semantic_stage", {}).get("status"),
"is_final": memory_id in final_ids_set,
}
debug_entries.append(entry)

# 限制日志输出数量
debug_entries.sort(key=lambda item: (item.get("is_final", False), item.get("final_score") or item.get("vector_similarity") or 0.0), reverse=True)
debug_entries.sort(
key=lambda item: (
item.get("is_final", False),
item.get("final_score") or item.get("vector_similarity") or 0.0,
),
reverse=True,
)
debug_payload = {
"query": query,
"semantic_query": context.get("resolved_query_text", query),
@@ -266,7 +291,7 @@ class MultiStageRetrieval:
stage_results=stage_results,
total_processing_time=total_time,
total_filtered=total_filtered,
retrieval_stats=self.retrieval_stats.copy()
retrieval_stats=self.retrieval_stats.copy(),
)

except Exception as e:
@@ -279,7 +304,7 @@ class MultiStageRetrieval:
stage_results=stage_results,
total_processing_time=time.time() - start_time,
total_filtered=0,
retrieval_stats=self.retrieval_stats.copy()
retrieval_stats=self.retrieval_stats.copy(),
)

async def _metadata_filtering_stage(
@@ -290,7 +315,7 @@ class MultiStageRetrieval:
metadata_index,
all_memories_cache: Dict[str, MemoryChunk],
*,
debug_log: Optional[Dict[str, Dict[str, Any]]] = None
debug_log: Optional[Dict[str, Dict[str, Any]]] = None,
) -> StageResult:
"""阶段1:元数据过滤"""
start_time = time.time()
@@ -302,7 +327,9 @@ class MultiStageRetrieval:

memory_types = self._extract_memory_types_from_context(context)
keywords = self._extract_keywords_from_query(query, query_plan)
subjects = query_plan.subject_includes if query_plan and getattr(query_plan, "subject_includes", None) else None
subjects = (
query_plan.subject_includes if query_plan and getattr(query_plan, "subject_includes", None) else None
)

index_query = IndexQuery(
user_ids=None,
@@ -311,7 +338,7 @@ class MultiStageRetrieval:
keywords=keywords,
limit=self.config.metadata_filter_limit,
sort_by="last_accessed",
sort_order="desc"
sort_order="desc",
)

# 执行查询
@@ -328,19 +355,16 @@ class MultiStageRetrieval:
reverse=True,
)
if memory_types:
type_filtered = [
mid for mid in sorted_ids
if all_memories_cache[mid].memory_type in memory_types
]
type_filtered = [mid for mid in sorted_ids if all_memories_cache[mid].memory_type in memory_types]
sorted_ids = type_filtered or sorted_ids
if subjects:
subject_candidates = [s.lower() for s in subjects if isinstance(s, str) and s.strip()]
if subject_candidates:
subject_filtered = [
mid for mid in sorted_ids
mid
for mid in sorted_ids
if any(
subj.strip().lower() in subject_candidates
for subj in all_memories_cache[mid].subjects
subj.strip().lower() in subject_candidates for subj in all_memories_cache[mid].subjects
)
]
sorted_ids = subject_filtered or sorted_ids
@@ -367,12 +391,14 @@ class MultiStageRetrieval:
bool(subjects),
bool(keywords),
)
details.append({
"note": "fallback_recent",
"requested_types": [mt.value for mt in memory_types] if memory_types else [],
"subjects": subjects or [],
"keywords": keywords or [],
})
details.append(
{
"note": "fallback_recent",
"requested_types": [mt.value for mt in memory_types] if memory_types else [],
"subjects": subjects or [],
"keywords": keywords or [],
}
)

logger.debug(
"元数据过滤:候选=%d, 返回=%d",
@@ -419,7 +445,7 @@ class MultiStageRetrieval:
candidate_ids: Set[str],
all_memories_cache: Dict[str, MemoryChunk],
*,
debug_log: Optional[Dict[str, Dict[str, Any]]] = None
debug_log: Optional[Dict[str, Dict[str, Any]]] = None,
) -> StageResult:
"""阶段2:向量搜索"""
start_time = time.time()
@@ -441,8 +467,7 @@ class MultiStageRetrieval:

# 执行向量搜索
search_result = await vector_storage.search_similar_memories(
query_vector=query_embedding,
limit=self.config.vector_search_limit
query_vector=query_embedding, limit=self.config.vector_search_limit
)

if not search_result:
@@ -464,16 +489,18 @@ class MultiStageRetrieval:
if in_metadata_candidates and above_threshold:
filtered_memories.append((memory_id, similarity))

raw_details.append({
"memory_id": memory_id,
"similarity": similarity,
"in_metadata": in_metadata_candidates,
"above_threshold": above_threshold,
})
raw_details.append(
{
"memory_id": memory_id,
"similarity": similarity,
"in_metadata": in_metadata_candidates,
"above_threshold": above_threshold,
}
)

# 按相似度排序
filtered_memories.sort(key=lambda x: x[1], reverse=True)
result_ids = [memory_id for memory_id, _ in filtered_memories[:self.config.vector_search_limit]]
result_ids = [memory_id for memory_id, _ in filtered_memories[: self.config.vector_search_limit]]
kept_ids = set(result_ids)

for entry in raw_details:
@@ -534,11 +561,7 @@ class MultiStageRetrieval:
)

def _create_text_search_fallback(
self,
candidate_ids: Set[str],
all_memories_cache: Dict[str, MemoryChunk],
query_text: str,
start_time: float
self, candidate_ids: Set[str], all_memories_cache: Dict[str, MemoryChunk], query_text: str, start_time: float
) -> StageResult:
"""当向量搜索失败时,使用文本搜索作为回退策略"""
try:
@@ -561,15 +584,13 @@ class MultiStageRetrieval:

# 按匹配度排序
text_matches.sort(key=lambda x: x[1], reverse=True)
result_ids = [memory_id for memory_id, _ in text_matches[:self.config.vector_search_limit]]
result_ids = [memory_id for memory_id, _ in text_matches[: self.config.vector_search_limit]]

details = []
for memory_id, score in text_matches[:self.config.vector_search_limit]:
details.append({
"memory_id": memory_id,
"text_match_score": round(score, 4),
"status": "text_match_fallback"
})
for memory_id, score in text_matches[: self.config.vector_search_limit]:
details.append(
{"memory_id": memory_id, "text_match_score": round(score, 4), "status": "text_match_fallback"}
)

logger.debug(f"向量搜索回退到文本匹配:找到 {len(result_ids)} 条匹配记忆")

@@ -579,18 +600,18 @@ class MultiStageRetrieval:
processing_time=time.time() - start_time,
filtered_count=len(candidate_ids) - len(result_ids),
score_threshold=0.0, # 文本匹配无严格阈值
details=details
details=details,
)

except Exception as e:
logger.error(f"文本搜索回退失败: {e}")
return StageResult(
stage=RetrievalStage.VECTOR_SEARCH,
memory_ids=list(candidate_ids)[:self.config.vector_search_limit],
memory_ids=list(candidate_ids)[: self.config.vector_search_limit],
processing_time=time.time() - start_time,
filtered_count=0,
score_threshold=0.0,
details=[{"error": str(e), "note": "text_fallback_failed"}]
details=[{"error": str(e), "note": "text_fallback_failed"}],
)

async def _semantic_reranking_stage(
@@ -601,7 +622,7 @@ class MultiStageRetrieval:
candidate_ids: Set[str],
all_memories_cache: Dict[str, MemoryChunk],
*,
debug_log: Optional[Dict[str, Dict[str, Any]]] = None
debug_log: Optional[Dict[str, Dict[str, Any]]] = None,
) -> StageResult:
"""阶段3:语义重排序"""
start_time = time.time()
@@ -643,7 +664,7 @@ class MultiStageRetrieval:

# 按语义相似度排序
reranked_memories.sort(key=lambda x: x[1], reverse=True)
result_ids = [memory_id for memory_id, _ in reranked_memories[:self.config.semantic_rerank_limit]]
result_ids = [memory_id for memory_id, _ in reranked_memories[: self.config.semantic_rerank_limit]]
kept_ids = set(result_ids)

filtered_count = len(candidate_ids) - len(result_ids)
@@ -688,7 +709,7 @@ class MultiStageRetrieval:
all_memories_cache: Dict[str, MemoryChunk],
limit: int,
*,
debug_log: Optional[Dict[str, Dict[str, Any]]] = None
debug_log: Optional[Dict[str, Dict[str, Any]]] = None,
) -> StageResult:
"""阶段4:上下文过滤"""
start_time = time.time()
@@ -777,7 +798,7 @@ class MultiStageRetrieval:
limit: int,
*,
excluded_ids: Optional[Set[str]] = None,
debug_log: Optional[Dict[str, Dict[str, Any]]] = None
debug_log: Optional[Dict[str, Dict[str, Any]]] = None,
) -> StageResult:
"""回退检索阶段 - 当主检索失败时使用更宽松的策略"""
start_time = time.time()
@@ -806,13 +827,15 @@ class MultiStageRetrieval:
if not fallback_candidates:
logger.debug("关键词匹配无结果,使用时序最近策略")
recent_memories = sorted(
[(mid, mem.metadata.last_accessed or mem.metadata.created_at)
for mid, mem in all_memories_cache.items()
if mid not in excluded_ids],
[
(mid, mem.metadata.last_accessed or mem.metadata.created_at)
for mid, mem in all_memories_cache.items()
if mid not in excluded_ids
],
key=lambda x: x[1],
reverse=True
reverse=True,
)
fallback_candidates = [(mid, 0.5) for mid, _ in recent_memories[:limit*2]]
fallback_candidates = [(mid, 0.5) for mid, _ in recent_memories[: limit * 2]]

# 按分数排序
fallback_candidates.sort(key=lambda x: x[1], reverse=True)
@@ -857,7 +880,9 @@ class MultiStageRetrieval:
details=[{"error": str(e)}],
)

async def _generate_query_embedding(self, query: str, context: Dict[str, Any], vector_storage) -> Optional[List[float]]:
async def _generate_query_embedding(
self, query: str, context: Dict[str, Any], vector_storage
) -> Optional[List[float]]:
"""生成查询向量"""
try:
query_plan = context.get("query_plan")
@@ -875,15 +900,15 @@ class MultiStageRetrieval:

logger.debug(f"正在生成查询向量,文本: '{query_text[:100]}'")
embedding = await vector_storage.generate_query_embedding(query_text)


if embedding is None:
logger.warning("向量存储返回空的查询向量")
return None


if len(embedding) == 0:
logger.warning("向量存储返回空列表作为查询向量")
return None


logger.debug(f"查询向量生成成功,维度: {len(embedding)}")
return embedding

@@ -926,8 +951,8 @@ class MultiStageRetrieval:
import re

# 分词处理
query_words = list(jieba.cut(query_text)) + re.findall(r'[a-zA-Z]+', query_text)
memory_words = list(jieba.cut(memory_text)) + re.findall(r'[a-zA-Z]+', memory_text)
query_words = list(jieba.cut(query_text)) + re.findall(r"[a-zA-Z]+", query_text)
memory_words = list(jieba.cut(memory_text)) + re.findall(r"[a-zA-Z]+", memory_text)

# 清理和标准化
query_words = [w.strip().lower() for w in query_words if w.strip() and len(w.strip()) > 1]
@@ -953,8 +978,9 @@ class MultiStageRetrieval:
except ImportError:
# 如果jieba不可用,使用简单分词
import re
query_words = re.findall(r'[\w\u4e00-\u9fa5]+', query_lower)
memory_words = re.findall(r'[\w\u4e00-\u9fa5]+', memory_lower)

query_words = re.findall(r"[\w\u4e00-\u9fa5]+", query_lower)
memory_words = re.findall(r"[\w\u4e00-\u9fa5]+", memory_lower)

if query_words and memory_words:
query_set = set(w for w in query_words if len(w) > 1)
@@ -971,13 +997,19 @@ class MultiStageRetrieval:
"天气": ["天气", "阳光", "雨", "晴", "阴", "温度", "weather", "sunny", "rain"],
"编程": ["编程", "代码", "程序", "开发", "语言", "programming", "code", "develop", "python"],
"时间": ["今天", "昨天", "明天", "现在", "时间", "today", "yesterday", "tomorrow", "time"],
"情感": ["好", "坏", "开心", "难过", "有趣", "good", "bad", "happy", "sad", "fun"]
"情感": ["好", "坏", "开心", "难过", "有趣", "good", "bad", "happy", "sad", "fun"],
}

query_concepts = {concept for concept, keywords in concept_groups.items()
if any(keyword in query_lower for keyword in keywords)}
memory_concepts = {concept for concept, keywords in concept_groups.items()
if any(keyword in memory_lower for keyword in keywords)}
query_concepts = {
concept
for concept, keywords in concept_groups.items()
if any(keyword in query_lower for keyword in keywords)
}
memory_concepts = {
concept
for concept, keywords in concept_groups.items()
if any(keyword in memory_lower for keyword in keywords)
}

if query_concepts and memory_concepts:
concept_overlap = query_concepts & memory_concepts
@@ -987,19 +1019,19 @@ class MultiStageRetrieval:
plan_bonus = 0.0
if query_plan:
# 主体匹配
if hasattr(query_plan, 'subjects') and query_plan.subjects:
if hasattr(query_plan, "subjects") and query_plan.subjects:
for subject in query_plan.subjects:
if subject.lower() in memory_lower:
plan_bonus += 0.15

# 对象匹配
if hasattr(query_plan, 'objects') and query_plan.objects:
if hasattr(query_plan, "objects") and query_plan.objects:
for obj in query_plan.objects:
if obj.lower() in memory_lower:
plan_bonus += 0.1

# 记忆类型匹配
if hasattr(query_plan, 'memory_types') and query_plan.memory_types:
if hasattr(query_plan, "memory_types") and query_plan.memory_types:
if memory.memory_type in query_plan.memory_types:
plan_bonus += 0.1

@@ -1059,14 +1091,22 @@ class MultiStageRetrieval:
object_keywords = getattr(query_plan, "object_includes", []) or []
if object_keywords:
display_text = (memory.display or memory.text_content or "").lower()
hits = sum(1 for kw in object_keywords if isinstance(kw, str) and kw.strip() and kw.strip().lower() in display_text)
hits = sum(
1
for kw in object_keywords
if isinstance(kw, str) and kw.strip() and kw.strip().lower() in display_text
)
if hits:
score += min(0.3, hits * 0.1)

optional_keywords = getattr(query_plan, "optional_keywords", []) or []
if optional_keywords:
display_text = (memory.display or memory.text_content or "").lower()
hits = sum(1 for kw in optional_keywords if isinstance(kw, str) and kw.strip() and kw.strip().lower() in display_text)
hits = sum(
1
for kw in optional_keywords
if isinstance(kw, str) and kw.strip() and kw.strip().lower() in display_text
)
if hits:
score += min(0.2, hits * 0.05)

@@ -1091,7 +1131,9 @@ class MultiStageRetrieval:
logger.warning(f"计算上下文相关度失败: {e}")
return 0.0

async def _calculate_final_score(self, query: str, memory: MemoryChunk, context: Dict[str, Any], context_score: float) -> float:
async def _calculate_final_score(
self, query: str, memory: MemoryChunk, context: Dict[str, Any], context_score: float
) -> float:
"""计算最终评分"""
try:
query_plan = context.get("query_plan")
@@ -1126,10 +1168,10 @@ class MultiStageRetrieval:
context_weight += 0.05

final_score = (
semantic_score * semantic_weight +
vector_score * vector_weight +
context_score * context_weight +
recency_score * recency_weight
semantic_score * semantic_weight
+ vector_score * vector_weight
+ context_score * context_weight
+ recency_score * recency_weight
)

# 加入记忆重要性权重
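The reformatted expression above is a weighted sum of four signals; with the RetrievalConfig defaults (vector 0.4, semantic 0.3, context 0.2, recency 0.1) the weights total 1.0 before the runtime adjustments shown in the hunk. A worked example under those defaults (scores are illustrative):

    semantic_score, vector_score, context_score, recency_score = 0.8, 0.6, 0.5, 0.9
    final_score = (
        semantic_score * 0.3
        + vector_score * 0.4
        + context_score * 0.2
        + recency_score * 0.1
    )
    assert abs(final_score - 0.67) < 1e-9  # 0.24 + 0.24 + 0.10 + 0.09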
@@ -1259,7 +1301,9 @@ class MultiStageRetrieval:
stage_stat["calls"] += 1

current_stage_avg = stage_stat["avg_time"]
new_stage_avg = (current_stage_avg * (stage_stat["calls"] - 1) + result.processing_time) / stage_stat["calls"]
new_stage_avg = (current_stage_avg * (stage_stat["calls"] - 1) + result.processing_time) / stage_stat[
"calls"
]
stage_stat["avg_time"] = new_stage_avg

def get_retrieval_stats(self) -> Dict[str, Any]:
@@ -1276,8 +1320,8 @@ class MultiStageRetrieval:
"vector_search": {"calls": 0, "avg_time": 0.0},
"semantic_reranking": {"calls": 0, "avg_time": 0.0},
"contextual_filtering": {"calls": 0, "avg_time": 0.0},
"enhanced_reranking": {"calls": 0, "avg_time": 0.0}
}
"enhanced_reranking": {"calls": 0, "avg_time": 0.0},
},
}

async def _enhanced_reranking_stage(
@@ -1289,7 +1333,7 @@ class MultiStageRetrieval:
all_memories_cache: Dict[str, MemoryChunk],
limit: int,
*,
debug_log: Optional[Dict[str, Dict[str, Any]]] = None
debug_log: Optional[Dict[str, Dict[str, Any]]] = None,
) -> StageResult:
"""阶段5:增强重排序 - 使用多维度评分模型"""
start_time = time.time()
@@ -1326,15 +1370,12 @@ class MultiStageRetrieval:

# 使用增强重排序器
reranked_memories = self.reranker.rerank_memories(
query=query,
candidate_memories=candidate_memories,
context=context,
limit=limit
query=query, candidate_memories=candidate_memories, context=context, limit=limit
)

# 提取重排序后的记忆ID
result_ids = [memory_id for memory_id, _, _ in reranked_memories]


# 生成调试详情
details = []
for memory_id, memory, final_score in reranked_memories:
@@ -1346,7 +1387,7 @@ class MultiStageRetrieval:
"access_count": memory.metadata.access_count,
}
details.append(detail_entry)


if debug_log is not None:
stage_entry = debug_log.setdefault(memory_id, {}).setdefault("enhanced_rerank_stage", {})
stage_entry["final_score"] = round(final_score, 4)
@@ -1357,13 +1398,9 @@ class MultiStageRetrieval:
kept_ids = set(result_ids)
for memory_id in candidate_ids:
if memory_id not in kept_ids:
detail_entry = {
"memory_id": memory_id,
"status": "filtered_out",
"reason": "ranked_below_limit"
}
detail_entry = {"memory_id": memory_id, "status": "filtered_out", "reason": "ranked_below_limit"}
details.append(detail_entry)


if debug_log is not None:
stage_entry = debug_log.setdefault(memory_id, {}).setdefault("enhanced_rerank_stage", {})
stage_entry["status"] = "filtered_out"
@@ -1371,10 +1408,7 @@ class MultiStageRetrieval:

filtered_count = len(candidate_ids) - len(result_ids)

logger.debug(
f"增强重排序完成:候选={len(candidate_ids)}, 返回={len(result_ids)}, "
f"过滤={filtered_count}"
)
logger.debug(f"增强重排序完成:候选={len(candidate_ids)}, 返回={len(result_ids)}, 过滤={filtered_count}")

return StageResult(
stage=RetrievalStage.CONTEXTUAL_FILTERING, # 保持与原有枚举兼容
@@ -1394,4 +1428,4 @@ class MultiStageRetrieval:
filtered_count=0,
score_threshold=0.0,
details=[{"error": str(e)}],
)
)

@@ -4,22 +4,19 @@
为记忆系统提供高效的向量存储和语义搜索能力
"""

import os
import time
import orjson
import asyncio
from typing import Dict, List, Optional, Tuple, Set, Any
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass
from datetime import datetime
import threading

import numpy as np
import pandas as pd
from pathlib import Path

from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.config.config import model_config
from src.common.config_helpers import resolve_embedding_dimension
from src.chat.memory_system.memory_chunk import MemoryChunk

@@ -28,6 +25,7 @@ logger = get_logger(__name__)
# 尝试导入FAISS,如果不可用则使用简单替代
try:
import faiss

FAISS_AVAILABLE = True
except ImportError:
FAISS_AVAILABLE = False
@@ -37,6 +35,7 @@ except ImportError:
@dataclass
class VectorStorageConfig:
"""向量存储配置"""

dimension: int = 1024
similarity_threshold: float = 0.8
index_type: str = "flat" # flat, ivf, hnsw
@@ -79,7 +78,7 @@ class VectorStorageManager:
"average_search_time": 0.0,
"cache_hit_rate": 0.0,
"total_searches": 0,
"cache_hits": 0
"cache_hits": 0,
}

# 线程锁
@@ -122,8 +121,7 @@ class VectorStorageManager:
"""初始化嵌入模型"""
if self.embedding_model is None:
self.embedding_model = LLMRequest(
model_set=model_config.model_task_config.embedding,
request_type="memory.embedding"
model_set=model_config.model_task_config.embedding, request_type="memory.embedding"
)
logger.info("✅ 嵌入模型初始化完成")

@@ -137,20 +135,16 @@ class VectorStorageManager:
await self.initialize_embedding_model()

logger.debug(f"开始生成查询向量,文本: '{query_text[:50]}{'...' if len(query_text) > 50 else ''}'")


embedding, _ = await self.embedding_model.get_embedding(query_text)
if not embedding:
logger.warning("嵌入模型返回空向量")
return None

logger.debug(f"生成的向量维度: {len(embedding)}, 期望维度: {self.config.dimension}")


if len(embedding) != self.config.dimension:
logger.error(
"查询向量维度不匹配: 期望 %d, 实际 %d",
self.config.dimension,
len(embedding)
)
logger.error("查询向量维度不匹配: 期望 %d, 实际 %d", self.config.dimension, len(embedding))
return None

normalized_vector = self._normalize_vector(embedding)
@@ -287,7 +281,7 @@ class VectorStorageManager:
logger.warning("生成记忆 %s 的嵌入向量失败: %s", memory_id, exc)
results[memory_id] = []

tasks = [asyncio.create_task(generate_embedding(mid, text)) for mid, text in zip(memory_ids, texts)]
tasks = [asyncio.create_task(generate_embedding(mid, text)) for mid, text in zip(memory_ids, texts, strict=False)]
await asyncio.gather(*tasks, return_exceptions=True)

except Exception as e:
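The zip(..., strict=False) edits here and below make the old truncating behaviour explicit rather than changing it; this matches Ruff's zip-without-explicit-strict rule (B905) on Python 3.10+. The difference only shows on mismatched lengths:

    ids = ["a", "b", "c"]
    texts = ["x", "y"]  # one element short
    assert list(zip(ids, texts, strict=False)) == [("a", "x"), ("b", "y")]  # silent truncation
    # zip(ids, texts, strict=True) would raise ValueError instead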
@@ -313,12 +307,12 @@ class VectorStorageManager:
memory.set_embedding(embedding)

# 添加到向量索引
if hasattr(self.vector_index, 'add'):
if hasattr(self.vector_index, "add"):
# FAISS索引
if isinstance(embedding, np.ndarray):
vector_array = embedding.reshape(1, -1).astype('float32')
vector_array = embedding.reshape(1, -1).astype("float32")
else:
vector_array = np.array([embedding], dtype='float32')
vector_array = np.array([embedding], dtype="float32")

# 特殊处理IVF索引
if self.config.index_type == "ivf" and self.vector_index.ntotal == 0:
@@ -367,14 +361,14 @@ class VectorStorageManager:
*,
query_text: Optional[str] = None,
limit: int = 10,
scope_id: Optional[str] = None
scope_id: Optional[str] = None,
) -> List[Tuple[str, float]]:
"""搜索相似记忆"""
start_time = time.time()

try:
logger.debug(f"开始向量搜索: query_text='{query_text[:30] if query_text else 'None'}', limit={limit}")


if query_vector is None:
if not query_text:
logger.warning("查询向量和查询文本都为空")
@@ -395,34 +389,34 @@ class VectorStorageManager:

# 规范化查询向量
query_vector = self._normalize_vector(query_vector)


logger.debug(f"查询向量维度: {len(query_vector)}, 存储总向量数: {self.storage_stats['total_vectors']}")

# 检查向量索引状态
if not self.vector_index:
logger.error("向量索引未初始化")
return []


total_vectors = 0
if hasattr(self.vector_index, 'ntotal'):
if hasattr(self.vector_index, "ntotal"):
total_vectors = self.vector_index.ntotal
elif hasattr(self.vector_index, 'vectors'):
elif hasattr(self.vector_index, "vectors"):
total_vectors = len(self.vector_index.vectors)


logger.debug(f"向量索引中实际向量数: {total_vectors}")


if total_vectors == 0:
logger.warning("向量索引为空,无法执行搜索")
return []

# 执行向量搜索
with self._lock:
if hasattr(self.vector_index, 'search'):
if hasattr(self.vector_index, "search"):
# FAISS索引
if isinstance(query_vector, np.ndarray):
query_array = query_vector.reshape(1, -1).astype('float32')
query_array = query_vector.reshape(1, -1).astype("float32")
else:
query_array = np.array([query_vector], dtype='float32')
query_array = np.array([query_vector], dtype="float32")

if self.config.index_type == "ivf" and self.vector_index.ntotal > 0:
# 设置IVF搜索参数
@@ -432,11 +426,11 @@ class VectorStorageManager:

search_limit = min(limit, total_vectors)
logger.debug(f"执行FAISS搜索,搜索限制: {search_limit}")


distances, indices = self.vector_index.search(query_array, search_limit)
distances = distances.flatten().tolist()
indices = indices.flatten().tolist()


logger.debug(f"FAISS搜索结果: {len(distances)} 个距离值, {len(indices)} 个索引")
else:
# 简单索引
@@ -451,8 +445,8 @@ class VectorStorageManager:
valid_results = 0
invalid_indices = 0
filtered_by_scope = 0

for distance, index in zip(distances, indices):

for distance, index in zip(distances, indices, strict=False):
if index == -1: # FAISS的无效索引标记
invalid_indices += 1
continue
@@ -462,7 +456,7 @@ class VectorStorageManager:
logger.debug(f"索引 {index} 没有对应的记忆ID")
invalid_indices += 1
continue


if scope_filter:
memory = self.memory_cache.get(memory_id)
if memory and str(memory.user_id) != scope_filter:
@@ -482,16 +476,15 @@ class VectorStorageManager:
search_time = time.time() - start_time
self.storage_stats["total_searches"] += 1
self.storage_stats["average_search_time"] = (
(self.storage_stats["average_search_time"] * (self.storage_stats["total_searches"] - 1) + search_time) /
self.storage_stats["total_searches"]
)
self.storage_stats["average_search_time"] * (self.storage_stats["total_searches"] - 1) + search_time
) / self.storage_stats["total_searches"]

final_results = results[:limit]
logger.info(
f"向量搜索完成: 查询='{query_text[:20] if query_text else 'vector'}' "
f"耗时={search_time:.3f}s, 返回={len(final_results)}个结果"
)


return final_results

except Exception as e:
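search_similar_memories normalizes both stored and query vectors, so an inner-product lookup is equivalent to cosine similarity; presumably that is why _normalize_vector sits on both the write and read paths (the exact FAISS index flavour is not shown here, so treat this as an assumption). A small numpy sketch of the equivalence:

    import numpy as np

    def normalize(v: np.ndarray) -> np.ndarray:
        n = np.linalg.norm(v)
        return v / n if n > 0 else v

    a = normalize(np.array([1.0, 2.0, 2.0], dtype="float32"))
    b = normalize(np.array([2.0, 0.0, 1.0], dtype="float32"))
    inner = float(np.dot(a, b))
    cosine = float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
    assert abs(inner - cosine) < 1e-6  # identical once both are unit length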
@@ -520,7 +513,7 @@ class VectorStorageManager:
old_index = self.memory_id_to_index[memory_id]

# 删除旧向量(如果支持)
if hasattr(self.vector_index, 'remove_ids'):
if hasattr(self.vector_index, "remove_ids"):
try:
self.vector_index.remove_ids(np.array([old_index]))
except:
@@ -530,11 +523,11 @@ class VectorStorageManager:
new_embedding = self._normalize_vector(new_embedding)

# 添加新向量
if hasattr(self.vector_index, 'add'):
if hasattr(self.vector_index, "add"):
if isinstance(new_embedding, np.ndarray):
vector_array = new_embedding.reshape(1, -1).astype('float32')
vector_array = new_embedding.reshape(1, -1).astype("float32")
else:
vector_array = np.array([new_embedding], dtype='float32')
vector_array = np.array([new_embedding], dtype="float32")

self.vector_index.add(vector_array)
new_index = self.vector_index.ntotal - 1
@@ -569,7 +562,7 @@ class VectorStorageManager:
index = self.memory_id_to_index[memory_id]

# 从向量索引中删除(如果支持)
if hasattr(self.vector_index, 'remove_ids'):
if hasattr(self.vector_index, "remove_ids"):
try:
self.vector_index.remove_ids(np.array([index]))
except:
@@ -598,44 +591,37 @@ class VectorStorageManager:
logger.info("正在保存向量存储...")

# 保存记忆缓存
cache_data = {
memory_id: memory.to_dict()
for memory_id, memory in self.memory_cache.items()
}
cache_data = {memory_id: memory.to_dict() for memory_id, memory in self.memory_cache.items()}

cache_file = self.storage_path / "memory_cache.json"
with open(cache_file, 'w', encoding='utf-8') as f:
f.write(orjson.dumps(cache_data, option=orjson.OPT_INDENT_2).decode('utf-8'))
with open(cache_file, "w", encoding="utf-8") as f:
f.write(orjson.dumps(cache_data, option=orjson.OPT_INDENT_2).decode("utf-8"))

# 保存向量缓存
vector_cache_file = self.storage_path / "vector_cache.json"
with open(vector_cache_file, 'w', encoding='utf-8') as f:
f.write(orjson.dumps(self.vector_cache, option=orjson.OPT_INDENT_2).decode('utf-8'))
with open(vector_cache_file, "w", encoding="utf-8") as f:
f.write(orjson.dumps(self.vector_cache, option=orjson.OPT_INDENT_2).decode("utf-8"))

# 保存映射关系
mapping_file = self.storage_path / "id_mapping.json"
mapping_data = {
"memory_id_to_index": {
str(memory_id): int(index)
for memory_id, index in self.memory_id_to_index.items()
str(memory_id): int(index) for memory_id, index in self.memory_id_to_index.items()
},
"index_to_memory_id": {
str(index): memory_id
for index, memory_id in self.index_to_memory_id.items()
}
"index_to_memory_id": {str(index): memory_id for index, memory_id in self.index_to_memory_id.items()},
}
with open(mapping_file, 'w', encoding='utf-8') as f:
f.write(orjson.dumps(mapping_data, option=orjson.OPT_INDENT_2).decode('utf-8'))
with open(mapping_file, "w", encoding="utf-8") as f:
f.write(orjson.dumps(mapping_data, option=orjson.OPT_INDENT_2).decode("utf-8"))

# 保存FAISS索引(如果可用)
if FAISS_AVAILABLE and hasattr(self.vector_index, 'save'):
if FAISS_AVAILABLE and hasattr(self.vector_index, "save"):
index_file = self.storage_path / "vector_index.faiss"
faiss.write_index(self.vector_index, str(index_file))

# 保存统计信息
stats_file = self.storage_path / "storage_stats.json"
with open(stats_file, 'w', encoding='utf-8') as f:
f.write(orjson.dumps(self.storage_stats, option=orjson.OPT_INDENT_2).decode('utf-8'))
with open(stats_file, "w", encoding="utf-8") as f:
f.write(orjson.dumps(self.storage_stats, option=orjson.OPT_INDENT_2).decode("utf-8"))

logger.info("✅ 向量存储保存完成")

@@ -650,36 +636,31 @@ class VectorStorageManager:
# 加载记忆缓存
cache_file = self.storage_path / "memory_cache.json"
if cache_file.exists():
with open(cache_file, 'r', encoding='utf-8') as f:
with open(cache_file, "r", encoding="utf-8") as f:
cache_data = orjson.loads(f.read())

self.memory_cache = {
memory_id: MemoryChunk.from_dict(memory_data)
for memory_id, memory_data in cache_data.items()
memory_id: MemoryChunk.from_dict(memory_data) for memory_id, memory_data in cache_data.items()
}

# 加载向量缓存
vector_cache_file = self.storage_path / "vector_cache.json"
if vector_cache_file.exists():
with open(vector_cache_file, 'r', encoding='utf-8') as f:
with open(vector_cache_file, "r", encoding="utf-8") as f:
self.vector_cache = orjson.loads(f.read())

# 加载映射关系
mapping_file = self.storage_path / "id_mapping.json"
if mapping_file.exists():
with open(mapping_file, 'r', encoding='utf-8') as f:
with open(mapping_file, "r", encoding="utf-8") as f:
mapping_data = orjson.loads(f.read())
raw_memory_to_index = mapping_data.get("memory_id_to_index", {})
self.memory_id_to_index = {
str(memory_id): int(index)
for memory_id, index in raw_memory_to_index.items()
str(memory_id): int(index) for memory_id, index in raw_memory_to_index.items()
}

raw_index_to_memory = mapping_data.get("index_to_memory_id", {})
self.index_to_memory_id = {
int(index): memory_id
for index, memory_id in raw_index_to_memory.items()
}
self.index_to_memory_id = {int(index): memory_id for index, memory_id in raw_index_to_memory.items()}

# 加载FAISS索引(如果可用)
index_loaded = False
@@ -699,7 +680,7 @@ class VectorStorageManager:
logger.warning(f"加载FAISS索引失败: {e},重新构建")
else:
logger.info("FAISS索引文件不存在,将重新构建")


# 如果索引没有成功加载且有向量数据,则重建索引
if not index_loaded and self.vector_cache:
logger.info(f"检测到 {len(self.vector_cache)} 个向量缓存,重建索引")
@@ -708,7 +689,7 @@ class VectorStorageManager:
# 加载统计信息
stats_file = self.storage_path / "storage_stats.json"
if stats_file.exists():
with open(stats_file, 'r', encoding='utf-8') as f:
with open(stats_file, "r", encoding="utf-8") as f:
self.storage_stats = orjson.loads(f.read())

# 更新向量计数
@@ -738,7 +719,7 @@ class VectorStorageManager:
# 准备向量数据
memory_ids = []
vectors = []


for memory_id, embedding in self.vector_cache.items():
if embedding and len(embedding) == self.config.dimension:
memory_ids.append(memory_id)
@@ -753,18 +734,18 @@ class VectorStorageManager:
logger.info(f"准备重建 {len(vectors)} 个向量到索引")

# 批量添加向量到FAISS索引
if hasattr(self.vector_index, 'add'):
if hasattr(self.vector_index, "add"):
# FAISS索引
vector_array = np.array(vectors, dtype='float32')

vector_array = np.array(vectors, dtype="float32")

# 特殊处理IVF索引
if self.config.index_type == "ivf" and hasattr(self.vector_index, 'train'):
if self.config.index_type == "ivf" and hasattr(self.vector_index, "train"):
logger.info("训练IVF索引...")
self.vector_index.train(vector_array)

# 添加向量
self.vector_index.add(vector_array)


# 重建映射关系
for i, memory_id in enumerate(memory_ids):
self.memory_id_to_index[memory_id] = i
@@ -772,15 +753,15 @@ class VectorStorageManager:

else:
# 简单索引
for i, (memory_id, vector) in enumerate(zip(memory_ids, vectors)):
for i, (memory_id, vector) in enumerate(zip(memory_ids, vectors, strict=False)):
index_id = self.vector_index.add_vector(vector)
self.memory_id_to_index[memory_id] = index_id
self.index_to_memory_id[index_id] = memory_id

# 更新统计
self.storage_stats["total_vectors"] = len(self.memory_id_to_index)

final_count = getattr(self.vector_index, 'ntotal', len(self.memory_id_to_index))

final_count = getattr(self.vector_index, "ntotal", len(self.memory_id_to_index))
logger.info(f"✅ 向量索引重建完成,索引中向量数: {final_count}")

except Exception as e:
@@ -875,7 +856,7 @@ class SimpleVectorIndex:
def _calculate_cosine_similarity(self, v1: List[float], v2: List[float]) -> float:
"""计算余弦相似度"""
try:
dot_product = sum(x * y for x, y in zip(v1, v2))
dot_product = sum(x * y for x, y in zip(v1, v2, strict=False))
norm1 = sum(x * x for x in v1) ** 0.5
norm2 = sum(x * x for x in v2) ** 0.5

@@ -890,4 +871,4 @@ class SimpleVectorIndex:
|
||||
@property
|
||||
def ntotal(self) -> int:
|
||||
"""向量总数"""
|
||||
return len(self.vectors)
|
||||
return len(self.vectors)
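
A self-contained sketch (not from the commit) of the cosine-similarity fallback used by SimpleVectorIndex; strict=False keeps zip() silently truncating when the two vectors differ in length, matching the original behaviour.

from typing import List

def cosine_similarity(v1: List[float], v2: List[float]) -> float:
    # dot(v1, v2) / (|v1| * |v2|); degenerate vectors score 0.0
    dot_product = sum(x * y for x, y in zip(v1, v2, strict=False))
    norm1 = sum(x * x for x in v1) ** 0.5
    norm2 = sum(x * x for x in v2) ** 0.5
    if norm1 == 0.0 or norm2 == 0.0:
        return 0.0
    return dot_product / (norm1 * norm2)

assert abs(cosine_similarity([1.0, 0.0], [1.0, 0.0]) - 1.0) < 1e-9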

@@ -6,7 +6,6 @@
import difflib
import orjson
import time
from typing import List, Dict, Optional
from datetime import datetime

@@ -15,7 +14,7 @@ from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.memory_system.memory_manager import memory_manager, MemoryResult
from src.chat.memory_system.memory_manager import MemoryResult

logger = get_logger("memory_activator")

@@ -127,8 +126,8 @@ class MemoryActivator:
for result in memory_results:
# 检查是否已存在相似内容的记忆
exists = any(
m["content"] == result.content or
difflib.SequenceMatcher(None, m["content"], result.content).ratio() >= 0.7
m["content"] == result.content
or difflib.SequenceMatcher(None, m["content"], result.content).ratio() >= 0.7
for m in self.running_memory
)
if not exists:

@@ -140,7 +139,7 @@ class MemoryActivator:
"confidence": result.confidence,
"importance": result.importance,
"source": result.source,
"relevance_score": result.relevance_score # 添加相关度评分
"relevance_score": result.relevance_score,  # 添加相关度评分
}
self.running_memory.append(memory_entry)
logger.debug(f"添加新记忆: {result.memory_type} - {result.content}")

@@ -168,17 +167,14 @@ class MemoryActivator:
return []

# 构建查询上下文
context = {
"keywords": keywords,
"query_intent": "conversation_response"
}
context = {"keywords": keywords, "query_intent": "conversation_response"}

# 查询记忆
memories = await memory_system.retrieve_relevant_memories(
query_text=query_text,
user_id="global",  # 使用全局作用域
context=context,
limit=5
limit=5,
)

# 转换为 MemoryResult 格式

@@ -191,7 +187,7 @@ class MemoryActivator:
importance=memory.metadata.importance.value,
timestamp=memory.metadata.created_at,
source="unified_memory",
relevance_score=memory.metadata.relevance_score
relevance_score=memory.metadata.relevance_score,
)
memory_results.append(result)

@@ -214,16 +210,10 @@ class MemoryActivator:
if not memory_system or memory_system.status.value != "ready":
return None

context = {
"query_intent": "instant_response",
"chat_id": chat_id
}
context = {"query_intent": "instant_response", "chat_id": chat_id}

memories = await memory_system.retrieve_relevant_memories(
query_text=target_message,
user_id="global",
context=context,
limit=1
query_text=target_message, user_id="global", context=context, limit=1
)

if memories:

@@ -248,4 +238,4 @@ memory_activator = MemoryActivator()
# 兼容性别名
enhanced_memory_activator = memory_activator

init_prompt()
init_prompt()
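
A standalone sketch (not from the commit) of the duplicate check reformatted above: a retrieved memory is skipped when its content is identical to, or at least 70% similar to, an entry already held in running_memory.

import difflib

running_memory = [{"content": "用户喜欢喝咖啡"}]
candidate = "用户喜欢喝咖啡和茶"

exists = any(
    m["content"] == candidate
    or difflib.SequenceMatcher(None, m["content"], candidate).ratio() >= 0.7
    for m in running_memory
)
print(exists)  # True: the similarity ratio (0.875) clears the 0.7 threshold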

@@ -6,7 +6,6 @@
import difflib
import orjson
import time
from typing import List, Dict, Optional
from datetime import datetime

@@ -15,7 +14,7 @@ from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.memory_system.memory_manager import memory_manager, MemoryResult
from src.chat.memory_system.memory_manager import MemoryResult

logger = get_logger("memory_activator")

@@ -127,8 +126,8 @@ class MemoryActivator:
for result in memory_results:
# 检查是否已存在相似内容的记忆
exists = any(
m["content"] == result.content or
difflib.SequenceMatcher(None, m["content"], result.content).ratio() >= 0.7
m["content"] == result.content
or difflib.SequenceMatcher(None, m["content"], result.content).ratio() >= 0.7
for m in self.running_memory
)
if not exists:

@@ -140,7 +139,7 @@ class MemoryActivator:
"confidence": result.confidence,
"importance": result.importance,
"source": result.source,
"relevance_score": result.relevance_score # 添加相关度评分
"relevance_score": result.relevance_score,  # 添加相关度评分
}
self.running_memory.append(memory_entry)
logger.debug(f"添加新记忆: {result.memory_type} - {result.content}")

@@ -168,17 +167,14 @@ class MemoryActivator:
return []

# 构建查询上下文
context = {
"keywords": keywords,
"query_intent": "conversation_response"
}
context = {"keywords": keywords, "query_intent": "conversation_response"}

# 查询记忆
memories = await memory_system.retrieve_relevant_memories(
query_text=query_text,
user_id="global",  # 使用全局作用域
context=context,
limit=5
limit=5,
)

# 转换为 MemoryResult 格式

@@ -191,7 +187,7 @@ class MemoryActivator:
importance=memory.metadata.importance.value,
timestamp=memory.metadata.created_at,
source="unified_memory",
relevance_score=memory.metadata.relevance_score
relevance_score=memory.metadata.relevance_score,
)
memory_results.append(result)

@@ -214,16 +210,10 @@ class MemoryActivator:
if not memory_system or memory_system.status.value != "ready":
return None

context = {
"query_intent": "instant_response",
"chat_id": chat_id
}
context = {"query_intent": "instant_response", "chat_id": chat_id}

memories = await memory_system.retrieve_relevant_memories(
query_text=target_message,
user_id="global",
context=context,
limit=1
query_text=target_message, user_id="global", context=context, limit=1
)

if memories:

@@ -246,4 +236,4 @@ class MemoryActivator:
memory_activator = MemoryActivator()

init_prompt()
init_prompt()

@@ -33,7 +33,7 @@ import time
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Any, Dict, Iterable, List, Optional, Union, Type
from typing import Any, Dict, List, Optional, Union, Type

import orjson

@@ -53,14 +53,15 @@ logger = get_logger(__name__)
class ExtractionStrategy(Enum):
"""提取策略"""

LLM_BASED = "llm_based" # 基于LLM的智能提取
RULE_BASED = "rule_based" # 基于规则的提取
HYBRID = "hybrid" # 混合策略
LLM_BASED = "llm_based"  # 基于LLM的智能提取
RULE_BASED = "rule_based"  # 基于规则的提取
HYBRID = "hybrid"  # 混合策略


@dataclass
class ExtractionResult:
"""提取结果"""

memories: List[MemoryChunk]
confidence_scores: List[float]
extraction_time: float

@@ -80,15 +81,11 @@ class MemoryBuilder:
"total_extractions": 0,
"successful_extractions": 0,
"failed_extractions": 0,
"average_confidence": 0.0
"average_confidence": 0.0,
}

async def build_memories(
self,
conversation_text: str,
context: Dict[str, Any],
user_id: str,
timestamp: float
self, conversation_text: str, context: Dict[str, Any], user_id: str, timestamp: float
) -> List[MemoryChunk]:
"""从对话中构建记忆"""
start_time = time.time()

@@ -119,19 +116,13 @@ class MemoryBuilder:
raise

async def _extract_with_llm(
self,
text: str,
context: Dict[str, Any],
user_id: str,
timestamp: float
self, text: str, context: Dict[str, Any], user_id: str, timestamp: float
) -> List[MemoryChunk]:
"""使用LLM提取记忆"""
try:
prompt = self._build_llm_extraction_prompt(text, context)

response, _ = await self.llm_model.generate_response_async(
prompt, temperature=0.3
)
response, _ = await self.llm_model.generate_response_async(prompt, temperature=0.3)

# 解析LLM响应
memories = self._parse_llm_response(response, user_id, timestamp, context)

@@ -342,16 +333,12 @@ class MemoryBuilder:
start = stripped.find("{")
end = stripped.rfind("}")
if start != -1 and end != -1 and end > start:
return stripped[start:end + 1].strip()
return stripped[start : end + 1].strip()

return stripped if stripped.startswith("{") and stripped.endswith("}") else None

def _parse_llm_response(
self,
response: str,
user_id: str,
timestamp: float,
context: Dict[str, Any]
self, response: str, user_id: str, timestamp: float, context: Dict[str, Any]
) -> List[MemoryChunk]:
"""解析LLM响应"""
if not response:

@@ -366,9 +353,7 @@ class MemoryBuilder:
data = orjson.loads(json_payload)
except Exception as e:
preview = json_payload[:200]
raise MemoryExtractionError(
f"LLM响应JSON解析失败: {e}, 片段: {preview}"
) from e
raise MemoryExtractionError(f"LLM响应JSON解析失败: {e}, 片段: {preview}") from e

memory_list = data.get("memories", [])

@@ -406,17 +391,15 @@ class MemoryBuilder:
try:
# 检查是否包含模糊代称
display_text = mem_data.get("display", "")
if any(ambiguous_term in display_text for ambiguous_term in ["用户", "user", "the user", "对方", "对手"]):
if any(
ambiguous_term in display_text for ambiguous_term in ["用户", "user", "the user", "对方", "对手"]
):
logger.debug(f"拒绝构建包含模糊代称的记忆,display字段: {display_text}")
continue

subject_value = mem_data.get("subject")
normalized_subject = self._normalize_subjects(
subject_value,
bot_identifiers,
system_identifiers,
default_subjects,
bot_display
subject_value, bot_identifiers, system_identifiers, default_subjects, bot_display
)

if not normalized_subject:

@@ -425,17 +408,11 @@ class MemoryBuilder:

# 创建记忆块
importance_level = self._parse_enum_value(
ImportanceLevel,
mem_data.get("importance"),
ImportanceLevel.NORMAL,
"importance"
ImportanceLevel, mem_data.get("importance"), ImportanceLevel.NORMAL, "importance"
)

confidence_level = self._parse_enum_value(
ConfidenceLevel,
mem_data.get("confidence"),
ConfidenceLevel.MEDIUM,
"confidence"
ConfidenceLevel, mem_data.get("confidence"), ConfidenceLevel.MEDIUM, "confidence"
)

predicate_value = mem_data.get("predicate", "")

@@ -457,7 +434,7 @@ class MemoryBuilder:
source_context=mem_data.get("reasoning", ""),
importance=importance_level,
confidence=confidence_level,
display=display_text
display=display_text,
)

if used_fallback_display:

@@ -483,13 +460,7 @@ class MemoryBuilder:

return memories

def _parse_enum_value(
self,
enum_cls: Type[Enum],
raw_value: Any,
default: Enum,
field_name: str
) -> Enum:
def _parse_enum_value(self, enum_cls: Type[Enum], raw_value: Any, default: Enum, field_name: str) -> Enum:
"""解析枚举值,兼容数字/字符串表示"""
if isinstance(raw_value, enum_cls):
return raw_value

@@ -533,12 +504,14 @@ class MemoryBuilder:
try:
return enum_cls(raw_value)
except Exception:
logger.debug("%s=%s 类型 %s 无法解析为 %s,使用默认值 %s",
field_name,
raw_value,
type(raw_value).__name__,
enum_cls.__name__,
default.name)
logger.debug(
"%s=%s 类型 %s 无法解析为 %s,使用默认值 %s",
field_name,
raw_value,
type(raw_value).__name__,
enum_cls.__name__,
default.name,
)
return default

def _collect_bot_identifiers(self, context: Optional[Dict[str, Any]]) -> set[str]:

@@ -606,7 +579,7 @@ class MemoryBuilder:
"members",
"member_names",
"mention_users",
"audiences"
"audiences",
]

for key in candidate_keys:

@@ -727,7 +700,7 @@ class MemoryBuilder:
bot_identifiers: set[str],
system_identifiers: set[str],
default_subjects: List[str],
bot_display: Optional[str] = None
bot_display: Optional[str] = None,
) -> List[str]:
defaults = default_subjects or ["对话参与者"]

@@ -800,7 +773,9 @@ class MemoryBuilder:
return obj.strip() or None
return None

def _compose_display_text(self, subjects: List[str], predicate: str, obj: Union[str, Dict[str, Any], List[Any]]) -> str:
def _compose_display_text(
self, subjects: List[str], predicate: str, obj: Union[str, Dict[str, Any], List[Any]]
) -> str:
subject_phrase = "、".join(subjects) if subjects else "对话参与者"
predicate = (predicate or "").strip()

@@ -866,11 +841,7 @@ class MemoryBuilder:
return f"{subject_phrase}{predicate}".strip()
return subject_phrase

def _validate_and_enhance_memories(
self,
memories: List[MemoryChunk],
context: Dict[str, Any]
) -> List[MemoryChunk]:
def _validate_and_enhance_memories(self, memories: List[MemoryChunk], context: Dict[str, Any]) -> List[MemoryChunk]:
"""验证和增强记忆"""
validated_memories = []

@@ -905,11 +876,7 @@ class MemoryBuilder:

return True

def _enhance_memory(
self,
memory: MemoryChunk,
context: Dict[str, Any]
) -> MemoryChunk:
def _enhance_memory(self, memory: MemoryChunk, context: Dict[str, Any]) -> MemoryChunk:
"""增强记忆块"""
# 时间规范化处理
self._normalize_time_in_memory(memory)

@@ -919,7 +886,7 @@ class MemoryBuilder:
memory.temporal_context = {
"timestamp": memory.metadata.created_at,
"timezone": context.get("timezone", "UTC"),
"day_of_week": datetime.fromtimestamp(memory.metadata.created_at).strftime("%A")
"day_of_week": datetime.fromtimestamp(memory.metadata.created_at).strftime("%A"),
}

# 添加情感上下文(如果有)

@@ -941,22 +908,22 @@ class MemoryBuilder:

# 定义相对时间映射
relative_time_patterns = {
r'今天|今日': current_time.strftime('%Y-%m-%d'),
r'昨天|昨日': (current_time - timedelta(days=1)).strftime('%Y-%m-%d'),
r'明天|明日': (current_time + timedelta(days=1)).strftime('%Y-%m-%d'),
r'后天': (current_time + timedelta(days=2)).strftime('%Y-%m-%d'),
r'大后天': (current_time + timedelta(days=3)).strftime('%Y-%m-%d'),
r'前天': (current_time - timedelta(days=2)).strftime('%Y-%m-%d'),
r'大前天': (current_time - timedelta(days=3)).strftime('%Y-%m-%d'),
r'本周|这周|这星期': current_time.strftime('%Y-%m-%d'),
r'上周|上星期': (current_time - timedelta(weeks=1)).strftime('%Y-%m-%d'),
r'下周|下星期': (current_time + timedelta(weeks=1)).strftime('%Y-%m-%d'),
r'本月|这个月': current_time.strftime('%Y-%m-01'),
r'上月|上个月': (current_time.replace(day=1) - timedelta(days=1)).strftime('%Y-%m-01'),
r'下月|下个月': (current_time.replace(day=1) + timedelta(days=32)).replace(day=1).strftime('%Y-%m-01'),
r'今年|今年': current_time.strftime('%Y'),
r'去年|上一年': str(current_time.year - 1),
r'明年|下一年': str(current_time.year + 1),
r"今天|今日": current_time.strftime("%Y-%m-%d"),
r"昨天|昨日": (current_time - timedelta(days=1)).strftime("%Y-%m-%d"),
r"明天|明日": (current_time + timedelta(days=1)).strftime("%Y-%m-%d"),
r"后天": (current_time + timedelta(days=2)).strftime("%Y-%m-%d"),
r"大后天": (current_time + timedelta(days=3)).strftime("%Y-%m-%d"),
r"前天": (current_time - timedelta(days=2)).strftime("%Y-%m-%d"),
r"大前天": (current_time - timedelta(days=3)).strftime("%Y-%m-%d"),
r"本周|这周|这星期": current_time.strftime("%Y-%m-%d"),
r"上周|上星期": (current_time - timedelta(weeks=1)).strftime("%Y-%m-%d"),
r"下周|下星期": (current_time + timedelta(weeks=1)).strftime("%Y-%m-%d"),
r"本月|这个月": current_time.strftime("%Y-%m-01"),
r"上月|上个月": (current_time.replace(day=1) - timedelta(days=1)).strftime("%Y-%m-01"),
r"下月|下个月": (current_time.replace(day=1) + timedelta(days=32)).replace(day=1).strftime("%Y-%m-01"),
r"今年|今年": current_time.strftime("%Y"),
r"去年|上一年": str(current_time.year - 1),
r"明年|下一年": str(current_time.year + 1),
}
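
A small sketch (not from the commit) of how a mapping like relative_time_patterns can be applied; the surrounding method body is not shown in this hunk, so the re.sub loop here is an assumed usage, not the project's actual code.

import re
from datetime import datetime, timedelta

now = datetime(2024, 5, 20)
patterns = {
    r"今天|今日": now.strftime("%Y-%m-%d"),
    r"昨天|昨日": (now - timedelta(days=1)).strftime("%Y-%m-%d"),
}

text = "昨天去了图书馆"
for pattern, replacement in patterns.items():
    text = re.sub(pattern, replacement, text)  # replace relative phrases with concrete dates
print(text)  # 2024-05-19去了图书馆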

def _normalize_value(value):

@@ -1009,10 +976,14 @@ class MemoryBuilder:

# 更新平均置信度
if self.extraction_stats["successful_extractions"] > 0:
total_confidence = self.extraction_stats["average_confidence"] * (self.extraction_stats["successful_extractions"] - success_count)
total_confidence = self.extraction_stats["average_confidence"] * (
self.extraction_stats["successful_extractions"] - success_count
)
# 假设新记忆的平均置信度为0.8
total_confidence += 0.8 * success_count
self.extraction_stats["average_confidence"] = total_confidence / self.extraction_stats["successful_extractions"]
self.extraction_stats["average_confidence"] = (
total_confidence / self.extraction_stats["successful_extractions"]
)

def get_extraction_stats(self) -> Dict[str, Any]:
"""获取提取统计信息"""

@@ -1024,5 +995,5 @@ class MemoryBuilder:
"total_extractions": 0,
"successful_extractions": 0,
"failed_extractions": 0,
"average_confidence": 0.0
}
"average_confidence": 0.0,
}

@@ -8,8 +8,7 @@ import time
import uuid
import orjson
from typing import Dict, List, Optional, Any, Union, Iterable
from dataclasses import dataclass, field, asdict
from datetime import datetime
from dataclasses import dataclass, field
from enum import Enum
import hashlib

@@ -21,33 +20,36 @@ logger = get_logger(__name__)

class MemoryType(Enum):
"""记忆类型分类"""
PERSONAL_FACT = "personal_fact" # 个人事实(姓名、职业、住址等)
EVENT = "event" # 事件(重要经历、约会等)
PREFERENCE = "preference" # 偏好(喜好、习惯等)
OPINION = "opinion" # 观点(对事物的看法)
RELATIONSHIP = "relationship" # 关系(与他人的关系)
EMOTION = "emotion" # 情感状态
KNOWLEDGE = "knowledge" # 知识信息
SKILL = "skill" # 技能能力
GOAL = "goal" # 目标计划
EXPERIENCE = "experience" # 经验教训
CONTEXTUAL = "contextual" # 上下文信息

PERSONAL_FACT = "personal_fact"  # 个人事实(姓名、职业、住址等)
EVENT = "event"  # 事件(重要经历、约会等)
PREFERENCE = "preference"  # 偏好(喜好、习惯等)
OPINION = "opinion"  # 观点(对事物的看法)
RELATIONSHIP = "relationship"  # 关系(与他人的关系)
EMOTION = "emotion"  # 情感状态
KNOWLEDGE = "knowledge"  # 知识信息
SKILL = "skill"  # 技能能力
GOAL = "goal"  # 目标计划
EXPERIENCE = "experience"  # 经验教训
CONTEXTUAL = "contextual"  # 上下文信息


class ConfidenceLevel(Enum):
"""置信度等级"""
LOW = 1 # 低置信度,可能不准确
MEDIUM = 2 # 中等置信度,有一定依据
HIGH = 3 # 高置信度,有明确来源
VERIFIED = 4 # 已验证,非常可靠

LOW = 1  # 低置信度,可能不准确
MEDIUM = 2  # 中等置信度,有一定依据
HIGH = 3  # 高置信度,有明确来源
VERIFIED = 4  # 已验证,非常可靠


class ImportanceLevel(Enum):
"""重要性等级"""
LOW = 1 # 低重要性,普通信息
NORMAL = 2 # 一般重要性,日常信息
HIGH = 3 # 高重要性,重要信息
CRITICAL = 4 # 关键重要性,核心信息

LOW = 1  # 低重要性,普通信息
NORMAL = 2  # 一般重要性,日常信息
HIGH = 3  # 高重要性,重要信息
CRITICAL = 4  # 关键重要性,核心信息


@dataclass

@@ -61,12 +63,7 @@ class ContentStructure:

def to_dict(self) -> Dict[str, Any]:
"""转换为字典格式"""
return {
"subject": self.subject,
"predicate": self.predicate,
"object": self.object,
"display": self.display
}
return {"subject": self.subject, "predicate": self.predicate, "object": self.object, "display": self.display}

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "ContentStructure":

@@ -75,7 +72,7 @@ class ContentStructure:
subject=data.get("subject", ""),
predicate=data.get("predicate", ""),
object=data.get("object", ""),
display=data.get("display", "")
display=data.get("display", ""),
)

def to_subject_list(self) -> List[str]:

@@ -98,24 +95,25 @@ class ContentStructure:
@dataclass
class MemoryMetadata:
"""记忆元数据 - 简化版本"""

# 基础信息
memory_id: str # 唯一标识符
user_id: str # 用户ID
chat_id: Optional[str] = None # 聊天ID(群聊或私聊)
memory_id: str  # 唯一标识符
user_id: str  # 用户ID
chat_id: Optional[str] = None  # 聊天ID(群聊或私聊)

# 时间信息
created_at: float = 0.0 # 创建时间戳
last_accessed: float = 0.0 # 最后访问时间
last_modified: float = 0.0 # 最后修改时间
created_at: float = 0.0  # 创建时间戳
last_accessed: float = 0.0  # 最后访问时间
last_modified: float = 0.0  # 最后修改时间

# 激活频率管理
last_activation_time: float = 0.0 # 最后激活时间
activation_frequency: int = 0 # 激活频率(单位时间内的激活次数)
total_activations: int = 0 # 总激活次数
last_activation_time: float = 0.0  # 最后激活时间
activation_frequency: int = 0  # 激活频率(单位时间内的激活次数)
total_activations: int = 0  # 总激活次数

# 统计信息
access_count: int = 0 # 访问次数
relevance_score: float = 0.0 # 相关度评分
access_count: int = 0  # 访问次数
relevance_score: float = 0.0  # 相关度评分

# 信心和重要性(核心字段)
confidence: ConfidenceLevel = ConfidenceLevel.MEDIUM

@@ -123,10 +121,10 @@ class MemoryMetadata:

# 遗忘机制相关
forgetting_threshold: float = 0.0 # 遗忘阈值(动态计算)
last_forgetting_check: float = 0.0 # 上次遗忘检查时间
last_forgetting_check: float = 0.0  # 上次遗忘检查时间

# 来源信息
source_context: Optional[str] = None # 来源上下文片段
source_context: Optional[str] = None  # 来源上下文片段
# 兼容旧字段: 一些代码或旧版本可能直接访问 metadata.source
source: Optional[str] = None

@@ -153,13 +151,13 @@ class MemoryMetadata:
self.last_forgetting_check = current_time

# 兼容性:如果旧字段 source 被使用,保证 source 与 source_context 同步
if not getattr(self, 'source', None) and getattr(self, 'source_context', None):
if not getattr(self, "source", None) and getattr(self, "source_context", None):
try:
self.source = str(self.source_context)
except Exception:
self.source = None
# 如果有 source 字段但 source_context 为空,也同步回去
if not getattr(self, 'source_context', None) and getattr(self, 'source', None):
if not getattr(self, "source_context", None) and getattr(self, "source", None):
try:
self.source_context = str(self.source)
except Exception:

@@ -177,7 +175,6 @@ class MemoryMetadata:

def _update_activation_frequency(self, current_time: float):
"""更新激活频率(24小时内的激活次数)"""
from datetime import datetime, timedelta

# 如果超过24小时,重置激活频率
if current_time - self.last_activation_time > 86400: # 24小时 = 86400秒

@@ -251,7 +248,7 @@ class MemoryMetadata:
"importance": self.importance.value,
"forgetting_threshold": self.forgetting_threshold,
"last_forgetting_check": self.last_forgetting_check,
"source_context": self.source_context
"source_context": self.source_context,
}

@classmethod

@@ -273,7 +270,7 @@ class MemoryMetadata:
importance=ImportanceLevel(data.get("importance", ImportanceLevel.NORMAL.value)),
forgetting_threshold=data.get("forgetting_threshold", 0.0),
last_forgetting_check=data.get("last_forgetting_check", 0),
source_context=data.get("source_context")
source_context=data.get("source_context"),
)


@@ -285,21 +282,21 @@ class MemoryChunk:
metadata: MemoryMetadata

# 内容结构
content: ContentStructure # 主谓宾结构
memory_type: MemoryType # 记忆类型
content: ContentStructure  # 主谓宾结构
memory_type: MemoryType  # 记忆类型

# 扩展信息
keywords: List[str] = field(default_factory=list) # 关键词列表
tags: List[str] = field(default_factory=list) # 标签列表
categories: List[str] = field(default_factory=list) # 分类列表
keywords: List[str] = field(default_factory=list)  # 关键词列表
tags: List[str] = field(default_factory=list)  # 标签列表
categories: List[str] = field(default_factory=list)  # 分类列表

# 语义信息
embedding: Optional[List[float]] = None # 语义向量
semantic_hash: Optional[str] = None # 语义哈希值
embedding: Optional[List[float]] = None  # 语义向量
semantic_hash: Optional[str] = None  # 语义哈希值

# 关联信息
related_memories: List[str] = field(default_factory=list) # 关联记忆ID列表
temporal_context: Optional[Dict[str, Any]] = None # 时间上下文
temporal_context: Optional[Dict[str, Any]] = None  # 时间上下文

def __post_init__(self):
"""后初始化处理"""

@@ -317,7 +314,7 @@ class MemoryChunk:
embedding_str = ",".join(map(str, [round(x, 6) for x in self.embedding]))

hash_input = f"{content_str}|{embedding_str}"
hash_object = hashlib.sha256(hash_input.encode('utf-8'))
hash_object = hashlib.sha256(hash_input.encode("utf-8"))
self.semantic_hash = hash_object.hexdigest()[:16]

except Exception as e:

@@ -430,7 +427,7 @@ class MemoryChunk:
"embedding": self.embedding,
"semantic_hash": self.semantic_hash,
"related_memories": self.related_memories,
"temporal_context": self.temporal_context
"temporal_context": self.temporal_context,
}

@classmethod

@@ -449,14 +446,14 @@ class MemoryChunk:
embedding=data.get("embedding"),
semantic_hash=data.get("semantic_hash"),
related_memories=data.get("related_memories", []),
temporal_context=data.get("temporal_context")
temporal_context=data.get("temporal_context"),
)

return chunk

def to_json(self) -> str:
"""转换为JSON字符串"""
return orjson.dumps(self.to_dict(), ensure_ascii=False).decode('utf-8')
return orjson.dumps(self.to_dict(), ensure_ascii=False).decode("utf-8")

@classmethod
def from_json(cls, json_str: str) -> "MemoryChunk":

@@ -530,7 +527,7 @@ class MemoryChunk:
MemoryType.SKILL: "🛠️",
MemoryType.GOAL: "🎯",
MemoryType.EXPERIENCE: "💡",
MemoryType.CONTEXTUAL: "📝"
MemoryType.CONTEXTUAL: "📝",
}

emoji = type_emoji.get(self.memory_type, "📝")

@@ -581,7 +578,7 @@ def create_memory_chunk(
importance: ImportanceLevel = ImportanceLevel.NORMAL,
confidence: ConfidenceLevel = ConfidenceLevel.MEDIUM,
display: Optional[str] = None,
**kwargs
**kwargs,
) -> MemoryChunk:
"""便捷的内存块创建函数"""
metadata = MemoryMetadata(

@@ -593,7 +590,7 @@ def create_memory_chunk(
last_modified=0,
confidence=confidence,
importance=importance,
source_context=source_context
source_context=source_context,
)

subjects: List[str]

@@ -607,18 +604,8 @@ def create_memory_chunk(

display_text = display or _build_display_text(subjects, predicate, obj)

content = ContentStructure(
subject=subject_payload,
predicate=predicate,
object=obj,
display=display_text
)
content = ContentStructure(subject=subject_payload, predicate=predicate, object=obj, display=display_text)

chunk = MemoryChunk(
metadata=metadata,
content=content,
memory_type=memory_type,
**kwargs
)
chunk = MemoryChunk(metadata=metadata, content=content, memory_type=memory_type, **kwargs)

return chunk
return chunk
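
A standalone sketch (not from the commit) of the semantic-hash recipe in __post_init__ above: the serialized content and the embedding rounded to six decimals are joined and hashed, and the first 16 hex characters become semantic_hash. The content_str format here is a stand-in; the real serialization lives in code outside this hunk.

import hashlib

content_str = "小明|喜欢|咖啡"  # stand-in for the serialized content structure
embedding = [0.1234567, 0.7654321]
embedding_str = ",".join(map(str, [round(x, 6) for x in embedding]))

hash_input = f"{content_str}|{embedding_str}"
semantic_hash = hashlib.sha256(hash_input.encode("utf-8")).hexdigest()[:16]
print(semantic_hash)  # stable 16-character fingerprint used for dedup checks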

@@ -6,8 +6,8 @@
import time
import asyncio
from typing import List, Dict, Optional, Set, Tuple
from datetime import datetime, timedelta
from typing import List, Dict, Optional, Tuple
from datetime import datetime
from dataclasses import dataclass

from src.common.logger import get_logger

@@ -19,6 +19,7 @@ logger = get_logger(__name__)
@dataclass
class ForgettingStats:
"""遗忘统计信息"""

total_checked: int = 0
marked_for_forgetting: int = 0
actually_forgotten: int = 0

@@ -30,34 +31,35 @@ class ForgettingStats:
@dataclass
class ForgettingConfig:
"""遗忘引擎配置"""

# 检查频率配置
check_interval_hours: int = 24 # 定期检查间隔(小时)
batch_size: int = 100 # 批处理大小
check_interval_hours: int = 24  # 定期检查间隔(小时)
batch_size: int = 100  # 批处理大小

# 遗忘阈值配置
base_forgetting_days: float = 30.0 # 基础遗忘天数
min_forgetting_days: float = 7.0 # 最小遗忘天数
max_forgetting_days: float = 365.0 # 最大遗忘天数
base_forgetting_days: float = 30.0  # 基础遗忘天数
min_forgetting_days: float = 7.0  # 最小遗忘天数
max_forgetting_days: float = 365.0  # 最大遗忘天数

# 重要程度权重
critical_importance_bonus: float = 45.0 # 关键重要性额外天数
high_importance_bonus: float = 30.0 # 高重要性额外天数
normal_importance_bonus: float = 15.0 # 一般重要性额外天数
low_importance_bonus: float = 0.0 # 低重要性额外天数
high_importance_bonus: float = 30.0  # 高重要性额外天数
normal_importance_bonus: float = 15.0  # 一般重要性额外天数
low_importance_bonus: float = 0.0  # 低重要性额外天数

# 置信度权重
verified_confidence_bonus: float = 30.0 # 已验证置信度额外天数
high_confidence_bonus: float = 20.0 # 高置信度额外天数
medium_confidence_bonus: float = 10.0 # 中等置信度额外天数
low_confidence_bonus: float = 0.0 # 低置信度额外天数
high_confidence_bonus: float = 20.0  # 高置信度额外天数
medium_confidence_bonus: float = 10.0  # 中等置信度额外天数
low_confidence_bonus: float = 0.0  # 低置信度额外天数

# 激活频率权重
activation_frequency_weight: float = 0.5 # 每次激活增加的天数权重
max_frequency_bonus: float = 10.0 # 最大激活频率奖励天数
max_frequency_bonus: float = 10.0  # 最大激活频率奖励天数

# 休眠配置
dormant_threshold_days: int = 90 # 休眠状态判定天数
force_forget_dormant_days: int = 180 # 强制遗忘休眠记忆的天数
dormant_threshold_days: int = 90  # 休眠状态判定天数
force_forget_dormant_days: int = 180  # 强制遗忘休眠记忆的天数


class MemoryForgettingEngine:

@@ -107,13 +109,12 @@ class MemoryForgettingEngine:
# 激活频率权重
frequency_bonus = min(
memory.metadata.activation_frequency * self.config.activation_frequency_weight,
self.config.max_frequency_bonus
self.config.max_frequency_bonus,
)
threshold += frequency_bonus

# 确保在合理范围内
return max(self.config.min_forgetting_days,
min(threshold, self.config.max_forgetting_days))
return max(self.config.min_forgetting_days, min(threshold, self.config.max_forgetting_days))
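
A worked sketch (not from the commit) of the threshold assembled above: base days plus importance, confidence, and activation-frequency bonuses, clamped to the configured range. The numbers reuse the ForgettingConfig defaults shown earlier; how the importance/confidence bonuses are selected is inferred from the field names.

base = 30.0                            # base_forgetting_days
importance_bonus = 30.0                # high_importance_bonus for a HIGH-importance memory
confidence_bonus = 20.0                # high_confidence_bonus for HIGH confidence
frequency_bonus = min(8 * 0.5, 10.0)   # activations * weight, capped at max_frequency_bonus

threshold = base + importance_bonus + confidence_bonus + frequency_bonus
threshold = max(7.0, min(threshold, 365.0))  # clamp to [min, max] forgetting days
print(threshold)  # 84.0 days before this memory becomes a forgetting candidate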

def should_forget_memory(self, memory: MemoryChunk, current_time: Optional[float] = None) -> bool:
"""

@@ -265,8 +266,8 @@ class MemoryForgettingEngine:
"actually_forgotten": self.stats.actually_forgotten,
"dormant_memories": self.stats.dormant_memories,
"check_duration": self.stats.check_duration,
"last_check_time": self.stats.last_check_time
}
"last_check_time": self.stats.last_check_time,
},
}

def is_forgetting_check_needed(self) -> bool:

@@ -302,7 +303,9 @@ class MemoryForgettingEngine:

# 如果启用自动清理,执行实际的遗忘操作
if enable_auto_cleanup and (result["normal_forgetting"] or result["force_forgetting"]):
logger.info(f"检测到 {len(result['normal_forgetting'])} 条普通遗忘和 {len(result['force_forgetting'])} 条强制遗忘记忆")
logger.info(
f"检测到 {len(result['normal_forgetting'])} 条普通遗忘和 {len(result['force_forgetting'])} 条强制遗忘记忆"
)
# 这里可以调用实际的删除逻辑
# await self.cleanup_forgotten_memories(result["normal_forgetting"] + result["force_forgetting"])

@@ -318,14 +321,16 @@ class MemoryForgettingEngine:
"marked_for_forgetting": self.stats.marked_for_forgetting,
"actually_forgotten": self.stats.actually_forgotten,
"dormant_memories": self.stats.dormant_memories,
"last_check_time": datetime.fromtimestamp(self.stats.last_check_time).isoformat() if self.stats.last_check_time else None,
"last_check_time": datetime.fromtimestamp(self.stats.last_check_time).isoformat()
if self.stats.last_check_time
else None,
"last_check_duration": self.stats.check_duration,
"config": {
"check_interval_hours": self.config.check_interval_hours,
"base_forgetting_days": self.config.base_forgetting_days,
"min_forgetting_days": self.config.min_forgetting_days,
"max_forgetting_days": self.config.max_forgetting_days
}
"max_forgetting_days": self.config.max_forgetting_days,
},
}

def reset_stats(self):

@@ -349,4 +354,4 @@ memory_forgetting_engine = MemoryForgettingEngine()

def get_memory_forgetting_engine() -> MemoryForgettingEngine:
"""获取全局遗忘引擎实例"""
return memory_forgetting_engine
return memory_forgetting_engine

@@ -5,14 +5,12 @@
"""

import time
from typing import Dict, List, Optional, Tuple, Set, Any
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass

from src.common.logger import get_logger
from src.chat.memory_system.memory_chunk import (
MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel
)
from src.chat.memory_system.memory_chunk import MemoryChunk, ConfidenceLevel, ImportanceLevel

logger = get_logger(__name__)

@@ -20,6 +18,7 @@ logger = get_logger(__name__)
@dataclass
class FusionResult:
"""融合结果"""

original_count: int
fused_count: int
removed_duplicates: int

@@ -31,6 +30,7 @@ class FusionResult:
@dataclass
class DuplicateGroup:
"""重复记忆组"""

group_id: str
memories: List[MemoryChunk]
similarity_matrix: List[List[float]]

@@ -46,22 +46,20 @@ class MemoryFusionEngine:
"total_fusions": 0,
"memories_fused": 0,
"duplicates_removed": 0,
"average_similarity": 0.0
"average_similarity": 0.0,
}

# 融合策略配置
self.fusion_strategies = {
"semantic_similarity": True, # 语义相似性融合
"temporal_proximity": True, # 时间接近性融合
"logical_consistency": True, # 逻辑一致性融合
"confidence_boosting": True, # 置信度提升
"importance_preservation": True # 重要性保持
"semantic_similarity": True,  # 语义相似性融合
"temporal_proximity": True,  # 时间接近性融合
"logical_consistency": True,  # 逻辑一致性融合
"confidence_boosting": True,  # 置信度提升
"importance_preservation": True,  # 重要性保持
}

async def fuse_memories(
self,
new_memories: List[MemoryChunk],
existing_memories: Optional[List[MemoryChunk]] = None
self, new_memories: List[MemoryChunk], existing_memories: Optional[List[MemoryChunk]] = None
) -> List[MemoryChunk]:
"""融合记忆列表"""
start_time = time.time()

@@ -73,9 +71,7 @@ class MemoryFusionEngine:
logger.info(f"开始记忆融合,新记忆: {len(new_memories)},现有记忆: {len(existing_memories or [])}")

# 1. 检测重复记忆组
duplicate_groups = await self._detect_duplicate_groups(
new_memories, existing_memories or []
)
duplicate_groups = await self._detect_duplicate_groups(new_memories, existing_memories or [])

if not duplicate_groups:
fusion_time = time.time() - start_time

@@ -110,9 +106,7 @@ class MemoryFusionEngine:
return new_memories # 失败时返回原始记忆

async def _detect_duplicate_groups(
self,
new_memories: List[MemoryChunk],
existing_memories: List[MemoryChunk]
self, new_memories: List[MemoryChunk], existing_memories: List[MemoryChunk]
) -> List[DuplicateGroup]:
"""检测重复记忆组"""
all_memories = new_memories + existing_memories

@@ -125,16 +119,12 @@ class MemoryFusionEngine:
continue

# 创建新的重复组
group = DuplicateGroup(
group_id=f"group_{len(groups)}",
memories=[memory1],
similarity_matrix=[[1.0]]
)
group = DuplicateGroup(group_id=f"group_{len(groups)}", memories=[memory1], similarity_matrix=[[1.0]])

processed_ids.add(memory1.memory_id)

# 寻找相似记忆
for j, memory2 in enumerate(all_memories[i+1:], i+1):
for j, memory2 in enumerate(all_memories[i + 1 :], i + 1):
if memory2.memory_id in processed_ids:
continue

@@ -182,9 +172,7 @@ class MemoryFusionEngine:

# 5. 时间接近性
if self.fusion_strategies["temporal_proximity"]:
temporal_sim = self._calculate_temporal_similarity(
mem1.metadata.created_at, mem2.metadata.created_at
)
temporal_sim = self._calculate_temporal_similarity(mem1.metadata.created_at, mem2.metadata.created_at)
similarity_scores.append(("temporal", temporal_sim))

# 6. 逻辑一致性

@@ -193,14 +181,7 @@ class MemoryFusionEngine:
similarity_scores.append(("logical", logical_sim))

# 计算加权平均相似度
weights = {
"semantic": 0.35,
"text": 0.25,
"keyword": 0.15,
"type": 0.10,
"temporal": 0.10,
"logical": 0.05
}
weights = {"semantic": 0.35, "text": 0.25, "keyword": 0.15, "type": 0.10, "temporal": 0.10, "logical": 0.05}

weighted_sum = 0.0
total_weight = 0.0
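
A minimal sketch (not from the commit) of the weighted mean the following lines compute from the weights dict above; only dimensions that actually produced a score contribute to total_weight, so a missing dimension does not drag the average down.

weights = {"semantic": 0.35, "text": 0.25, "keyword": 0.15, "type": 0.10, "temporal": 0.10, "logical": 0.05}
similarity_scores = [("semantic", 0.9), ("text", 0.8), ("type", 1.0)]

weighted_sum = sum(weights[name] * score for name, score in similarity_scores)
total_weight = sum(weights[name] for name, _ in similarity_scores)
overall = weighted_sum / total_weight if total_weight else 0.0
print(round(overall, 3))  # 0.879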

@@ -276,9 +257,7 @@ class MemoryFusionEngine:

# 宾语相似性
if isinstance(mem1.content.object, str) and isinstance(mem2.content.object, str):
object_sim = self._calculate_text_similarity(
str(mem1.content.object), str(mem2.content.object)
)
object_sim = self._calculate_text_similarity(str(mem1.content.object), str(mem2.content.object))
consistency_score += object_sim * 0.3

return consistency_score

@@ -349,11 +328,7 @@ class MemoryFusionEngine:
# 返回置信度最高的记忆
return max(group.memories, key=lambda m: m.metadata.confidence.value)

async def _merge_memory_attributes(
self,
base_memory: MemoryChunk,
memories: List[MemoryChunk]
) -> MemoryChunk:
async def _merge_memory_attributes(self, base_memory: MemoryChunk, memories: List[MemoryChunk]) -> MemoryChunk:
"""合并记忆属性"""
# 创建基础记忆的深拷贝
fused_memory = MemoryChunk.from_dict(base_memory.to_dict())

@@ -436,7 +411,7 @@ class MemoryFusionEngine:
"earliest_timestamp": earliest_time,
"latest_timestamp": latest_time,
"time_span_hours": (latest_time - earliest_time) / 3600,
"source_memories": len(memories)
"source_memories": len(memories),
}

# 合并其他上下文信息

@@ -451,9 +426,7 @@ class MemoryFusionEngine:
return merged_context

async def incremental_fusion(
self,
new_memory: MemoryChunk,
existing_memories: List[MemoryChunk]
self, new_memory: MemoryChunk, existing_memories: List[MemoryChunk]
) -> Tuple[MemoryChunk, List[MemoryChunk]]:
"""增量融合(单个新记忆与现有记忆融合)"""
# 寻找相似记忆

@@ -478,7 +451,7 @@ class MemoryFusionEngine:
group = DuplicateGroup(
group_id=f"incremental_{int(time.time())}",
memories=[new_memory, best_match],
similarity_matrix=[[1.0, similarity], [similarity, 1.0]]
similarity_matrix=[[1.0, similarity], [similarity, 1.0]],
)

# 执行融合

@@ -530,5 +503,5 @@ class MemoryFusionEngine:
"total_fusions": 0,
"memories_fused": 0,
"duplicates_removed": 0,
"average_similarity": 0.0
}
"average_similarity": 0.0,
}

@@ -9,12 +9,9 @@ from typing import Dict, List, Optional, Any, Tuple
from dataclasses import dataclass

from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.memory_system.memory_system import MemorySystem
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
from src.chat.memory_system.memory_system import (
initialize_memory_system
)
from src.chat.memory_system.memory_system import initialize_memory_system

logger = get_logger(__name__)

@@ -22,6 +19,7 @@ logger = get_logger(__name__)
@dataclass
class MemoryResult:
"""记忆查询结果"""

content: str
memory_type: str
confidence: float

@@ -67,6 +65,7 @@ class MemoryManager:
# 获取LLM模型
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config

llm_model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="memory")

# 初始化记忆系统

@@ -121,7 +120,7 @@ class MemoryManager:
max_memory_num: int = 3,
max_memory_length: int = 2,
time_weight: float = 1.0,
keyword_weight: float = 1.0
keyword_weight: float = 1.0,
) -> List[Tuple[str, str]]:
"""从文本获取相关记忆 - 兼容原有接口"""
if not self.is_initialized or not self.memory_system:

@@ -131,14 +130,11 @@ class MemoryManager:
# 使用增强记忆系统检索
context = {
"chat_id": chat_id,
"expected_memory_types": [MemoryType.PERSONAL_FACT, MemoryType.EVENT, MemoryType.PREFERENCE]
"expected_memory_types": [MemoryType.PERSONAL_FACT, MemoryType.EVENT, MemoryType.PREFERENCE],
}

relevant_memories = await self.memory_system.retrieve_relevant_memories(
query=text,
user_id=user_id,
context=context,
limit=max_memory_num
query=text, user_id=user_id, context=context, limit=max_memory_num
)

# 转换为原有格式 (topic, content)

@@ -156,11 +152,7 @@ class MemoryManager:
return []

async def get_memory_from_topic(
self,
valid_keywords: List[str],
max_memory_num: int = 3,
max_memory_length: int = 2,
max_depth: int = 3
self, valid_keywords: List[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3
) -> List[Tuple[str, str]]:
"""从关键词获取记忆 - 兼容原有接口"""
if not self.is_initialized or not self.memory_system:

@@ -177,15 +169,15 @@ class MemoryManager:
MemoryType.PERSONAL_FACT,
MemoryType.EVENT,
MemoryType.PREFERENCE,
MemoryType.OPINION
]
MemoryType.OPINION,
],
}

relevant_memories = await self.memory_system.retrieve_relevant_memories(
query_text=query_text,
user_id="default_user",  # 可以根据实际需要传递
context=context,
limit=max_memory_num
limit=max_memory_num,
)

# 转换为原有格式 (topic, content)

@@ -216,11 +208,7 @@ class MemoryManager:
return []

async def process_conversation(
self,
conversation_text: str,
context: Dict[str, Any],
user_id: str,
timestamp: Optional[float] = None
self, conversation_text: str, context: Dict[str, Any], user_id: str, timestamp: Optional[float] = None
) -> List[MemoryChunk]:
"""处理对话并构建记忆 - 新增功能"""
if not self.is_initialized or not self.memory_system:

@@ -247,11 +235,7 @@ class MemoryManager:
return []

async def get_enhanced_memory_context(
self,
query_text: str,
user_id: str,
context: Optional[Dict[str, Any]] = None,
limit: int = 5
self, query_text: str, user_id: str, context: Optional[Dict[str, Any]] = None, limit: int = 5
) -> List[MemoryResult]:
"""获取增强记忆上下文 - 新增功能"""
if not self.is_initialized or not self.memory_system:

@@ -259,10 +243,7 @@ class MemoryManager:

try:
relevant_memories = await self.memory_system.retrieve_relevant_memories(
query=query_text,
user_id=None,
context=context or {},
limit=limit
query=query_text, user_id=None, context=context or {}, limit=limit
)

results = []

@@ -276,7 +257,7 @@ class MemoryManager:
timestamp=memory.metadata.created_at,
source="enhanced_memory",
relevance_score=memory.metadata.relevance_score,
structure=structure
structure=structure,
)
results.append(result)

@@ -342,7 +323,9 @@ class MemoryManager:
return None
return f"{subject}的职业是{profession}"
if predicate == "lives_in":
location = self._extract_from_object(obj_value, ["location", "city", "place"]) or self._format_object(obj_value)
location = self._extract_from_object(obj_value, ["location", "city", "place"]) or self._format_object(
obj_value
)
location = self._clean_text(location)
if not location:
return None

@@ -385,7 +368,9 @@ class MemoryManager:
return None
return f"{subject}最喜欢{favorite}"
if predicate == "mentioned_event":
event_text = self._extract_from_object(obj_value, ["event_text", "description"]) or self._format_object(obj_value)
event_text = self._extract_from_object(obj_value, ["event_text", "description"]) or self._format_object(
obj_value
)
event_text = self._clean_text(self._truncate(event_text))
if not event_text:
return None

@@ -494,4 +479,4 @@ class MemoryManager:

# 全局记忆管理器实例
memory_manager = MemoryManager()
memory_manager = MemoryManager()

@@ -12,7 +12,6 @@ from dataclasses import dataclass, asdict
from datetime import datetime

from src.common.logger import get_logger
from src.chat.memory_system.memory_chunk import MemoryType, ImportanceLevel, ConfidenceLevel

logger = get_logger(__name__)

@@ -20,22 +19,23 @@ logger = get_logger(__name__)
@dataclass
class MemoryMetadataIndexEntry:
"""元数据索引条目(轻量级,只用于快速过滤)"""

memory_id: str
user_id: str

# 分类信息
memory_type: str # MemoryType.value
subjects: List[str] # 主语列表
objects: List[str] # 宾语列表
keywords: List[str] # 关键词列表
tags: List[str] # 标签列表

# 数值字段(用于范围过滤)
importance: int # ImportanceLevel.value (1-4)
confidence: int # ConfidenceLevel.value (1-4)
created_at: float # 创建时间戳
access_count: int # 访问次数

# 可选字段
chat_id: Optional[str] = None
content_preview: Optional[str] = None # 内容预览(前100字符)

@@ -43,152 +43,152 @@ class MemoryMetadataIndexEntry:

class MemoryMetadataIndex:
"""记忆元数据索引管理器"""

def __init__(self, index_file: str = "data/memory_metadata_index.json"):
self.index_file = Path(index_file)
self.index: Dict[str, MemoryMetadataIndexEntry] = {} # memory_id -> entry

# 倒排索引(用于快速查找)
self.type_index: Dict[str, Set[str]] = {} # type -> {memory_ids}
self.subject_index: Dict[str, Set[str]] = {} # subject -> {memory_ids}
self.keyword_index: Dict[str, Set[str]] = {} # keyword -> {memory_ids}
self.tag_index: Dict[str, Set[str]] = {} # tag -> {memory_ids}

self.lock = threading.RLock()

# 加载已有索引
self._load_index()

def _load_index(self):
"""从文件加载索引"""
if not self.index_file.exists():
logger.info(f"元数据索引文件不存在,将创建新索引: {self.index_file}")
return

try:
with open(self.index_file, 'rb') as f:
with open(self.index_file, "rb") as f:
data = orjson.loads(f.read())

# 重建内存索引
for entry_data in data.get('entries', []):
for entry_data in data.get("entries", []):
entry = MemoryMetadataIndexEntry(**entry_data)
self.index[entry.memory_id] = entry
self._update_inverted_indices(entry)

logger.info(f"✅ 加载元数据索引: {len(self.index)} 条记录")

except Exception as e:
logger.error(f"加载元数据索引失败: {e}", exc_info=True)

def _save_index(self):
"""保存索引到文件"""
try:
# 确保目录存在
self.index_file.parent.mkdir(parents=True, exist_ok=True)

# 序列化所有条目
entries = [asdict(entry) for entry in self.index.values()]
data = {
'version': '1.0',
'count': len(entries),
'last_updated': datetime.now().isoformat(),
'entries': entries
"version": "1.0",
"count": len(entries),
"last_updated": datetime.now().isoformat(),
"entries": entries,
}

# 写入文件(使用临时文件 + 原子重命名)
temp_file = self.index_file.with_suffix('.tmp')
with open(temp_file, 'wb') as f:
temp_file = self.index_file.with_suffix(".tmp")
with open(temp_file, "wb") as f:
f.write(orjson.dumps(data, option=orjson.OPT_INDENT_2))

temp_file.replace(self.index_file)
logger.debug(f"元数据索引已保存: {len(entries)} 条记录")

except Exception as e:
logger.error(f"保存元数据索引失败: {e}", exc_info=True)
|
||||
|
||||
|
||||
def _update_inverted_indices(self, entry: MemoryMetadataIndexEntry):
|
||||
"""更新倒排索引"""
|
||||
memory_id = entry.memory_id
|
||||
|
||||
|
||||
# 类型索引
|
||||
self.type_index.setdefault(entry.memory_type, set()).add(memory_id)
|
||||
|
||||
|
||||
# 主语索引
|
||||
for subject in entry.subjects:
|
||||
subject_norm = subject.strip().lower()
|
||||
if subject_norm:
|
||||
self.subject_index.setdefault(subject_norm, set()).add(memory_id)
|
||||
|
||||
|
||||
# 关键词索引
|
||||
for keyword in entry.keywords:
|
||||
keyword_norm = keyword.strip().lower()
|
||||
if keyword_norm:
|
||||
self.keyword_index.setdefault(keyword_norm, set()).add(memory_id)
|
||||
|
||||
|
||||
# 标签索引
|
||||
for tag in entry.tags:
|
||||
tag_norm = tag.strip().lower()
|
||||
if tag_norm:
|
||||
self.tag_index.setdefault(tag_norm, set()).add(memory_id)

    def add_or_update(self, entry: MemoryMetadataIndexEntry):
        """Add or update an index entry."""
        with self.lock:
            # If the entry already exists, remove the old record from the inverted indices first
            if entry.memory_id in self.index:
                self._remove_from_inverted_indices(entry.memory_id)

            # Add the new record
            self.index[entry.memory_id] = entry
            self._update_inverted_indices(entry)

    def _remove_from_inverted_indices(self, memory_id: str):
        """Remove a record from the inverted indices."""
        if memory_id not in self.index:
            return

        entry = self.index[memory_id]

        # Remove from the type index
        if entry.memory_type in self.type_index:
            self.type_index[entry.memory_type].discard(memory_id)

        # Remove from the subject index
        for subject in entry.subjects:
            subject_norm = subject.strip().lower()
            if subject_norm in self.subject_index:
                self.subject_index[subject_norm].discard(memory_id)

        # Remove from the keyword index
        for keyword in entry.keywords:
            keyword_norm = keyword.strip().lower()
            if keyword_norm in self.keyword_index:
                self.keyword_index[keyword_norm].discard(memory_id)

        # Remove from the tag index
        for tag in entry.tags:
            tag_norm = tag.strip().lower()
            if tag_norm in self.tag_index:
                self.tag_index[tag_norm].discard(memory_id)

    def remove(self, memory_id: str):
        """Remove an index entry."""
        with self.lock:
            if memory_id in self.index:
                self._remove_from_inverted_indices(memory_id)
                del self.index[memory_id]

    def batch_add_or_update(self, entries: List[MemoryMetadataIndexEntry]):
        """Batch add or update."""
        with self.lock:
            for entry in entries:
                self.add_or_update(entry)

    def save(self):
        """Save the index to disk."""
        with self.lock:
            self._save_index()

    def search(
        self,
        memory_types: Optional[List[str]] = None,
@@ -201,11 +201,11 @@ class MemoryMetadataIndex:
        created_before: Optional[float] = None,
        user_id: Optional[str] = None,
        limit: Optional[int] = None,
        flexible_mode: bool = True  # new: flexible matching mode
        flexible_mode: bool = True,  # new: flexible matching mode
    ) -> List[str]:
        """
        Search for memory IDs matching the criteria (fuzzy matching supported).

        Returns:
            List[str]: the matching memory_id values
        """
@@ -219,7 +219,7 @@ class MemoryMetadataIndex:
                created_after=created_after,
                created_before=created_before,
                user_id=user_id,
                limit=limit
                limit=limit,
            )
        else:
            return self._search_strict(
@@ -232,7 +232,7 @@ class MemoryMetadataIndex:
                created_after=created_after,
                created_before=created_before,
                user_id=user_id,
                limit=limit
                limit=limit,
            )

    def _search_flexible(
@@ -243,7 +243,7 @@ class MemoryMetadataIndex:
        created_before: Optional[float] = None,
        user_id: Optional[str] = None,
        limit: Optional[int] = None,
        **kwargs  # accepted but unused parameters
        **kwargs,  # accepted but unused parameters
    ) -> List[str]:
        """
        Flexible search mode: matching 2 of 4 criteria is enough; partial matches are supported.
@@ -258,10 +258,7 @@ class MemoryMetadataIndex:
        """
        # User filter (required)
        if user_id:
            base_candidates = {
                mid for mid, entry in self.index.items()
                if entry.user_id == user_id
            }
            base_candidates = {mid for mid, entry in self.index.items() if entry.user_id == user_id}
        else:
            base_candidates = set(self.index.keys())
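# Illustrative sketch, not part of this commit: one way to read the "2 of 4 criteria"
# idea from the docstring above, reduced to counting how many supplied filter groups an
# entry satisfies. The entry attributes follow the usage above; the `required` threshold
# and this exact counting rule are assumptions, not the project's actual implementation.
def matches_flexibly(entry, memory_types=None, subjects=None, keywords=None, tags=None, required=2):
    checks = []
    if memory_types:
        checks.append(entry.memory_type in memory_types)
    if subjects:
        wanted = {s.strip().lower() for s in subjects}
        checks.append(bool(wanted & {s.strip().lower() for s in entry.subjects}))
    if keywords:
        wanted = {k.strip().lower() for k in keywords}
        checks.append(bool(wanted & {k.strip().lower() for k in entry.keywords}))
    if tags:
        wanted = {t.strip().lower() for t in tags}
        checks.append(bool(wanted & {t.strip().lower() for t in entry.tags}))
    # Pass when at least `required` of the supplied criteria match (all pass if none supplied).
    return sum(checks) >= min(required, len(checks)) if checks else True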

@@ -386,7 +383,7 @@ class MemoryMetadataIndex:
        created_after: Optional[float] = None,
        created_before: Optional[float] = None,
        user_id: Optional[str] = None,
        limit: Optional[int] = None
        limit: Optional[int] = None,
    ) -> List[str]:
        """Strict search mode (original logic)."""
        # Initial candidate set (all memories)
@@ -394,10 +391,7 @@ class MemoryMetadataIndex:

        # User filter (required)
        if user_id:
            candidate_ids = {
                mid for mid, entry in self.index.items()
                if entry.user_id == user_id
            }
            candidate_ids = {mid for mid, entry in self.index.items() if entry.user_id == user_id}
        else:
            candidate_ids = set(self.index.keys())

@@ -447,7 +441,8 @@ class MemoryMetadataIndex:
        # Importance filter
        if importance_min is not None or importance_max is not None:
            importance_ids = {
                mid for mid in candidate_ids
                mid
                for mid in candidate_ids
                if (importance_min is None or self.index[mid].importance >= importance_min)
                and (importance_max is None or self.index[mid].importance <= importance_max)
            }
@@ -456,41 +451,37 @@ class MemoryMetadataIndex:
        # Time-range filter
        if created_after is not None or created_before is not None:
            time_ids = {
                mid for mid in candidate_ids
                mid
                for mid in candidate_ids
                if (created_after is None or self.index[mid].created_at >= created_after)
                and (created_before is None or self.index[mid].created_at <= created_before)
            }
            candidate_ids &= time_ids

        # Convert to a list and sort (newest first)
        result_ids = sorted(
            candidate_ids,
            key=lambda mid: self.index[mid].created_at,
            reverse=True
        )
        result_ids = sorted(candidate_ids, key=lambda mid: self.index[mid].created_at, reverse=True)

        # Apply the limit
        if limit:
            result_ids = result_ids[:limit]

        logger.debug(
            f"[严格搜索] types={memory_types}, subjects={subjects}, "
            f"keywords={keywords}, 返回={len(result_ids)}条"
            f"[严格搜索] types={memory_types}, subjects={subjects}, keywords={keywords}, 返回={len(result_ids)}条"
        )

        return result_ids

    def get_entry(self, memory_id: str) -> Optional[MemoryMetadataIndexEntry]:
        """Get a single index entry."""
        return self.index.get(memory_id)

    def get_stats(self) -> Dict[str, Any]:
        """Get index statistics."""
        with self.lock:
            return {
                'total_memories': len(self.index),
                'types': {mtype: len(ids) for mtype, ids in self.type_index.items()},
                'subjects_count': len(self.subject_index),
                'keywords_count': len(self.keyword_index),
                'tags_count': len(self.tag_index),
                "total_memories": len(self.index),
                "types": {mtype: len(ids) for mtype, ids in self.type_index.items()},
                "subjects_count": len(self.subject_index),
                "keywords_count": len(self.keyword_index),
                "tags_count": len(self.tag_index),
            }

@@ -80,10 +80,7 @@ class MemoryQueryPlanner:
        return self._default_plan(query_text)

    def _default_plan(self, query_text: str) -> MemoryQueryPlan:
        return MemoryQueryPlan(
            semantic_query=query_text,
            limit=self.default_limit
        )
        return MemoryQueryPlan(semantic_query=query_text, limit=self.default_limit)

    def _parse_plan_dict(self, data: Dict[str, Any], fallback_query: str) -> MemoryQueryPlan:
        semantic_query = self._safe_str(data.get("semantic_query")) or fallback_query
@@ -122,7 +119,7 @@ class MemoryQueryPlanner:
            recency_preference=self._safe_str(data.get("recency")) or "any",
            limit=self._safe_int(data.get("limit"), self.default_limit),
            emphasis=self._safe_str(data.get("emphasis")) or "balanced",
            raw_plan=data
            raw_plan=data,
        )
        return plan

@@ -154,18 +151,18 @@ class MemoryQueryPlanner:

            context_section = f"""

## 📋 未读消息上下文 (共{unread_context.get('total_count', 0)}条未读消息)
## 📋 未读消息上下文 (共{unread_context.get("total_count", 0)}条未读消息)
### 最近消息预览:
{chr(10).join(message_previews)}

### 上下文关键词:
{', '.join(unread_keywords[:15]) if unread_keywords else '无'}
{", ".join(unread_keywords[:15]) if unread_keywords else "无"}

### 对话参与者:
{', '.join(unread_participants) if unread_participants else '无'}
{", ".join(unread_participants) if unread_participants else "无"}

### 上下文摘要:
{context_summary[:300] if context_summary else '无'}
{context_summary[:300] if context_summary else "无"}
"""
        else:
            context_section = """
@@ -223,7 +220,7 @@ class MemoryQueryPlanner:
        start = stripped.find("{")
        end = stripped.rfind("}")
        if start != -1 and end != -1 and end > start:
            return stripped[start:end + 1]
            return stripped[start : end + 1]

        return stripped if stripped.startswith("{") and stripped.endswith("}") else None

@@ -243,4 +240,4 @@ class MemoryQueryPlanner:
                return default
            return number
        except (TypeError, ValueError):
            return default

@@ -35,6 +35,7 @@ GLOBAL_MEMORY_SCOPE = "global"

class MemorySystemStatus(Enum):
    """Memory system status."""

    INITIALIZING = "initializing"
    READY = "ready"
    BUILDING = "building"
@@ -45,6 +46,7 @@ class MemorySystemStatus(Enum):
@dataclass
class MemorySystemConfig:
    """Memory system configuration."""

    # Memory building settings
    min_memory_length: int = 10
    max_memory_length: int = 500
@@ -97,11 +99,9 @@ class MemorySystemConfig:
            max_memory_length=global_config.memory.max_memory_length,
            memory_value_threshold=global_config.memory.memory_value_threshold,
            min_build_interval_seconds=getattr(global_config.memory, "memory_build_interval", 300.0),

            # Vector storage settings
            vector_dimension=int(embedding_dimension),
            similarity_threshold=global_config.memory.vector_similarity_threshold,

            # Recall settings
            coarse_recall_limit=global_config.memory.metadata_filter_limit,
            fine_recall_limit=global_config.memory.vector_search_limit,
@@ -112,21 +112,16 @@ class MemorySystemConfig:
            semantic_weight=global_config.memory.semantic_weight,
            context_weight=global_config.memory.context_weight,
            recency_weight=global_config.memory.recency_weight,

            # Fusion settings
            fusion_similarity_threshold=global_config.memory.fusion_similarity_threshold,
            deduplication_window=timedelta(hours=global_config.memory.deduplication_window_hours)
            deduplication_window=timedelta(hours=global_config.memory.deduplication_window_hours),
        )


class MemorySystem:
    """Core class of the precision memory system."""

    def __init__(
        self,
        llm_model: Optional[LLMRequest] = None,
        config: Optional[MemorySystemConfig] = None
    ):
    def __init__(self, llm_model: Optional[LLMRequest] = None, config: Optional[MemorySystemConfig] = None):
        self.config = config or MemorySystemConfig.from_global_config()
        self.llm_model = llm_model
        self.status = MemorySystemStatus.INITIALIZING
@@ -175,16 +170,16 @@ class MemorySystem:
        extraction_task_config = value_task_config or fallback_task

        if value_task_config is None or extraction_task_config is None:
            raise RuntimeError("无法初始化记忆系统所需的模型配置,请检查 model_task_config 中的 utils / utils_small 设置。")
            raise RuntimeError(
                "无法初始化记忆系统所需的模型配置,请检查 model_task_config 中的 utils / utils_small 设置。"
            )

        self.value_assessment_model = LLMRequest(
            model_set=value_task_config,
            request_type="memory.value_assessment"
            model_set=value_task_config, request_type="memory.value_assessment"
        )

        self.memory_extraction_model = LLMRequest(
            model_set=extraction_task_config,
            request_type="memory.extraction"
            model_set=extraction_task_config, request_type="memory.extraction"
        )

        # Initialize core components (simplified)
@@ -198,13 +193,13 @@ class MemorySystem:
            memory_collection="unified_memory_v2",
            metadata_collection="memory_metadata_v2",
            similarity_threshold=self.config.similarity_threshold,
            search_limit=getattr(global_config.memory, 'unified_storage_search_limit', 20),
            batch_size=getattr(global_config.memory, 'unified_storage_batch_size', 100),
            enable_caching=getattr(global_config.memory, 'unified_storage_enable_caching', True),
            cache_size_limit=getattr(global_config.memory, 'unified_storage_cache_limit', 1000),
            auto_cleanup_interval=getattr(global_config.memory, 'unified_storage_auto_cleanup_interval', 3600),
            enable_forgetting=getattr(global_config.memory, 'enable_memory_forgetting', True),
            retention_hours=getattr(global_config.memory, 'memory_retention_hours', 720)  # 30 days
            search_limit=getattr(global_config.memory, "unified_storage_search_limit", 20),
            batch_size=getattr(global_config.memory, "unified_storage_batch_size", 100),
            enable_caching=getattr(global_config.memory, "unified_storage_enable_caching", True),
            cache_size_limit=getattr(global_config.memory, "unified_storage_cache_limit", 1000),
            auto_cleanup_interval=getattr(global_config.memory, "unified_storage_auto_cleanup_interval", 3600),
            enable_forgetting=getattr(global_config.memory, "enable_memory_forgetting", True),
            retention_hours=getattr(global_config.memory, "memory_retention_hours", 720),  # 30 days
        )

        try:
@@ -220,32 +215,27 @@ class MemorySystem:
            # Build the forgetting-engine config from the global config
            forgetting_config = ForgettingConfig(
                # Check-frequency settings
                check_interval_hours=getattr(global_config.memory, 'forgetting_check_interval_hours', 24),
                check_interval_hours=getattr(global_config.memory, "forgetting_check_interval_hours", 24),
                batch_size=100,  # fixed value, not configurable yet

                # Forgetting-threshold settings
                base_forgetting_days=getattr(global_config.memory, 'base_forgetting_days', 30.0),
                min_forgetting_days=getattr(global_config.memory, 'min_forgetting_days', 7.0),
                max_forgetting_days=getattr(global_config.memory, 'max_forgetting_days', 365.0),

                base_forgetting_days=getattr(global_config.memory, "base_forgetting_days", 30.0),
                min_forgetting_days=getattr(global_config.memory, "min_forgetting_days", 7.0),
                max_forgetting_days=getattr(global_config.memory, "max_forgetting_days", 365.0),
                # Importance weights
                critical_importance_bonus=getattr(global_config.memory, 'critical_importance_bonus', 45.0),
                high_importance_bonus=getattr(global_config.memory, 'high_importance_bonus', 30.0),
                normal_importance_bonus=getattr(global_config.memory, 'normal_importance_bonus', 15.0),
                low_importance_bonus=getattr(global_config.memory, 'low_importance_bonus', 0.0),

                critical_importance_bonus=getattr(global_config.memory, "critical_importance_bonus", 45.0),
                high_importance_bonus=getattr(global_config.memory, "high_importance_bonus", 30.0),
                normal_importance_bonus=getattr(global_config.memory, "normal_importance_bonus", 15.0),
                low_importance_bonus=getattr(global_config.memory, "low_importance_bonus", 0.0),
                # Confidence weights
                verified_confidence_bonus=getattr(global_config.memory, 'verified_confidence_bonus', 30.0),
                high_confidence_bonus=getattr(global_config.memory, 'high_confidence_bonus', 20.0),
                medium_confidence_bonus=getattr(global_config.memory, 'medium_confidence_bonus', 10.0),
                low_confidence_bonus=getattr(global_config.memory, 'low_confidence_bonus', 0.0),

                verified_confidence_bonus=getattr(global_config.memory, "verified_confidence_bonus", 30.0),
                high_confidence_bonus=getattr(global_config.memory, "high_confidence_bonus", 20.0),
                medium_confidence_bonus=getattr(global_config.memory, "medium_confidence_bonus", 10.0),
                low_confidence_bonus=getattr(global_config.memory, "low_confidence_bonus", 0.0),
                # Activation-frequency weights
                activation_frequency_weight=getattr(global_config.memory, 'activation_frequency_weight', 0.5),
                max_frequency_bonus=getattr(global_config.memory, 'max_frequency_bonus', 10.0),

                activation_frequency_weight=getattr(global_config.memory, "activation_frequency_weight", 0.5),
                max_frequency_bonus=getattr(global_config.memory, "max_frequency_bonus", 10.0),
                # Dormancy settings
                dormant_threshold_days=getattr(global_config.memory, 'dormant_threshold_days', 90)
                dormant_threshold_days=getattr(global_config.memory, "dormant_threshold_days", 90),
            )

            self.forgetting_engine = MemoryForgettingEngine(forgetting_config)
@@ -253,17 +243,11 @@ class MemorySystem:
            planner_task_config = getattr(model_config.model_task_config, "utils_small", None)
            planner_model: Optional[LLMRequest] = None
            try:
                planner_model = LLMRequest(
                    model_set=planner_task_config,
                    request_type="memory.query_planner"
                )
                planner_model = LLMRequest(model_set=planner_task_config, request_type="memory.query_planner")
            except Exception as planner_exc:
                logger.warning("查询规划模型初始化失败,将使用默认规划策略: %s", planner_exc, exc_info=True)

            self.query_planner = MemoryQueryPlanner(
                planner_model,
                default_limit=self.config.final_recall_limit
            )
            self.query_planner = MemoryQueryPlanner(planner_model, default_limit=self.config.final_recall_limit)

            # Unified storage loads its data automatically; no extra loading is needed
            logger.info("✅ 简化版记忆系统初始化完成")
@@ -277,11 +261,7 @@ class MemorySystem:
            raise

    async def retrieve_memories_for_building(
        self,
        query_text: str,
        user_id: Optional[str] = None,
        context: Optional[Dict[str, Any]] = None,
        limit: int = 5
        self, query_text: str, user_id: Optional[str] = None, context: Optional[Dict[str, Any]] = None, limit: int = 5
    ) -> List[MemoryChunk]:
        """Retrieve related memories during building, via the unified storage system.

@@ -304,9 +284,7 @@ class MemorySystem:
        try:
            # Retrieve similar memories via unified storage
            search_results = await self.unified_storage.search_similar_memories(
                query_text=query_text,
                limit=limit,
                scope_id=user_id
                query_text=query_text, limit=limit, scope_id=user_id
            )

            # Convert to memory objects
@@ -324,10 +302,7 @@ class MemorySystem:
            return []

    async def build_memory_from_conversation(
        self,
        conversation_text: str,
        context: Dict[str, Any],
        timestamp: Optional[float] = None
        self, conversation_text: str, context: Dict[str, Any], timestamp: Optional[float] = None
    ) -> List[MemoryChunk]:
        """Build memories from a conversation.

@@ -383,7 +358,7 @@ class MemorySystem:
            conversation_text,
            normalized_context,
            GLOBAL_MEMORY_SCOPE,  # force the global scope; no per-user separation
            timestamp or time.time()
            timestamp or time.time(),
        )

        if not memory_chunks:
@@ -393,10 +368,7 @@ class MemorySystem:

        # 3. Memory fusion and deduplication (including fusion with historical memories)
        existing_candidates = await self._collect_fusion_candidates(memory_chunks)
        fused_chunks = await self.fusion_engine.fuse_memories(
            memory_chunks,
            existing_candidates
        )
        fused_chunks = await self.fusion_engine.fuse_memories(memory_chunks, existing_candidates)

        # 4. Store memories in unified storage
        stored_count = await self._store_memories_unified(fused_chunks)
@@ -459,11 +431,7 @@ class MemorySystem:
            return []

        candidate_ids: Set[str] = set()
        new_memory_ids = {
            memory.memory_id
            for memory in new_memories
            if memory and getattr(memory, "memory_id", None)
        }
        new_memory_ids = {memory.memory_id for memory in new_memories if memory and getattr(memory, "memory_id", None)}

        # Direct matching based on fingerprints
        for memory in new_memories:
@@ -501,9 +469,7 @@ class MemorySystem:
                continue
            search_tasks.append(
                self.unified_storage.search_similar_memories(
                    query_text=display_text,
                    limit=8,
                    scope_id=GLOBAL_MEMORY_SCOPE
                    query_text=display_text, limit=8, scope_id=GLOBAL_MEMORY_SCOPE
                )
            )

@@ -545,10 +511,7 @@ class MemorySystem:

        return existing_candidates

    async def process_conversation_memory(
        self,
        context: Dict[str, Any]
    ) -> Dict[str, Any]:
    async def process_conversation_memory(self, context: Dict[str, Any]) -> Dict[str, Any]:
        """Public conversation-memory entry point; relies only on context information."""
        start_time = time.time()

@@ -563,7 +526,9 @@ class MemorySystem:
            or ""
        )

        conversation_text = conversation_candidate if isinstance(conversation_candidate, str) else str(conversation_candidate)
        conversation_text = (
            conversation_candidate if isinstance(conversation_candidate, str) else str(conversation_candidate)
        )

        timestamp = context.get("timestamp")
        if timestamp is None:
@@ -573,9 +538,7 @@ class MemorySystem:
        normalized_context.setdefault("conversation_text", conversation_text)

        memories = await self.build_memory_from_conversation(
            conversation_text=conversation_text,
            context=normalized_context,
            timestamp=timestamp
            conversation_text=conversation_text, context=normalized_context, timestamp=timestamp
        )

        processing_time = time.time() - start_time
@@ -586,18 +549,13 @@ class MemorySystem:
                "created_memories": memories,
                "memory_count": memory_count,
                "processing_time": processing_time,
                "status": self.status.value
                "status": self.status.value,
            }

        except Exception as e:
            processing_time = time.time() - start_time
            logger.error(f"对话记忆处理失败: {e}", exc_info=True)
            return {
                "success": False,
                "error": str(e),
                "processing_time": processing_time,
                "status": self.status.value
            }
            return {"success": False, "error": str(e), "processing_time": processing_time, "status": self.status.value}

    async def retrieve_relevant_memories(
        self,
@@ -605,7 +563,7 @@ class MemorySystem:
        user_id: Optional[str] = None,
        context: Optional[Dict[str, Any]] = None,
        limit: int = 5,
        **kwargs
        **kwargs,
    ) -> List[MemoryChunk]:
        """Retrieve relevant memories (three-stage recall: metadata coarse filter → vector fine filter → combined rerank)."""
        raw_query = query_text or kwargs.get("query")
@@ -617,7 +575,7 @@ class MemorySystem:
            return []

        context = context or {}

        # All memories are fully shared; always use the global scope, no per-user separation
        resolved_user_id = GLOBAL_MEMORY_SCOPE

@@ -627,7 +585,7 @@ class MemorySystem:
        try:
            normalized_context = self._normalize_context(context, GLOBAL_MEMORY_SCOPE, None)
            effective_limit = self.config.final_recall_limit

            # === Stage 1: metadata coarse filter (soft filtering) ===
            coarse_filters = {
                "user_id": GLOBAL_MEMORY_SCOPE,  # required: keeps the scope correct
@@ -642,119 +600,126 @@ class MemorySystem:
                # Build an enhanced context that includes unread messages
                enhanced_context = await self._build_enhanced_query_context(raw_query, normalized_context)
                query_plan = await self.query_planner.plan_query(raw_query, enhanced_context)

                # Use the LLM-optimized query (a more precise semantic expression)
                if getattr(query_plan, "semantic_query", None):
                    optimized_query = query_plan.semantic_query

                # Build JSON metadata filters (for the stage-1 coarse filter)
                # Convert the query-plan result into metadata filter conditions
                if getattr(query_plan, "memory_types", None):
                    metadata_filters['memory_types'] = [mt.value for mt in query_plan.memory_types]
                    metadata_filters["memory_types"] = [mt.value for mt in query_plan.memory_types]

                if getattr(query_plan, "subject_includes", None):
                    metadata_filters['subjects'] = query_plan.subject_includes
                    metadata_filters["subjects"] = query_plan.subject_includes

                if getattr(query_plan, "required_keywords", None):
                    metadata_filters['keywords'] = query_plan.required_keywords
                    metadata_filters["keywords"] = query_plan.required_keywords

                # Time-range filter
                recency = getattr(query_plan, "recency_preference", "any")
                current_time = time.time()
                if recency == "recent":
                    # last 7 days
                    metadata_filters['created_after'] = current_time - (7 * 24 * 3600)
                    metadata_filters["created_after"] = current_time - (7 * 24 * 3600)
                elif recency == "historical":
                    # older than 30 days
                    metadata_filters['created_before'] = current_time - (30 * 24 * 3600)
                    metadata_filters["created_before"] = current_time - (30 * 24 * 3600)

                # Add the user ID to the metadata filters
                metadata_filters['user_id'] = GLOBAL_MEMORY_SCOPE
                metadata_filters["user_id"] = GLOBAL_MEMORY_SCOPE

                logger.debug(f"[阶段一] 查询优化: '{raw_query}' → '{optimized_query}'")
                logger.debug(f"[阶段一] 元数据过滤条件: {metadata_filters}")

            except Exception as plan_exc:
                logger.warning("查询规划失败,使用原始查询: %s", plan_exc, exc_info=True)
                # Keep the basic user_id filter even if query planning fails
                metadata_filters = {'user_id': GLOBAL_MEMORY_SCOPE}
                metadata_filters = {"user_id": GLOBAL_MEMORY_SCOPE}

            # === Stage 2: vector fine filter ===
            coarse_limit = self.config.coarse_recall_limit  # the coarse stage returns more candidates

            logger.debug(f"[阶段二] 开始向量搜索: query='{optimized_query[:60]}...', limit={coarse_limit}")

            search_results = await self.unified_storage.search_similar_memories(
                query_text=optimized_query,
                limit=coarse_limit,
                filters=coarse_filters,  # ChromaDB where clause (kept for compatibility)
                metadata_filters=metadata_filters  # JSON metadata index filter
                metadata_filters=metadata_filters,  # JSON metadata index filter
            )

            logger.info(f"[阶段二] 向量搜索完成: 返回 {len(search_results)} 条候选")

            # === Stage 3: combined rerank ===
            scored_memories = []
            current_time = time.time()

            for memory, vector_similarity in search_results:
                # 1. Vector similarity score (already normalized to 0-1)
                vector_score = vector_similarity

                # 2. Recency score (exponential decay, ~30-day time constant)
                age_seconds = current_time - memory.metadata.created_at
                age_days = age_seconds / (24 * 3600)
                # Use math.exp instead of np.exp (avoids a numpy dependency)
                import math

                recency_score = math.exp(-age_days / 30)

                # 3. Importance score (enum mapped to a normalized 0-1 score)
                # ImportanceLevel: LOW=1, NORMAL=2, HIGH=3, CRITICAL=4
                importance_enum = memory.metadata.importance
                if hasattr(importance_enum, 'value'):
                if hasattr(importance_enum, "value"):
                    # Enum type; map to the 0-1 range: (value - 1) / 3
                    importance_score = (importance_enum.value - 1) / 3.0
                else:
                    # Already numeric; use as-is
                    importance_score = float(importance_enum) if importance_enum else 0.5

                # 4. Access-frequency score (normalized; 10+ accesses score full marks)
                access_count = memory.metadata.access_count
                frequency_score = min(access_count / 10.0, 1.0)

                # Combined score (weighted average)
                final_score = (
                    self.config.vector_weight * vector_score +
                    self.config.recency_weight * recency_score +
                    self.config.context_weight * importance_score +
                    0.1 * frequency_score  # access-frequency weight (fixed 10%)
                    self.config.vector_weight * vector_score
                    + self.config.recency_weight * recency_score
                    + self.config.context_weight * importance_score
                    + 0.1 * frequency_score  # access-frequency weight (fixed 10%)
                )

                scored_memories.append((memory, final_score, {
                    "vector": vector_score,
                    "recency": recency_score,
                    "importance": importance_score,
                    "frequency": frequency_score,
                    "final": final_score
                }))

                scored_memories.append(
                    (
                        memory,
                        final_score,
                        {
                            "vector": vector_score,
                            "recency": recency_score,
                            "importance": importance_score,
                            "frequency": frequency_score,
                            "final": final_score,
                        },
                    )
                )

                # Update access stats
                memory.update_access()

            # Sort by combined score
            scored_memories.sort(key=lambda x: x[1], reverse=True)

            # Return the Top-K
            final_memories = [mem for mem, score, details in scored_memories[:effective_limit]]
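# Illustrative sketch, not part of this commit: the stage-three score above as a pure
# function. The weights here are placeholders; in the code they come from the config
# (vector_weight / recency_weight / context_weight, plus a fixed 0.1 for frequency).
import math

def rerank_score(similarity, age_days, importance_value, access_count,
                 w_vec=0.5, w_rec=0.2, w_imp=0.2):
    recency = math.exp(-age_days / 30)         # ~30-day decay, as above
    importance = (importance_value - 1) / 3.0  # LOW=1 .. CRITICAL=4 -> 0..1
    frequency = min(access_count / 10.0, 1.0)  # saturates at 10 accesses
    return w_vec * similarity + w_rec * recency + w_imp * importance + 0.1 * frequency

# A fresh memory outranks an old one of equal similarity, importance and frequency:
assert rerank_score(0.9, 1, 2, 3) > rerank_score(0.9, 90, 2, 3)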

            retrieval_time = time.time() - start_time

            # Detailed logging
            if scored_memories:
                logger.info(f"[阶段三] 综合重排完成: Top 3 得分详情")
                logger.info("[阶段三] 综合重排完成: Top 3 得分详情")
                for i, (mem, score, details) in enumerate(scored_memories[:3], 1):
                    try:
                        summary = mem.content[:60] if hasattr(mem, 'content') and mem.content else ""
                        summary = mem.content[:60] if hasattr(mem, "content") and mem.content else ""
                    except:
                        summary = ""
                    logger.info(
@@ -803,15 +768,12 @@ class MemorySystem:
        start = stripped.find("{")
        end = stripped.rfind("}")
        if start != -1 and end != -1 and end > start:
            return stripped[start:end + 1].strip()
            return stripped[start : end + 1].strip()

        return stripped if stripped.startswith("{") and stripped.endswith("}") else None

    def _normalize_context(
        self,
        raw_context: Optional[Dict[str, Any]],
        user_id: Optional[str],
        timestamp: Optional[float]
        self, raw_context: Optional[Dict[str, Any]], user_id: Optional[str], timestamp: Optional[float]
    ) -> Dict[str, Any]:
        """Normalize the context, ensuring required fields exist and are well-formed."""
        context: Dict[str, Any] = {}
@@ -850,9 +812,7 @@ class MemorySystem:

        # History-window settings
        window_candidate = (
            context.get("history_limit")
            or context.get("history_window")
            or context.get("memory_history_limit")
            context.get("history_limit") or context.get("history_window") or context.get("memory_history_limit")
        )
        if window_candidate is not None:
            try:
@@ -888,7 +848,9 @@ class MemorySystem:
            enhanced_context["unread_messages_context"] = unread_messages_summary
            enhanced_context["has_unread_context"] = True

            logger.debug(f"为查询规划构建了增强上下文,包含 {len(unread_messages_summary.get('messages', []))} 条未读消息")
            logger.debug(
                f"为查询规划构建了增强上下文,包含 {len(unread_messages_summary.get('messages', []))} 条未读消息"
            )
        else:
            enhanced_context["has_unread_context"] = False
            logger.debug("未找到未读消息,使用基础上下文进行查询规划")
@@ -934,26 +896,30 @@ class MemorySystem:
        for msg in unread_messages[:10]:  # process at most the 10 most recent unread messages
            try:
                # Extract the message content
                content = (getattr(msg, "processed_plain_text", None) or
                           getattr(msg, "display_message", None) or "")
                content = getattr(msg, "processed_plain_text", None) or getattr(msg, "display_message", None) or ""
                if not content:
                    continue

                # Extract sender info
                sender_name = "未知用户"
                if hasattr(msg, "user_info") and msg.user_info:
                    sender_name = (getattr(msg.user_info, "user_nickname", None) or
                                   getattr(msg.user_info, "user_cardname", None) or
                                   getattr(msg.user_info, "user_id", None) or "未知用户")
                    sender_name = (
                        getattr(msg.user_info, "user_nickname", None)
                        or getattr(msg.user_info, "user_cardname", None)
                        or getattr(msg.user_info, "user_id", None)
                        or "未知用户"
                    )

                participant_names.add(sender_name)

                # Append to the message summary
                messages_summary.append({
                    "sender": sender_name,
                    "content": content[:200],  # cap the length
                    "timestamp": getattr(msg, "time", None)
                })
                messages_summary.append(
                    {
                        "sender": sender_name,
                        "content": content[:200],  # cap the length
                        "timestamp": getattr(msg, "time", None),
                    }
                )

                # Extract keywords (naive implementation)
                content_lower = content.lower()
@@ -975,10 +941,12 @@ class MemorySystem:
            "processed_count": len(messages_summary),
            "keywords": list(all_keywords)[:20],  # at most 20 keywords
            "participants": list(participant_names),
            "context_summary": self._build_unread_context_summary(messages_summary)
            "context_summary": self._build_unread_context_summary(messages_summary),
        }

        logger.debug(f"收集到未读消息上下文: {len(messages_summary)}条消息,{len(all_keywords)}个关键词,{len(participant_names)}个参与者")
        logger.debug(
            f"收集到未读消息上下文: {len(messages_summary)}条消息,{len(all_keywords)}个关键词,{len(participant_names)}个参与者"
        )
        return unread_context

    except Exception as e:
@@ -1051,10 +1019,7 @@ class MemorySystem:
        if user_id and fallback_text:
            try:
                relevant_memories = await self.retrieve_memories_for_building(
                    query_text=fallback_text,
                    user_id=user_id,
                    context=context,
                    limit=3
                    query_text=fallback_text, user_id=user_id, context=context, limit=3
                )

                if relevant_memories:
@@ -1068,9 +1033,7 @@ class MemorySystem:
                    memory_transcript = f"{memory_transcript}\n[当前消息] {cleaned_fallback}"

                    logger.debug(
                        "使用检索到的历史记忆构建记忆上下文,记忆数=%d,用户=%s",
                        len(relevant_memories),
                        user_id
                        "使用检索到的历史记忆构建记忆上下文,记忆数=%d,用户=%s", len(relevant_memories), user_id
                    )
                    return memory_transcript

@@ -1087,11 +1050,7 @@ class MemorySystem:
    def _determine_history_limit(self, context: Dict[str, Any]) -> int:
        """Decide how many history messages to fetch, clamped between 30 and 50."""
        default_limit = 40
        candidate = (
            context.get("history_limit")
            or context.get("history_window")
            or context.get("memory_history_limit")
        )
        candidate = context.get("history_limit") or context.get("history_window") or context.get("memory_history_limit")

        if isinstance(candidate, str):
            try:
@@ -1186,9 +1145,9 @@ class MemorySystem:
{text}

上下文信息:
- 用户ID: {context.get('user_id', 'unknown')}
- 消息类型: {context.get('message_type', 'unknown')}
- 时间: {datetime.fromtimestamp(context.get('timestamp', time.time()))}
- 用户ID: {context.get("user_id", "unknown")}
- 消息类型: {context.get("message_type", "unknown")}
- 时间: {datetime.fromtimestamp(context.get("timestamp", time.time()))}

## 📋 评估要求:

@@ -1214,9 +1173,7 @@ class MemorySystem:
}}
"""

        response, _ = await self.value_assessment_model.generate_response_async(
            prompt, temperature=0.3
        )
        response, _ = await self.value_assessment_model.generate_response_async(prompt, temperature=0.3)

        # Parse the response
        try:
@@ -1236,7 +1193,7 @@ class MemorySystem:
            return max(0.0, min(1.0, value_score))

        except (orjson.JSONDecodeError, ValueError) as e:
            preview = response[:200].replace('\n', ' ')
            preview = response[:200].replace("\n", " ")
            logger.warning(f"解析价值评估响应失败: {e}, 响应片段: {preview}")
            return 0.5  # default to medium value

@@ -1331,13 +1288,15 @@ class MemorySystem:
        else:
            obj_part = str(obj).strip()

        base = "|".join([
            str(memory.user_id or "unknown"),
            memory.memory_type.value,
            subject_part,
            predicate_part,
            obj_part,
        ])
        base = "|".join(
            [
                str(memory.user_id or "unknown"),
                memory.memory_type.value,
                subject_part,
                predicate_part,
                obj_part,
            ]
        )

        return hashlib.sha256(base.encode("utf-8")).hexdigest()
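# Illustrative sketch, not part of this commit: the fingerprint above is SHA-256 over a
# "|"-joined tuple, so two memories with the same scope, type and subject/predicate/object
# triple collide deliberately and can be deduplicated. The values below are made up.
import hashlib

base = "|".join(["global", "fact", "alice", "likes", "tea"])
fingerprint = hashlib.sha256(base.encode("utf-8")).hexdigest()
print(fingerprint[:16])  # stable across runs for identical inputs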

@@ -1352,7 +1311,7 @@ class MemorySystem:
            "total_memories": self.total_memories,
            "last_build_time": self.last_build_time,
            "last_retrieval_time": self.last_retrieval_time,
            "config": asdict(self.config)
            "config": asdict(self.config),
        }

    def _compute_memory_score(self, query_text: str, memory: MemoryChunk, context: Dict[str, Any]) -> float:
@@ -1369,7 +1328,9 @@ class MemorySystem:
        keyword_overlap = 0.0
        if context_keywords:
            memory_keywords = set(k.lower() for k in memory.keywords)
            keyword_overlap = len(memory_keywords & set(k.lower() for k in context_keywords)) / max(len(context_keywords), 1)
            keyword_overlap = len(memory_keywords & set(k.lower() for k in context_keywords)) / max(
                len(context_keywords), 1
            )

        importance_boost = (memory.metadata.importance.value - 1) / 3 * 0.1
        confidence_boost = (memory.metadata.confidence.value - 1) / 3 * 0.05
@@ -1429,7 +1390,7 @@ class MemorySystem:
        """Rebuild the vector store (if needed)."""
        try:
            # Check whether memory cache data exists
            if not hasattr(self.unified_storage, 'memory_cache') or not self.unified_storage.memory_cache:
            if not hasattr(self.unified_storage, "memory_cache") or not self.unified_storage.memory_cache:
                logger.info("无记忆缓存数据,跳过向量存储重建")
                return

@@ -1443,19 +1404,19 @@ class MemorySystem:
                    memories_to_rebuild.append(memory)
                elif memory.text_content and memory.text_content.strip():
                    memories_to_rebuild.append(memory)

            if not memories_to_rebuild:
                logger.warning("没有找到可重建向量的记忆")
                return

            logger.info(f"准备为 {len(memories_to_rebuild)} 条记忆重建向量")

            # Rebuild vectors in batches
            batch_size = 10
            rebuild_count = 0

            for i in range(0, len(memories_to_rebuild), batch_size):
                batch = memories_to_rebuild[i:i + batch_size]
                batch = memories_to_rebuild[i : i + batch_size]
                try:
                    await self.unified_storage.store_memories(batch)
                    rebuild_count += len(batch)
@@ -1472,7 +1433,7 @@ class MemorySystem:

            final_count = self.unified_storage.storage_stats.get("total_vectors", 0)
            logger.info(f"✅ 向量存储重建完成,最终向量数量: {final_count}")

        except Exception as e:
            logger.error(f"❌ 向量存储重建失败: {e}", exc_info=True)

@@ -1495,4 +1456,4 @@ async def initialize_memory_system(llm_model: Optional[LLMRequest] = None):
    if memory_system is None:
        memory_system = MemorySystem(llm_model=llm_model)
        await memory_system.initialize()
    return memory_system
File diff suppressed because it is too large
@@ -5,15 +5,12 @@

from .message_manager import MessageManager, message_manager
from .context_manager import SingleStreamContextManager
from .distribution_manager import (
    StreamLoopManager,
    stream_loop_manager
)
from .distribution_manager import StreamLoopManager, stream_loop_manager

__all__ = [
    "MessageManager",
    "message_manager",
    "SingleStreamContextManager",
    "StreamLoopManager",
    "stream_loop_manager"
]
    "stream_loop_manager",
]

@@ -230,12 +230,14 @@ class SingleStreamContextManager:
        Asynchronously compute the message's interest score.
        Stays compatible with both sync and async callers by checking for a running asyncio event loop.
        """

        # Inner async function wrapping the actual computation
        async def _get_score():
            try:
                from src.plugins.built_in.affinity_flow_chatter.interest_scoring import (
                    chatter_interest_scoring_system,
                )

                interest_score = await chatter_interest_scoring_system._calculate_single_message_score(
                    message=message, bot_nickname=global_config.bot.nickname
                )

@@ -34,17 +34,13 @@ class StreamLoopManager:
        }

        # Configuration parameters
        self.max_concurrent_streams = (
            max_concurrent_streams or global_config.chat.max_concurrent_distributions
        )
        self.max_concurrent_streams = max_concurrent_streams or global_config.chat.max_concurrent_distributions

        # Forced-dispatch policy
        self.force_dispatch_unread_threshold: Optional[int] = getattr(
            global_config.chat, "force_dispatch_unread_threshold", 20
        )
        self.force_dispatch_min_interval: float = getattr(
            global_config.chat, "force_dispatch_min_interval", 0.1
        )
        self.force_dispatch_min_interval: float = getattr(global_config.chat, "force_dispatch_min_interval", 0.1)

        # Chatter manager
        self.chatter_manager: Optional[ChatterManager] = None
@@ -108,7 +104,9 @@ class StreamLoopManager:

        if force and len(self.stream_loops) >= self.max_concurrent_streams:
            logger.warning(
                "流 %s 未读消息积压严重(>%s),突破并发限制强制启动分发", stream_id, self.force_dispatch_unread_threshold
                "流 %s 未读消息积压严重(>%s),突破并发限制强制启动分发",
                stream_id,
                self.force_dispatch_unread_threshold,
            )

        # Create the stream loop task
@@ -168,9 +166,7 @@ class StreamLoopManager:

        if has_messages:
            if force_dispatch:
                logger.info(
                    "流 %s 未读消息 %d 条,触发强制分发", stream_id, unread_count
                )
                logger.info("流 %s 未读消息 %d 条,触发强制分发", stream_id, unread_count)
            # 3. Activate chatter processing
            success = await self._process_stream_messages(stream_id, context)

@@ -11,7 +11,7 @@ from typing import Dict, Optional, Any, TYPE_CHECKING, List
from src.chat.message_receive.chat_stream import ChatStream
from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.message_manager_data_model import StreamContext, MessageManagerStats, StreamStats
from src.common.data_models.message_manager_data_model import MessageManagerStats, StreamStats
from src.chat.chatter_manager import ChatterManager
from src.chat.planner_actions.action_manager import ChatterActionManager
from .sleep_manager.sleep_manager import SleepManager
@@ -21,7 +21,7 @@ from src.plugin_system.apis.chat_api import get_chat_manager
from .distribution_manager import stream_loop_manager

if TYPE_CHECKING:
    from src.common.data_models.message_manager_data_model import StreamContext
    pass

logger = get_logger("message_manager")

@@ -63,7 +63,7 @@ class MessageManager:
        stream_loop_manager.set_chatter_manager(self.chatter_manager)

        logger.info("🚀 消息管理器已启动 | 流循环管理器已启动")

    async def stop(self):
        """Stop the message manager."""
        if not self.is_running:
@@ -88,7 +88,9 @@ class MessageManager:
                logger.warning(f"MessageManager.add_message: 聊天流 {stream_id} 不存在")
                return
            await self._check_and_handle_interruption(chat_stream)
            chat_stream.context_manager.context.processing_task = asyncio.create_task(chat_stream.context_manager.add_message(message))
            chat_stream.context_manager.context.processing_task = asyncio.create_task(
                chat_stream.context_manager.add_message(message)
            )
        except Exception as e:
            logger.error(f"添加消息到聊天流 {stream_id} 时发生错误: {e}")

@@ -141,11 +143,7 @@ class MessageManager:
            if not message_id:
                continue

            payload = {
                key: value
                for key, value in item.items()
                if key != "message_id" and value is not None
            }
            payload = {key: value for key, value in item.items() if key != "message_id" and value is not None}

            if not payload:
                continue
@@ -169,9 +167,7 @@ class MessageManager:
        if not chat_stream:
            logger.warning(f"MessageManager.add_action: 聊天流 {stream_id} 不存在")
            return
        success = await chat_stream.context_manager.update_message(
            message_id, {"actions": [action]}
        )
        success = await chat_stream.context_manager.update_message(message_id, {"actions": [action]})
        if success:
            logger.debug(f"为消息 {message_id} 添加动作 {action} 成功")
        else:
@@ -193,7 +189,7 @@ class MessageManager:
            context.is_active = False

            # Cancel the processing task
            if hasattr(context, 'processing_task') and context.processing_task and not context.processing_task.done():
            if hasattr(context, "processing_task") and context.processing_task and not context.processing_task.done():
                context.processing_task.cancel()

            logger.info(f"停用聊天流: {stream_id}")
@@ -236,7 +232,11 @@ class MessageManager:
                unread_count=unread_count,
                history_count=len(context.history_messages),
                last_check_time=context.last_check_time,
                has_active_task=bool(hasattr(context, 'processing_task') and context.processing_task and not context.processing_task.done()),
                has_active_task=bool(
                    hasattr(context, "processing_task")
                    and context.processing_task
                    and not context.processing_task.done()
                ),
            )

        except Exception as e:
@@ -284,7 +284,10 @@ class MessageManager:
            return

        # Check for an in-flight processing task
        if chat_stream.context_manager.context.processing_task and not chat_stream.context_manager.context.processing_task.done():
        if (
            chat_stream.context_manager.context.processing_task
            and not chat_stream.context_manager.context.processing_task.done()
        ):
            # Compute the interruption probability
            interruption_probability = chat_stream.context_manager.context.calculate_interruption_probability(
                global_config.chat.interruption_max_limit, global_config.chat.interruption_probability_factor
@@ -310,7 +313,9 @@ class MessageManager:

            # Increment the interruption count and apply the afc-threshold reduction
            chat_stream.context_manager.context.increment_interruption_count()
            chat_stream.context_manager.context.apply_interruption_afc_reduction(global_config.chat.interruption_afc_reduction)
            chat_stream.context_manager.context.apply_interruption_afc_reduction(
                global_config.chat.interruption_afc_reduction
            )

            # Check whether the maximum count has been reached
            if chat_stream.context_manager.context.interruption_count >= global_config.chat.interruption_max_limit:
@@ -364,7 +369,7 @@ class MessageManager:
            return

        context = chat_stream.context_manager.context
        if hasattr(context, 'unread_messages') and context.unread_messages:
        if hasattr(context, "unread_messages") and context.unread_messages:
            logger.debug(f"正在为流 {stream_id} 清除 {len(context.unread_messages)} 条未读消息")
            context.unread_messages.clear()
        else:

@@ -1,33 +1,33 @@
from src.common.logger import get_logger

#from ..hfc_context import HfcContext
# from ..hfc_context import HfcContext

logger = get_logger("notification_sender")


class NotificationSender:
    @staticmethod
    async def send_goodnight_notification(context):  # type: ignore
        """Send the goodnight notification."""
        #try:
            #from ..proactive.events import ProactiveTriggerEvent
            #from ..proactive.proactive_thinker import ProactiveThinker

            #event = ProactiveTriggerEvent(source="sleep_manager", reason="goodnight")
            #proactive_thinker = ProactiveThinker(context, context.chat_instance.cycle_processor)
            #await proactive_thinker.think(event)
        #except Exception as e:
            #logger.error(f"发送晚安通知失败: {e}")
        # try:
        #     from ..proactive.events import ProactiveTriggerEvent
        #     from ..proactive.proactive_thinker import ProactiveThinker

        #     event = ProactiveTriggerEvent(source="sleep_manager", reason="goodnight")
        #     proactive_thinker = ProactiveThinker(context, context.chat_instance.cycle_processor)
        #     await proactive_thinker.think(event)
        # except Exception as e:
        #     logger.error(f"发送晚安通知失败: {e}")

    @staticmethod
    async def send_insomnia_notification(context, reason: str):  # type: ignore
        """Send the insomnia notification."""
        #try:
            #from ..proactive.events import ProactiveTriggerEvent
            #from ..proactive.proactive_thinker import ProactiveThinker
        # try:
        #     from ..proactive.events import ProactiveTriggerEvent
        #     from ..proactive.proactive_thinker import ProactiveThinker

            #event = ProactiveTriggerEvent(source="sleep_manager", reason=reason)
            #proactive_thinker = ProactiveThinker(context, context.chat_instance.cycle_processor)
            #await proactive_thinker.think(event)
        #except Exception as e:
            #logger.error(f"发送失眠通知失败: {e}")
        #     event = ProactiveTriggerEvent(source="sleep_manager", reason=reason)
        #     proactive_thinker = ProactiveThinker(context, context.chat_instance.cycle_processor)
        #     await proactive_thinker.think(event)
        # except Exception as e:
        #     logger.error(f"发送失眠通知失败: {e}")

@@ -1,6 +1,6 @@
import asyncio
import random
from datetime import datetime, timedelta, date
from datetime import datetime, timedelta
from typing import Optional, TYPE_CHECKING

from src.common.logger import get_logger
@@ -21,6 +21,7 @@ class SleepManager:
    It implements a state machine that switches between sleep states (awake, preparing to
    sleep, sleeping, insomnia) based on a preset schedule, sleep pressure, and random factors.
    """

    def __init__(self):
        """
        Initialize the sleep manager.
@@ -97,7 +98,7 @@ class SleepManager:
                logger.info(f"进入理论休眠时间 '{activity}',开始进行睡眠决策...")
            else:
                logger.info("进入理论休眠时间,开始进行睡眠决策...")

            if global_config.sleep_system.enable_flexible_sleep:
                # --- New flexible-sleep logic ---
                if wakeup_manager:
@@ -112,7 +113,7 @@ class SleepManager:
                    pressure_diff = (pressure_threshold - sleep_pressure) / pressure_threshold
                    # Delay in minutes; the lower the pressure, the longer the delay
                    delay_minutes = int(pressure_diff * max_delay_minutes)

                    # Make sure the total delay does not exceed today's maximum
                    remaining_delay = max_delay_minutes - self.context.total_delayed_minutes_today
                    delay_minutes = min(delay_minutes, remaining_delay)
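# Illustrative sketch, not part of this commit: the flexible-sleep delay above as a pure
# function. The further the pressure sits below the threshold, the longer bedtime is
# pushed back, capped by what remains of today's delay budget. Numbers are made up.
def sleep_delay(sleep_pressure, pressure_threshold, max_delay_minutes, already_delayed):
    pressure_diff = (pressure_threshold - sleep_pressure) / pressure_threshold
    delay = int(pressure_diff * max_delay_minutes)
    return min(delay, max_delay_minutes - already_delayed)

assert sleep_delay(20, 80, 60, 0) == 45   # low pressure -> long delay
assert sleep_delay(20, 80, 60, 50) == 10  # budget nearly spent -> clamped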
@@ -151,9 +152,10 @@ class SleepManager:
            if wakeup_manager and global_config.sleep_system.enable_pre_sleep_notification:
                asyncio.create_task(NotificationSender.send_goodnight_notification(wakeup_manager.context))
            self.context.current_state = SleepState.SLEEPING

    def _handle_preparing_sleep(self, now: datetime, is_in_theoretical_sleep: bool, wakeup_manager: Optional["WakeUpManager"]):
    def _handle_preparing_sleep(
        self, now: datetime, is_in_theoretical_sleep: bool, wakeup_manager: Optional["WakeUpManager"]
    ):
        """Handle the "preparing to sleep" state."""
        # If the theoretical sleep window ends during preparation, cancel falling asleep
        if not is_in_theoretical_sleep:
@@ -166,16 +168,22 @@ class SleepManager:
            logger.info("睡眠缓冲期结束,正式进入休眠状态。")
            self.context.current_state = SleepState.SLEEPING
            self._last_fully_slept_log_time = now.timestamp()

            # Schedule a random delay that triggers the post-sleep insomnia check
            delay_minutes_range = global_config.sleep_system.insomnia_trigger_delay_minutes
            delay_minutes = random.randint(delay_minutes_range[0], delay_minutes_range[1])
            self.context.sleep_buffer_end_time = now + timedelta(minutes=delay_minutes)
            logger.info(f"已设置睡后失眠检查,将在 {delay_minutes} 分钟后触发。")

        self.context.save()

    def _handle_sleeping(self, now: datetime, is_in_theoretical_sleep: bool, activity: Optional[str], wakeup_manager: Optional["WakeUpManager"]):
    def _handle_sleeping(
        self,
        now: datetime,
        is_in_theoretical_sleep: bool,
        activity: Optional[str],
        wakeup_manager: Optional["WakeUpManager"],
    ):
        """Handle the "sleeping" state."""
        # Wake up naturally once the theoretical sleep window ends
        if not is_in_theoretical_sleep:
@@ -198,14 +206,16 @@ class SleepManager:

            if insomnia_reason:
                self.context.current_state = SleepState.INSOMNIA

                # Set the insomnia duration
                duration_minutes_range = global_config.sleep_system.insomnia_duration_minutes
                duration_minutes = random.randint(*duration_minutes_range)
                self.context.sleep_buffer_end_time = now + timedelta(minutes=duration_minutes)

                # Send the insomnia notification
                asyncio.create_task(NotificationSender.send_insomnia_notification(wakeup_manager.context, insomnia_reason))
                asyncio.create_task(
                    NotificationSender.send_insomnia_notification(wakeup_manager.context, insomnia_reason)
                )
                logger.info(f"进入失眠状态 (原因: {insomnia_reason}),将持续 {duration_minutes} 分钟。")
            else:
                # Sleep pressure is normal; no insomnia, clear the check time

@@ -25,6 +25,7 @@ class SleepContext:
    """
    Sleep context: encapsulates and manages all sleep-related state and its persistence.
    """

    def __init__(self):
        """Initialize the sleep context and load initial state from local storage."""
        self.current_state: SleepState = SleepState.AWAKE
@@ -83,4 +84,4 @@ class SleepContext:

            logger.info(f"成功从本地存储加载睡眠上下文: {state}")
        except Exception as e:
            logger.warning(f"加载睡眠上下文失败,将使用默认值: {e}")

@@ -15,23 +15,25 @@ class TimeChecker:
        self._daily_sleep_offset: int = 0
        self._daily_wake_offset: int = 0
        self._offset_date = None

    def _get_daily_offsets(self):
        """Get today's sleep and wake time offsets, generated once per day."""
        today = datetime.now().date()

        # Regenerate the offsets on a new day
        if self._offset_date != today:
            sleep_offset_range = global_config.sleep_system.sleep_time_offset_minutes
            wake_offset_range = global_config.sleep_system.wake_up_time_offset_minutes

            # Generate random offsets within ±offset_range
            self._daily_sleep_offset = random.randint(-sleep_offset_range, sleep_offset_range)
            self._daily_wake_offset = random.randint(-wake_offset_range, wake_offset_range)
            self._offset_date = today

            logger.debug(f"生成新的每日偏移量 - 睡觉时间偏移: {self._daily_sleep_offset}分钟, 起床时间偏移: {self._daily_wake_offset}分钟")

            logger.debug(
                f"生成新的每日偏移量 - 睡觉时间偏移: {self._daily_sleep_offset}分钟, 起床时间偏移: {self._daily_wake_offset}分钟"
            )

        return self._daily_sleep_offset, self._daily_wake_offset

    @staticmethod
@@ -82,28 +84,36 @@ class TimeChecker:
        try:
            start_time_str = global_config.sleep_system.fixed_sleep_time
            end_time_str = global_config.sleep_system.fixed_wake_up_time

            # Get today's offsets
            sleep_offset, wake_offset = self._get_daily_offsets()

            # Parse the base times
            base_start_time = datetime.strptime(start_time_str, "%H:%M")
            base_end_time = datetime.strptime(end_time_str, "%H:%M")

            # Apply the offsets
            actual_start_time = (base_start_time + timedelta(minutes=sleep_offset)).time()
            actual_end_time = (base_end_time + timedelta(minutes=wake_offset)).time()

            logger.debug(f"固定睡眠时间检查 - 基础时间: {start_time_str}-{end_time_str}, "
                         f"偏移后时间: {actual_start_time.strftime('%H:%M')}-{actual_end_time.strftime('%H:%M')}, "
                         f"当前时间: {now_time.strftime('%H:%M')}")

            logger.debug(
                f"固定睡眠时间检查 - 基础时间: {start_time_str}-{end_time_str}, "
                f"偏移后时间: {actual_start_time.strftime('%H:%M')}-{actual_end_time.strftime('%H:%M')}, "
                f"当前时间: {now_time.strftime('%H:%M')}"
            )

            if actual_start_time <= actual_end_time:
                if actual_start_time <= now_time < actual_end_time:
                    return True, f"固定睡眠时间(偏移后: {actual_start_time.strftime('%H:%M')}-{actual_end_time.strftime('%H:%M')})"
                    return (
                        True,
                        f"固定睡眠时间(偏移后: {actual_start_time.strftime('%H:%M')}-{actual_end_time.strftime('%H:%M')})",
                    )
            else:
                if now_time >= actual_start_time or now_time < actual_end_time:
                    return True, f"固定睡眠时间(偏移后: {actual_start_time.strftime('%H:%M')}-{actual_end_time.strftime('%H:%M')})"
                    return (
                        True,
                        f"固定睡眠时间(偏移后: {actual_start_time.strftime('%H:%M')}-{actual_end_time.strftime('%H:%M')})",
                    )
        except ValueError as e:
            logger.error(f"固定的睡眠时间格式不正确,请使用 HH:MM 格式: {e}")
            return False, None
        return False, None
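# Illustrative sketch, not part of this commit: the two branches above handle a sleep
# window that may wrap past midnight. Same comparison logic, minimal form.
from datetime import time

def in_window(now, start, end):
    if start <= end:                  # e.g. 13:00-15:00, same day
        return start <= now < end
    return now >= start or now < end  # e.g. 23:00-07:00, wraps midnight

assert in_window(time(23, 30), time(23, 0), time(7, 0))
assert not in_window(time(12, 0), time(23, 0), time(7, 0))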
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import time
|
||||
from src.common.logger import get_logger
|
||||
from src.manager.local_store_manager import local_storage
|
||||
|
||||
@@ -9,6 +8,7 @@ class WakeUpContext:
|
||||
"""
|
||||
唤醒上下文,负责封装和管理所有与唤醒相关的状态,并处理其持久化。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化唤醒上下文,并从本地存储加载初始状态。"""
|
||||
self.wakeup_value: float = 0.0
|
||||
@@ -42,4 +42,4 @@ class WakeUpContext:
|
||||
"sleep_pressure": self.sleep_pressure,
|
||||
}
|
||||
local_storage[self._get_storage_key()] = state
|
||||
logger.debug(f"已将唤醒上下文保存到本地存储: {state}")
|
||||
logger.debug(f"已将唤醒上下文保存到本地存储: {state}")
|
||||
|
||||
@@ -3,7 +3,6 @@ import time
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.manager.local_store_manager import local_storage
|
||||
from src.chat.message_manager.sleep_manager.wakeup_context import WakeUpContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -51,7 +50,7 @@ class WakeUpManager:
|
||||
if not self.enabled:
|
||||
logger.info("唤醒度系统已禁用,跳过启动")
|
||||
return
|
||||
|
||||
|
||||
self.is_running = True
|
||||
if not self._decay_task or self._decay_task.done():
|
||||
self._decay_task = asyncio.create_task(self._decay_loop())
|
||||
@@ -88,6 +87,7 @@ class WakeUpManager:
|
||||
self.context.is_angry = False
|
||||
# 通知情绪管理系统清除愤怒状态
|
||||
from src.mood.mood_manager import mood_manager
|
||||
|
||||
if self.angry_chat_id:
|
||||
mood_manager.clear_angry_from_wakeup(self.angry_chat_id)
|
||||
self.angry_chat_id = None
|
||||
@@ -104,7 +104,9 @@ class WakeUpManager:
|
||||
logger.debug(f"唤醒度衰减: {old_value:.1f} -> {self.context.wakeup_value:.1f}")
|
||||
self.context.save()
|
||||
|
||||
def add_wakeup_value(self, is_private_chat: bool, is_mentioned: bool = False, chat_id: Optional[str] = None) -> bool:
|
||||
def add_wakeup_value(
|
||||
self, is_private_chat: bool, is_mentioned: bool = False, chat_id: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
增加唤醒度值
|
||||
|
||||
@@ -173,6 +175,7 @@ class WakeUpManager:
|
||||
|
||||
# 通知情绪管理系统进入愤怒状态
|
||||
from src.mood.mood_manager import mood_manager
|
||||
|
||||
mood_manager.set_angry_from_wakeup(chat_id)
|
||||
|
||||
# 通知SleepManager重置睡眠状态
|
||||
@@ -194,6 +197,7 @@ class WakeUpManager:
|
||||
self.context.is_angry = False
|
||||
# 通知情绪管理系统清除愤怒状态
|
||||
from src.mood.mood_manager import mood_manager
|
||||
|
||||
if self.angry_chat_id:
|
||||
mood_manager.clear_angry_from_wakeup(self.angry_chat_id)
|
||||
self.angry_chat_id = None
|
||||
|
||||
@@ -191,7 +191,7 @@ class ChatBot:
                try:
                    # 检查聊天类型限制
                    if not plus_command_instance.is_chat_type_allowed():
                        is_group = message.message_info.group_info
                        is_group = message.message_info.group_info
                        logger.info(
                            f"PlusCommand {plus_command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}"
                        )
@@ -424,7 +424,9 @@ class ChatBot:
            await message.process()

            # 在这里打印[所见]日志,确保在所有处理和过滤之前记录
            logger.info(f"\u001b[38;5;118m{message.message_info.user_info.user_nickname}:{message.processed_plain_text}\u001b[0m")
            logger.info(
                f"\u001b[38;5;118m{message.message_info.user_info.user_nickname}:{message.processed_plain_text}\u001b[0m"
            )

            # 过滤检查
            if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex(  # type: ignore
@@ -456,7 +458,7 @@ class ChatBot:
            result = await event_manager.trigger_event(EventType.ON_MESSAGE, permission_group="SYSTEM", message=message)
            if not result.all_continue_process():
                raise UserWarning(f"插件{result.get_summary().get('stopped_handlers', '')}于消息到达时取消了消息处理")

            # TODO:暂不可用
            # 确认从接口发来的message是否有自定义的prompt模板信息
            if message.message_info.template_info and not message.message_info.template_info.template_default:
@@ -473,14 +475,14 @@ class ChatBot:
            async def preprocess():
                # 存储消息到数据库
                from .storage import MessageStorage

                try:
                    await MessageStorage.store_message(message, message.chat_stream)
                    logger.debug(f"消息已存储到数据库: {message.message_info.message_id}")
                except Exception as e:
                    logger.error(f"存储消息到数据库失败: {e}")
                    traceback.print_exc()

                # 使用消息管理器处理消息(保持原有功能)
                from src.common.data_models.database_data_model import DatabaseMessages
@@ -89,7 +89,7 @@ class ChatStream:

        # 复制 stream_context,但跳过 processing_task
        new_stream.stream_context = copy.deepcopy(self.stream_context, memo)
        if hasattr(new_stream.stream_context, 'processing_task'):
        if hasattr(new_stream.stream_context, "processing_task"):
            new_stream.stream_context.processing_task = None

        # 复制 context_manager
@@ -377,6 +377,7 @@ class ChatStream:
        # 默认基础分
        return 0.3


class ChatManager:
    """聊天管理器,管理所有聊天流"""

@@ -563,9 +564,8 @@ class ChatManager:
            if not hasattr(stream, "context_manager"):
                # 创建新的单流上下文管理器
                from src.chat.message_manager.context_manager import SingleStreamContextManager
                stream.context_manager = SingleStreamContextManager(
                    stream_id=stream_id, context=stream.stream_context
                )

                stream.context_manager = SingleStreamContextManager(stream_id=stream_id, context=stream.stream_context)

            # 保存到内存和数据库
            self.streams[stream_id] = stream
@@ -721,6 +721,7 @@ class ChatManager:
            # 确保 ChatStream 有自己的 context_manager
            if not hasattr(stream, "context_manager"):
                from src.chat.message_manager.context_manager import SingleStreamContextManager

                stream.context_manager = SingleStreamContextManager(
                    stream_id=stream.stream_id, context=stream.stream_context
                )
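
The `@@ -89` hunk deep-copies a stream context and then nulls `processing_task`, because a live asyncio.Task is not deep-copyable. A reduced sketch of the same idea expressed as a `__deepcopy__` hook (class and fields are hypothetical, not the project's StreamContext):

import asyncio
import copy
from typing import Any, Dict, Optional

class StreamContextSketch:
    """Illustrative stand-in: the running task is dropped, data is copied."""

    def __init__(self) -> None:
        self.unread: list = []
        self.processing_task: Optional[asyncio.Task] = None

    def __deepcopy__(self, memo: Dict[int, Any]) -> "StreamContextSketch":
        clone = StreamContextSketch()
        memo[id(self)] = clone
        clone.unread = copy.deepcopy(self.unread, memo)
        clone.processing_task = None  # never copy a live task
        return clone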
@@ -108,7 +108,7 @@ class MessageRecv(Message):
        self.message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {}))
        self.message_segment = Seg.from_dict(message_dict.get("message_segment", {}))
        self.raw_message = message_dict.get("raw_message")

        self.chat_stream = None
        self.reply = None
        self.processed_plain_text = message_dict.get("processed_plain_text", "")
@@ -53,7 +53,11 @@ class MessageStorage:
                filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
            else:
                # 如果没有设置display_message,使用processed_plain_text作为显示消息
                filtered_display_message = re.sub(pattern, "", message.processed_plain_text, flags=re.DOTALL) if message.processed_plain_text else ""
                filtered_display_message = (
                    re.sub(pattern, "", message.processed_plain_text, flags=re.DOTALL)
                    if message.processed_plain_text
                    else ""
                )
            interest_value = 0
            is_mentioned = False
            reply_to = message.reply_to
@@ -168,9 +172,11 @@ class MessageStorage:
            from src.common.database.sqlalchemy_models import get_db_session

            async with get_db_session() as session:
                matched_message = (await session.execute(
                    select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time))
                )).scalar()
                matched_message = (
                    await session.execute(
                        select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time))
                    )
                ).scalar()

                if matched_message:
                    await session.execute(
@@ -204,9 +210,11 @@ class MessageStorage:
            from src.common.database.sqlalchemy_models import get_db_session

            async with get_db_session() as session:
                image_record = (await session.execute(
                    select(Images).where(Images.description == description).order_by(desc(Images.timestamp))
                )).scalar()
                image_record = (
                    await session.execute(
                        select(Images).where(Images.description == description).order_by(desc(Images.timestamp))
                    )
                ).scalar()
                return f"[picid:{image_record.image_id}]" if image_record else match.group(0)
        except Exception:
            return match.group(0)
@@ -287,15 +295,19 @@ class MessageStorage:
            from src.common.database.sqlalchemy_models import Messages

            # 查找需要修复的记录:interest_value为0、null或很小的值
            query = select(Messages).where(
                (Messages.chat_id == chat_id) &
                (Messages.time >= since_time) &
                (
                    (Messages.interest_value == 0) |
                    (Messages.interest_value.is_(None)) |
                    (Messages.interest_value < 0.1)
            query = (
                select(Messages)
                .where(
                    (Messages.chat_id == chat_id)
                    & (Messages.time >= since_time)
                    & (
                        (Messages.interest_value == 0)
                        | (Messages.interest_value.is_(None))
                        | (Messages.interest_value < 0.1)
                    )
                )
                ).limit(50)  # 限制每次修复的数量,避免性能问题
                .limit(50)
            )  # 限制每次修复的数量,避免性能问题

            result = await session.execute(query)
            messages_to_fix = result.scalars().all()
@@ -307,7 +319,7 @@ class MessageStorage:
                default_interest = 0.3  # 默认中等兴趣度

                # 如果消息内容较长,可能是重要消息,兴趣度稍高
                if hasattr(msg, 'processed_plain_text') and msg.processed_plain_text:
                if hasattr(msg, "processed_plain_text") and msg.processed_plain_text:
                    text_length = len(msg.processed_plain_text)
                    if text_length > 50:  # 长消息
                        default_interest = 0.4
@@ -315,13 +327,15 @@ class MessageStorage:
                        default_interest = 0.35

                # 如果是被@的消息,兴趣度更高
                if getattr(msg, 'is_mentioned', False):
                if getattr(msg, "is_mentioned", False):
                    default_interest = min(default_interest + 0.2, 0.8)

                # 执行更新
                update_stmt = update(Messages).where(
                    Messages.message_id == msg.message_id
                ).values(interest_value=default_interest)
                update_stmt = (
                    update(Messages)
                    .where(Messages.message_id == msg.message_id)
                    .values(interest_value=default_interest)
                )

                result = await session.execute(update_stmt)
                if result.rowcount > 0:
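
Several hunks above apply the same mechanical rewrite: `await` binds looser than attribute access, so chaining `.scalar()` onto `await session.execute(...)` requires parentheses, and the formatter prefers one outer pair with the awaited call indented inside. A reduced sketch of the two equivalent spellings, with a placeholder model and key:

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

async def load_latest(session: AsyncSession, model, key):
    # Compact form: the extra parentheses make the awaited result chainable.
    row = (await session.execute(select(model).where(model.id == key))).scalar()

    # Formatted form produced by this commit: same expression, the outer
    # parentheses hold the await and the chained call hangs off the close.
    row = (
        await session.execute(
            select(model).where(model.id == key)
        )
    ).scalar()
    return row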
@@ -40,7 +40,7 @@ class ChatterActionManager:

    @staticmethod
    def create_action(
        action_name: str,
        action_name: str,
        action_data: dict,
        reasoning: str,
        cycle_timers: dict,
@@ -162,7 +162,7 @@ class ChatterActionManager:
        Returns:
            执行结果
        """
        from src.chat.message_manager.message_manager import message_manager

        try:
            logger.debug(f"🎯 [ActionManager] execute_action接收到 target_message: {target_message}")
            # 通过chat_id获取chat_stream
@@ -309,9 +309,7 @@ class ChatterActionManager:

                # 通过message_manager更新消息的动作记录并刷新focus_energy
                await message_manager.add_action(
                    stream_id=chat_stream.stream_id,
                    message_id=target_message_id,
                    action=action_name
                    stream_id=chat_stream.stream_id, message_id=target_message_id, action=action_name
                )
                logger.debug(f"已记录动作 {action_name} 到消息 {target_message_id} 并更新focus_energy")

@@ -321,9 +319,10 @@ class ChatterActionManager:

    async def _reset_interruption_count_after_action(self, stream_id: str):
        """在动作执行成功后重置打断计数"""
        from src.chat.message_manager.message_manager import message_manager

        try:
            from src.plugin_system.apis.chat_api import get_chat_manager

            chat_manager = get_chat_manager()
            chat_stream = chat_manager.get_stream(stream_id)
            if chat_stream:
@@ -332,7 +331,9 @@ class ChatterActionManager:
                old_count = context.context.interruption_count
                old_afc_adjustment = context.context.get_afc_threshold_adjustment()
                context.context.reset_interruption_count()
                logger.debug(f"动作执行成功,重置聊天流 {stream_id} 的打断计数: {old_count} -> 0, afc调整: {old_afc_adjustment} -> 0")
                logger.debug(
                    f"动作执行成功,重置聊天流 {stream_id} 的打断计数: {old_count} -> 0, afc调整: {old_afc_adjustment} -> 0"
                )
        except Exception as e:
            logger.warning(f"重置打断计数时出错: {e}")

@@ -531,7 +532,7 @@ class ChatterActionManager:
        # 根据新消息数量决定是否需要引用回复
        reply_text = ""
        is_proactive_thinking = (message_data.get("message_type") == "proactive_thinking") if message_data else True

        logger.debug(f"[send_response] message_data: {message_data}")

        first_replied = False
@@ -558,7 +559,9 @@ class ChatterActionManager:
                # 发送第一段回复
                if not first_replied:
                    set_reply_flag = bool(message_data)
                    logger.debug(f"📤 [ActionManager] 准备发送第一段回复。message_data: {message_data}, set_reply: {set_reply_flag}")
                    logger.debug(
                        f"📤 [ActionManager] 准备发送第一段回复。message_data: {message_data}, set_reply: {set_reply_flag}"
                    )
                    await send_api.text_to_stream(
                        text=data,
                        stream_id=chat_stream.stream_id,
@@ -577,4 +580,4 @@ class ChatterActionManager:
                        typing=True,
                    )

        return reply_text
        return reply_text
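
Most call-site changes in this commit follow the formatter's magic-trailing-comma rule: a call that fits the line limit and carries no trailing comma is collapsed onto one line, while a trailing comma after the last argument pins the one-argument-per-line layout. A runnable toy illustration with a hypothetical stand-in function:

def add_action(*, stream_id: str, message_id: str, action: str) -> None:
    """Hypothetical stand-in mirroring the call sites above."""

# No trailing comma and the call fits the line limit, so the formatter
# collapses it (compare the @@ -309 hunk):
add_action(stream_id="s1", message_id="m1", action="reply")

# A trailing comma after the last argument keeps the exploded layout
# (compare the @@ -577 hunk, which preserves `typing=True,`):
add_action(
    stream_id="s1",
    message_id="m1",
    action="reply",
)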
@@ -29,6 +29,7 @@ from src.chat.utils.chat_message_builder import (
    replace_user_references_sync,
)
from src.chat.express.expression_selector import expression_selector

# 旧记忆系统已被移除
# 旧记忆系统已被移除
from src.mood.mood_manager import mood_manager
@@ -580,7 +581,9 @@ class DefaultReplyer:
                memory_context["user_aliases"] = memory_aliases

            if group_info_obj is not None:
                group_name = getattr(group_info_obj, "group_name", None) or getattr(group_info_obj, "group_nickname", None)
                group_name = getattr(group_info_obj, "group_name", None) or getattr(
                    group_info_obj, "group_nickname", None
                )
                if group_name:
                    memory_context["group_name"] = str(group_name)
                group_id = getattr(group_info_obj, "group_id", None)
@@ -594,11 +597,7 @@ class DefaultReplyer:

            # 检索相关记忆
            enhanced_memories = await memory_system.retrieve_relevant_memories(
                query=target,
                user_id=memory_user_id,
                scope_id=stream.stream_id,
                context=memory_context,
                limit=10
                query=target, user_id=memory_user_id, scope_id=stream.stream_id, context=memory_context, limit=10
            )

            # 注意:记忆存储已迁移到回复生成完成后进行,不在查询阶段执行
@@ -609,23 +608,27 @@ class DefaultReplyer:
                logger.debug(f"[记忆转换] 收到 {len(enhanced_memories)} 条原始记忆")
                for idx, memory_chunk in enumerate(enhanced_memories, 1):
                    # 获取结构化内容的字符串表示
                    structure_display = str(memory_chunk.content) if hasattr(memory_chunk, 'content') else "unknown"

                    structure_display = str(memory_chunk.content) if hasattr(memory_chunk, "content") else "unknown"

                    # 获取记忆内容,优先使用display
                    content = memory_chunk.display or memory_chunk.text_content or ""

                    # 调试:记录每条记忆的内容获取情况
                    logger.debug(f"[记忆转换] 第{idx}条: display={repr(memory_chunk.display)[:80]}, text_content={repr(memory_chunk.text_content)[:80]}, final_content={repr(content)[:80]}")

                    running_memories.append({
                        "content": content,
                        "memory_type": memory_chunk.memory_type.value,
                        "confidence": memory_chunk.metadata.confidence.value,
                        "importance": memory_chunk.metadata.importance.value,
                        "relevance": getattr(memory_chunk.metadata, 'relevance_score', 0.5),
                        "source": memory_chunk.metadata.source,
                        "structure": structure_display,
                    })
                    logger.debug(
                        f"[记忆转换] 第{idx}条: display={repr(memory_chunk.display)[:80]}, text_content={repr(memory_chunk.text_content)[:80]}, final_content={repr(content)[:80]}"
                    )

                    running_memories.append(
                        {
                            "content": content,
                            "memory_type": memory_chunk.memory_type.value,
                            "confidence": memory_chunk.metadata.confidence.value,
                            "importance": memory_chunk.metadata.importance.value,
                            "relevance": getattr(memory_chunk.metadata, "relevance_score", 0.5),
                            "source": memory_chunk.metadata.source,
                            "structure": structure_display,
                        }
                    )

                # 构建瞬时记忆字符串
                if running_memories:
@@ -633,7 +636,9 @@ class DefaultReplyer:
                    if top_memory:
                        instant_memory = top_memory[0].get("content", "")

                logger.info(f"增强记忆系统检索到 {len(enhanced_memories)} 条原始记忆,转换为 {len(running_memories)} 条可用记忆")
                logger.info(
                    f"增强记忆系统检索到 {len(enhanced_memories)} 条原始记忆,转换为 {len(running_memories)} 条可用记忆"
                )

        except Exception as e:
            logger.warning(f"增强记忆系统检索失败: {e}")
@@ -650,17 +655,17 @@ class DefaultReplyer:
            memory_parts = ["### 🧠 相关记忆 (Relevant Memories)", ""]

            # 按相关度排序,并记录相关度信息用于调试
            sorted_memories = sorted(running_memories, key=lambda x: x.get('relevance', 0.0), reverse=True)
            sorted_memories = sorted(running_memories, key=lambda x: x.get("relevance", 0.0), reverse=True)

            # 调试相关度信息
            relevance_info = [(m.get('memory_type', 'unknown'), m.get('relevance', 0.0)) for m in sorted_memories]
            relevance_info = [(m.get("memory_type", "unknown"), m.get("relevance", 0.0)) for m in sorted_memories]
            logger.debug(f"记忆相关度信息: {relevance_info}")
            logger.debug(f"[记忆构建] 准备将 {len(sorted_memories)} 条记忆添加到提示词")

            for idx, running_memory in enumerate(sorted_memories, 1):
                content = running_memory.get('content', '')
                memory_type = running_memory.get('memory_type', 'unknown')

                content = running_memory.get("content", "")
                memory_type = running_memory.get("memory_type", "unknown")

                # 跳过空内容
                if not content or not content.strip():
                    logger.warning(f"[记忆构建] 跳过第 {idx} 条记忆:内容为空 (type={memory_type})")
@@ -822,10 +827,10 @@ class DefaultReplyer:
        """
        try:
            # 从message_manager获取真实的已读/未读消息
            from src.chat.message_manager.message_manager import message_manager

            # 获取聊天流的上下文
            from src.plugin_system.apis.chat_api import get_chat_manager

            chat_manager = get_chat_manager()
            chat_stream = chat_manager.get_stream(chat_id)
            if chat_stream:
@@ -1021,7 +1026,9 @@ class DefaultReplyer:
        interest_scores = {}

        try:
            from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system as interest_scoring_system
            from src.plugins.built_in.affinity_flow_chatter.interest_scoring import (
                chatter_interest_scoring_system as interest_scoring_system,
            )
            from src.common.data_models.database_data_model import DatabaseMessages

            # 转换消息格式
@@ -1204,7 +1211,7 @@ class DefaultReplyer:
                platform,  # type: ignore
                reply_message.get("user_id"),  # type: ignore
                reply_message.get("user_nickname"),
                reply_message.get("user_cardname")
                reply_message.get("user_cardname"),
            )

            # 检查是否是bot自己的名字,如果是则替换为"(你)"
@@ -1763,6 +1770,7 @@ class DefaultReplyer:

            # 创建关系追踪器实例
            from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system

            relationship_tracker = ChatterRelationshipTracker(chatter_interest_scoring_system)
            if relationship_tracker:
                # 获取用户信息以获取真实的user_id
@@ -1805,7 +1813,7 @@ class DefaultReplyer:
    async def _store_chat_memory_async(self, reply_to: str, reply_message: Optional[Dict[str, Any]] = None):
        """
        异步存储聊天记忆(从build_memory_block迁移而来)

        Args:
            reply_to: 回复对象
            reply_message: 回复的原始消息
@@ -1874,9 +1882,7 @@ class DefaultReplyer:
                    memory_aliases.append(stripped)

            alias_values = (
                user_info_dict.get("aliases")
                or user_info_dict.get("alias_names")
                or user_info_dict.get("alias")
                user_info_dict.get("aliases") or user_info_dict.get("alias_names") or user_info_dict.get("alias")
            )
            if isinstance(alias_values, (list, tuple, set)):
                for alias in alias_values:
@@ -1900,7 +1906,9 @@ class DefaultReplyer:
                memory_context["user_aliases"] = memory_aliases

            if group_info_obj is not None:
                group_name = getattr(group_info_obj, "group_name", None) or getattr(group_info_obj, "group_nickname", None)
                group_name = getattr(group_info_obj, "group_name", None) or getattr(
                    group_info_obj, "group_nickname", None
                )
                if group_name:
                    memory_context["group_name"] = str(group_name)
                group_id = getattr(group_info_obj, "group_id", None)
@@ -1932,11 +1940,11 @@ class DefaultReplyer:
                        "conversation_text": chat_history,
                        "user_id": memory_user_id,
                        "scope_id": stream.stream_id,
                        **memory_context
                        **memory_context,
                    }
                )
            )

            logger.debug(f"已启动记忆存储任务,用户: {memory_user_display or memory_user_id}")

        except Exception as e:
@@ -13,6 +13,7 @@ from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_m
from src.common.database.sqlalchemy_database_api import get_db_session
from sqlalchemy import select, and_
from src.common.logger import get_logger

logger = get_logger("chat_message_builder")

install(extra_lines=3)
@@ -274,21 +275,52 @@ async def get_actions_by_timestamp_with_chat(

    async with get_db_session() as session:
        if limit > 0:
            result = await session.execute(
                select(ActionRecords)
                .where(
                    and_(
                        ActionRecords.chat_id == chat_id,
                        ActionRecords.time >= timestamp_start,
                        ActionRecords.time <= timestamp_end,
                    )
                )
                .order_by(ActionRecords.time.desc())
                .limit(limit)
            )
            actions = list(result.scalars())
            actions_result = []
            for action in reversed(actions):
                action_dict = {
                    "id": action.id,
                    "action_id": action.action_id,
                    "time": action.time,
                    "action_name": action.action_name,
                    "action_data": action.action_data,
                    "action_done": action.action_done,
                    "action_build_into_prompt": action.action_build_into_prompt,
                    "action_prompt_display": action.action_prompt_display,
                    "chat_id": action.chat_id,
                    "chat_info_stream_id": action.chat_info_stream_id,
                    "chat_info_platform": action.chat_info_platform,
                }
                actions_result.append(action_dict)
                actions_result.append(action_dict)
        else:  # earliest
            result = await session.execute(
                select(ActionRecords)
                .where(
                    and_(
                        ActionRecords.chat_id == chat_id,
                        ActionRecords.time >= timestamp_start,
                        ActionRecords.time <= timestamp_end,
                        ActionRecords.time > timestamp_start,
                        ActionRecords.time < timestamp_end,
                    )
                )
                .order_by(ActionRecords.time.desc())
                .order_by(ActionRecords.time.asc())
                .limit(limit)
            )
            actions = list(result.scalars())
            actions_result = []
            for action in reversed(actions):
            for action in actions:
                action_dict = {
                    "id": action.id,
                    "action_id": action.action_id,
@@ -303,37 +335,6 @@ async def get_actions_by_timestamp_with_chat(
                    "chat_info_platform": action.chat_info_platform,
                }
                actions_result.append(action_dict)
                actions_result.append(action_dict)
        else:  # earliest
            result = await session.execute(
                select(ActionRecords)
                .where(
                    and_(
                        ActionRecords.chat_id == chat_id,
                        ActionRecords.time > timestamp_start,
                        ActionRecords.time < timestamp_end,
                    )
                )
                .order_by(ActionRecords.time.asc())
                .limit(limit)
            )
            actions = list(result.scalars())
            actions_result = []
            for action in actions:
                action_dict = {
                    "id": action.id,
                    "action_id": action.action_id,
                    "time": action.time,
                    "action_name": action.action_name,
                    "action_data": action.action_data,
                    "action_done": action.action_done,
                    "action_build_into_prompt": action.action_build_into_prompt,
                    "action_prompt_display": action.action_prompt_display,
                    "chat_id": action.chat_id,
                    "chat_info_stream_id": action.chat_info_stream_id,
                    "chat_info_platform": action.chat_info_platform,
                }
                actions_result.append(action_dict)
        else:
            result = await session.execute(
                select(ActionRecords)
@@ -457,7 +458,9 @@ async def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List
    return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit)


async def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float, limit: int = 0) -> List[Dict[str, Any]]:
async def get_raw_msg_before_timestamp_with_chat(
    chat_id: str, timestamp: float, limit: int = 0
) -> List[Dict[str, Any]]:
    """获取指定时间戳之前的消息,按时间升序排序,返回消息列表
    limit: 限制返回的消息数量,0为不限制
    """
@@ -466,7 +469,9 @@ async def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float,
    return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit)


async def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: list, limit: int = 0) -> List[Dict[str, Any]]:
async def get_raw_msg_before_timestamp_with_users(
    timestamp: float, person_ids: list, limit: int = 0
) -> List[Dict[str, Any]]:
    """获取指定时间戳之前的消息,按时间升序排序,返回消息列表
    limit: 限制返回的消息数量,0为不限制
    """
@@ -475,7 +480,9 @@ async def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids:
    return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit)


async def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None) -> int:
async def num_new_messages_since(
    chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None
) -> int:
    """
    检查特定聊天从 timestamp_start (不含) 到 timestamp_end (不含) 之间有多少新消息。
    如果 timestamp_end 为 None,则检查从 timestamp_start (不含) 到当前时间的消息。
@@ -830,7 +837,7 @@ async def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
            async with get_db_session() as session:
                result = await session.execute(select(Images).where(Images.image_id == pic_id))
                image = result.scalar_one_or_none()
                if image and hasattr(image, 'description') and image.description:
                if image and hasattr(image, "description") and image.description:
                    description = image.description
        except Exception as e:
            # 如果查询失败,保持默认描述
@@ -1017,24 +1024,29 @@ async def build_readable_messages(

        async with get_db_session() as session:
            # 获取这个时间范围内的动作记录,并匹配chat_id
            actions_in_range = (await session.execute(
                select(ActionRecords)
                .where(
                    and_(
                        ActionRecords.time >= min_time, ActionRecords.time <= max_time, ActionRecords.chat_id == chat_id
            actions_in_range = (
                await session.execute(
                    select(ActionRecords)
                    .where(
                        and_(
                            ActionRecords.time >= min_time,
                            ActionRecords.time <= max_time,
                            ActionRecords.chat_id == chat_id,
                        )
                    )
                    .order_by(ActionRecords.time)
                )
                .order_by(ActionRecords.time)
            )).scalars()
            ).scalars()

            # 获取最新消息之后的第一个动作记录
            action_after_latest = (await session.execute(
                select(ActionRecords)
                .where(and_(ActionRecords.time > max_time, ActionRecords.chat_id == chat_id))
                .order_by(ActionRecords.time)
                .limit(1)
            )).scalars()
            action_after_latest = (
                await session.execute(
                    select(ActionRecords)
                    .where(and_(ActionRecords.time > max_time, ActionRecords.chat_id == chat_id))
                    .order_by(ActionRecords.time)
                    .limit(1)
                )
            ).scalars()

            # 合并两部分动作记录,并转为 dict,避免 DetachedInstanceError
            actions = [
@@ -1225,9 +1237,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
        except Exception:
            return "?"

        content = await replace_user_references_async(
            content, platform, anon_name_resolver, replace_bot_name=False
        )
        content = await replace_user_references_async(content, platform, anon_name_resolver, replace_bot_name=False)

        header = f"{anon_name}说 "
        output_lines.append(header)
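
The `earliest` branch in the hunk above swaps `order_by(time.desc())` plus `reversed(...)` for `order_by(time.asc())` with a plain loop, and tightens the boundaries from `>=`/`<=` to `>`/`<`. For an unlimited result set, reversing a descending sort yields exactly the ascending order, which a tiny sketch confirms (note that once a LIMIT is applied the two forms select different rows, newest-N versus oldest-N):

rows = [3.0, 1.0, 2.0]  # stand-in for ActionRecords.time values

newest_first = sorted(rows, reverse=True)  # like .order_by(time.desc())
oldest_first = sorted(rows)                # like .order_by(time.asc())

# reversed(desc) produces the same sequence as asc:
assert list(reversed(newest_first)) == oldest_first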
@@ -17,7 +17,7 @@ MEMORY_TYPE_CHINESE_MAPPING = {
    "goal": "目标计划",
    "experience": "经验教训",
    "contextual": "上下文信息",
    "unknown": "未知"
    "unknown": "未知",
}

# 置信度等级到中文标签的映射表
@@ -30,7 +30,7 @@ CONFIDENCE_LEVEL_CHINESE_MAPPING = {
    "MEDIUM": "中等置信度",
    "HIGH": "高置信度",
    "VERIFIED": "已验证",
    "unknown": "未知"
    "unknown": "未知",
}

# 重要性等级到中文标签的映射表
@@ -43,7 +43,7 @@ IMPORTANCE_LEVEL_CHINESE_MAPPING = {
    "NORMAL": "一般重要性",
    "HIGH": "高重要性",
    "CRITICAL": "关键重要性",
    "unknown": "未知"
    "unknown": "未知",
}


@@ -69,7 +69,7 @@ def get_confidence_level_chinese_label(level) -> str:
        str: 对应的中文标签,如果找不到则返回"未知"
    """
    # 处理枚举实例
    if hasattr(level, 'value'):
    if hasattr(level, "value"):
        level = level.value

    # 处理数字
@@ -94,7 +94,7 @@ def get_importance_level_chinese_label(level) -> str:
        str: 对应的中文标签,如果找不到则返回"未知"
    """
    # 处理枚举实例
    if hasattr(level, 'value'):
    if hasattr(level, "value"):
        level = level.value

    # 处理数字
@@ -106,4 +106,4 @@ def get_importance_level_chinese_label(level) -> str:
        level_upper = level.upper()
        return IMPORTANCE_LEVEL_CHINESE_MAPPING.get(level_upper, "未知")

    return "未知"
    return "未知"
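
These label helpers unwrap enum members via `.value`, then fall through type checks before the table lookup, so enum, string, and unsupported inputs all resolve to a label. A hedged usage sketch of `get_confidence_level_chinese_label`, assuming a hypothetical enum and that the elided branches dispatch on type the way the visible tail suggests:

from enum import Enum

class ConfidenceLevel(Enum):  # hypothetical; the real enum lives elsewhere
    HIGH = "HIGH"

print(get_confidence_level_chinese_label(ConfidenceLevel.HIGH))  # 高置信度
print(get_confidence_level_chinese_label("verified"))            # 已验证 (upper-cased lookup)
print(get_confidence_level_chinese_label(object()))              # 未知 (fallback)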
@@ -381,12 +381,12 @@ class Prompt:

        # 性能优化 - 为不同任务设置不同的超时时间
        task_timeouts = {
            "memory_block": 15.0,      # 记忆系统 - 降低超时时间,鼓励预构建
            "tool_info": 15.0,         # 工具信息
            "relation_info": 10.0,     # 关系信息
            "knowledge_info": 10.0,    # 知识库查询
            "cross_context": 10.0,     # 上下文处理
            "expression_habits": 10.0, # 表达习惯
            "memory_block": 15.0,  # 记忆系统 - 降低超时时间,鼓励预构建
            "tool_info": 15.0,  # 工具信息
            "relation_info": 10.0,  # 关系信息
            "knowledge_info": 10.0,  # 知识库查询
            "cross_context": 10.0,  # 上下文处理
            "expression_habits": 10.0,  # 表达习惯
        }

        # 分别处理每个任务,避免慢任务影响快任务
@@ -563,7 +563,7 @@ class Prompt:
            ),
            enhanced_memory_activator.get_instant_memory(
                target_message=self.parameters.target, chat_id=self.parameters.chat_id
            )
            ),
        ]

        try:
@@ -606,26 +606,27 @@ class Prompt:
                        "opinion": "opinion",
                        "personal_fact": "personal_fact",
                        "preference": "preference",
                        "event": "event"
                        "event": "event",
                    }
                    mapped_type = memory_type_mapping.get(topic, "personal_fact")

                    formatted_memories.append({
                        "display": display_text,
                        "memory_type": mapped_type,
                        "metadata": {
                            "confidence": memory.get("confidence", "未知"),
                            "importance": memory.get("importance", "一般"),
                            "timestamp": memory.get("timestamp", ""),
                            "source": memory.get("source", "unknown"),
                            "relevance_score": memory.get("relevance_score", 0.0)
                    formatted_memories.append(
                        {
                            "display": display_text,
                            "memory_type": mapped_type,
                            "metadata": {
                                "confidence": memory.get("confidence", "未知"),
                                "importance": memory.get("importance", "一般"),
                                "timestamp": memory.get("timestamp", ""),
                                "source": memory.get("source", "unknown"),
                                "relevance_score": memory.get("relevance_score", 0.0),
                            },
                        }
                    })
                    )

                # 使用方括号格式格式化记忆
                memory_block = format_memories_bracket_style(
                    formatted_memories,
                    query_context=self.parameters.target
                    formatted_memories, query_context=self.parameters.target
                )
            except Exception as e:
                logger.warning(f"记忆格式化失败,使用简化格式: {e}")
@@ -833,7 +834,8 @@ class Prompt:
            "moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
            "safety_guidelines_block": self.parameters.safety_guidelines_block
            or context_data.get("safety_guidelines_block", ""),
            "chat_scene": self.parameters.chat_scene or "你正在一个QQ群里聊天,你需要理解整个群的聊天动态和话题走向,并做出自然的回应。",
            "chat_scene": self.parameters.chat_scene
            or "你正在一个QQ群里聊天,你需要理解整个群的聊天动态和话题走向,并做出自然的回应。",
        }

    def _prepare_normal_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
@@ -860,7 +862,8 @@ class Prompt:
            "moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
            "safety_guidelines_block": self.parameters.safety_guidelines_block
            or context_data.get("safety_guidelines_block", ""),
            "chat_scene": self.parameters.chat_scene or "你正在一个QQ群里聊天,你需要理解整个群的聊天动态和话题走向,并做出自然的回应。",
            "chat_scene": self.parameters.chat_scene
            or "你正在一个QQ群里聊天,你需要理解整个群的聊天动态和话题走向,并做出自然的回应。",
        }

    def _prepare_default_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
@@ -305,11 +305,14 @@ class StatisticOutputTask(AsyncTask):

        # 以最早的时间戳为起始时间获取记录
        query_start_time = collect_period[-1][1]
        records = await db_get(
            model_class=LLMUsage,
            filters={"timestamp": {"$gte": query_start_time}},
            order_by="-timestamp",
        ) or []
        records = (
            await db_get(
                model_class=LLMUsage,
                filters={"timestamp": {"$gte": query_start_time}},
                order_by="-timestamp",
            )
            or []
        )

        for record in records:
            if not isinstance(record, dict):
@@ -401,7 +404,9 @@ class StatisticOutputTask(AsyncTask):
        return stats

    @staticmethod
    async def _collect_online_time_for_period(collect_period: List[Tuple[str, datetime]], now: datetime) -> Dict[str, Any]:
    async def _collect_online_time_for_period(
        collect_period: List[Tuple[str, datetime]], now: datetime
    ) -> Dict[str, Any]:
        """
        收集指定时间段的在线时间统计数据

@@ -420,11 +425,14 @@ class StatisticOutputTask(AsyncTask):
        }

        query_start_time = collect_period[-1][1]
        records = await db_get(
            model_class=OnlineTime,
            filters={"end_timestamp": {"$gte": query_start_time}},
            order_by="-end_timestamp",
        ) or []
        records = (
            await db_get(
                model_class=OnlineTime,
                filters={"end_timestamp": {"$gte": query_start_time}},
                order_by="-end_timestamp",
            )
            or []
        )

        for record in records:
            if not isinstance(record, dict):
@@ -476,11 +484,14 @@ class StatisticOutputTask(AsyncTask):
        }

        query_start_timestamp = collect_period[-1][1].timestamp()  # Messages.time is a DoubleField (timestamp)
        records = await db_get(
            model_class=Messages,
            filters={"time": {"$gte": query_start_timestamp}},
            order_by="-time",
        ) or []
        records = (
            await db_get(
                model_class=Messages,
                filters={"time": {"$gte": query_start_timestamp}},
                order_by="-time",
            )
            or []
        )

        for message in records:
            if not isinstance(message, dict):
@@ -1038,11 +1049,14 @@ class StatisticOutputTask(AsyncTask):
        interval_seconds = interval_minutes * 60

        # 单次查询 LLMUsage
        llm_records = await db_get(
            model_class=LLMUsage,
            filters={"timestamp": {"$gte": start_time}},
            order_by="-timestamp",
        ) or []
        llm_records = (
            await db_get(
                model_class=LLMUsage,
                filters={"timestamp": {"$gte": start_time}},
                order_by="-timestamp",
            )
            or []
        )
        for record in llm_records:
            if not isinstance(record, dict) or not record.get("timestamp"):
                continue
@@ -1068,11 +1082,14 @@ class StatisticOutputTask(AsyncTask):
                cost_by_module[module_name][idx] += cost

        # 单次查询 Messages
        msg_records = await db_get(
            model_class=Messages,
            filters={"time": {"$gte": start_time.timestamp()}},
            order_by="-time",
        ) or []
        msg_records = (
            await db_get(
                model_class=Messages,
                filters={"time": {"$gte": start_time.timestamp()}},
                order_by="-time",
            )
            or []
        )
        for msg in msg_records:
            if not isinstance(msg, dict) or not msg.get("time"):
                continue
@@ -1375,4 +1392,4 @@ class StatisticOutputTask(AsyncTask):
        }});
    </script>
</div>
"""
"""
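
The repeated `records = (await db_get(...)) or []` pattern above guards against the fetch returning `None`, so the following `for` loop never needs a separate null check. A reduced, runnable sketch with a stand-in fetcher (db_get itself is the project's helper; the stub below is hypothetical):

from typing import Any, Dict, List, Optional

async def db_get_stub(**_: Any) -> Optional[List[Dict[str, Any]]]:
    """Stand-in for db_get; may legitimately return None."""
    return None

async def collect() -> int:
    # `or []` normalizes None to an empty list before iteration.
    records = (await db_get_stub(order_by="-timestamp")) or []
    total = 0
    for record in records:  # safe even when the query matched nothing
        if isinstance(record, dict):
            total += 1
    return total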
@@ -675,7 +675,6 @@ async def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Di
    if loop.is_running():
        # 如果事件循环在运行,从其他线程提交并等待结果
        try:
            from concurrent.futures import TimeoutError

            fut = asyncio.run_coroutine_threadsafe(
                person_info_manager.get_value(person_id, "person_name"), loop
@@ -81,14 +81,16 @@ class ImageManager:
        """
        try:
            async with get_db_session() as session:
                record = (await session.execute(
                    select(ImageDescriptions).where(
                        and_(
                            ImageDescriptions.image_description_hash == image_hash,
                            ImageDescriptions.type == description_type,
                record = (
                    await session.execute(
                        select(ImageDescriptions).where(
                            and_(
                                ImageDescriptions.image_description_hash == image_hash,
                                ImageDescriptions.type == description_type,
                            )
                        )
                    )
                )).scalar()
                ).scalar()
                return record.description if record else None
        except Exception as e:
            logger.error(f"从数据库获取描述失败 (SQLAlchemy): {str(e)}")
@@ -107,14 +109,16 @@ class ImageManager:
            current_timestamp = time.time()
            async with get_db_session() as session:
                # 查找现有记录
                existing = (await session.execute(
                    select(ImageDescriptions).where(
                        and_(
                            ImageDescriptions.image_description_hash == image_hash,
                            ImageDescriptions.type == description_type,
                existing = (
                    await session.execute(
                        select(ImageDescriptions).where(
                            and_(
                                ImageDescriptions.image_description_hash == image_hash,
                                ImageDescriptions.type == description_type,
                            )
                        )
                    )
                )).scalar()
                ).scalar()

                if existing:
                    # 更新现有记录
@@ -262,9 +266,11 @@ class ImageManager:
            from src.common.database.sqlalchemy_models import get_db_session

            async with get_db_session() as session:
                existing_img = (await session.execute(
                    select(Images).where(and_(Images.emoji_hash == image_hash, Images.type == "emoji"))
                )).scalar()
                existing_img = (
                    await session.execute(
                        select(Images).where(and_(Images.emoji_hash == image_hash, Images.type == "emoji"))
                    )
                ).scalar()

                if existing_img:
                    existing_img.path = file_path
@@ -35,7 +35,7 @@ logger = get_logger("utils_video")
# Rust模块可用性检测
RUST_VIDEO_AVAILABLE = False
try:
    import rust_video  # pyright: ignore[reportMissingImports]
    import rust_video  # pyright: ignore[reportMissingImports]

    RUST_VIDEO_AVAILABLE = True
    logger.info("✅ Rust 视频处理模块加载成功")
@@ -222,7 +222,7 @@ class VideoAnalyzer:
        return None

    async def _store_video_result(
        self, video_hash: str, description: str, metadata: Optional[Dict] = None
        self, video_hash: str, description: str, metadata: Optional[Dict] = None
    ) -> Optional[Videos]:
        """存储视频分析结果到数据库"""
        # 检查描述是否为错误信息,如果是则不保存
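
The try/except import above is the usual optional-dependency probe: attempt the native module once at import time, record the outcome in a module-level flag, and let callers branch on the flag. A minimal sketch of that shape — the module name is from the diff, while the except clause and the `extract_frames` API are assumptions for illustration:

RUST_VIDEO_AVAILABLE = False
try:
    import rust_video  # pyright: ignore[reportMissingImports]

    RUST_VIDEO_AVAILABLE = True
except ImportError:
    rust_video = None  # a pure-Python path would be used instead

def decode(path: str):
    if RUST_VIDEO_AVAILABLE:
        return rust_video.extract_frames(path)  # hypothetical API
    raise NotImplementedError("fallback decoder not shown in this sketch")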