From 943c2a656650780f64ab81ef4bc8844055ade7d4 Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Mon, 1 Dec 2025 19:57:33 +0800 Subject: [PATCH] =?UTF-8?q?feat(data-models):=20=E4=BD=BF=E7=94=A8=20=5F?= =?UTF-8?q?=5Fslots=5F=5F=20=E4=BC=98=E5=8C=96=E5=86=85=E5=AD=98=E5=8D=A0?= =?UTF-8?q?=E7=94=A8=E5=92=8C=E5=B1=9E=E6=80=A7=E8=AE=BF=E9=97=AE=E6=80=A7?= =?UTF-8?q?=E8=83=BD=EF=BC=8C=E6=9B=B4=E6=96=B0=E5=A4=9A=E4=B8=AA=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/common/data_models/__init__.py | 1 + src/common/data_models/database_data_model.py | 180 ++++++++++++++---- .../data_models/message_manager_data_model.py | 2 +- src/memory_graph/models.py | 21 +- src/memory_graph/utils/path_expansion.py | 4 +- src/plugin_system/base/component_types.py | 8 +- src/plugin_system/base/config_types.py | 2 +- src/plugin_system/base/plugin_metadata.py | 2 +- src/plugin_system/core/stream_tool_history.py | 2 +- src/plugin_system/core/tool_use.py | 4 +- .../planner/plan_filter.py | 10 +- .../kokoro_flow_chatter/prompt/builder.py | 2 +- 12 files changed, 175 insertions(+), 63 deletions(-) diff --git a/src/common/data_models/__init__.py b/src/common/data_models/__init__.py index d104eec9c..7a84e86cd 100644 --- a/src/common/data_models/__init__.py +++ b/src/common/data_models/__init__.py @@ -3,6 +3,7 @@ from typing import Any class BaseDataModel: + __slots__ = () def deepcopy(self): return copy.deepcopy(self) diff --git a/src/common/data_models/database_data_model.py b/src/common/data_models/database_data_model.py index 8662df845..0d56ee2c2 100644 --- a/src/common/data_models/database_data_model.py +++ b/src/common/data_models/database_data_model.py @@ -5,11 +5,12 @@ from typing import Any from . import BaseDataModel -@dataclass +@dataclass(slots=True) class DatabaseUserInfo(BaseDataModel): """ 用户信息数据模型,用于存储用户的基本信息。 该类通过 dataclass 实现,继承自 BaseDataModel。 + 使用 __slots__ 优化内存占用和属性访问性能。 """ platform: str = field(default_factory=str) # 用户所属平台(如微信、QQ 等) user_id: str = field(default_factory=str) # 用户唯一标识 ID @@ -35,10 +36,12 @@ class DatabaseUserInfo(BaseDataModel): "user_cardname": self.user_cardname, } -@dataclass + +@dataclass(slots=True) class DatabaseGroupInfo(BaseDataModel): """ 群组信息数据模型,用于存储群组的基本信息。 + 使用 __slots__ 优化内存占用和属性访问性能。 """ group_id: str = field(default_factory=str) # 群组唯一标识 ID group_name: str = field(default_factory=str) # 群组名称 @@ -52,7 +55,7 @@ class DatabaseGroupInfo(BaseDataModel): group_name=data.get("group_name", ""), platform=data.get("platform"), ) - + def to_dict(self) -> dict: """将实例转换为字典""" return { @@ -60,12 +63,14 @@ class DatabaseGroupInfo(BaseDataModel): "group_name": self.group_name, "group_platform": self.platform, } - -@dataclass + + +@dataclass(slots=True) class DatabaseChatInfo(BaseDataModel): """ 聊天会话信息数据模型,用于描述一个聊天对话的上下文信息。 包括会话 ID、平台、创建时间、最后活跃时间以及关联的用户和群组信息。 + 使用 __slots__ 优化内存占用和属性访问性能。 """ stream_id: str = field(default_factory=str) # 会话流 ID,唯一标识一个聊天对话 platform: str = field(default_factory=str) # 所属平台(如微信、QQ 等) @@ -80,7 +85,50 @@ class DatabaseMessages(BaseDataModel): """ 消息数据模型,用于存储每一条消息的完整信息,包括内容、元数据、用户、聊天上下文等。 使用 init=False 实现自定义初始化逻辑,通过 __init__ 手动设置字段。 + 使用 __slots__ 优化内存占用和属性访问性能。 """ + + __slots__ = ( + # 基础消息字段 + "message_id", + "time", + "chat_id", + "reply_to", + "interest_value", + "key_words", + "key_words_lite", + "is_mentioned", + "is_at", + "reply_probability_boost", + "processed_plain_text", + "display_message", + "priority_mode", + "priority_info", + "additional_config", + "is_emoji", + "is_picid", + "is_command", + "is_notify", + "is_public_notice", + "notice_type", + "selected_expressions", + "is_read", + "actions", + "should_reply", + "should_act", + # 关联对象 + "user_info", + "group_info", + "chat_info", + # 运行时扩展字段(固定) + "semantic_embedding", + "interest_calculated", + "is_voice", + "is_video", + "has_emoji", + "has_picid", + ) + def __init__( self, message_id: str = "", # 消息唯一 ID @@ -101,30 +149,38 @@ class DatabaseMessages(BaseDataModel): is_emoji: bool = False, # 是否为表情消息 is_picid: bool = False, # 是否为图片消息(包含图片 ID) is_command: bool = False, # 是否为命令消息(如 /help) - is_notify: bool = False, # 是否为notice消息(如禁言、戳一戳等系统事件) - is_public_notice: bool = False, # 是否为公共notice(所有聊天可见) - notice_type: str | None = None, # notice类型(由适配器指定,如 "group_ban", "poke" 等) + is_notify: bool = False, # 是否为 notice 消息(如禁言、戳一戳等系统事件) + is_public_notice: bool = False, # 是否为公共 notice(所有聊天可见) + notice_type: str | None = None, # notice 类型(由适配器指定,如 "group_ban", "poke" 等) selected_expressions: str | None = None, # 选择的表情或响应模板 is_read: bool = False, # 是否已读 - user_id: str = "", # 用户 ID - user_nickname: str = "", # 用户昵称 - user_cardname: str | None = None, # 用户备注名或群名片 - user_platform: str = "", # 用户所属平台 - chat_info_group_id: str | None = None, # 所属群组 ID(聊天上下文信息) - chat_info_group_name: str | None = None, # 所属群组名称 - chat_info_group_platform: str | None = None, # 所属群组平台 - chat_info_user_id: str = "", # 聊天上下文中的用户 ID - chat_info_user_nickname: str = "", # 聊天上下文中的用户昵称 - chat_info_user_cardname: str | None = None, # 聊天上下文中的用户备注名 - chat_info_user_platform: str = "", # 聊天上下文中的用户平台 - chat_info_stream_id: str = "", # 聊天上下文的会话流 ID - chat_info_platform: str = "", # 聊天上下文平台 - chat_info_create_time: float = 0.0, # 聊天上下文创建时间 - chat_info_last_active_time: float = 0.0, # 聊天上下文最后活跃时间 actions: list | None = None, # 与消息相关的动作列表(如回复、转发等) should_reply: bool = False, # 是否应该自动回复 should_act: bool = False, # 是否应该执行动作(如发送消息) - **kwargs: Any, # 允许传入任意额外字段 + # 用户信息(用于构建 user_info) + user_id: str = "", + user_nickname: str = "", + user_cardname: str | None = None, + user_platform: str = "", + # 群组 / 聊天上下文信息(用于构建 group_info / chat_info) + chat_info_group_id: str | None = None, + chat_info_group_name: str | None = None, + chat_info_group_platform: str | None = None, + chat_info_user_id: str = "", + chat_info_user_nickname: str = "", + chat_info_user_cardname: str | None = None, + chat_info_user_platform: str = "", + chat_info_stream_id: str = "", + chat_info_platform: str = "", + chat_info_create_time: float = 0.0, + chat_info_last_active_time: float = 0.0, + # 运行时字段(固定) + semantic_embedding: Any | None = None, + interest_calculated: bool = False, + is_voice: bool = False, # 是否为语音消息 + is_video: bool = False, # 是否为视频消息 + has_emoji: bool = False, # 是否包含表情 + has_picid: bool = False, # 是否包含图片 ID ): # 初始化基础字段 self.message_id = message_id @@ -186,21 +242,19 @@ class DatabaseMessages(BaseDataModel): group_info=self.group_info, ) - # 扩展运行时字段 - self.semantic_embedding = kwargs.pop("semantic_embedding", None) - self.interest_calculated = kwargs.pop("interest_calculated", False) - - # 处理额外传入的字段(kwargs) - if kwargs: - for key, value in kwargs.items(): - setattr(self, key, value) + # 运行时字段 + self.semantic_embedding = semantic_embedding + self.interest_calculated = interest_calculated + self.is_voice = is_voice + self.is_video = is_video + self.has_emoji = has_emoji + self.has_picid = has_picid + # 注意: id 参数从数据库加载时会传入,但不存储(使用 message_id 作为业务主键) def flatten(self) -> dict[str, Any]: """ 将消息对象转换为字典格式,便于序列化存储或传输。 - - Returns: - 包含所有字段的字典,其中嵌套对象(如 user_info、group_info)已展开为扁平结构。 + 嵌套对象(如 user_info、group_info、chat_info)展开为扁平结构。 """ return { "message_id": self.message_id, @@ -228,13 +282,17 @@ class DatabaseMessages(BaseDataModel): "is_read": self.is_read, "actions": self.actions, "should_reply": self.should_reply, + "should_act": self.should_act, + # user_info 展开 "user_id": self.user_info.user_id, "user_nickname": self.user_info.user_nickname, "user_cardname": self.user_info.user_cardname, "user_platform": self.user_info.platform, + # group_info 展开(可能为 None) "chat_info_group_id": self.group_info.group_id if self.group_info else None, "chat_info_group_name": self.group_info.group_name if self.group_info else None, "chat_info_group_platform": self.group_info.platform if self.group_info else None, + # chat_info 展开 "chat_info_stream_id": self.chat_info.stream_id, "chat_info_platform": self.chat_info.platform, "chat_info_create_time": self.chat_info.create_time, @@ -243,6 +301,9 @@ class DatabaseMessages(BaseDataModel): "chat_info_user_nickname": self.chat_info.user_info.user_nickname, "chat_info_user_cardname": self.chat_info.user_info.user_cardname, "chat_info_user_platform": self.chat_info.user_info.platform, + # 运行时字段 + "semantic_embedding": self.semantic_embedding, + "interest_calculated": self.interest_calculated, } def update_message_info( @@ -301,13 +362,62 @@ class DatabaseMessages(BaseDataModel): "display_message": self.display_message, } + # DatabaseMessages 接受的所有参数名集合(用于 from_dict 过滤) + _VALID_INIT_PARAMS: frozenset[str] = frozenset({ + "message_id", "time", "chat_id", "reply_to", "interest_value", + "key_words", "key_words_lite", "is_mentioned", "is_at", + "reply_probability_boost", "processed_plain_text", "display_message", + "priority_mode", "priority_info", "additional_config", + "is_emoji", "is_picid", "is_command", "is_notify", "is_public_notice", + "notice_type", "selected_expressions", "is_read", "actions", + "should_reply", "should_act", + "user_id", "user_nickname", "user_cardname", "user_platform", + "chat_info_group_id", "chat_info_group_name", "chat_info_group_platform", + "chat_info_user_id", "chat_info_user_nickname", "chat_info_user_cardname", + "chat_info_user_platform", "chat_info_stream_id", "chat_info_platform", + "chat_info_create_time", "chat_info_last_active_time", + "semantic_embedding", "interest_calculated", + "is_voice", "is_video", "has_emoji", "has_picid", + "id", # 数据库自增主键 + }) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "DatabaseMessages": + """ + 从字典创建 DatabaseMessages 实例,自动过滤掉不支持的参数。 + + Args: + data: 包含消息数据的字典(如从数据库查询返回的结果) + + Returns: + DatabaseMessages 实例 + """ + # 只保留有效的参数 + filtered_data = {k: v for k, v in data.items() if k in cls._VALID_INIT_PARAMS} + return cls(**filtered_data) + @dataclass(init=False) class DatabaseActionRecords(BaseDataModel): """ 动作记录数据模型,用于记录系统执行的某个操作或动作的详细信息。 用于审计、日志或调试。 + 使用 __slots__ 优化内存占用和属性访问性能。 """ + + __slots__ = ( + "action_id", + "time", + "action_name", + "action_data", + "action_done", + "action_build_into_prompt", + "action_prompt_display", + "chat_id", + "chat_info_stream_id", + "chat_info_platform", + ) + def __init__( self, action_id: str, # 动作唯一 ID diff --git a/src/common/data_models/message_manager_data_model.py b/src/common/data_models/message_manager_data_model.py index e0733ad4b..4c4955cef 100644 --- a/src/common/data_models/message_manager_data_model.py +++ b/src/common/data_models/message_manager_data_model.py @@ -533,7 +533,7 @@ class StreamContext(BaseDataModel): loaded_count = 0 for msg_dict in db_messages: try: - db_msg = DatabaseMessages(**msg_dict) + db_msg = DatabaseMessages.from_dict(msg_dict) db_msg.is_read = True self.history_messages.append(db_msg) loaded_count += 1 diff --git a/src/memory_graph/models.py b/src/memory_graph/models.py index 3f4378e9c..6b662b28f 100644 --- a/src/memory_graph/models.py +++ b/src/memory_graph/models.py @@ -2,6 +2,7 @@ 记忆图系统核心数据模型 定义节点、边、记忆等核心数据结构(包含三层记忆系统) +使用 __slots__ 优化内存占用和属性访问性能 """ from __future__ import annotations @@ -112,7 +113,7 @@ class MemoryStatus(Enum): ARCHIVED = "archived" # 已归档(低价值,很少访问) -@dataclass +@dataclass(slots=True) class MemoryNode: """记忆节点""" @@ -168,7 +169,7 @@ class MemoryNode: return f"Node({self.node_type.value}: {self.content})" -@dataclass +@dataclass(slots=True) class MemoryEdge: """记忆边(节点之间的关系)""" @@ -219,7 +220,7 @@ class MemoryEdge: return f"Edge({self.source_id} --{self.relation}--> {self.target_id})" -@dataclass +@dataclass(slots=True) class Memory: """完整记忆(由节点和边组成的子图)""" @@ -342,7 +343,7 @@ class Memory: return f"Memory({self.memory_type.value}: {self.to_text()})" -@dataclass +@dataclass(slots=True) class StagedMemory: """临时记忆(未整理状态)""" @@ -379,7 +380,7 @@ class StagedMemory: # ============================================================================ -@dataclass +@dataclass(slots=True) class MemoryBlock: """ 感知记忆块 @@ -439,7 +440,7 @@ class MemoryBlock: return f"MemoryBlock({self.id[:8]}, messages={len(self.messages)}, recalls={self.recall_count})" -@dataclass +@dataclass(slots=True) class PerceptualMemory: """ 感知记忆(记忆堆的完整状态) @@ -478,7 +479,7 @@ class PerceptualMemory: ) -@dataclass +@dataclass(slots=True) class ShortTermMemory: """ 短期记忆 @@ -558,7 +559,7 @@ class ShortTermMemory: return f"ShortTermMemory({self.id[:8]}, content={self.content[:30]}..., importance={self.importance:.2f})" -@dataclass +@dataclass(slots=True) class GraphOperation: """ 图操作指令 @@ -604,7 +605,7 @@ class GraphOperation: return f"GraphOperation({self.operation_type.value}, target={self.target_id}, confidence={self.confidence:.2f})" -@dataclass +@dataclass(slots=True) class JudgeDecision: """ 裁判模型决策结果 @@ -648,7 +649,7 @@ class JudgeDecision: return f"JudgeDecision({status}, confidence={self.confidence:.2f}, extra_queries={len(self.additional_queries)})" -@dataclass +@dataclass(slots=True) class ShortTermDecision: """ 短期记忆决策结果 diff --git a/src/memory_graph/utils/path_expansion.py b/src/memory_graph/utils/path_expansion.py index 8f34bfae8..90421d2a4 100644 --- a/src/memory_graph/utils/path_expansion.py +++ b/src/memory_graph/utils/path_expansion.py @@ -33,7 +33,7 @@ if TYPE_CHECKING: logger = get_logger(__name__) -@dataclass +@dataclass(slots=True) class Path: """表示一条路径""" @@ -58,7 +58,7 @@ class Path: return node_id in self.nodes -@dataclass +@dataclass(slots=True) class PathExpansionConfig: """路径扩展配置""" diff --git a/src/plugin_system/base/component_types.py b/src/plugin_system/base/component_types.py index aa3147785..d9a97ce09 100644 --- a/src/plugin_system/base/component_types.py +++ b/src/plugin_system/base/component_types.py @@ -16,7 +16,7 @@ class InjectionType(Enum): return self.value -@dataclass +@dataclass(slots=True) class InjectionRule: """Prompt注入规则""" @@ -118,7 +118,7 @@ class EventType(Enum): return self.value -@dataclass +@dataclass(slots=True) class PythonDependency: """Python包依赖信息""" @@ -139,7 +139,7 @@ class PythonDependency: return self.install_name -@dataclass +@dataclass(slots=True) class PermissionNodeField: """权限节点声明字段""" @@ -147,7 +147,7 @@ class PermissionNodeField: description: str # 权限描述 -@dataclass +@dataclass(slots=True) class AdapterInfo: """适配器组件信息""" diff --git a/src/plugin_system/base/config_types.py b/src/plugin_system/base/config_types.py index 9dc9b58eb..0fe2acd00 100644 --- a/src/plugin_system/base/config_types.py +++ b/src/plugin_system/base/config_types.py @@ -6,7 +6,7 @@ from dataclasses import dataclass, field from typing import Any -@dataclass +@dataclass(slots=True) class ConfigField: """配置字段定义""" diff --git a/src/plugin_system/base/plugin_metadata.py b/src/plugin_system/base/plugin_metadata.py index be25e04d7..0fd9169b6 100644 --- a/src/plugin_system/base/plugin_metadata.py +++ b/src/plugin_system/base/plugin_metadata.py @@ -4,7 +4,7 @@ from typing import Any from src.plugin_system.base.component_types import PythonDependency -@dataclass +@dataclass(slots=True) class PluginMetadata: """ 插件元数据,用于存储插件的开发者信息和用户帮助信息。 diff --git a/src/plugin_system/core/stream_tool_history.py b/src/plugin_system/core/stream_tool_history.py index 6a2cfc997..e589e6fe7 100644 --- a/src/plugin_system/core/stream_tool_history.py +++ b/src/plugin_system/core/stream_tool_history.py @@ -15,7 +15,7 @@ from src.common.logger import get_logger logger = get_logger("stream_tool_history") -@dataclass +@dataclass(slots=True) class ToolCallRecord: """工具调用记录""" tool_name: str diff --git a/src/plugin_system/core/tool_use.py b/src/plugin_system/core/tool_use.py index b79565f89..09322c184 100644 --- a/src/plugin_system/core/tool_use.py +++ b/src/plugin_system/core/tool_use.py @@ -17,7 +17,7 @@ from src.plugin_system.core.stream_tool_history import ToolCallRecord, get_strea logger = get_logger("tool_use") -@dataclass +@dataclass(slots=True) class ToolExecutionConfig: """工具执行配置""" max_concurrent_tools: int = 5 # 最大并发工具数量 @@ -25,7 +25,7 @@ class ToolExecutionConfig: enable_dependency_check: bool = True # 是否启用依赖检查 -@dataclass +@dataclass(slots=True) class ToolExecutionResult: """工具执行结果""" tool_call: ToolCall diff --git a/src/plugins/built_in/affinity_flow_chatter/planner/plan_filter.py b/src/plugins/built_in/affinity_flow_chatter/planner/plan_filter.py index d13fbe2a3..3373a8b0d 100644 --- a/src/plugins/built_in/affinity_flow_chatter/planner/plan_filter.py +++ b/src/plugins/built_in/affinity_flow_chatter/planner/plan_filter.py @@ -481,7 +481,7 @@ class ChatterPlanFilter: ) # 将字典转换为DatabaseMessages对象 read_messages = [ - DatabaseMessages(**msg_dict) for msg_dict in fallback_messages_dicts + DatabaseMessages.from_dict(msg_dict) for msg_dict in fallback_messages_dicts ] unread_messages = stream_context.get_unread_messages() # 获取未读消息 @@ -646,8 +646,8 @@ class ChatterPlanFilter: target_message_obj["message_id"] = target_message_obj["id"] try: - # 使用 ** 解包字典传入构造函数 - action_message_obj = DatabaseMessages(**target_message_obj) + # 使用 from_dict 工厂方法创建对象(自动过滤无效参数) + action_message_obj = DatabaseMessages.from_dict(target_message_obj) logger.debug( f"[{action}] 成功转换目标消息为 DatabaseMessages 对象: {action_message_obj.message_id}" ) @@ -670,7 +670,7 @@ class ChatterPlanFilter: if latest_message_dict: from src.common.data_models.database_data_model import DatabaseMessages try: - action_message_obj = DatabaseMessages(**latest_message_dict) + action_message_obj = DatabaseMessages.from_dict(latest_message_dict) logger.info(f"[{action}] 成功使用最新消息: {action_message_obj.message_id}") except Exception as e: logger.error(f"[{action}] 无法转换最新消息: {e}") @@ -689,7 +689,7 @@ class ChatterPlanFilter: if target_message_dict: from src.common.data_models.database_data_model import DatabaseMessages try: - action_message_obj = DatabaseMessages(**target_message_dict) + action_message_obj = DatabaseMessages.from_dict(target_message_dict) except Exception as e: logger.error( f"[{action}] 无法将默认的最新消息转换为 DatabaseMessages 对象: {e}", diff --git a/src/plugins/built_in/kokoro_flow_chatter/prompt/builder.py b/src/plugins/built_in/kokoro_flow_chatter/prompt/builder.py index cf14e06e1..6a7bb4fd1 100644 --- a/src/plugins/built_in/kokoro_flow_chatter/prompt/builder.py +++ b/src/plugins/built_in/kokoro_flow_chatter/prompt/builder.py @@ -309,7 +309,7 @@ class PromptBuilder: limit=30, # 限制数量,私聊不需要太多 ) history_messages = [ - DatabaseMessages(**msg_dict) for msg_dict in fallback_messages_dicts + DatabaseMessages.from_dict(msg_dict) for msg_dict in fallback_messages_dicts ] if not history_messages: