feat: 批量生成文本embedding,优化兴趣匹配计算逻辑,支持消息兴趣值的批量更新
This commit is contained in:
@@ -103,7 +103,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
||||
|
||||
# 1. 计算兴趣匹配分
|
||||
keywords = self._extract_keywords_from_database(message)
|
||||
interest_match_score = await self._calculate_interest_match_score(content, keywords)
|
||||
interest_match_score = await self._calculate_interest_match_score(message, content, keywords)
|
||||
logger.debug(f"[Affinity兴趣计算] 兴趣匹配分: {interest_match_score}")
|
||||
|
||||
# 2. 计算关系分
|
||||
@@ -180,7 +180,9 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
||||
success=False, message_id=getattr(message, "message_id", ""), interest_value=0.0, error_message=str(e)
|
||||
)
|
||||
|
||||
async def _calculate_interest_match_score(self, content: str, keywords: list[str] | None = None) -> float:
|
||||
async def _calculate_interest_match_score(
|
||||
self, message: "DatabaseMessages", content: str, keywords: list[str] | None = None
|
||||
) -> float:
|
||||
"""计算兴趣匹配度(使用智能兴趣匹配系统,带超时保护)"""
|
||||
|
||||
# 调试日志:检查各个条件
|
||||
@@ -199,7 +201,9 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
||||
try:
|
||||
# 使用机器人的兴趣标签系统进行智能匹配(1.5秒超时保护)
|
||||
match_result = await asyncio.wait_for(
|
||||
bot_interest_manager.calculate_interest_match(content, keywords or []),
|
||||
bot_interest_manager.calculate_interest_match(
|
||||
content, keywords or [], getattr(message, "semantic_embedding", None)
|
||||
),
|
||||
timeout=1.5
|
||||
)
|
||||
logger.debug(f"兴趣匹配结果: {match_result}")
|
||||
|
||||
@@ -7,6 +7,9 @@ import asyncio
|
||||
from dataclasses import asdict
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from src.chat.interest_system import bot_interest_manager
|
||||
from src.chat.interest_system.interest_manager import get_interest_manager
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.mood.mood_manager import mood_manager
|
||||
@@ -19,6 +22,7 @@ if TYPE_CHECKING:
|
||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||
from src.common.data_models.info_data_model import Plan
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
# 导入提示词模块以确保其被初始化
|
||||
|
||||
@@ -115,6 +119,74 @@ class ChatterActionPlanner:
|
||||
context.processing_message_id = None
|
||||
return [], None
|
||||
|
||||
async def _prepare_interest_scores(
|
||||
self, context: "StreamContext | None", unread_messages: list["DatabaseMessages"]
|
||||
) -> None:
|
||||
"""在执行规划前,为未计算兴趣的消息批量补齐兴趣数据"""
|
||||
if not context or not unread_messages:
|
||||
return
|
||||
|
||||
pending_messages = [msg for msg in unread_messages if not getattr(msg, "interest_calculated", False)]
|
||||
if not pending_messages:
|
||||
return
|
||||
|
||||
logger.debug(f"批量兴趣值计算:待处理 {len(pending_messages)} 条消息")
|
||||
|
||||
if not bot_interest_manager.is_initialized:
|
||||
logger.debug("bot_interest_manager 未初始化,跳过批量兴趣计算")
|
||||
return
|
||||
|
||||
try:
|
||||
interest_manager = get_interest_manager()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning(f"获取兴趣管理器失败: {exc}")
|
||||
return
|
||||
|
||||
if not interest_manager or not interest_manager.has_calculator():
|
||||
logger.debug("当前无可用兴趣计算器,跳过批量兴趣计算")
|
||||
return
|
||||
|
||||
text_map: dict[str, str] = {}
|
||||
for message in pending_messages:
|
||||
text = getattr(message, "processed_plain_text", None) or getattr(message, "display_message", "") or ""
|
||||
text_map[str(message.message_id)] = text
|
||||
|
||||
try:
|
||||
embeddings = await bot_interest_manager.generate_embeddings_for_texts(text_map)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error(f"批量获取消息embedding失败: {exc}")
|
||||
embeddings = {}
|
||||
|
||||
interest_updates: dict[str, float] = {}
|
||||
reply_updates: dict[str, bool] = {}
|
||||
|
||||
for message in pending_messages:
|
||||
message_id = str(message.message_id)
|
||||
if message_id in embeddings:
|
||||
message.semantic_embedding = embeddings[message_id]
|
||||
|
||||
try:
|
||||
result = await interest_manager.calculate_interest(message)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error(f"批量计算消息兴趣失败: {exc}")
|
||||
continue
|
||||
|
||||
if result.success:
|
||||
message.interest_value = result.interest_value
|
||||
message.should_reply = result.should_reply
|
||||
message.should_act = result.should_act
|
||||
message.interest_calculated = True
|
||||
interest_updates[message_id] = result.interest_value
|
||||
reply_updates[message_id] = result.should_reply
|
||||
else:
|
||||
message.interest_calculated = False
|
||||
|
||||
if interest_updates:
|
||||
try:
|
||||
await MessageStorage.bulk_update_interest_values(interest_updates, reply_updates)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error(f"批量更新消息兴趣值失败: {exc}")
|
||||
|
||||
async def _focus_mode_flow(self, context: "StreamContext | None") -> tuple[list[dict[str, Any]], Any | None]:
|
||||
"""Focus模式下的完整plan流程
|
||||
|
||||
@@ -122,6 +194,7 @@ class ChatterActionPlanner:
|
||||
"""
|
||||
try:
|
||||
unread_messages = context.get_unread_messages() if context else []
|
||||
await self._prepare_interest_scores(context, unread_messages)
|
||||
|
||||
# 1. 使用新的兴趣度管理系统进行评分
|
||||
max_message_interest = 0.0
|
||||
@@ -303,6 +376,7 @@ class ChatterActionPlanner:
|
||||
|
||||
try:
|
||||
unread_messages = context.get_unread_messages() if context else []
|
||||
await self._prepare_interest_scores(context, unread_messages)
|
||||
|
||||
# 1. 检查是否有未读消息
|
||||
if not unread_messages:
|
||||
|
||||
Reference in New Issue
Block a user