feat: batch-generate text embeddings, optimize the interest-match calculation logic, and support bulk updates of message interest values
@@ -442,6 +442,43 @@ class BotInterestManager:
logger.debug(f"✅ 消息embedding生成成功,维度: {len(embedding)}")
return embedding

async def generate_embeddings_for_texts(
self, text_map: dict[str, str], batch_size: int = 16
) -> dict[str, list[float]]:
"""批量获取多段文本的embedding,供上层统一处理。"""
if not text_map:
return {}

if not self.embedding_request:
raise RuntimeError("Embedding客户端未初始化")

batch_size = max(1, batch_size)
keys = list(text_map.keys())
results: dict[str, list[float]] = {}

for start in range(0, len(keys), batch_size):
chunk_keys = keys[start : start + batch_size]
chunk_texts = [text_map[key] or "" for key in chunk_keys]

try:
chunk_embeddings, _ = await self.embedding_request.get_embedding(chunk_texts)
except Exception as exc: # noqa: BLE001
logger.error(f"批量获取embedding失败 (chunk {start // batch_size + 1}): {exc}")
continue

if isinstance(chunk_embeddings, list) and chunk_embeddings and isinstance(chunk_embeddings[0], list):
normalized = chunk_embeddings
elif isinstance(chunk_embeddings, list):
normalized = [chunk_embeddings]
else:
normalized = []

for idx_offset, message_id in enumerate(chunk_keys):
vector = normalized[idx_offset] if idx_offset < len(normalized) else []
results[message_id] = vector

return results

async def _calculate_similarity_scores(
self, result: InterestMatchResult, message_embedding: list[float], keywords: list[str]
):
@@ -473,7 +510,7 @@ class BotInterestManager:
logger.error(f"❌ 计算相似度分数失败: {e}")

async def calculate_interest_match(
self, message_text: str, keywords: list[str] | None = None
self, message_text: str, keywords: list[str] | None = None, message_embedding: list[float] | None = None
) -> InterestMatchResult:
"""计算消息与机器人兴趣的匹配度(优化版 - 标签扩展策略)

@@ -505,6 +542,7 @@ class BotInterestManager:

# 生成消息的embedding
logger.debug("正在生成消息 embedding...")
if not message_embedding:
message_embedding = await self._get_embedding(message_text)
logger.debug(f"消息 embedding 生成成功, 维度: {len(message_embedding)}")
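
A minimal usage sketch of the new batch helper (not part of this commit). The import path mirrors the planner hunk later in this diff; the message ids and texts are invented.

# Usage sketch only — illustrative message ids/texts.
from src.chat.interest_system import bot_interest_manager


async def embed_pending_messages() -> dict[str, list[float]]:
    text_map = {
        "msg_001": "今天的天气真不错",          # hypothetical message_id -> text
        "msg_002": "有人研究过向量数据库吗?",
    }
    # Texts go to the embedding backend in chunks of `batch_size`; a failed
    # chunk is logged and skipped, and any missing vector comes back as [].
    return await bot_interest_manager.generate_embeddings_for_texts(text_map, batch_size=16)
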
@@ -88,8 +88,13 @@ class SingleStreamContextManager:
self.context.enable_cache(True)
logger.debug(f"为StreamContext {self.stream_id} 启用缓存系统")

# 先计算兴趣值(需要在缓存前计算)
await self._calculate_message_interest(message)
# 新消息默认占位兴趣值,延迟到 Chatter 批量处理阶段
if message.interest_value is None:
message.interest_value = 0.3
message.should_reply = False
message.should_act = False
message.interest_calculated = False
message.semantic_embedding = None
message.is_read = False

# 使用StreamContext的智能缓存功能
@@ -440,6 +445,7 @@ class SingleStreamContextManager:
message.interest_value = result.interest_value
message.should_reply = result.should_reply
message.should_act = result.should_act
message.interest_calculated = True

logger.debug(
f"消息 {message.message_id} 兴趣值已更新: {result.interest_value:.3f}, "
@@ -448,6 +454,7 @@ class SingleStreamContextManager:
return result.interest_value
else:
logger.warning(f"消息 {message.message_id} 兴趣值计算失败: {result.error_message}")
message.interest_calculated = False
return 0.5
else:
logger.debug("未找到兴趣值计算器,使用默认兴趣值")
@@ -455,6 +462,8 @@ class SingleStreamContextManager:

except Exception as e:
logger.error(f"计算消息兴趣度时发生错误: {e}", exc_info=True)
if hasattr(message, "interest_calculated"):
message.interest_calculated = False
return 0.5

def _detect_chat_type(self, message: DatabaseMessages):
@@ -110,6 +110,7 @@ def init_prompt():

## 其他信息
{memory_block}

{relation_info_block}

{extra_info_block}
@@ -579,7 +580,7 @@ class DefaultReplyer:

try:
from src.memory_graph.manager_singleton import get_unified_memory_manager
from src.memory_graph.utils.memory_formatter import format_memory_for_prompt
from src.memory_graph.utils.three_tier_formatter import memory_formatter

unified_manager = get_unified_memory_manager()
if not unified_manager:
@@ -602,38 +603,12 @@ class DefaultReplyer:
short_term_memories = search_result.get("short_term_memories", [])
long_term_memories = search_result.get("long_term_memories", [])

memory_parts = ["### 🧠 相关记忆 (Relevant Memories)", ""]

# 添加感知记忆(最近的消息块)
if perceptual_blocks:
memory_parts.append("#### 🌊 感知记忆")
for block in perceptual_blocks:
messages = block.messages if hasattr(block, 'messages') else []
if messages:
block_content = "\n".join([
f"{msg.get('sender_name', msg.get('sender_id', ''))}: {msg.get('content', '')[:30]}"
for msg in messages
])
memory_parts.append(f"- {block_content}")
memory_parts.append("")

# 添加短期记忆(结构化活跃记忆)
if short_term_memories:
memory_parts.append("#### 💭 短期记忆")
for mem in short_term_memories:
content = format_memory_for_prompt(mem, include_metadata=False)
if content:
memory_parts.append(f"- {content}")
memory_parts.append("")

# 添加长期记忆(图谱记忆)
if long_term_memories:
memory_parts.append("#### 🗄️ 长期记忆")
for mem in long_term_memories:
content = format_memory_for_prompt(mem, include_metadata=False)
if content:
memory_parts.append(f"- {content}")
memory_parts.append("")
# 使用新的三级记忆格式化器
formatted_memories = await memory_formatter.format_all_tiers(
perceptual_blocks=perceptual_blocks,
short_term_memories=short_term_memories,
long_term_memories=long_term_memories
)

total_count = len(perceptual_blocks) + len(short_term_memories) + len(long_term_memories)
if total_count > 0:
@@ -642,7 +617,11 @@ class DefaultReplyer:
f"(感知:{len(perceptual_blocks)}, 短期:{len(short_term_memories)}, 长期:{len(long_term_memories)})"
)

return "\n".join(memory_parts) if len(memory_parts) > 2 else ""
# 添加标题并返回格式化后的记忆
if formatted_memories.strip():
return "### 🧠 相关记忆 (Relevant Memories)\n\n" + formatted_memories

return ""

except Exception as e:
logger.error(f"[三层记忆] 检索失败: {e}", exc_info=True)
@@ -152,6 +152,10 @@ class DatabaseMessages(BaseDataModel):
group_info=self.group_info,
)

# 扩展运行时字段
self.semantic_embedding = kwargs.pop("semantic_embedding", None)
self.interest_calculated = kwargs.pop("interest_calculated", False)

# 处理额外传入的字段(kwargs)
if kwargs:
for key, value in kwargs.items():
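
The two `kwargs.pop` lines above attach optional runtime-only fields to the data model without disturbing callers that never pass them. A self-contained sketch of the same pattern (illustrative class, not the project's DatabaseMessages):

# Sketch of the kwargs.pop pattern used in the hunk above.
class ExampleMessage:
    def __init__(self, message_id: str, **kwargs):
        self.message_id = message_id
        # Runtime-only extras are consumed here so they never reach the
        # generic "unknown kwargs" handling further down.
        self.semantic_embedding: list[float] | None = kwargs.pop("semantic_embedding", None)
        self.interest_calculated: bool = kwargs.pop("interest_calculated", False)
        # Anything still left in kwargs would be handled afterwards.
        self.extra = dict(kwargs)


msg = ExampleMessage("msg_001", semantic_embedding=[0.1, 0.2], interest_calculated=True)
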
@@ -652,7 +652,7 @@ class AiohttpGeminiClient(BaseClient):
async def get_embedding(
self,
model_info: ModelInfo,
embedding_input: str,
embedding_input: str | list[str],
extra_params: dict[str, Any] | None = None,
) -> APIResponse:
"""
@@ -51,8 +51,8 @@ class APIResponse:
tool_calls: list[ToolCall] | None = None
"""工具调用 [(工具名称, 工具参数), ...]"""

embedding: list[float] | None = None
"""嵌入向量"""
embedding: list[float] | list[list[float]] | None = None
"""嵌入结果(单条时为一维向量,批量时为向量列表)"""

usage: UsageRecord | None = None
"""使用情况 (prompt_tokens, completion_tokens, total_tokens)"""
@@ -105,7 +105,7 @@ class BaseClient(ABC):
async def get_embedding(
self,
model_info: ModelInfo,
embedding_input: str,
embedding_input: str | list[str],
extra_params: dict[str, Any] | None = None,
) -> APIResponse:
"""
@@ -580,7 +580,7 @@ class OpenaiClient(BaseClient):
async def get_embedding(
self,
model_info: ModelInfo,
embedding_input: str,
embedding_input: str | list[str],
extra_params: dict[str, Any] | None = None,
) -> APIResponse:
"""
@@ -590,6 +590,7 @@ class OpenaiClient(BaseClient):
:return: 嵌入响应
"""
client = self._create_client()
is_batch_request = isinstance(embedding_input, list)
try:
raw_response = await client.embeddings.create(
model=model_info.model_identifier,
@@ -616,7 +617,8 @@ class OpenaiClient(BaseClient):

# 解析嵌入响应
if len(raw_response.data) > 0:
response.embedding = raw_response.data[0].embedding
embeddings = [item.embedding for item in raw_response.data]
response.embedding = embeddings if is_batch_request else embeddings[0]
else:
raise RespParseException(
raw_response,
@@ -1111,15 +1111,15 @@ class LLMRequest:

return response.content or "", (response.reasoning_content or "", model_info.name, response.tool_calls)

async def get_embedding(self, embedding_input: str) -> tuple[list[float], str]:
async def get_embedding(self, embedding_input: str | list[str]) -> tuple[list[float] | list[list[float]], str]:
"""
获取嵌入向量。
获取嵌入向量,支持批量文本

Args:
embedding_input (str): 获取嵌入的目标
embedding_input (str | list[str]): 需要生成嵌入的文本或文本列表

Returns:
(Tuple[List[float], str]): (嵌入向量,使用的模型名称)
(Tuple[Union[List[float], List[List[float]]], str]): 嵌入结果及使用的模型名称
"""
start_time = time.time()
response, model_info = await self._strategy.execute_with_failover(
@@ -1128,10 +1128,25 @@ class LLMRequest:

await self._record_usage(model_info, response.usage, time.time() - start_time, "/embeddings")

if not response.embedding:
if response.embedding is None:
raise RuntimeError("获取embedding失败")

return response.embedding, model_info.name
embeddings = response.embedding
is_batch_request = isinstance(embedding_input, list)

if is_batch_request:
if not isinstance(embeddings, list):
raise RuntimeError("获取embedding失败,批量结果格式异常")

if embeddings and not isinstance(embeddings[0], list):
embeddings = [embeddings] # type: ignore[list-item]

return embeddings, model_info.name

if isinstance(embeddings, list) and embeddings and isinstance(embeddings[0], list):
return embeddings[0], model_info.name

return embeddings, model_info.name

async def _record_usage(self, model_info: ModelInfo, usage: UsageRecord | None, time_cost: float, endpoint: str):
"""
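
A minimal call-shape sketch for the widened `get_embedding` API, based on the docstring above (illustrative only; `llm_request` stands for an already-configured LLMRequest instance):

# Illustrative only — assumes `llm_request` is an initialized LLMRequest.
async def embed_examples(llm_request) -> None:
    # Single text: returns one vector (list[float]) plus the model name.
    vector, model_name = await llm_request.get_embedding("你好,世界")

    # Batch of texts: returns a list of vectors (list[list[float]]),
    # one per input, in input order.
    vectors, model_name = await llm_request.get_embedding(["文本一", "文本二", "文本三"])
    assert len(vectors) == 3
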
@@ -308,7 +308,7 @@ class UnifiedMemoryManager:
from src.memory_graph.utils.three_tier_formatter import memory_formatter

# 使用新的三级记忆格式化器
perceptual_desc = memory_formatter.format_perceptual_memory(perceptual_blocks)
perceptual_desc = await memory_formatter.format_perceptual_memory(perceptual_blocks)
short_term_desc = memory_formatter.format_short_term_memory(short_term_memories)

# 构建聊天历史块(如果提供)
@@ -1,234 +0,0 @@
"""
记忆格式化工具

提供将记忆对象格式化为提示词的功能,使用 "主体-主题(属性)" 格式。
"""

from src.memory_graph.models import Memory, MemoryNode, NodeType
from src.memory_graph.models import ShortTermMemory


def get_memory_type_label(memory_type: str) -> str:
"""
获取记忆类型的中文标签

Args:
memory_type: 记忆类型(英文)

Returns:
中文标签
"""
type_mapping = {
"事实": "事实",
"事件": "事件",
"观点": "观点",
"关系": "关系",
"目标": "目标",
"计划": "计划",
"fact": "事实",
"event": "事件",
"opinion": "观点",
"relation": "关系",
"goal": "目标",
"plan": "计划",
"unknown": "未知",
}
return type_mapping.get(memory_type.lower(), memory_type)


def format_memory_for_prompt(memory: Memory | ShortTermMemory, include_metadata: bool = True) -> str:
"""
格式化记忆为提示词文本

使用 "主体-主题(属性)" 格式,例如:
- "张三-职业(程序员, 公司=MoFox)"
- "小明-喜欢(Python, 原因=简洁优雅)"
- "拾风-地址(https://mofox.com)"

Args:
memory: Memory 或 ShortTermMemory 对象
include_metadata: 是否包含元数据(如重要性、时间等)

Returns:
格式化后的记忆文本
"""
if isinstance(memory, ShortTermMemory):
return _format_short_term_memory(memory, include_metadata)
elif isinstance(memory, Memory):
return _format_long_term_memory(memory, include_metadata)
else:
return str(memory)


def _format_short_term_memory(mem: ShortTermMemory, include_metadata: bool) -> str:
"""
格式化短期记忆

Args:
mem: ShortTermMemory 对象
include_metadata: 是否包含元数据

Returns:
格式化后的文本
"""
parts = []

# 主体
subject = mem.subject or ""
# 主题
topic = mem.topic or ""
# 客体
obj = mem.object or ""

# 构建基础格式:主体-主题
if subject and topic:
base = f"{subject}-{topic}"
elif subject:
base = subject
elif topic:
base = topic
else:
# 如果没有结构化字段,使用 content
# 防御性编程:确保 content 是字符串
if isinstance(mem.content, list):
return " ".join(str(item) for item in mem.content)
return str(mem.content) if mem.content else ""

# 添加客体和属性
attr_parts = []
if obj:
attr_parts.append(obj)

# 添加属性
if mem.attributes:
for key, value in mem.attributes.items():
if value:
attr_parts.append(f"{key}={value}")

# 组合
if attr_parts:
result = f"{base}({', '.join(attr_parts)})"
else:
result = base

# 添加元数据(可选)
if include_metadata:
metadata_parts = []
if mem.memory_type:
metadata_parts.append(f"类型:{get_memory_type_label(mem.memory_type)}")
if mem.importance > 0:
metadata_parts.append(f"重要性:{mem.importance:.2f}")

if metadata_parts:
result = f"{result} [{', '.join(metadata_parts)}]"

return result


def _format_long_term_memory(mem: Memory, include_metadata: bool) -> str:
"""
格式化长期记忆(Memory 对象)

Args:
mem: Memory 对象
include_metadata: 是否包含元数据

Returns:
格式化后的文本
"""
from src.memory_graph.models import EdgeType

# 获取主体节点
subject_node = mem.get_subject_node()
if not subject_node:
return mem.to_text()

subject = subject_node.content

# 查找主题节点
topic_node = None
for edge in mem.edges:
edge_type = edge.edge_type.value if hasattr(edge.edge_type, 'value') else str(edge.edge_type)
if edge_type == "记忆类型" and edge.source_id == mem.subject_id:
topic_node = mem.get_node_by_id(edge.target_id)
break

if not topic_node:
return subject

topic = topic_node.content

# 基础格式:主体-主题
base = f"{subject}-{topic}"

# 收集客体和属性
attr_parts = []

# 查找客体节点(通过核心关系边)
for edge in mem.edges:
edge_type = edge.edge_type.value if hasattr(edge.edge_type, 'value') else str(edge.edge_type)
if edge_type == "核心关系" and edge.source_id == topic_node.id:
obj_node = mem.get_node_by_id(edge.target_id)
if obj_node:
# 如果有关系名称,使用关系名称
if edge.relation and edge.relation != "未知":
attr_parts.append(f"{edge.relation}={obj_node.content}")
else:
attr_parts.append(obj_node.content)

# 查找属性节点
for node in mem.nodes:
if node.node_type == NodeType.ATTRIBUTE:
# 属性节点的 content 格式可能是 "key=value" 或 "value"
attr_parts.append(node.content)

# 组合
if attr_parts:
result = f"{base}({', '.join(attr_parts)})"
else:
result = base

# 添加元数据(可选)
if include_metadata:
metadata_parts = []
if mem.memory_type:
type_value = mem.memory_type.value if hasattr(mem.memory_type, 'value') else str(mem.memory_type)
metadata_parts.append(f"类型:{get_memory_type_label(type_value)}")
if mem.importance > 0:
metadata_parts.append(f"重要性:{mem.importance:.2f}")

if metadata_parts:
result = f"{result} [{', '.join(metadata_parts)}]"

return result


def format_memories_block(
memories: list[Memory | ShortTermMemory],
title: str = "相关记忆",
max_count: int = 10,
include_metadata: bool = False,
) -> str:
"""
格式化多个记忆为提示词块

Args:
memories: 记忆列表
title: 块标题
max_count: 最多显示的记忆数量
include_metadata: 是否包含元数据

Returns:
格式化后的记忆块
"""
if not memories:
return ""

lines = [f"### 🧠 {title}", ""]

for mem in memories[:max_count]:
formatted = format_memory_for_prompt(mem, include_metadata=include_metadata)
if formatted:
lines.append(f"- {formatted}")

return "\n".join(lines)
@@ -22,7 +22,7 @@ class ThreeTierMemoryFormatter:
"""初始化格式化器"""
pass

def format_perceptual_memory(self, blocks: list[MemoryBlock]) -> str:
async def format_perceptual_memory(self, blocks: list[MemoryBlock]) -> str:
"""
格式化感知记忆为提示词

@@ -53,7 +53,7 @@ class ThreeTierMemoryFormatter:
for block in blocks:
# 提取时间和聊天流信息
time_str = self._extract_time_from_block(block)
stream_name = self._extract_stream_name_from_block(block)
stream_name = await self._extract_stream_name_from_block(block)

# 添加块标题
lines.append(f"- 【{time_str} ({stream_name})】")
@@ -122,7 +122,7 @@ class ThreeTierMemoryFormatter:

return "\n".join(lines)

def format_all_tiers(
async def format_all_tiers(
self,
perceptual_blocks: list[MemoryBlock],
short_term_memories: list[ShortTermMemory],
@@ -142,7 +142,7 @@ class ThreeTierMemoryFormatter:
sections = []

# 感知记忆
perceptual_text = self.format_perceptual_memory(perceptual_blocks)
perceptual_text = await self.format_perceptual_memory(perceptual_blocks)
if perceptual_text:
sections.append("### 感知记忆(即时对话)")
sections.append(perceptual_text)
@@ -198,7 +198,7 @@ class ThreeTierMemoryFormatter:

return "未知时间"

def _extract_stream_name_from_block(self, block: MemoryBlock) -> str:
async def _extract_stream_name_from_block(self, block: MemoryBlock) -> str:
"""
从记忆块中提取聊天流名称

@@ -208,18 +208,31 @@ class ThreeTierMemoryFormatter:
Returns:
聊天流名称
"""
# 尝试从元数据中获取
if block.metadata:
stream_name = block.metadata.get("stream_name") or block.metadata.get("chat_stream")
if stream_name:
return stream_name
stream_id = None

# 尝试从消息中提取
if block.messages:
# 首先尝试从元数据中获取 stream_id
if block.metadata:
stream_id = block.metadata.get("stream_id")

# 如果从元数据中没找到,尝试从消息中提取
if not stream_id and block.messages:
first_msg = block.messages[0]
stream_name = first_msg.get("stream_name") or first_msg.get("chat_stream")
if stream_name:
return stream_name
stream_id = first_msg.get("stream_id") or first_msg.get("chat_id")

# 如果有 stream_id,尝试获取实际的流名称
if stream_id:
try:
from src.chat.message_receive.chat_stream import get_chat_manager
chat_manager = get_chat_manager()
actual_name = await chat_manager.get_stream_name(stream_id)
if actual_name:
return actual_name
else:
# 如果获取不到名称,返回 stream_id 的截断版本
return stream_id[:12] + "..." if len(stream_id) > 12 else stream_id
except Exception:
# 如果获取失败,返回 stream_id 的截断版本
return stream_id[:12] + "..." if len(stream_id) > 12 else stream_id

return "默认聊天"

@@ -375,7 +388,7 @@ class ThreeTierMemoryFormatter:

return type_mapping.get(type_value, "事实")

def format_for_context_injection(
async def format_for_context_injection(
self,
query: str,
perceptual_blocks: list[MemoryBlock],
@@ -407,7 +420,7 @@ class ThreeTierMemoryFormatter:
limited_short_term = short_term_memories[:max_short_term]
limited_long_term = long_term_memories[:max_long_term]

all_tiers_text = self.format_all_tiers(
all_tiers_text = await self.format_all_tiers(
limited_perceptual,
limited_short_term,
limited_long_term
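
Because format_perceptual_memory, format_all_tiers and format_for_context_injection are now coroutines, every call site has to await them (the DefaultReplyer and UnifiedMemoryManager hunks above were updated accordingly). A minimal sketch of the new call pattern (illustrative wrapper function; names taken from this diff):

# Illustrative only — shows the awaited formatter calls after this commit.
from src.memory_graph.utils.three_tier_formatter import memory_formatter


async def build_memory_prompt(perceptual_blocks, short_term_memories, long_term_memories) -> str:
    formatted = await memory_formatter.format_all_tiers(
        perceptual_blocks=perceptual_blocks,
        short_term_memories=short_term_memories,
        long_term_memories=long_term_memories,
    )
    return f"### 🧠 相关记忆 (Relevant Memories)\n\n{formatted}" if formatted.strip() else ""
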
@@ -185,18 +185,19 @@ async def initialize_smart_interests(personality_description: str, personality_i
await interest_service.initialize_smart_interests(personality_description, personality_id)


async def calculate_interest_match(content: str, keywords: list[str] | None = None):
"""
计算内容与兴趣的匹配度
async def calculate_interest_match(
content: str, keywords: list[str] | None = None, message_embedding: list[float] | None = None
):
"""计算消息兴趣匹配,返回匹配结果"""
if not content:
logger.warning("[PersonAPI] 请求兴趣匹配时 content 为空")
return None

Args:
content: 消息内容
keywords: 关键词列表

Returns:
匹配结果
"""
return await interest_service.calculate_interest_match(content, keywords)
try:
return await interest_service.calculate_interest_match(content, keywords, message_embedding)
except Exception as e:
logger.error(f"[PersonAPI] 计算消息兴趣匹配失败: {e}")
return None


# =============================================================================
@@ -213,7 +214,7 @@ def get_system_stats() -> dict[str, Any]:
"""
return {
"relationship_service": relationship_service.get_cache_stats(),
"interest_service": interest_service.get_interest_stats()
"interest_service": interest_service.get_interest_stats(),
}
@@ -40,13 +40,16 @@ class InterestService:
logger.error(f"初始化智能兴趣系统失败: {e}")
self.is_initialized = False

async def calculate_interest_match(self, content: str, keywords: list[str] | None = None):
async def calculate_interest_match(
self, content: str, keywords: list[str] | None = None, message_embedding: list[float] | None = None
):
"""
计算内容与兴趣的匹配度
计算消息与兴趣的匹配度

Args:
content: 消息内容
keywords: 关键词列表
keywords: 关键字列表
message_embedding: 已经生成的消息embedding,可选

Returns:
匹配结果
@@ -57,12 +60,12 @@ class InterestService:

try:
if not keywords:
# 如果没有关键词,尝试从内容提取
# 如果没有关键字,则从内容中提取
keywords = self._extract_keywords_from_content(content)

return await bot_interest_manager.calculate_interest_match(content, keywords)
return await bot_interest_manager.calculate_interest_match(content, keywords, message_embedding)
except Exception as e:
logger.error(f"计算兴趣匹配度失败: {e}")
logger.error(f"计算兴趣匹配失败: {e}")
return None

def _extract_keywords_from_content(self, content: str) -> list[str]:
@@ -103,7 +103,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):

# 1. 计算兴趣匹配分
keywords = self._extract_keywords_from_database(message)
interest_match_score = await self._calculate_interest_match_score(content, keywords)
interest_match_score = await self._calculate_interest_match_score(message, content, keywords)
logger.debug(f"[Affinity兴趣计算] 兴趣匹配分: {interest_match_score}")

# 2. 计算关系分
@@ -180,7 +180,9 @@ class AffinityInterestCalculator(BaseInterestCalculator):
success=False, message_id=getattr(message, "message_id", ""), interest_value=0.0, error_message=str(e)
)

async def _calculate_interest_match_score(self, content: str, keywords: list[str] | None = None) -> float:
async def _calculate_interest_match_score(
self, message: "DatabaseMessages", content: str, keywords: list[str] | None = None
) -> float:
"""计算兴趣匹配度(使用智能兴趣匹配系统,带超时保护)"""

# 调试日志:检查各个条件
@@ -199,7 +201,9 @@ class AffinityInterestCalculator(BaseInterestCalculator):
try:
# 使用机器人的兴趣标签系统进行智能匹配(1.5秒超时保护)
match_result = await asyncio.wait_for(
bot_interest_manager.calculate_interest_match(content, keywords or []),
bot_interest_manager.calculate_interest_match(
content, keywords or [], getattr(message, "semantic_embedding", None)
),
timeout=1.5
)
logger.debug(f"兴趣匹配结果: {match_result}")
@@ -7,6 +7,9 @@ import asyncio
from dataclasses import asdict
from typing import TYPE_CHECKING, Any

from src.chat.interest_system import bot_interest_manager
from src.chat.interest_system.interest_manager import get_interest_manager
from src.chat.message_receive.storage import MessageStorage
from src.common.logger import get_logger
from src.config.config import global_config
from src.mood.mood_manager import mood_manager
@@ -19,6 +22,7 @@ if TYPE_CHECKING:
from src.chat.planner_actions.action_manager import ChatterActionManager
from src.common.data_models.info_data_model import Plan
from src.common.data_models.message_manager_data_model import StreamContext
from src.common.data_models.database_data_model import DatabaseMessages

# 导入提示词模块以确保其被初始化

@@ -115,6 +119,74 @@ class ChatterActionPlanner:
context.processing_message_id = None
return [], None

async def _prepare_interest_scores(
self, context: "StreamContext | None", unread_messages: list["DatabaseMessages"]
) -> None:
"""在执行规划前,为未计算兴趣的消息批量补齐兴趣数据"""
if not context or not unread_messages:
return

pending_messages = [msg for msg in unread_messages if not getattr(msg, "interest_calculated", False)]
if not pending_messages:
return

logger.debug(f"批量兴趣值计算:待处理 {len(pending_messages)} 条消息")

if not bot_interest_manager.is_initialized:
logger.debug("bot_interest_manager 未初始化,跳过批量兴趣计算")
return

try:
interest_manager = get_interest_manager()
except Exception as exc: # noqa: BLE001
logger.warning(f"获取兴趣管理器失败: {exc}")
return

if not interest_manager or not interest_manager.has_calculator():
logger.debug("当前无可用兴趣计算器,跳过批量兴趣计算")
return

text_map: dict[str, str] = {}
for message in pending_messages:
text = getattr(message, "processed_plain_text", None) or getattr(message, "display_message", "") or ""
text_map[str(message.message_id)] = text

try:
embeddings = await bot_interest_manager.generate_embeddings_for_texts(text_map)
except Exception as exc: # noqa: BLE001
logger.error(f"批量获取消息embedding失败: {exc}")
embeddings = {}

interest_updates: dict[str, float] = {}
reply_updates: dict[str, bool] = {}

for message in pending_messages:
message_id = str(message.message_id)
if message_id in embeddings:
message.semantic_embedding = embeddings[message_id]

try:
result = await interest_manager.calculate_interest(message)
except Exception as exc: # noqa: BLE001
logger.error(f"批量计算消息兴趣失败: {exc}")
continue

if result.success:
message.interest_value = result.interest_value
message.should_reply = result.should_reply
message.should_act = result.should_act
message.interest_calculated = True
interest_updates[message_id] = result.interest_value
reply_updates[message_id] = result.should_reply
else:
message.interest_calculated = False

if interest_updates:
try:
await MessageStorage.bulk_update_interest_values(interest_updates, reply_updates)
except Exception as exc: # noqa: BLE001
logger.error(f"批量更新消息兴趣值失败: {exc}")

async def _focus_mode_flow(self, context: "StreamContext | None") -> tuple[list[dict[str, Any]], Any | None]:
"""Focus模式下的完整plan流程

@@ -122,6 +194,7 @@ class ChatterActionPlanner:
"""
try:
unread_messages = context.get_unread_messages() if context else []
await self._prepare_interest_scores(context, unread_messages)

# 1. 使用新的兴趣度管理系统进行评分
max_message_interest = 0.0
@@ -303,6 +376,7 @@ class ChatterActionPlanner:

try:
unread_messages = context.get_unread_messages() if context else []
await self._prepare_interest_scores(context, unread_messages)

# 1. 检查是否有未读消息
if not unread_messages:
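
For reference, a sketch of the bulk-persistence call shape used at the end of `_prepare_interest_scores` (the body of `bulk_update_interest_values` is not part of this diff; the message ids and values below are invented):

# Illustrative call shape only.
from src.chat.message_receive.storage import MessageStorage


async def persist_batch_results() -> None:
    interest_updates = {"msg_001": 0.72, "msg_002": 0.31}  # message_id -> interest_value
    reply_updates = {"msg_001": True, "msg_002": False}    # message_id -> should_reply
    await MessageStorage.bulk_update_interest_values(interest_updates, reply_updates)
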