diff --git a/src/chat/interest_system/bot_interest_manager.py b/src/chat/interest_system/bot_interest_manager.py index 959796a51..ada2d0365 100644 --- a/src/chat/interest_system/bot_interest_manager.py +++ b/src/chat/interest_system/bot_interest_manager.py @@ -442,6 +442,43 @@ class BotInterestManager: logger.debug(f"✅ 消息embedding生成成功,维度: {len(embedding)}") return embedding + async def generate_embeddings_for_texts( + self, text_map: dict[str, str], batch_size: int = 16 + ) -> dict[str, list[float]]: + """批量获取多段文本的embedding,供上层统一处理。""" + if not text_map: + return {} + + if not self.embedding_request: + raise RuntimeError("Embedding客户端未初始化") + + batch_size = max(1, batch_size) + keys = list(text_map.keys()) + results: dict[str, list[float]] = {} + + for start in range(0, len(keys), batch_size): + chunk_keys = keys[start : start + batch_size] + chunk_texts = [text_map[key] or "" for key in chunk_keys] + + try: + chunk_embeddings, _ = await self.embedding_request.get_embedding(chunk_texts) + except Exception as exc: # noqa: BLE001 + logger.error(f"批量获取embedding失败 (chunk {start // batch_size + 1}): {exc}") + continue + + if isinstance(chunk_embeddings, list) and chunk_embeddings and isinstance(chunk_embeddings[0], list): + normalized = chunk_embeddings + elif isinstance(chunk_embeddings, list): + normalized = [chunk_embeddings] + else: + normalized = [] + + for idx_offset, message_id in enumerate(chunk_keys): + vector = normalized[idx_offset] if idx_offset < len(normalized) else [] + results[message_id] = vector + + return results + async def _calculate_similarity_scores( self, result: InterestMatchResult, message_embedding: list[float], keywords: list[str] ): @@ -473,7 +510,7 @@ class BotInterestManager: logger.error(f"❌ 计算相似度分数失败: {e}") async def calculate_interest_match( - self, message_text: str, keywords: list[str] | None = None + self, message_text: str, keywords: list[str] | None = None, message_embedding: list[float] | None = None ) -> InterestMatchResult: """计算消息与机器人兴趣的匹配度(优化版 - 标签扩展策略) @@ -505,7 +542,8 @@ class BotInterestManager: # 生成消息的embedding logger.debug("正在生成消息 embedding...") - message_embedding = await self._get_embedding(message_text) + if not message_embedding: + message_embedding = await self._get_embedding(message_text) logger.debug(f"消息 embedding 生成成功, 维度: {len(message_embedding)}") # 计算与每个兴趣标签的相似度(使用扩展标签) diff --git a/src/chat/message_manager/context_manager.py b/src/chat/message_manager/context_manager.py index 99a20cce0..61c5f0440 100644 --- a/src/chat/message_manager/context_manager.py +++ b/src/chat/message_manager/context_manager.py @@ -88,8 +88,13 @@ class SingleStreamContextManager: self.context.enable_cache(True) logger.debug(f"为StreamContext {self.stream_id} 启用缓存系统") - # 先计算兴趣值(需要在缓存前计算) - await self._calculate_message_interest(message) + # 新消息默认占位兴趣值,延迟到 Chatter 批量处理阶段 + if message.interest_value is None: + message.interest_value = 0.3 + message.should_reply = False + message.should_act = False + message.interest_calculated = False + message.semantic_embedding = None message.is_read = False # 使用StreamContext的智能缓存功能 @@ -440,6 +445,7 @@ class SingleStreamContextManager: message.interest_value = result.interest_value message.should_reply = result.should_reply message.should_act = result.should_act + message.interest_calculated = True logger.debug( f"消息 {message.message_id} 兴趣值已更新: {result.interest_value:.3f}, " @@ -448,6 +454,7 @@ class SingleStreamContextManager: return result.interest_value else: logger.warning(f"消息 {message.message_id} 兴趣值计算失败: {result.error_message}") + message.interest_calculated = False return 0.5 else: logger.debug("未找到兴趣值计算器,使用默认兴趣值") @@ -455,6 +462,8 @@ class SingleStreamContextManager: except Exception as e: logger.error(f"计算消息兴趣度时发生错误: {e}", exc_info=True) + if hasattr(message, "interest_calculated"): + message.interest_calculated = False return 0.5 def _detect_chat_type(self, message: DatabaseMessages): diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index 0880ef08e..0f1ab2c26 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -110,6 +110,7 @@ def init_prompt(): ## 其他信息 {memory_block} + {relation_info_block} {extra_info_block} @@ -579,7 +580,7 @@ class DefaultReplyer: try: from src.memory_graph.manager_singleton import get_unified_memory_manager - from src.memory_graph.utils.memory_formatter import format_memory_for_prompt + from src.memory_graph.utils.three_tier_formatter import memory_formatter unified_manager = get_unified_memory_manager() if not unified_manager: @@ -602,38 +603,12 @@ class DefaultReplyer: short_term_memories = search_result.get("short_term_memories", []) long_term_memories = search_result.get("long_term_memories", []) - memory_parts = ["### 🧠 相关记忆 (Relevant Memories)", ""] - - # 添加感知记忆(最近的消息块) - if perceptual_blocks: - memory_parts.append("#### 🌊 感知记忆") - for block in perceptual_blocks: - messages = block.messages if hasattr(block, 'messages') else [] - if messages: - block_content = "\n".join([ - f"{msg.get('sender_name', msg.get('sender_id', ''))}: {msg.get('content', '')[:30]}" - for msg in messages - ]) - memory_parts.append(f"- {block_content}") - memory_parts.append("") - - # 添加短期记忆(结构化活跃记忆) - if short_term_memories: - memory_parts.append("#### 💭 短期记忆") - for mem in short_term_memories: - content = format_memory_for_prompt(mem, include_metadata=False) - if content: - memory_parts.append(f"- {content}") - memory_parts.append("") - - # 添加长期记忆(图谱记忆) - if long_term_memories: - memory_parts.append("#### 🗄️ 长期记忆") - for mem in long_term_memories: - content = format_memory_for_prompt(mem, include_metadata=False) - if content: - memory_parts.append(f"- {content}") - memory_parts.append("") + # 使用新的三级记忆格式化器 + formatted_memories = await memory_formatter.format_all_tiers( + perceptual_blocks=perceptual_blocks, + short_term_memories=short_term_memories, + long_term_memories=long_term_memories + ) total_count = len(perceptual_blocks) + len(short_term_memories) + len(long_term_memories) if total_count > 0: @@ -642,7 +617,11 @@ class DefaultReplyer: f"(感知:{len(perceptual_blocks)}, 短期:{len(short_term_memories)}, 长期:{len(long_term_memories)})" ) - return "\n".join(memory_parts) if len(memory_parts) > 2 else "" + # 添加标题并返回格式化后的记忆 + if formatted_memories.strip(): + return "### 🧠 相关记忆 (Relevant Memories)\n\n" + formatted_memories + + return "" except Exception as e: logger.error(f"[三层记忆] 检索失败: {e}", exc_info=True) diff --git a/src/common/data_models/database_data_model.py b/src/common/data_models/database_data_model.py index 24b56ff4e..af06eb7b5 100644 --- a/src/common/data_models/database_data_model.py +++ b/src/common/data_models/database_data_model.py @@ -152,6 +152,10 @@ class DatabaseMessages(BaseDataModel): group_info=self.group_info, ) + # 扩展运行时字段 + self.semantic_embedding = kwargs.pop("semantic_embedding", None) + self.interest_calculated = kwargs.pop("interest_calculated", False) + # 处理额外传入的字段(kwargs) if kwargs: for key, value in kwargs.items(): diff --git a/src/llm_models/model_client/aiohttp_gemini_client.py b/src/llm_models/model_client/aiohttp_gemini_client.py index 3114b5fda..507fd8436 100644 --- a/src/llm_models/model_client/aiohttp_gemini_client.py +++ b/src/llm_models/model_client/aiohttp_gemini_client.py @@ -652,7 +652,7 @@ class AiohttpGeminiClient(BaseClient): async def get_embedding( self, model_info: ModelInfo, - embedding_input: str, + embedding_input: str | list[str], extra_params: dict[str, Any] | None = None, ) -> APIResponse: """ diff --git a/src/llm_models/model_client/base_client.py b/src/llm_models/model_client/base_client.py index baab2897b..ebb4b1b86 100644 --- a/src/llm_models/model_client/base_client.py +++ b/src/llm_models/model_client/base_client.py @@ -51,8 +51,8 @@ class APIResponse: tool_calls: list[ToolCall] | None = None """工具调用 [(工具名称, 工具参数), ...]""" - embedding: list[float] | None = None - """嵌入向量""" + embedding: list[float] | list[list[float]] | None = None + """嵌入结果(单条时为一维向量,批量时为向量列表)""" usage: UsageRecord | None = None """使用情况 (prompt_tokens, completion_tokens, total_tokens)""" @@ -105,7 +105,7 @@ class BaseClient(ABC): async def get_embedding( self, model_info: ModelInfo, - embedding_input: str, + embedding_input: str | list[str], extra_params: dict[str, Any] | None = None, ) -> APIResponse: """ diff --git a/src/llm_models/model_client/openai_client.py b/src/llm_models/model_client/openai_client.py index 7245a79db..e62ef597f 100644 --- a/src/llm_models/model_client/openai_client.py +++ b/src/llm_models/model_client/openai_client.py @@ -580,7 +580,7 @@ class OpenaiClient(BaseClient): async def get_embedding( self, model_info: ModelInfo, - embedding_input: str, + embedding_input: str | list[str], extra_params: dict[str, Any] | None = None, ) -> APIResponse: """ @@ -590,6 +590,7 @@ class OpenaiClient(BaseClient): :return: 嵌入响应 """ client = self._create_client() + is_batch_request = isinstance(embedding_input, list) try: raw_response = await client.embeddings.create( model=model_info.model_identifier, @@ -616,7 +617,8 @@ class OpenaiClient(BaseClient): # 解析嵌入响应 if len(raw_response.data) > 0: - response.embedding = raw_response.data[0].embedding + embeddings = [item.embedding for item in raw_response.data] + response.embedding = embeddings if is_batch_request else embeddings[0] else: raise RespParseException( raw_response, diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index a46824a72..2bb1a3c37 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -1111,15 +1111,15 @@ class LLMRequest: return response.content or "", (response.reasoning_content or "", model_info.name, response.tool_calls) - async def get_embedding(self, embedding_input: str) -> tuple[list[float], str]: + async def get_embedding(self, embedding_input: str | list[str]) -> tuple[list[float] | list[list[float]], str]: """ - 获取嵌入向量。 + 获取嵌入向量,支持批量文本 Args: - embedding_input (str): 获取嵌入的目标 + embedding_input (str | list[str]): 需要生成嵌入的文本或文本列表 Returns: - (Tuple[List[float], str]): (嵌入向量,使用的模型名称) + (Tuple[Union[List[float], List[List[float]]], str]): 嵌入结果及使用的模型名称 """ start_time = time.time() response, model_info = await self._strategy.execute_with_failover( @@ -1128,10 +1128,25 @@ class LLMRequest: await self._record_usage(model_info, response.usage, time.time() - start_time, "/embeddings") - if not response.embedding: + if response.embedding is None: raise RuntimeError("获取embedding失败") - return response.embedding, model_info.name + embeddings = response.embedding + is_batch_request = isinstance(embedding_input, list) + + if is_batch_request: + if not isinstance(embeddings, list): + raise RuntimeError("获取embedding失败,批量结果格式异常") + + if embeddings and not isinstance(embeddings[0], list): + embeddings = [embeddings] # type: ignore[list-item] + + return embeddings, model_info.name + + if isinstance(embeddings, list) and embeddings and isinstance(embeddings[0], list): + return embeddings[0], model_info.name + + return embeddings, model_info.name async def _record_usage(self, model_info: ModelInfo, usage: UsageRecord | None, time_cost: float, endpoint: str): """ diff --git a/src/memory_graph/unified_manager.py b/src/memory_graph/unified_manager.py index c9e5990e5..a7755f49b 100644 --- a/src/memory_graph/unified_manager.py +++ b/src/memory_graph/unified_manager.py @@ -308,7 +308,7 @@ class UnifiedMemoryManager: from src.memory_graph.utils.three_tier_formatter import memory_formatter # 使用新的三级记忆格式化器 - perceptual_desc = memory_formatter.format_perceptual_memory(perceptual_blocks) + perceptual_desc = await memory_formatter.format_perceptual_memory(perceptual_blocks) short_term_desc = memory_formatter.format_short_term_memory(short_term_memories) # 构建聊天历史块(如果提供) diff --git a/src/memory_graph/utils/memory_formatter.py b/src/memory_graph/utils/memory_formatter.py deleted file mode 100644 index c546e0614..000000000 --- a/src/memory_graph/utils/memory_formatter.py +++ /dev/null @@ -1,234 +0,0 @@ -""" -记忆格式化工具 - -提供将记忆对象格式化为提示词的功能,使用 "主体-主题(属性)" 格式。 -""" - -from src.memory_graph.models import Memory, MemoryNode, NodeType -from src.memory_graph.models import ShortTermMemory - - -def get_memory_type_label(memory_type: str) -> str: - """ - 获取记忆类型的中文标签 - - Args: - memory_type: 记忆类型(英文) - - Returns: - 中文标签 - """ - type_mapping = { - "事实": "事实", - "事件": "事件", - "观点": "观点", - "关系": "关系", - "目标": "目标", - "计划": "计划", - "fact": "事实", - "event": "事件", - "opinion": "观点", - "relation": "关系", - "goal": "目标", - "plan": "计划", - "unknown": "未知", - } - return type_mapping.get(memory_type.lower(), memory_type) - - -def format_memory_for_prompt(memory: Memory | ShortTermMemory, include_metadata: bool = True) -> str: - """ - 格式化记忆为提示词文本 - - 使用 "主体-主题(属性)" 格式,例如: - - "张三-职业(程序员, 公司=MoFox)" - - "小明-喜欢(Python, 原因=简洁优雅)" - - "拾风-地址(https://mofox.com)" - - Args: - memory: Memory 或 ShortTermMemory 对象 - include_metadata: 是否包含元数据(如重要性、时间等) - - Returns: - 格式化后的记忆文本 - """ - if isinstance(memory, ShortTermMemory): - return _format_short_term_memory(memory, include_metadata) - elif isinstance(memory, Memory): - return _format_long_term_memory(memory, include_metadata) - else: - return str(memory) - - -def _format_short_term_memory(mem: ShortTermMemory, include_metadata: bool) -> str: - """ - 格式化短期记忆 - - Args: - mem: ShortTermMemory 对象 - include_metadata: 是否包含元数据 - - Returns: - 格式化后的文本 - """ - parts = [] - - # 主体 - subject = mem.subject or "" - # 主题 - topic = mem.topic or "" - # 客体 - obj = mem.object or "" - - # 构建基础格式:主体-主题 - if subject and topic: - base = f"{subject}-{topic}" - elif subject: - base = subject - elif topic: - base = topic - else: - # 如果没有结构化字段,使用 content - # 防御性编程:确保 content 是字符串 - if isinstance(mem.content, list): - return " ".join(str(item) for item in mem.content) - return str(mem.content) if mem.content else "" - - # 添加客体和属性 - attr_parts = [] - if obj: - attr_parts.append(obj) - - # 添加属性 - if mem.attributes: - for key, value in mem.attributes.items(): - if value: - attr_parts.append(f"{key}={value}") - - # 组合 - if attr_parts: - result = f"{base}({', '.join(attr_parts)})" - else: - result = base - - # 添加元数据(可选) - if include_metadata: - metadata_parts = [] - if mem.memory_type: - metadata_parts.append(f"类型:{get_memory_type_label(mem.memory_type)}") - if mem.importance > 0: - metadata_parts.append(f"重要性:{mem.importance:.2f}") - - if metadata_parts: - result = f"{result} [{', '.join(metadata_parts)}]" - - return result - - -def _format_long_term_memory(mem: Memory, include_metadata: bool) -> str: - """ - 格式化长期记忆(Memory 对象) - - Args: - mem: Memory 对象 - include_metadata: 是否包含元数据 - - Returns: - 格式化后的文本 - """ - from src.memory_graph.models import EdgeType - - # 获取主体节点 - subject_node = mem.get_subject_node() - if not subject_node: - return mem.to_text() - - subject = subject_node.content - - # 查找主题节点 - topic_node = None - for edge in mem.edges: - edge_type = edge.edge_type.value if hasattr(edge.edge_type, 'value') else str(edge.edge_type) - if edge_type == "记忆类型" and edge.source_id == mem.subject_id: - topic_node = mem.get_node_by_id(edge.target_id) - break - - if not topic_node: - return subject - - topic = topic_node.content - - # 基础格式:主体-主题 - base = f"{subject}-{topic}" - - # 收集客体和属性 - attr_parts = [] - - # 查找客体节点(通过核心关系边) - for edge in mem.edges: - edge_type = edge.edge_type.value if hasattr(edge.edge_type, 'value') else str(edge.edge_type) - if edge_type == "核心关系" and edge.source_id == topic_node.id: - obj_node = mem.get_node_by_id(edge.target_id) - if obj_node: - # 如果有关系名称,使用关系名称 - if edge.relation and edge.relation != "未知": - attr_parts.append(f"{edge.relation}={obj_node.content}") - else: - attr_parts.append(obj_node.content) - - # 查找属性节点 - for node in mem.nodes: - if node.node_type == NodeType.ATTRIBUTE: - # 属性节点的 content 格式可能是 "key=value" 或 "value" - attr_parts.append(node.content) - - # 组合 - if attr_parts: - result = f"{base}({', '.join(attr_parts)})" - else: - result = base - - # 添加元数据(可选) - if include_metadata: - metadata_parts = [] - if mem.memory_type: - type_value = mem.memory_type.value if hasattr(mem.memory_type, 'value') else str(mem.memory_type) - metadata_parts.append(f"类型:{get_memory_type_label(type_value)}") - if mem.importance > 0: - metadata_parts.append(f"重要性:{mem.importance:.2f}") - - if metadata_parts: - result = f"{result} [{', '.join(metadata_parts)}]" - - return result - - -def format_memories_block( - memories: list[Memory | ShortTermMemory], - title: str = "相关记忆", - max_count: int = 10, - include_metadata: bool = False, -) -> str: - """ - 格式化多个记忆为提示词块 - - Args: - memories: 记忆列表 - title: 块标题 - max_count: 最多显示的记忆数量 - include_metadata: 是否包含元数据 - - Returns: - 格式化后的记忆块 - """ - if not memories: - return "" - - lines = [f"### 🧠 {title}", ""] - - for mem in memories[:max_count]: - formatted = format_memory_for_prompt(mem, include_metadata=include_metadata) - if formatted: - lines.append(f"- {formatted}") - - return "\n".join(lines) diff --git a/src/memory_graph/utils/three_tier_formatter.py b/src/memory_graph/utils/three_tier_formatter.py index 451eb90f9..47ae3cda0 100644 --- a/src/memory_graph/utils/three_tier_formatter.py +++ b/src/memory_graph/utils/three_tier_formatter.py @@ -22,7 +22,7 @@ class ThreeTierMemoryFormatter: """初始化格式化器""" pass - def format_perceptual_memory(self, blocks: list[MemoryBlock]) -> str: + async def format_perceptual_memory(self, blocks: list[MemoryBlock]) -> str: """ 格式化感知记忆为提示词 @@ -53,7 +53,7 @@ class ThreeTierMemoryFormatter: for block in blocks: # 提取时间和聊天流信息 time_str = self._extract_time_from_block(block) - stream_name = self._extract_stream_name_from_block(block) + stream_name = await self._extract_stream_name_from_block(block) # 添加块标题 lines.append(f"- 【{time_str} ({stream_name})】") @@ -122,7 +122,7 @@ class ThreeTierMemoryFormatter: return "\n".join(lines) - def format_all_tiers( + async def format_all_tiers( self, perceptual_blocks: list[MemoryBlock], short_term_memories: list[ShortTermMemory], @@ -142,7 +142,7 @@ class ThreeTierMemoryFormatter: sections = [] # 感知记忆 - perceptual_text = self.format_perceptual_memory(perceptual_blocks) + perceptual_text = await self.format_perceptual_memory(perceptual_blocks) if perceptual_text: sections.append("### 感知记忆(即时对话)") sections.append(perceptual_text) @@ -198,7 +198,7 @@ class ThreeTierMemoryFormatter: return "未知时间" - def _extract_stream_name_from_block(self, block: MemoryBlock) -> str: + async def _extract_stream_name_from_block(self, block: MemoryBlock) -> str: """ 从记忆块中提取聊天流名称 @@ -208,18 +208,31 @@ class ThreeTierMemoryFormatter: Returns: 聊天流名称 """ - # 尝试从元数据中获取 - if block.metadata: - stream_name = block.metadata.get("stream_name") or block.metadata.get("chat_stream") - if stream_name: - return stream_name + stream_id = None - # 尝试从消息中提取 - if block.messages: + # 首先尝试从元数据中获取 stream_id + if block.metadata: + stream_id = block.metadata.get("stream_id") + + # 如果从元数据中没找到,尝试从消息中提取 + if not stream_id and block.messages: first_msg = block.messages[0] - stream_name = first_msg.get("stream_name") or first_msg.get("chat_stream") - if stream_name: - return stream_name + stream_id = first_msg.get("stream_id") or first_msg.get("chat_id") + + # 如果有 stream_id,尝试获取实际的流名称 + if stream_id: + try: + from src.chat.message_receive.chat_stream import get_chat_manager + chat_manager = get_chat_manager() + actual_name = await chat_manager.get_stream_name(stream_id) + if actual_name: + return actual_name + else: + # 如果获取不到名称,返回 stream_id 的截断版本 + return stream_id[:12] + "..." if len(stream_id) > 12 else stream_id + except Exception: + # 如果获取失败,返回 stream_id 的截断版本 + return stream_id[:12] + "..." if len(stream_id) > 12 else stream_id return "默认聊天" @@ -375,7 +388,7 @@ class ThreeTierMemoryFormatter: return type_mapping.get(type_value, "事实") - def format_for_context_injection( + async def format_for_context_injection( self, query: str, perceptual_blocks: list[MemoryBlock], @@ -407,7 +420,7 @@ class ThreeTierMemoryFormatter: limited_short_term = short_term_memories[:max_short_term] limited_long_term = long_term_memories[:max_long_term] - all_tiers_text = self.format_all_tiers( + all_tiers_text = await self.format_all_tiers( limited_perceptual, limited_short_term, limited_long_term diff --git a/src/plugin_system/apis/person_api.py b/src/plugin_system/apis/person_api.py index a97e741b8..03e0b716f 100644 --- a/src/plugin_system/apis/person_api.py +++ b/src/plugin_system/apis/person_api.py @@ -185,18 +185,19 @@ async def initialize_smart_interests(personality_description: str, personality_i await interest_service.initialize_smart_interests(personality_description, personality_id) -async def calculate_interest_match(content: str, keywords: list[str] | None = None): - """ - 计算内容与兴趣的匹配度 +async def calculate_interest_match( + content: str, keywords: list[str] | None = None, message_embedding: list[float] | None = None +): + """计算消息兴趣匹配,返回匹配结果""" + if not content: + logger.warning("[PersonAPI] 请求兴趣匹配时 content 为空") + return None - Args: - content: 消息内容 - keywords: 关键词列表 - - Returns: - 匹配结果 - """ - return await interest_service.calculate_interest_match(content, keywords) + try: + return await interest_service.calculate_interest_match(content, keywords, message_embedding) + except Exception as e: + logger.error(f"[PersonAPI] 计算消息兴趣匹配失败: {e}") + return None # ============================================================================= @@ -213,7 +214,7 @@ def get_system_stats() -> dict[str, Any]: """ return { "relationship_service": relationship_service.get_cache_stats(), - "interest_service": interest_service.get_interest_stats() + "interest_service": interest_service.get_interest_stats(), } diff --git a/src/plugin_system/services/interest_service.py b/src/plugin_system/services/interest_service.py index fd127a425..478f04ee2 100644 --- a/src/plugin_system/services/interest_service.py +++ b/src/plugin_system/services/interest_service.py @@ -40,13 +40,16 @@ class InterestService: logger.error(f"初始化智能兴趣系统失败: {e}") self.is_initialized = False - async def calculate_interest_match(self, content: str, keywords: list[str] | None = None): + async def calculate_interest_match( + self, content: str, keywords: list[str] | None = None, message_embedding: list[float] | None = None + ): """ - 计算内容与兴趣的匹配度 + 计算消息与兴趣的匹配度 Args: content: 消息内容 - keywords: 关键词列表 + keywords: 关键字列表 + message_embedding: 已经生成的消息embedding,可选 Returns: 匹配结果 @@ -57,12 +60,12 @@ class InterestService: try: if not keywords: - # 如果没有关键词,尝试从内容提取 + # 如果没有关键字,则从内容中提取 keywords = self._extract_keywords_from_content(content) - return await bot_interest_manager.calculate_interest_match(content, keywords) + return await bot_interest_manager.calculate_interest_match(content, keywords, message_embedding) except Exception as e: - logger.error(f"计算兴趣匹配度失败: {e}") + logger.error(f"计算兴趣匹配失败: {e}") return None def _extract_keywords_from_content(self, content: str) -> list[str]: diff --git a/src/plugins/built_in/affinity_flow_chatter/core/affinity_interest_calculator.py b/src/plugins/built_in/affinity_flow_chatter/core/affinity_interest_calculator.py index 38ae6ad8c..66ba4cee5 100644 --- a/src/plugins/built_in/affinity_flow_chatter/core/affinity_interest_calculator.py +++ b/src/plugins/built_in/affinity_flow_chatter/core/affinity_interest_calculator.py @@ -103,7 +103,7 @@ class AffinityInterestCalculator(BaseInterestCalculator): # 1. 计算兴趣匹配分 keywords = self._extract_keywords_from_database(message) - interest_match_score = await self._calculate_interest_match_score(content, keywords) + interest_match_score = await self._calculate_interest_match_score(message, content, keywords) logger.debug(f"[Affinity兴趣计算] 兴趣匹配分: {interest_match_score}") # 2. 计算关系分 @@ -180,7 +180,9 @@ class AffinityInterestCalculator(BaseInterestCalculator): success=False, message_id=getattr(message, "message_id", ""), interest_value=0.0, error_message=str(e) ) - async def _calculate_interest_match_score(self, content: str, keywords: list[str] | None = None) -> float: + async def _calculate_interest_match_score( + self, message: "DatabaseMessages", content: str, keywords: list[str] | None = None + ) -> float: """计算兴趣匹配度(使用智能兴趣匹配系统,带超时保护)""" # 调试日志:检查各个条件 @@ -199,7 +201,9 @@ class AffinityInterestCalculator(BaseInterestCalculator): try: # 使用机器人的兴趣标签系统进行智能匹配(1.5秒超时保护) match_result = await asyncio.wait_for( - bot_interest_manager.calculate_interest_match(content, keywords or []), + bot_interest_manager.calculate_interest_match( + content, keywords or [], getattr(message, "semantic_embedding", None) + ), timeout=1.5 ) logger.debug(f"兴趣匹配结果: {match_result}") diff --git a/src/plugins/built_in/affinity_flow_chatter/planner/planner.py b/src/plugins/built_in/affinity_flow_chatter/planner/planner.py index 1483b73f2..cebb32a66 100644 --- a/src/plugins/built_in/affinity_flow_chatter/planner/planner.py +++ b/src/plugins/built_in/affinity_flow_chatter/planner/planner.py @@ -7,6 +7,9 @@ import asyncio from dataclasses import asdict from typing import TYPE_CHECKING, Any +from src.chat.interest_system import bot_interest_manager +from src.chat.interest_system.interest_manager import get_interest_manager +from src.chat.message_receive.storage import MessageStorage from src.common.logger import get_logger from src.config.config import global_config from src.mood.mood_manager import mood_manager @@ -19,6 +22,7 @@ if TYPE_CHECKING: from src.chat.planner_actions.action_manager import ChatterActionManager from src.common.data_models.info_data_model import Plan from src.common.data_models.message_manager_data_model import StreamContext + from src.common.data_models.database_data_model import DatabaseMessages # 导入提示词模块以确保其被初始化 @@ -115,6 +119,74 @@ class ChatterActionPlanner: context.processing_message_id = None return [], None + async def _prepare_interest_scores( + self, context: "StreamContext | None", unread_messages: list["DatabaseMessages"] + ) -> None: + """在执行规划前,为未计算兴趣的消息批量补齐兴趣数据""" + if not context or not unread_messages: + return + + pending_messages = [msg for msg in unread_messages if not getattr(msg, "interest_calculated", False)] + if not pending_messages: + return + + logger.debug(f"批量兴趣值计算:待处理 {len(pending_messages)} 条消息") + + if not bot_interest_manager.is_initialized: + logger.debug("bot_interest_manager 未初始化,跳过批量兴趣计算") + return + + try: + interest_manager = get_interest_manager() + except Exception as exc: # noqa: BLE001 + logger.warning(f"获取兴趣管理器失败: {exc}") + return + + if not interest_manager or not interest_manager.has_calculator(): + logger.debug("当前无可用兴趣计算器,跳过批量兴趣计算") + return + + text_map: dict[str, str] = {} + for message in pending_messages: + text = getattr(message, "processed_plain_text", None) or getattr(message, "display_message", "") or "" + text_map[str(message.message_id)] = text + + try: + embeddings = await bot_interest_manager.generate_embeddings_for_texts(text_map) + except Exception as exc: # noqa: BLE001 + logger.error(f"批量获取消息embedding失败: {exc}") + embeddings = {} + + interest_updates: dict[str, float] = {} + reply_updates: dict[str, bool] = {} + + for message in pending_messages: + message_id = str(message.message_id) + if message_id in embeddings: + message.semantic_embedding = embeddings[message_id] + + try: + result = await interest_manager.calculate_interest(message) + except Exception as exc: # noqa: BLE001 + logger.error(f"批量计算消息兴趣失败: {exc}") + continue + + if result.success: + message.interest_value = result.interest_value + message.should_reply = result.should_reply + message.should_act = result.should_act + message.interest_calculated = True + interest_updates[message_id] = result.interest_value + reply_updates[message_id] = result.should_reply + else: + message.interest_calculated = False + + if interest_updates: + try: + await MessageStorage.bulk_update_interest_values(interest_updates, reply_updates) + except Exception as exc: # noqa: BLE001 + logger.error(f"批量更新消息兴趣值失败: {exc}") + async def _focus_mode_flow(self, context: "StreamContext | None") -> tuple[list[dict[str, Any]], Any | None]: """Focus模式下的完整plan流程 @@ -122,6 +194,7 @@ class ChatterActionPlanner: """ try: unread_messages = context.get_unread_messages() if context else [] + await self._prepare_interest_scores(context, unread_messages) # 1. 使用新的兴趣度管理系统进行评分 max_message_interest = 0.0 @@ -303,6 +376,7 @@ class ChatterActionPlanner: try: unread_messages = context.get_unread_messages() if context else [] + await self._prepare_interest_scores(context, unread_messages) # 1. 检查是否有未读消息 if not unread_messages: