re-style: 格式化代码
This commit is contained in:
@@ -4,8 +4,8 @@
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Dict, Optional, Any
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from . import BaseDataModel
|
||||
|
||||
@@ -16,12 +16,12 @@ class BotInterestTag(BaseDataModel):
|
||||
|
||||
tag_name: str
|
||||
weight: float = 1.0 # 权重,表示对这个兴趣的喜好程度 (0.0-1.0)
|
||||
embedding: Optional[List[float]] = None # 标签的embedding向量
|
||||
embedding: list[float] | None = None # 标签的embedding向量
|
||||
created_at: datetime = field(default_factory=datetime.now)
|
||||
updated_at: datetime = field(default_factory=datetime.now)
|
||||
is_active: bool = True
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""转换为字典格式"""
|
||||
return {
|
||||
"tag_name": self.tag_name,
|
||||
@@ -33,7 +33,7 @@ class BotInterestTag(BaseDataModel):
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "BotInterestTag":
|
||||
def from_dict(cls, data: dict[str, Any]) -> "BotInterestTag":
|
||||
"""从字典创建对象"""
|
||||
return cls(
|
||||
tag_name=data["tag_name"],
|
||||
@@ -51,16 +51,16 @@ class BotPersonalityInterests(BaseDataModel):
|
||||
|
||||
personality_id: str
|
||||
personality_description: str # 人设描述文本
|
||||
interest_tags: List[BotInterestTag] = field(default_factory=list)
|
||||
interest_tags: list[BotInterestTag] = field(default_factory=list)
|
||||
embedding_model: str = "text-embedding-ada-002" # 使用的embedding模型
|
||||
last_updated: datetime = field(default_factory=datetime.now)
|
||||
version: int = 1 # 版本号,用于追踪更新
|
||||
|
||||
def get_active_tags(self) -> List[BotInterestTag]:
|
||||
def get_active_tags(self) -> list[BotInterestTag]:
|
||||
"""获取活跃的兴趣标签"""
|
||||
return [tag for tag in self.interest_tags if tag.is_active]
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""转换为字典格式"""
|
||||
return {
|
||||
"personality_id": self.personality_id,
|
||||
@@ -72,7 +72,7 @@ class BotPersonalityInterests(BaseDataModel):
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "BotPersonalityInterests":
|
||||
def from_dict(cls, data: dict[str, Any]) -> "BotPersonalityInterests":
|
||||
"""从字典创建对象"""
|
||||
return cls(
|
||||
personality_id=data["personality_id"],
|
||||
@@ -89,14 +89,14 @@ class InterestMatchResult(BaseDataModel):
|
||||
"""兴趣匹配结果"""
|
||||
|
||||
message_id: str
|
||||
matched_tags: List[str] = field(default_factory=list)
|
||||
match_scores: Dict[str, float] = field(default_factory=dict) # tag_name -> score
|
||||
matched_tags: list[str] = field(default_factory=list)
|
||||
match_scores: dict[str, float] = field(default_factory=dict) # tag_name -> score
|
||||
overall_score: float = 0.0
|
||||
top_tag: Optional[str] = None
|
||||
top_tag: str | None = None
|
||||
confidence: float = 0.0 # 匹配置信度 (0.0-1.0)
|
||||
matched_keywords: List[str] = field(default_factory=list)
|
||||
matched_keywords: list[str] = field(default_factory=list)
|
||||
|
||||
def add_match(self, tag_name: str, score: float, keywords: List[str] = None):
|
||||
def add_match(self, tag_name: str, score: float, keywords: list[str] = None):
|
||||
"""添加匹配结果"""
|
||||
self.matched_tags.append(tag_name)
|
||||
self.match_scores[tag_name] = score
|
||||
@@ -131,7 +131,7 @@ class InterestMatchResult(BaseDataModel):
|
||||
else:
|
||||
self.confidence = 0.0
|
||||
|
||||
def get_top_matches(self, top_n: int = 3) -> List[tuple]:
|
||||
def get_top_matches(self, top_n: int = 3) -> list[tuple]:
|
||||
"""获取前N个最佳匹配"""
|
||||
sorted_matches = sorted(self.match_scores.items(), key=lambda x: x[1], reverse=True)
|
||||
return sorted_matches[:top_n]
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
from typing import Optional, Any, Dict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from . import BaseDataModel
|
||||
|
||||
@@ -10,7 +10,7 @@ class DatabaseUserInfo(BaseDataModel):
|
||||
platform: str = field(default_factory=str)
|
||||
user_id: str = field(default_factory=str)
|
||||
user_nickname: str = field(default_factory=str)
|
||||
user_cardname: Optional[str] = None
|
||||
user_cardname: str | None = None
|
||||
|
||||
# def __post_init__(self):
|
||||
# assert isinstance(self.platform, str), "platform must be a string"
|
||||
@@ -25,7 +25,7 @@ class DatabaseUserInfo(BaseDataModel):
|
||||
class DatabaseGroupInfo(BaseDataModel):
|
||||
group_id: str = field(default_factory=str)
|
||||
group_name: str = field(default_factory=str)
|
||||
group_platform: Optional[str] = None
|
||||
group_platform: str | None = None
|
||||
|
||||
# def __post_init__(self):
|
||||
# assert isinstance(self.group_id, str), "group_id must be a string"
|
||||
@@ -42,7 +42,7 @@ class DatabaseChatInfo(BaseDataModel):
|
||||
create_time: float = field(default_factory=float)
|
||||
last_active_time: float = field(default_factory=float)
|
||||
user_info: DatabaseUserInfo = field(default_factory=DatabaseUserInfo)
|
||||
group_info: Optional[DatabaseGroupInfo] = None
|
||||
group_info: DatabaseGroupInfo | None = None
|
||||
|
||||
# def __post_init__(self):
|
||||
# assert isinstance(self.stream_id, str), "stream_id must be a string"
|
||||
@@ -62,41 +62,41 @@ class DatabaseMessages(BaseDataModel):
|
||||
message_id: str = "",
|
||||
time: float = 0.0,
|
||||
chat_id: str = "",
|
||||
reply_to: Optional[str] = None,
|
||||
interest_value: Optional[float] = None,
|
||||
key_words: Optional[str] = None,
|
||||
key_words_lite: Optional[str] = None,
|
||||
is_mentioned: Optional[bool] = None,
|
||||
is_at: Optional[bool] = None,
|
||||
reply_probability_boost: Optional[float] = None,
|
||||
processed_plain_text: Optional[str] = None,
|
||||
display_message: Optional[str] = None,
|
||||
priority_mode: Optional[str] = None,
|
||||
priority_info: Optional[str] = None,
|
||||
additional_config: Optional[str] = None,
|
||||
reply_to: str | None = None,
|
||||
interest_value: float | None = None,
|
||||
key_words: str | None = None,
|
||||
key_words_lite: str | None = None,
|
||||
is_mentioned: bool | None = None,
|
||||
is_at: bool | None = None,
|
||||
reply_probability_boost: float | None = None,
|
||||
processed_plain_text: str | None = None,
|
||||
display_message: str | None = None,
|
||||
priority_mode: str | None = None,
|
||||
priority_info: str | None = None,
|
||||
additional_config: str | None = None,
|
||||
is_emoji: bool = False,
|
||||
is_picid: bool = False,
|
||||
is_command: bool = False,
|
||||
is_notify: bool = False,
|
||||
selected_expressions: Optional[str] = None,
|
||||
selected_expressions: str | None = None,
|
||||
is_read: bool = False,
|
||||
user_id: str = "",
|
||||
user_nickname: str = "",
|
||||
user_cardname: Optional[str] = None,
|
||||
user_cardname: str | None = None,
|
||||
user_platform: str = "",
|
||||
chat_info_group_id: Optional[str] = None,
|
||||
chat_info_group_name: Optional[str] = None,
|
||||
chat_info_group_platform: Optional[str] = None,
|
||||
chat_info_group_id: str | None = None,
|
||||
chat_info_group_name: str | None = None,
|
||||
chat_info_group_platform: str | None = None,
|
||||
chat_info_user_id: str = "",
|
||||
chat_info_user_nickname: str = "",
|
||||
chat_info_user_cardname: Optional[str] = None,
|
||||
chat_info_user_cardname: str | None = None,
|
||||
chat_info_user_platform: str = "",
|
||||
chat_info_stream_id: str = "",
|
||||
chat_info_platform: str = "",
|
||||
chat_info_create_time: float = 0.0,
|
||||
chat_info_last_active_time: float = 0.0,
|
||||
# 新增字段
|
||||
actions: Optional[list] = None,
|
||||
actions: list | None = None,
|
||||
should_reply: bool = False,
|
||||
**kwargs: Any,
|
||||
):
|
||||
@@ -132,7 +132,7 @@ class DatabaseMessages(BaseDataModel):
|
||||
self.selected_expressions = selected_expressions
|
||||
self.is_read = is_read
|
||||
|
||||
self.group_info: Optional[DatabaseGroupInfo] = None
|
||||
self.group_info: DatabaseGroupInfo | None = None
|
||||
self.user_info = DatabaseUserInfo(
|
||||
user_id=user_id,
|
||||
user_nickname=user_nickname,
|
||||
@@ -172,7 +172,7 @@ class DatabaseMessages(BaseDataModel):
|
||||
# assert isinstance(self.interest_value, float) or self.interest_value is None, (
|
||||
# "interest_value must be a float or None"
|
||||
# )
|
||||
def flatten(self) -> Dict[str, Any]:
|
||||
def flatten(self) -> dict[str, Any]:
|
||||
"""
|
||||
将消息数据模型转换为字典格式,便于存储或传输
|
||||
"""
|
||||
@@ -255,7 +255,7 @@ class DatabaseMessages(BaseDataModel):
|
||||
"""
|
||||
return self.actions or []
|
||||
|
||||
def get_message_summary(self) -> Dict[str, Any]:
|
||||
def get_message_summary(self) -> dict[str, Any]:
|
||||
"""
|
||||
获取消息摘要信息
|
||||
|
||||
|
||||
@@ -1,30 +1,32 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Dict, List, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from src.plugin_system.base.component_types import ChatType
|
||||
|
||||
from . import BaseDataModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .database_data_model import DatabaseMessages
|
||||
from src.plugin_system.base.component_types import ActionInfo, ChatMode
|
||||
|
||||
from .database_data_model import DatabaseMessages
|
||||
|
||||
|
||||
@dataclass
|
||||
class TargetPersonInfo(BaseDataModel):
|
||||
platform: str = field(default_factory=str)
|
||||
user_id: str = field(default_factory=str)
|
||||
user_nickname: str = field(default_factory=str)
|
||||
person_id: Optional[str] = None
|
||||
person_name: Optional[str] = None
|
||||
person_id: str | None = None
|
||||
person_name: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActionPlannerInfo(BaseDataModel):
|
||||
action_type: str = field(default_factory=str)
|
||||
reasoning: Optional[str] = None
|
||||
action_data: Optional[Dict] = None
|
||||
reasoning: str | None = None
|
||||
action_data: dict | None = None
|
||||
action_message: Optional["DatabaseMessages"] = None
|
||||
available_actions: Optional[Dict[str, "ActionInfo"]] = None
|
||||
available_actions: dict[str, "ActionInfo"] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -36,7 +38,7 @@ class InterestScore(BaseDataModel):
|
||||
interest_match_score: float
|
||||
relationship_score: float
|
||||
mentioned_score: float
|
||||
details: Dict[str, str]
|
||||
details: dict[str, str]
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -50,10 +52,10 @@ class Plan(BaseDataModel):
|
||||
|
||||
chat_type: "ChatType"
|
||||
# Generator 填充
|
||||
available_actions: Dict[str, "ActionInfo"] = field(default_factory=dict)
|
||||
chat_history: List["DatabaseMessages"] = field(default_factory=list)
|
||||
target_info: Optional[TargetPersonInfo] = None
|
||||
available_actions: dict[str, "ActionInfo"] = field(default_factory=dict)
|
||||
chat_history: list["DatabaseMessages"] = field(default_factory=list)
|
||||
target_info: TargetPersonInfo | None = None
|
||||
|
||||
# Filter 填充
|
||||
llm_prompt: Optional[str] = None
|
||||
decided_actions: Optional[List[ActionPlannerInfo]] = None
|
||||
llm_prompt: str | None = None
|
||||
decided_actions: list[ActionPlannerInfo] | None = None
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, List, Tuple, TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from . import BaseDataModel
|
||||
|
||||
@@ -9,10 +9,10 @@ if TYPE_CHECKING:
|
||||
|
||||
@dataclass
|
||||
class LLMGenerationDataModel(BaseDataModel):
|
||||
content: Optional[str] = None
|
||||
reasoning: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
tool_calls: Optional[List["ToolCall"]] = None
|
||||
prompt: Optional[str] = None
|
||||
selected_expressions: Optional[List[int]] = None
|
||||
reply_set: Optional[List[Tuple[str, Any]]] = None
|
||||
content: str | None = None
|
||||
reasoning: str | None = None
|
||||
model: str | None = None
|
||||
tool_calls: list["ToolCall"] | None = None
|
||||
prompt: str | None = None
|
||||
selected_expressions: list[int] | None = None
|
||||
reply_set: list[tuple[str, Any]] | None = None
|
||||
|
||||
@@ -7,11 +7,12 @@ import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import List, Optional, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||
|
||||
from . import BaseDataModel
|
||||
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||
from src.common.logger import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .database_data_model import DatabaseMessages
|
||||
@@ -34,11 +35,11 @@ class StreamContext(BaseDataModel):
|
||||
stream_id: str
|
||||
chat_type: ChatType = ChatType.PRIVATE # 聊天类型,默认为私聊
|
||||
chat_mode: ChatMode = ChatMode.NORMAL # 聊天模式,默认为普通模式
|
||||
unread_messages: List["DatabaseMessages"] = field(default_factory=list)
|
||||
history_messages: List["DatabaseMessages"] = field(default_factory=list)
|
||||
unread_messages: list["DatabaseMessages"] = field(default_factory=list)
|
||||
history_messages: list["DatabaseMessages"] = field(default_factory=list)
|
||||
last_check_time: float = field(default_factory=time.time)
|
||||
is_active: bool = True
|
||||
processing_task: Optional[asyncio.Task] = None
|
||||
processing_task: asyncio.Task | None = None
|
||||
interruption_count: int = 0 # 打断计数器
|
||||
last_interruption_time: float = 0.0 # 上次打断时间
|
||||
afc_threshold_adjustment: float = 0.0 # afc阈值调整量
|
||||
@@ -49,8 +50,8 @@ class StreamContext(BaseDataModel):
|
||||
|
||||
# 新增字段以替代ChatMessageContext功能
|
||||
current_message: Optional["DatabaseMessages"] = None
|
||||
priority_mode: Optional[str] = None
|
||||
priority_info: Optional[dict] = None
|
||||
priority_mode: str | None = None
|
||||
priority_info: dict | None = None
|
||||
|
||||
def add_message(self, message: "DatabaseMessages"):
|
||||
"""添加消息到上下文"""
|
||||
@@ -150,11 +151,11 @@ class StreamContext(BaseDataModel):
|
||||
self.unread_messages.remove(msg)
|
||||
break
|
||||
|
||||
def get_unread_messages(self) -> List["DatabaseMessages"]:
|
||||
def get_unread_messages(self) -> list["DatabaseMessages"]:
|
||||
"""获取未读消息"""
|
||||
return [msg for msg in self.unread_messages if not msg.is_read]
|
||||
|
||||
def get_history_messages(self, limit: int = 20) -> List["DatabaseMessages"]:
|
||||
def get_history_messages(self, limit: int = 20) -> list["DatabaseMessages"]:
|
||||
"""获取历史消息"""
|
||||
# 优先返回最近的历史消息和所有未读消息
|
||||
recent_history = self.history_messages[-limit:] if len(self.history_messages) > limit else self.history_messages
|
||||
@@ -230,7 +231,7 @@ class StreamContext(BaseDataModel):
|
||||
"""设置当前消息"""
|
||||
self.current_message = message
|
||||
|
||||
def get_template_name(self) -> Optional[str]:
|
||||
def get_template_name(self) -> str | None:
|
||||
"""获取模板名称"""
|
||||
if (
|
||||
self.current_message
|
||||
@@ -336,11 +337,11 @@ class StreamContext(BaseDataModel):
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_priority_mode(self) -> Optional[str]:
|
||||
def get_priority_mode(self) -> str | None:
|
||||
"""获取优先级模式"""
|
||||
return self.priority_mode
|
||||
|
||||
def get_priority_info(self) -> Optional[dict]:
|
||||
def get_priority_info(self) -> dict | None:
|
||||
"""获取优先级信息"""
|
||||
return self.priority_info
|
||||
|
||||
|
||||
Reference in New Issue
Block a user