feat: 批量生成文本embedding,优化兴趣匹配计算逻辑,支持消息兴趣值的批量更新

This commit is contained in:
Windpicker-owo
2025-11-19 16:30:44 +08:00
parent a11d251ec1
commit 14133410e6
15 changed files with 231 additions and 323 deletions

View File

@@ -442,6 +442,43 @@ class BotInterestManager:
logger.debug(f"✅ 消息embedding生成成功维度: {len(embedding)}")
return embedding
async def generate_embeddings_for_texts(
    self, text_map: dict[str, str], batch_size: int = 16
) -> dict[str, list[float]]:
    """Batch-generate embeddings for multiple texts so callers can process them uniformly.

    Args:
        text_map: Mapping of message id -> text to embed. ``None``/empty texts
            are sent as ``""`` so the batch keeps its alignment.
        batch_size: Maximum number of texts per embedding request; clamped to >= 1.

    Returns:
        Mapping of message id -> embedding vector. Ids belonging to a failed
        chunk are omitted entirely; ids whose vector could not be aligned with
        the response get an empty list.

    Raises:
        RuntimeError: If the embedding client has not been initialized.
    """
    if not text_map:
        return {}
    if not self.embedding_request:
        raise RuntimeError("Embedding客户端未初始化")
    batch_size = max(1, batch_size)
    keys = list(text_map.keys())
    results: dict[str, list[float]] = {}
    for start in range(0, len(keys), batch_size):
        chunk_keys = keys[start : start + batch_size]
        # Coerce falsy texts to "" so every key still gets a slot in the request.
        chunk_texts = [text_map[key] or "" for key in chunk_keys]
        try:
            chunk_embeddings, _ = await self.embedding_request.get_embedding(chunk_texts)
        except Exception as exc:  # noqa: BLE001
            # Best-effort per chunk: log and move on so one bad batch does not
            # abort the whole run; affected ids are simply absent from results.
            logger.error(f"批量获取embedding失败 (chunk {start // batch_size + 1}): {exc}")
            continue
        # Normalize the response shape: the client may return a list of vectors
        # (batch) or a single flat vector (single-input fallback).
        if isinstance(chunk_embeddings, list) and chunk_embeddings and isinstance(chunk_embeddings[0], list):
            normalized = chunk_embeddings
        elif isinstance(chunk_embeddings, list):
            normalized = [chunk_embeddings]
        else:
            normalized = []
        for idx_offset, message_id in enumerate(chunk_keys):
            # Positional alignment with the request; missing positions yield [].
            vector = normalized[idx_offset] if idx_offset < len(normalized) else []
            results[message_id] = vector
    return results
async def _calculate_similarity_scores(
self, result: InterestMatchResult, message_embedding: list[float], keywords: list[str]
):
@@ -473,7 +510,7 @@ class BotInterestManager:
logger.error(f"❌ 计算相似度分数失败: {e}")
async def calculate_interest_match(
self, message_text: str, keywords: list[str] | None = None
self, message_text: str, keywords: list[str] | None = None, message_embedding: list[float] | None = None
) -> InterestMatchResult:
"""计算消息与机器人兴趣的匹配度(优化版 - 标签扩展策略)
@@ -505,7 +542,8 @@ class BotInterestManager:
# 生成消息的embedding
logger.debug("正在生成消息 embedding...")
message_embedding = await self._get_embedding(message_text)
if not message_embedding:
message_embedding = await self._get_embedding(message_text)
logger.debug(f"消息 embedding 生成成功, 维度: {len(message_embedding)}")
# 计算与每个兴趣标签的相似度(使用扩展标签)

View File

@@ -88,8 +88,13 @@ class SingleStreamContextManager:
self.context.enable_cache(True)
logger.debug(f"为StreamContext {self.stream_id} 启用缓存系统")
# 先计算兴趣值(需要在缓存前计算)
await self._calculate_message_interest(message)
# 新消息默认占位兴趣值,延迟到 Chatter 批量处理阶段
if message.interest_value is None:
message.interest_value = 0.3
message.should_reply = False
message.should_act = False
message.interest_calculated = False
message.semantic_embedding = None
message.is_read = False
# 使用StreamContext的智能缓存功能
@@ -440,6 +445,7 @@ class SingleStreamContextManager:
message.interest_value = result.interest_value
message.should_reply = result.should_reply
message.should_act = result.should_act
message.interest_calculated = True
logger.debug(
f"消息 {message.message_id} 兴趣值已更新: {result.interest_value:.3f}, "
@@ -448,6 +454,7 @@ class SingleStreamContextManager:
return result.interest_value
else:
logger.warning(f"消息 {message.message_id} 兴趣值计算失败: {result.error_message}")
message.interest_calculated = False
return 0.5
else:
logger.debug("未找到兴趣值计算器,使用默认兴趣值")
@@ -455,6 +462,8 @@ class SingleStreamContextManager:
except Exception as e:
logger.error(f"计算消息兴趣度时发生错误: {e}", exc_info=True)
if hasattr(message, "interest_calculated"):
message.interest_calculated = False
return 0.5
def _detect_chat_type(self, message: DatabaseMessages):

View File

@@ -110,6 +110,7 @@ def init_prompt():
## 其他信息
{memory_block}
{relation_info_block}
{extra_info_block}
@@ -579,7 +580,7 @@ class DefaultReplyer:
try:
from src.memory_graph.manager_singleton import get_unified_memory_manager
from src.memory_graph.utils.memory_formatter import format_memory_for_prompt
from src.memory_graph.utils.three_tier_formatter import memory_formatter
unified_manager = get_unified_memory_manager()
if not unified_manager:
@@ -602,38 +603,12 @@ class DefaultReplyer:
short_term_memories = search_result.get("short_term_memories", [])
long_term_memories = search_result.get("long_term_memories", [])
memory_parts = ["### 🧠 相关记忆 (Relevant Memories)", ""]
# 添加感知记忆(最近的消息块)
if perceptual_blocks:
memory_parts.append("#### 🌊 感知记忆")
for block in perceptual_blocks:
messages = block.messages if hasattr(block, 'messages') else []
if messages:
block_content = "\n".join([
f"{msg.get('sender_name', msg.get('sender_id', ''))}: {msg.get('content', '')[:30]}"
for msg in messages
])
memory_parts.append(f"- {block_content}")
memory_parts.append("")
# 添加短期记忆(结构化活跃记忆)
if short_term_memories:
memory_parts.append("#### 💭 短期记忆")
for mem in short_term_memories:
content = format_memory_for_prompt(mem, include_metadata=False)
if content:
memory_parts.append(f"- {content}")
memory_parts.append("")
# 添加长期记忆(图谱记忆)
if long_term_memories:
memory_parts.append("#### 🗄️ 长期记忆")
for mem in long_term_memories:
content = format_memory_for_prompt(mem, include_metadata=False)
if content:
memory_parts.append(f"- {content}")
memory_parts.append("")
# 使用新的三级记忆格式化器
formatted_memories = await memory_formatter.format_all_tiers(
perceptual_blocks=perceptual_blocks,
short_term_memories=short_term_memories,
long_term_memories=long_term_memories
)
total_count = len(perceptual_blocks) + len(short_term_memories) + len(long_term_memories)
if total_count > 0:
@@ -642,7 +617,11 @@ class DefaultReplyer:
f"(感知:{len(perceptual_blocks)}, 短期:{len(short_term_memories)}, 长期:{len(long_term_memories)})"
)
return "\n".join(memory_parts) if len(memory_parts) > 2 else ""
# 添加标题并返回格式化后的记忆
if formatted_memories.strip():
return "### 🧠 相关记忆 (Relevant Memories)\n\n" + formatted_memories
return ""
except Exception as e:
logger.error(f"[三层记忆] 检索失败: {e}", exc_info=True)

View File

@@ -152,6 +152,10 @@ class DatabaseMessages(BaseDataModel):
group_info=self.group_info,
)
# 扩展运行时字段
self.semantic_embedding = kwargs.pop("semantic_embedding", None)
self.interest_calculated = kwargs.pop("interest_calculated", False)
# 处理额外传入的字段kwargs
if kwargs:
for key, value in kwargs.items():

View File

@@ -652,7 +652,7 @@ class AiohttpGeminiClient(BaseClient):
async def get_embedding(
self,
model_info: ModelInfo,
embedding_input: str,
embedding_input: str | list[str],
extra_params: dict[str, Any] | None = None,
) -> APIResponse:
"""

View File

@@ -51,8 +51,8 @@ class APIResponse:
tool_calls: list[ToolCall] | None = None
"""工具调用 [(工具名称, 工具参数), ...]"""
embedding: list[float] | None = None
"""嵌入向量"""
embedding: list[float] | list[list[float]] | None = None
"""嵌入结果(单条时为一维向量,批量时为向量列表)"""
usage: UsageRecord | None = None
"""使用情况 (prompt_tokens, completion_tokens, total_tokens)"""
@@ -105,7 +105,7 @@ class BaseClient(ABC):
async def get_embedding(
self,
model_info: ModelInfo,
embedding_input: str,
embedding_input: str | list[str],
extra_params: dict[str, Any] | None = None,
) -> APIResponse:
"""

View File

@@ -580,7 +580,7 @@ class OpenaiClient(BaseClient):
async def get_embedding(
self,
model_info: ModelInfo,
embedding_input: str,
embedding_input: str | list[str],
extra_params: dict[str, Any] | None = None,
) -> APIResponse:
"""
@@ -590,6 +590,7 @@ class OpenaiClient(BaseClient):
:return: 嵌入响应
"""
client = self._create_client()
is_batch_request = isinstance(embedding_input, list)
try:
raw_response = await client.embeddings.create(
model=model_info.model_identifier,
@@ -616,7 +617,8 @@ class OpenaiClient(BaseClient):
# 解析嵌入响应
if len(raw_response.data) > 0:
response.embedding = raw_response.data[0].embedding
embeddings = [item.embedding for item in raw_response.data]
response.embedding = embeddings if is_batch_request else embeddings[0]
else:
raise RespParseException(
raw_response,

View File

@@ -1111,15 +1111,15 @@ class LLMRequest:
return response.content or "", (response.reasoning_content or "", model_info.name, response.tool_calls)
async def get_embedding(self, embedding_input: str) -> tuple[list[float], str]:
async def get_embedding(self, embedding_input: str | list[str]) -> tuple[list[float] | list[list[float]], str]:
"""
获取嵌入向量
获取嵌入向量,支持批量文本
Args:
embedding_input (str): 获取嵌入的目标
embedding_input (str | list[str]): 需要生成嵌入的文本或文本列表
Returns:
(Tuple[List[float], str]): (嵌入向量,使用的模型名称)
(Tuple[Union[List[float], List[List[float]]], str]): 嵌入结果及使用的模型名称
"""
start_time = time.time()
response, model_info = await self._strategy.execute_with_failover(
@@ -1128,10 +1128,25 @@ class LLMRequest:
await self._record_usage(model_info, response.usage, time.time() - start_time, "/embeddings")
if not response.embedding:
if response.embedding is None:
raise RuntimeError("获取embedding失败")
return response.embedding, model_info.name
embeddings = response.embedding
is_batch_request = isinstance(embedding_input, list)
if is_batch_request:
if not isinstance(embeddings, list):
raise RuntimeError("获取embedding失败批量结果格式异常")
if embeddings and not isinstance(embeddings[0], list):
embeddings = [embeddings] # type: ignore[list-item]
return embeddings, model_info.name
if isinstance(embeddings, list) and embeddings and isinstance(embeddings[0], list):
return embeddings[0], model_info.name
return embeddings, model_info.name
async def _record_usage(self, model_info: ModelInfo, usage: UsageRecord | None, time_cost: float, endpoint: str):
"""

View File

@@ -308,7 +308,7 @@ class UnifiedMemoryManager:
from src.memory_graph.utils.three_tier_formatter import memory_formatter
# 使用新的三级记忆格式化器
perceptual_desc = memory_formatter.format_perceptual_memory(perceptual_blocks)
perceptual_desc = await memory_formatter.format_perceptual_memory(perceptual_blocks)
short_term_desc = memory_formatter.format_short_term_memory(short_term_memories)
# 构建聊天历史块(如果提供)

View File

@@ -1,234 +0,0 @@
"""
记忆格式化工具
提供将记忆对象格式化为提示词的功能,使用 "主体-主题(属性)" 格式。
"""
from src.memory_graph.models import Memory, MemoryNode, NodeType
from src.memory_graph.models import ShortTermMemory
def get_memory_type_label(memory_type: str) -> str:
    """Return the Chinese display label for a memory type.

    Args:
        memory_type: Memory type name, English (e.g. ``"fact"``) or already
            localized Chinese (e.g. ``"事实"``). Lookup is case-insensitive.

    Returns:
        The Chinese label, or the input unchanged when the type is unknown.
    """
    labels = {
        "fact": "事实",
        "event": "事件",
        "opinion": "观点",
        "relation": "关系",
        "goal": "目标",
        "plan": "计划",
        "unknown": "未知",
    }
    # Chinese labels map to themselves so already-localized input passes through.
    for chinese in ("事实", "事件", "观点", "关系", "目标", "计划"):
        labels[chinese] = chinese
    return labels.get(memory_type.lower(), memory_type)
def format_memory_for_prompt(memory: Memory | ShortTermMemory, include_metadata: bool = True) -> str:
    """Format a memory object as prompt text in "subject-topic(attributes)" form.

    Examples of the output format:
        - "张三-职业(程序员, 公司=MoFox)"
        - "小明-喜欢(Python, 原因=简洁优雅)"
        - "拾风-地址(https://mofox.com)"

    Args:
        memory: A ``Memory`` or ``ShortTermMemory`` object.
        include_metadata: Whether to append metadata (type, importance).

    Returns:
        The formatted memory text; unknown types fall back to ``str(memory)``.
    """
    # Dispatch on concrete type; ShortTermMemory is checked first.
    if isinstance(memory, ShortTermMemory):
        return _format_short_term_memory(memory, include_metadata)
    elif isinstance(memory, Memory):
        return _format_long_term_memory(memory, include_metadata)
    else:
        return str(memory)
def _format_short_term_memory(mem: ShortTermMemory, include_metadata: bool) -> str:
    """Format a short-term memory as "subject-topic(object, key=value, ...)".

    Args:
        mem: The ``ShortTermMemory`` object to format.
        include_metadata: Whether to append type/importance metadata in brackets.

    Returns:
        The formatted text; falls back to ``mem.content`` when no structured
        subject/topic fields are available.
    """
    # Structured fields (any of them may be empty).
    subject = mem.subject or ""
    topic = mem.topic or ""
    obj = mem.object or ""
    # Base form: "subject-topic", degrading gracefully when fields are missing.
    if subject and topic:
        base = f"{subject}-{topic}"
    elif subject:
        base = subject
    elif topic:
        base = topic
    else:
        # No structured fields: fall back to raw content.
        # Defensive: content may be a list — join items into a single string.
        if isinstance(mem.content, list):
            return " ".join(str(item) for item in mem.content)
        return str(mem.content) if mem.content else ""
    # Collect the object plus any key=value attributes.
    attr_parts = []
    if obj:
        attr_parts.append(obj)
    if mem.attributes:
        for key, value in mem.attributes.items():
            if value:
                attr_parts.append(f"{key}={value}")
    # Combine base with parenthesized attributes, if any.
    if attr_parts:
        result = f"{base}({', '.join(attr_parts)})"
    else:
        result = base
    # Optional metadata suffix, e.g. "[类型:事实, 重要性:0.80]".
    if include_metadata:
        metadata_parts = []
        if mem.memory_type:
            metadata_parts.append(f"类型:{get_memory_type_label(mem.memory_type)}")
        if mem.importance > 0:
            metadata_parts.append(f"重要性:{mem.importance:.2f}")
        if metadata_parts:
            result = f"{result} [{', '.join(metadata_parts)}]"
    return result
def _format_long_term_memory(mem: Memory, include_metadata: bool) -> str:
    """Format a long-term (graph) memory as "subject-topic(attributes)".

    Walks the memory's node/edge graph: the subject node anchors the text, a
    "记忆类型" edge selects the topic node, "核心关系" edges contribute objects,
    and attribute nodes contribute extra attributes.

    Args:
        mem: The ``Memory`` object to format.
        include_metadata: Whether to append type/importance metadata in brackets.

    Returns:
        The formatted text; degrades to ``mem.to_text()`` when no subject node
        exists, or to the bare subject when no topic node is found.
    """
    # Subject node is mandatory for the structured form.
    subject_node = mem.get_subject_node()
    if not subject_node:
        return mem.to_text()
    subject = subject_node.content
    # Find the topic node via a "记忆类型" edge leaving the subject.
    topic_node = None
    for edge in mem.edges:
        # edge_type may be an enum or a plain string — normalize to str.
        edge_type = edge.edge_type.value if hasattr(edge.edge_type, 'value') else str(edge.edge_type)
        if edge_type == "记忆类型" and edge.source_id == mem.subject_id:
            topic_node = mem.get_node_by_id(edge.target_id)
            break
    if not topic_node:
        return subject
    topic = topic_node.content
    # Base form: "subject-topic".
    base = f"{subject}-{topic}"
    # Collect objects and attributes.
    attr_parts = []
    # Objects come from "核心关系" edges leaving the topic node.
    for edge in mem.edges:
        edge_type = edge.edge_type.value if hasattr(edge.edge_type, 'value') else str(edge.edge_type)
        if edge_type == "核心关系" and edge.source_id == topic_node.id:
            obj_node = mem.get_node_by_id(edge.target_id)
            if obj_node:
                # Prefer a named relation, e.g. "原因=简洁优雅".
                if edge.relation and edge.relation != "未知":
                    attr_parts.append(f"{edge.relation}={obj_node.content}")
                else:
                    attr_parts.append(obj_node.content)
    # Attribute nodes contribute their content directly ("key=value" or "value").
    for node in mem.nodes:
        if node.node_type == NodeType.ATTRIBUTE:
            attr_parts.append(node.content)
    # Combine base with parenthesized attributes, if any.
    if attr_parts:
        result = f"{base}({', '.join(attr_parts)})"
    else:
        result = base
    # Optional metadata suffix.
    if include_metadata:
        metadata_parts = []
        if mem.memory_type:
            type_value = mem.memory_type.value if hasattr(mem.memory_type, 'value') else str(mem.memory_type)
            metadata_parts.append(f"类型:{get_memory_type_label(type_value)}")
        if mem.importance > 0:
            metadata_parts.append(f"重要性:{mem.importance:.2f}")
        if metadata_parts:
            result = f"{result} [{', '.join(metadata_parts)}]"
    return result
def format_memories_block(
    memories: list[Memory | ShortTermMemory],
    title: str = "相关记忆",
    max_count: int = 10,
    include_metadata: bool = False,
) -> str:
    """Render a list of memories as a titled prompt block.

    Args:
        memories: Memory objects to render.
        title: Heading text for the block.
        max_count: Cap on how many memories are rendered.
        include_metadata: Whether each entry carries metadata.

    Returns:
        The formatted block, or "" when there are no memories.
    """
    if not memories:
        return ""
    # Render each memory and keep only the non-empty results as bullet items.
    entries = [
        f"- {text}"
        for text in (
            format_memory_for_prompt(mem, include_metadata=include_metadata)
            for mem in memories[:max_count]
        )
        if text
    ]
    return "\n".join([f"### 🧠 {title}", "", *entries])

View File

@@ -22,7 +22,7 @@ class ThreeTierMemoryFormatter:
"""初始化格式化器"""
pass
def format_perceptual_memory(self, blocks: list[MemoryBlock]) -> str:
async def format_perceptual_memory(self, blocks: list[MemoryBlock]) -> str:
"""
格式化感知记忆为提示词
@@ -53,7 +53,7 @@ class ThreeTierMemoryFormatter:
for block in blocks:
# 提取时间和聊天流信息
time_str = self._extract_time_from_block(block)
stream_name = self._extract_stream_name_from_block(block)
stream_name = await self._extract_stream_name_from_block(block)
# 添加块标题
lines.append(f"- 【{time_str} ({stream_name})】")
@@ -122,7 +122,7 @@ class ThreeTierMemoryFormatter:
return "\n".join(lines)
def format_all_tiers(
async def format_all_tiers(
self,
perceptual_blocks: list[MemoryBlock],
short_term_memories: list[ShortTermMemory],
@@ -142,7 +142,7 @@ class ThreeTierMemoryFormatter:
sections = []
# 感知记忆
perceptual_text = self.format_perceptual_memory(perceptual_blocks)
perceptual_text = await self.format_perceptual_memory(perceptual_blocks)
if perceptual_text:
sections.append("### 感知记忆(即时对话)")
sections.append(perceptual_text)
@@ -198,7 +198,7 @@ class ThreeTierMemoryFormatter:
return "未知时间"
def _extract_stream_name_from_block(self, block: MemoryBlock) -> str:
async def _extract_stream_name_from_block(self, block: MemoryBlock) -> str:
"""
从记忆块中提取聊天流名称
@@ -208,18 +208,31 @@ class ThreeTierMemoryFormatter:
Returns:
聊天流名称
"""
# 尝试从元数据中获取
if block.metadata:
stream_name = block.metadata.get("stream_name") or block.metadata.get("chat_stream")
if stream_name:
return stream_name
stream_id = None
# 尝试从消息中提取
if block.messages:
# 首先尝试从元数据中获取 stream_id
if block.metadata:
stream_id = block.metadata.get("stream_id")
# 如果从元数据中没找到,尝试从消息中提取
if not stream_id and block.messages:
first_msg = block.messages[0]
stream_name = first_msg.get("stream_name") or first_msg.get("chat_stream")
if stream_name:
return stream_name
stream_id = first_msg.get("stream_id") or first_msg.get("chat_id")
# 如果有 stream_id尝试获取实际的流名称
if stream_id:
try:
from src.chat.message_receive.chat_stream import get_chat_manager
chat_manager = get_chat_manager()
actual_name = await chat_manager.get_stream_name(stream_id)
if actual_name:
return actual_name
else:
# 如果获取不到名称,返回 stream_id 的截断版本
return stream_id[:12] + "..." if len(stream_id) > 12 else stream_id
except Exception:
# 如果获取失败,返回 stream_id 的截断版本
return stream_id[:12] + "..." if len(stream_id) > 12 else stream_id
return "默认聊天"
@@ -375,7 +388,7 @@ class ThreeTierMemoryFormatter:
return type_mapping.get(type_value, "事实")
def format_for_context_injection(
async def format_for_context_injection(
self,
query: str,
perceptual_blocks: list[MemoryBlock],
@@ -407,7 +420,7 @@ class ThreeTierMemoryFormatter:
limited_short_term = short_term_memories[:max_short_term]
limited_long_term = long_term_memories[:max_long_term]
all_tiers_text = self.format_all_tiers(
all_tiers_text = await self.format_all_tiers(
limited_perceptual,
limited_short_term,
limited_long_term

View File

@@ -185,18 +185,19 @@ async def initialize_smart_interests(personality_description: str, personality_i
await interest_service.initialize_smart_interests(personality_description, personality_id)
async def calculate_interest_match(content: str, keywords: list[str] | None = None):
"""
计算内容与兴趣的匹配度
async def calculate_interest_match(
content: str, keywords: list[str] | None = None, message_embedding: list[float] | None = None
):
"""计算消息兴趣匹配,返回匹配结果"""
if not content:
logger.warning("[PersonAPI] 请求兴趣匹配时 content 为空")
return None
Args:
content: 消息内容
keywords: 关键词列表
Returns:
匹配结果
"""
return await interest_service.calculate_interest_match(content, keywords)
try:
return await interest_service.calculate_interest_match(content, keywords, message_embedding)
except Exception as e:
logger.error(f"[PersonAPI] 计算消息兴趣匹配失败: {e}")
return None
# =============================================================================
@@ -213,7 +214,7 @@ def get_system_stats() -> dict[str, Any]:
"""
return {
"relationship_service": relationship_service.get_cache_stats(),
"interest_service": interest_service.get_interest_stats()
"interest_service": interest_service.get_interest_stats(),
}

View File

@@ -40,13 +40,16 @@ class InterestService:
logger.error(f"初始化智能兴趣系统失败: {e}")
self.is_initialized = False
async def calculate_interest_match(self, content: str, keywords: list[str] | None = None):
async def calculate_interest_match(
self, content: str, keywords: list[str] | None = None, message_embedding: list[float] | None = None
):
"""
计算内容与兴趣的匹配度
计算消息与兴趣的匹配度
Args:
content: 消息内容
keywords: 关键列表
keywords: 关键列表
message_embedding: 已经生成的消息embedding可选
Returns:
匹配结果
@@ -57,12 +60,12 @@ class InterestService:
try:
if not keywords:
# 如果没有关键词,尝试从内容提取
# 如果没有关键字,则从内容提取
keywords = self._extract_keywords_from_content(content)
return await bot_interest_manager.calculate_interest_match(content, keywords)
return await bot_interest_manager.calculate_interest_match(content, keywords, message_embedding)
except Exception as e:
logger.error(f"计算兴趣匹配失败: {e}")
logger.error(f"计算兴趣匹配失败: {e}")
return None
def _extract_keywords_from_content(self, content: str) -> list[str]:

View File

@@ -103,7 +103,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
# 1. 计算兴趣匹配分
keywords = self._extract_keywords_from_database(message)
interest_match_score = await self._calculate_interest_match_score(content, keywords)
interest_match_score = await self._calculate_interest_match_score(message, content, keywords)
logger.debug(f"[Affinity兴趣计算] 兴趣匹配分: {interest_match_score}")
# 2. 计算关系分
@@ -180,7 +180,9 @@ class AffinityInterestCalculator(BaseInterestCalculator):
success=False, message_id=getattr(message, "message_id", ""), interest_value=0.0, error_message=str(e)
)
async def _calculate_interest_match_score(self, content: str, keywords: list[str] | None = None) -> float:
async def _calculate_interest_match_score(
self, message: "DatabaseMessages", content: str, keywords: list[str] | None = None
) -> float:
"""计算兴趣匹配度(使用智能兴趣匹配系统,带超时保护)"""
# 调试日志:检查各个条件
@@ -199,7 +201,9 @@ class AffinityInterestCalculator(BaseInterestCalculator):
try:
# 使用机器人的兴趣标签系统进行智能匹配1.5秒超时保护)
match_result = await asyncio.wait_for(
bot_interest_manager.calculate_interest_match(content, keywords or []),
bot_interest_manager.calculate_interest_match(
content, keywords or [], getattr(message, "semantic_embedding", None)
),
timeout=1.5
)
logger.debug(f"兴趣匹配结果: {match_result}")

View File

@@ -7,6 +7,9 @@ import asyncio
from dataclasses import asdict
from typing import TYPE_CHECKING, Any
from src.chat.interest_system import bot_interest_manager
from src.chat.interest_system.interest_manager import get_interest_manager
from src.chat.message_receive.storage import MessageStorage
from src.common.logger import get_logger
from src.config.config import global_config
from src.mood.mood_manager import mood_manager
@@ -19,6 +22,7 @@ if TYPE_CHECKING:
from src.chat.planner_actions.action_manager import ChatterActionManager
from src.common.data_models.info_data_model import Plan
from src.common.data_models.message_manager_data_model import StreamContext
from src.common.data_models.database_data_model import DatabaseMessages
# 导入提示词模块以确保其被初始化
@@ -115,6 +119,74 @@ class ChatterActionPlanner:
context.processing_message_id = None
return [], None
async def _prepare_interest_scores(
    self, context: "StreamContext | None", unread_messages: list["DatabaseMessages"]
) -> None:
    """Batch-fill interest data for messages not yet scored, before planning runs.

    For every unread message whose ``interest_calculated`` flag is unset:
    batch-generate embeddings, attach them as ``semantic_embedding``, run the
    interest calculator per message, update the in-memory message fields, and
    finally bulk-persist the new interest/reply values. Every step is
    best-effort: failures are logged and the remaining messages still proceed.

    Args:
        context: The stream context being planned for; ``None`` aborts early.
        unread_messages: Unread messages from the context.
    """
    if not context or not unread_messages:
        return
    # Only messages that have not been scored yet (placeholder interest).
    pending_messages = [msg for msg in unread_messages if not getattr(msg, "interest_calculated", False)]
    if not pending_messages:
        return
    logger.debug(f"批量兴趣值计算:待处理 {len(pending_messages)} 条消息")
    if not bot_interest_manager.is_initialized:
        logger.debug("bot_interest_manager 未初始化,跳过批量兴趣计算")
        return
    try:
        interest_manager = get_interest_manager()
    except Exception as exc:  # noqa: BLE001
        logger.warning(f"获取兴趣管理器失败: {exc}")
        return
    if not interest_manager or not interest_manager.has_calculator():
        logger.debug("当前无可用兴趣计算器,跳过批量兴趣计算")
        return
    # Build message_id -> text for batch embedding; fall back through
    # processed_plain_text then display_message, defaulting to "".
    text_map: dict[str, str] = {}
    for message in pending_messages:
        text = getattr(message, "processed_plain_text", None) or getattr(message, "display_message", "") or ""
        text_map[str(message.message_id)] = text
    try:
        embeddings = await bot_interest_manager.generate_embeddings_for_texts(text_map)
    except Exception as exc:  # noqa: BLE001
        # Embeddings are an optimization; scoring still runs without them.
        logger.error(f"批量获取消息embedding失败: {exc}")
        embeddings = {}
    interest_updates: dict[str, float] = {}
    reply_updates: dict[str, bool] = {}
    for message in pending_messages:
        message_id = str(message.message_id)
        # Attach the precomputed embedding so the calculator can reuse it.
        if message_id in embeddings:
            message.semantic_embedding = embeddings[message_id]
        try:
            result = await interest_manager.calculate_interest(message)
        except Exception as exc:  # noqa: BLE001
            logger.error(f"批量计算消息兴趣失败: {exc}")
            continue
        if result.success:
            # Mirror the result onto the message and queue the DB update.
            message.interest_value = result.interest_value
            message.should_reply = result.should_reply
            message.should_act = result.should_act
            message.interest_calculated = True
            interest_updates[message_id] = result.interest_value
            reply_updates[message_id] = result.should_reply
        else:
            message.interest_calculated = False
    if interest_updates:
        try:
            await MessageStorage.bulk_update_interest_values(interest_updates, reply_updates)
        except Exception as exc:  # noqa: BLE001
            logger.error(f"批量更新消息兴趣值失败: {exc}")
async def _focus_mode_flow(self, context: "StreamContext | None") -> tuple[list[dict[str, Any]], Any | None]:
"""Focus模式下的完整plan流程
@@ -122,6 +194,7 @@ class ChatterActionPlanner:
"""
try:
unread_messages = context.get_unread_messages() if context else []
await self._prepare_interest_scores(context, unread_messages)
# 1. 使用新的兴趣度管理系统进行评分
max_message_interest = 0.0
@@ -303,6 +376,7 @@ class ChatterActionPlanner:
try:
unread_messages = context.get_unread_messages() if context else []
await self._prepare_interest_scores(context, unread_messages)
# 1. 检查是否有未读消息
if not unread_messages: