From b18a13b0915fe5a86522151d77ec98c57c0a44d8 Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Fri, 12 Sep 2025 21:35:19 +0800 Subject: [PATCH] =?UTF-8?q?=E5=85=88=E6=B7=BB=E5=8A=A0=E4=B8=80=E4=B8=AA?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E6=A8=A1=E5=9E=8B=E4=BD=A0=E5=88=AB=E7=AE=A1?= =?UTF-8?q?=E4=BB=96=E7=94=A8=E6=B2=A1=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/common/data_models/__init__.py | 53 ++++ src/common/data_models/database_data_model.py | 235 ++++++++++++++++++ src/common/data_models/info_data_model.py | 25 ++ src/common/data_models/llm_data_model.py | 16 ++ src/common/data_models/message_data_model.py | 36 +++ 5 files changed, 365 insertions(+) create mode 100644 src/common/data_models/__init__.py create mode 100644 src/common/data_models/database_data_model.py create mode 100644 src/common/data_models/info_data_model.py create mode 100644 src/common/data_models/llm_data_model.py create mode 100644 src/common/data_models/message_data_model.py diff --git a/src/common/data_models/__init__.py b/src/common/data_models/__init__.py new file mode 100644 index 000000000..222ff59ca --- /dev/null +++ b/src/common/data_models/__init__.py @@ -0,0 +1,53 @@ +import copy +from typing import Any + + +class BaseDataModel: + def deepcopy(self): + return copy.deepcopy(self) + +def temporarily_transform_class_to_dict(obj: Any) -> Any: + # sourcery skip: assign-if-exp, reintroduce-else + """ + 将对象或容器中的 BaseDataModel 子类(类对象)或 BaseDataModel 实例 + 递归转换为普通 dict,不修改原对象。 + - 对于类对象(isinstance(value, type) 且 issubclass(..., BaseDataModel)), + 读取类的 __dict__ 中非 dunder 项并递归转换。 + - 对于实例(isinstance(value, BaseDataModel)),读取 vars(instance) 并递归转换。 + """ + + def _transform(value: Any) -> Any: + # 值是类对象且为 BaseDataModel 的子类 + if isinstance(value, type) and issubclass(value, BaseDataModel): + return {k: _transform(v) for k, v in value.__dict__.items() if not k.startswith("__") and not callable(v)} + + # 值是 BaseDataModel 的实例 + if isinstance(value, BaseDataModel): + return {k: _transform(v) for k, v in vars(value).items()} + + # 常见容器类型,递归处理 + if isinstance(value, dict): + return {k: _transform(v) for k, v in value.items()} + if isinstance(value, list): + return [_transform(v) for v in value] + if isinstance(value, tuple): + return tuple(_transform(v) for v in value) + if isinstance(value, set): + return {_transform(v) for v in value} + # 基本类型,直接返回 + return value + + result = _transform(obj) + + def flatten(target_dict: dict): + flat_dict = {} + for k, v in target_dict.items(): + if isinstance(v, dict): + # 递归扁平化子字典 + sub_flat = flatten(v) + flat_dict.update(sub_flat) + else: + flat_dict[k] = v + return flat_dict + + return flatten(result) if isinstance(result, dict) else result diff --git a/src/common/data_models/database_data_model.py b/src/common/data_models/database_data_model.py new file mode 100644 index 000000000..bf4a5f527 --- /dev/null +++ b/src/common/data_models/database_data_model.py @@ -0,0 +1,235 @@ +import json +from typing import Optional, Any, Dict +from dataclasses import dataclass, field + +from . import BaseDataModel + + +@dataclass +class DatabaseUserInfo(BaseDataModel): + platform: str = field(default_factory=str) + user_id: str = field(default_factory=str) + user_nickname: str = field(default_factory=str) + user_cardname: Optional[str] = None + + # def __post_init__(self): + # assert isinstance(self.platform, str), "platform must be a string" + # assert isinstance(self.user_id, str), "user_id must be a string" + # assert isinstance(self.user_nickname, str), "user_nickname must be a string" + # assert isinstance(self.user_cardname, str) or self.user_cardname is None, ( + # "user_cardname must be a string or None" + # ) + + +@dataclass +class DatabaseGroupInfo(BaseDataModel): + group_id: str = field(default_factory=str) + group_name: str = field(default_factory=str) + group_platform: Optional[str] = None + + # def __post_init__(self): + # assert isinstance(self.group_id, str), "group_id must be a string" + # assert isinstance(self.group_name, str), "group_name must be a string" + # assert isinstance(self.group_platform, str) or self.group_platform is None, ( + # "group_platform must be a string or None" + # ) + + +@dataclass +class DatabaseChatInfo(BaseDataModel): + stream_id: str = field(default_factory=str) + platform: str = field(default_factory=str) + create_time: float = field(default_factory=float) + last_active_time: float = field(default_factory=float) + user_info: DatabaseUserInfo = field(default_factory=DatabaseUserInfo) + group_info: Optional[DatabaseGroupInfo] = None + + # def __post_init__(self): + # assert isinstance(self.stream_id, str), "stream_id must be a string" + # assert isinstance(self.platform, str), "platform must be a string" + # assert isinstance(self.create_time, float), "create_time must be a float" + # assert isinstance(self.last_active_time, float), "last_active_time must be a float" + # assert isinstance(self.user_info, DatabaseUserInfo), "user_info must be a DatabaseUserInfo instance" + # assert isinstance(self.group_info, DatabaseGroupInfo) or self.group_info is None, ( + # "group_info must be a DatabaseGroupInfo instance or None" + # ) + + +@dataclass(init=False) +class DatabaseMessages(BaseDataModel): + def __init__( + self, + message_id: str = "", + time: float = 0.0, + chat_id: str = "", + reply_to: Optional[str] = None, + interest_value: Optional[float] = None, + key_words: Optional[str] = None, + key_words_lite: Optional[str] = None, + is_mentioned: Optional[bool] = None, + is_at: Optional[bool] = None, + reply_probability_boost: Optional[float] = None, + processed_plain_text: Optional[str] = None, + display_message: Optional[str] = None, + priority_mode: Optional[str] = None, + priority_info: Optional[str] = None, + additional_config: Optional[str] = None, + is_emoji: bool = False, + is_picid: bool = False, + is_command: bool = False, + is_notify: bool = False, + selected_expressions: Optional[str] = None, + user_id: str = "", + user_nickname: str = "", + user_cardname: Optional[str] = None, + user_platform: str = "", + chat_info_group_id: Optional[str] = None, + chat_info_group_name: Optional[str] = None, + chat_info_group_platform: Optional[str] = None, + chat_info_user_id: str = "", + chat_info_user_nickname: str = "", + chat_info_user_cardname: Optional[str] = None, + chat_info_user_platform: str = "", + chat_info_stream_id: str = "", + chat_info_platform: str = "", + chat_info_create_time: float = 0.0, + chat_info_last_active_time: float = 0.0, + **kwargs: Any, + ): + self.message_id = message_id + self.time = time + self.chat_id = chat_id + self.reply_to = reply_to + self.interest_value = interest_value + + self.key_words = key_words + self.key_words_lite = key_words_lite + self.is_mentioned = is_mentioned + + self.is_at = is_at + self.reply_probability_boost = reply_probability_boost + + self.processed_plain_text = processed_plain_text + self.display_message = display_message + + self.priority_mode = priority_mode + self.priority_info = priority_info + + self.additional_config = additional_config + self.is_emoji = is_emoji + self.is_picid = is_picid + self.is_command = is_command + self.is_notify = is_notify + + self.selected_expressions = selected_expressions + + self.group_info: Optional[DatabaseGroupInfo] = None + self.user_info = DatabaseUserInfo( + user_id=user_id, + user_nickname=user_nickname, + user_cardname=user_cardname, + platform=user_platform, + ) + if chat_info_group_id and chat_info_group_name: + self.group_info = DatabaseGroupInfo( + group_id=chat_info_group_id, + group_name=chat_info_group_name, + group_platform=chat_info_group_platform, + ) + + self.chat_info = DatabaseChatInfo( + stream_id=chat_info_stream_id, + platform=chat_info_platform, + create_time=chat_info_create_time, + last_active_time=chat_info_last_active_time, + user_info=DatabaseUserInfo( + user_id=chat_info_user_id, + user_nickname=chat_info_user_nickname, + user_cardname=chat_info_user_cardname, + platform=chat_info_user_platform, + ), + group_info=self.group_info, + ) + + if kwargs: + for key, value in kwargs.items(): + setattr(self, key, value) + + # def __post_init__(self): + # assert isinstance(self.message_id, str), "message_id must be a string" + # assert isinstance(self.time, float), "time must be a float" + # assert isinstance(self.chat_id, str), "chat_id must be a string" + # assert isinstance(self.reply_to, str) or self.reply_to is None, "reply_to must be a string or None" + # assert isinstance(self.interest_value, float) or self.interest_value is None, ( + # "interest_value must be a float or None" + # ) + def flatten(self) -> Dict[str, Any]: + """ + 将消息数据模型转换为字典格式,便于存储或传输 + """ + return { + "message_id": self.message_id, + "time": self.time, + "chat_id": self.chat_id, + "reply_to": self.reply_to, + "interest_value": self.interest_value, + "key_words": self.key_words, + "key_words_lite": self.key_words_lite, + "is_mentioned": self.is_mentioned, + "is_at": self.is_at, + "reply_probability_boost": self.reply_probability_boost, + "processed_plain_text": self.processed_plain_text, + "display_message": self.display_message, + "priority_mode": self.priority_mode, + "priority_info": self.priority_info, + "additional_config": self.additional_config, + "is_emoji": self.is_emoji, + "is_picid": self.is_picid, + "is_command": self.is_command, + "is_notify": self.is_notify, + "selected_expressions": self.selected_expressions, + "user_id": self.user_info.user_id, + "user_nickname": self.user_info.user_nickname, + "user_cardname": self.user_info.user_cardname, + "user_platform": self.user_info.platform, + "chat_info_group_id": self.group_info.group_id if self.group_info else None, + "chat_info_group_name": self.group_info.group_name if self.group_info else None, + "chat_info_group_platform": self.group_info.group_platform if self.group_info else None, + "chat_info_stream_id": self.chat_info.stream_id, + "chat_info_platform": self.chat_info.platform, + "chat_info_create_time": self.chat_info.create_time, + "chat_info_last_active_time": self.chat_info.last_active_time, + "chat_info_user_platform": self.chat_info.user_info.platform, + "chat_info_user_id": self.chat_info.user_info.user_id, + "chat_info_user_nickname": self.chat_info.user_info.user_nickname, + "chat_info_user_cardname": self.chat_info.user_info.user_cardname, + } + +@dataclass(init=False) +class DatabaseActionRecords(BaseDataModel): + def __init__( + self, + action_id: str, + time: float, + action_name: str, + action_data: str, + action_done: bool, + action_build_into_prompt: bool, + action_prompt_display: str, + chat_id: str, + chat_info_stream_id: str, + chat_info_platform: str, + ): + self.action_id = action_id + self.time = time + self.action_name = action_name + if isinstance(action_data, str): + self.action_data = json.loads(action_data) + else: + raise ValueError("action_data must be a JSON string") + self.action_done = action_done + self.action_build_into_prompt = action_build_into_prompt + self.action_prompt_display = action_prompt_display + self.chat_id = chat_id + self.chat_info_stream_id = chat_info_stream_id + self.chat_info_platform = chat_info_platform \ No newline at end of file diff --git a/src/common/data_models/info_data_model.py b/src/common/data_models/info_data_model.py new file mode 100644 index 000000000..0f7b1f950 --- /dev/null +++ b/src/common/data_models/info_data_model.py @@ -0,0 +1,25 @@ +from dataclasses import dataclass, field +from typing import Optional, Dict, TYPE_CHECKING +from . import BaseDataModel + +if TYPE_CHECKING: + from .database_data_model import DatabaseMessages + from src.plugin_system.base.component_types import ActionInfo + + +@dataclass +class TargetPersonInfo(BaseDataModel): + platform: str = field(default_factory=str) + user_id: str = field(default_factory=str) + user_nickname: str = field(default_factory=str) + person_id: Optional[str] = None + person_name: Optional[str] = None + + +@dataclass +class ActionPlannerInfo(BaseDataModel): + action_type: str = field(default_factory=str) + reasoning: Optional[str] = None + action_data: Optional[Dict] = None + action_message: Optional["DatabaseMessages"] = None + available_actions: Optional[Dict[str, "ActionInfo"]] = None diff --git a/src/common/data_models/llm_data_model.py b/src/common/data_models/llm_data_model.py new file mode 100644 index 000000000..1d5b75e0c --- /dev/null +++ b/src/common/data_models/llm_data_model.py @@ -0,0 +1,16 @@ +from dataclasses import dataclass +from typing import Optional, List, Tuple, TYPE_CHECKING, Any + +from . import BaseDataModel +if TYPE_CHECKING: + from src.llm_models.payload_content.tool_option import ToolCall + +@dataclass +class LLMGenerationDataModel(BaseDataModel): + content: Optional[str] = None + reasoning: Optional[str] = None + model: Optional[str] = None + tool_calls: Optional[List["ToolCall"]] = None + prompt: Optional[str] = None + selected_expressions: Optional[List[int]] = None + reply_set: Optional[List[Tuple[str, Any]]] = None \ No newline at end of file diff --git a/src/common/data_models/message_data_model.py b/src/common/data_models/message_data_model.py new file mode 100644 index 000000000..8e0b77862 --- /dev/null +++ b/src/common/data_models/message_data_model.py @@ -0,0 +1,36 @@ +from typing import Optional, TYPE_CHECKING +from dataclasses import dataclass, field + +from . import BaseDataModel + +if TYPE_CHECKING: + from .database_data_model import DatabaseMessages + + +@dataclass +class MessageAndActionModel(BaseDataModel): + chat_id: str = field(default_factory=str) + time: float = field(default_factory=float) + user_id: str = field(default_factory=str) + user_platform: str = field(default_factory=str) + user_nickname: str = field(default_factory=str) + user_cardname: Optional[str] = None + processed_plain_text: Optional[str] = None + display_message: Optional[str] = None + chat_info_platform: str = field(default_factory=str) + is_action_record: bool = field(default=False) + action_name: Optional[str] = None + + @classmethod + def from_DatabaseMessages(cls, message: "DatabaseMessages"): + return cls( + chat_id=message.chat_id, + time=message.time, + user_id=message.user_info.user_id, + user_platform=message.user_info.platform, + user_nickname=message.user_info.user_nickname, + user_cardname=message.user_info.user_cardname, + processed_plain_text=message.processed_plain_text, + display_message=message.display_message, + chat_info_platform=message.chat_info.platform, + )