feat: 批量生成文本embedding,优化兴趣匹配计算逻辑,支持消息兴趣值的批量更新

This commit is contained in:
Windpicker-owo
2025-11-19 16:30:44 +08:00
parent a11d251ec1
commit 14133410e6
15 changed files with 231 additions and 323 deletions

View File

@@ -442,6 +442,43 @@ class BotInterestManager:
logger.debug(f"✅ 消息embedding生成成功维度: {len(embedding)}")
return embedding
async def generate_embeddings_for_texts(
    self, text_map: dict[str, str], batch_size: int = 16
) -> dict[str, list[float]]:
    """Fetch embeddings for several texts in batches, keyed by caller ids.

    Args:
        text_map: Mapping from an arbitrary key (e.g. message id) to its text.
            Falsy text values are sent as empty strings.
        batch_size: Maximum number of texts per embedding request (clamped to >= 1).

    Returns:
        Mapping from the same keys to embedding vectors. Keys whose chunk
        request failed are omitted entirely; keys without a matching vector
        in an otherwise successful response map to an empty list.

    Raises:
        RuntimeError: If the embedding client has not been initialised.
    """
    if not text_map:
        return {}
    if not self.embedding_request:
        raise RuntimeError("Embedding客户端未初始化")

    step = max(1, batch_size)
    ordered_keys = list(text_map)
    embeddings_by_key: dict[str, list[float]] = {}

    for offset in range(0, len(ordered_keys), step):
        batch_keys = ordered_keys[offset : offset + step]
        batch_texts = [text_map[k] or "" for k in batch_keys]
        try:
            raw, _ = await self.embedding_request.get_embedding(batch_texts)
        except Exception as exc:  # noqa: BLE001
            # Best effort: log and move on; this chunk's keys stay absent.
            logger.error(f"批量获取embedding失败 (chunk {offset // step + 1}): {exc}")
            continue

        # Normalise the provider response into a list of vectors.
        if isinstance(raw, list) and raw and isinstance(raw[0], list):
            vectors = raw
        elif isinstance(raw, list):
            # A single flat vector came back; treat it as a one-element batch.
            vectors = [raw]
        else:
            vectors = []

        for pos, key in enumerate(batch_keys):
            embeddings_by_key[key] = vectors[pos] if pos < len(vectors) else []

    return embeddings_by_key
async def _calculate_similarity_scores(
self, result: InterestMatchResult, message_embedding: list[float], keywords: list[str]
):
@@ -473,7 +510,7 @@ class BotInterestManager:
logger.error(f"❌ 计算相似度分数失败: {e}")
async def calculate_interest_match(
self, message_text: str, keywords: list[str] | None = None
self, message_text: str, keywords: list[str] | None = None, message_embedding: list[float] | None = None
) -> InterestMatchResult:
"""计算消息与机器人兴趣的匹配度(优化版 - 标签扩展策略)
@@ -505,7 +542,8 @@ class BotInterestManager:
# 生成消息的embedding
logger.debug("正在生成消息 embedding...")
message_embedding = await self._get_embedding(message_text)
if not message_embedding:
message_embedding = await self._get_embedding(message_text)
logger.debug(f"消息 embedding 生成成功, 维度: {len(message_embedding)}")
# 计算与每个兴趣标签的相似度(使用扩展标签)

View File

@@ -88,8 +88,13 @@ class SingleStreamContextManager:
self.context.enable_cache(True)
logger.debug(f"为StreamContext {self.stream_id} 启用缓存系统")
# 先计算兴趣值(需要在缓存前计算)
await self._calculate_message_interest(message)
# 新消息默认占位兴趣值,延迟到 Chatter 批量处理阶段
if message.interest_value is None:
message.interest_value = 0.3
message.should_reply = False
message.should_act = False
message.interest_calculated = False
message.semantic_embedding = None
message.is_read = False
# 使用StreamContext的智能缓存功能
@@ -440,6 +445,7 @@ class SingleStreamContextManager:
message.interest_value = result.interest_value
message.should_reply = result.should_reply
message.should_act = result.should_act
message.interest_calculated = True
logger.debug(
f"消息 {message.message_id} 兴趣值已更新: {result.interest_value:.3f}, "
@@ -448,6 +454,7 @@ class SingleStreamContextManager:
return result.interest_value
else:
logger.warning(f"消息 {message.message_id} 兴趣值计算失败: {result.error_message}")
message.interest_calculated = False
return 0.5
else:
logger.debug("未找到兴趣值计算器,使用默认兴趣值")
@@ -455,6 +462,8 @@ class SingleStreamContextManager:
except Exception as e:
logger.error(f"计算消息兴趣度时发生错误: {e}", exc_info=True)
if hasattr(message, "interest_calculated"):
message.interest_calculated = False
return 0.5
def _detect_chat_type(self, message: DatabaseMessages):

View File

@@ -110,6 +110,7 @@ def init_prompt():
## 其他信息
{memory_block}
{relation_info_block}
{extra_info_block}
@@ -579,7 +580,7 @@ class DefaultReplyer:
try:
from src.memory_graph.manager_singleton import get_unified_memory_manager
from src.memory_graph.utils.memory_formatter import format_memory_for_prompt
from src.memory_graph.utils.three_tier_formatter import memory_formatter
unified_manager = get_unified_memory_manager()
if not unified_manager:
@@ -602,38 +603,12 @@ class DefaultReplyer:
short_term_memories = search_result.get("short_term_memories", [])
long_term_memories = search_result.get("long_term_memories", [])
memory_parts = ["### 🧠 相关记忆 (Relevant Memories)", ""]
# 添加感知记忆(最近的消息块)
if perceptual_blocks:
memory_parts.append("#### 🌊 感知记忆")
for block in perceptual_blocks:
messages = block.messages if hasattr(block, 'messages') else []
if messages:
block_content = "\n".join([
f"{msg.get('sender_name', msg.get('sender_id', ''))}: {msg.get('content', '')[:30]}"
for msg in messages
])
memory_parts.append(f"- {block_content}")
memory_parts.append("")
# 添加短期记忆(结构化活跃记忆)
if short_term_memories:
memory_parts.append("#### 💭 短期记忆")
for mem in short_term_memories:
content = format_memory_for_prompt(mem, include_metadata=False)
if content:
memory_parts.append(f"- {content}")
memory_parts.append("")
# 添加长期记忆(图谱记忆)
if long_term_memories:
memory_parts.append("#### 🗄️ 长期记忆")
for mem in long_term_memories:
content = format_memory_for_prompt(mem, include_metadata=False)
if content:
memory_parts.append(f"- {content}")
memory_parts.append("")
# 使用新的三级记忆格式化器
formatted_memories = await memory_formatter.format_all_tiers(
perceptual_blocks=perceptual_blocks,
short_term_memories=short_term_memories,
long_term_memories=long_term_memories
)
total_count = len(perceptual_blocks) + len(short_term_memories) + len(long_term_memories)
if total_count > 0:
@@ -642,7 +617,11 @@ class DefaultReplyer:
f"(感知:{len(perceptual_blocks)}, 短期:{len(short_term_memories)}, 长期:{len(long_term_memories)})"
)
return "\n".join(memory_parts) if len(memory_parts) > 2 else ""
# 添加标题并返回格式化后的记忆
if formatted_memories.strip():
return "### 🧠 相关记忆 (Relevant Memories)\n\n" + formatted_memories
return ""
except Exception as e:
logger.error(f"[三层记忆] 检索失败: {e}", exc_info=True)