re-style: 格式化代码

This commit is contained in:
John Richard
2025-10-02 20:26:01 +08:00
parent ecb02cae31
commit 7923eafef3
263 changed files with 3103 additions and 3123 deletions

View File

@@ -4,8 +4,8 @@
"""
from dataclasses import dataclass, field
from typing import List, Dict, Optional, Any
from datetime import datetime
from typing import Any
from . import BaseDataModel
@@ -16,12 +16,12 @@ class BotInterestTag(BaseDataModel):
tag_name: str
weight: float = 1.0 # 权重,表示对这个兴趣的喜好程度 (0.0-1.0)
embedding: Optional[List[float]] = None # 标签的embedding向量
embedding: list[float] | None = None # 标签的embedding向量
created_at: datetime = field(default_factory=datetime.now)
updated_at: datetime = field(default_factory=datetime.now)
is_active: bool = True
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""转换为字典格式"""
return {
"tag_name": self.tag_name,
@@ -33,7 +33,7 @@ class BotInterestTag(BaseDataModel):
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "BotInterestTag":
def from_dict(cls, data: dict[str, Any]) -> "BotInterestTag":
"""从字典创建对象"""
return cls(
tag_name=data["tag_name"],
@@ -51,16 +51,16 @@ class BotPersonalityInterests(BaseDataModel):
personality_id: str
personality_description: str # 人设描述文本
interest_tags: List[BotInterestTag] = field(default_factory=list)
interest_tags: list[BotInterestTag] = field(default_factory=list)
embedding_model: str = "text-embedding-ada-002" # 使用的embedding模型
last_updated: datetime = field(default_factory=datetime.now)
version: int = 1 # 版本号,用于追踪更新
def get_active_tags(self) -> List[BotInterestTag]:
def get_active_tags(self) -> list[BotInterestTag]:
"""获取活跃的兴趣标签"""
return [tag for tag in self.interest_tags if tag.is_active]
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""转换为字典格式"""
return {
"personality_id": self.personality_id,
@@ -72,7 +72,7 @@ class BotPersonalityInterests(BaseDataModel):
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "BotPersonalityInterests":
def from_dict(cls, data: dict[str, Any]) -> "BotPersonalityInterests":
"""从字典创建对象"""
return cls(
personality_id=data["personality_id"],
@@ -89,14 +89,14 @@ class InterestMatchResult(BaseDataModel):
"""兴趣匹配结果"""
message_id: str
matched_tags: List[str] = field(default_factory=list)
match_scores: Dict[str, float] = field(default_factory=dict) # tag_name -> score
matched_tags: list[str] = field(default_factory=list)
match_scores: dict[str, float] = field(default_factory=dict) # tag_name -> score
overall_score: float = 0.0
top_tag: Optional[str] = None
top_tag: str | None = None
confidence: float = 0.0 # 匹配置信度 (0.0-1.0)
matched_keywords: List[str] = field(default_factory=list)
matched_keywords: list[str] = field(default_factory=list)
def add_match(self, tag_name: str, score: float, keywords: List[str] = None):
def add_match(self, tag_name: str, score: float, keywords: list[str] = None):
"""添加匹配结果"""
self.matched_tags.append(tag_name)
self.match_scores[tag_name] = score
@@ -131,7 +131,7 @@ class InterestMatchResult(BaseDataModel):
else:
self.confidence = 0.0
def get_top_matches(self, top_n: int = 3) -> List[tuple]:
def get_top_matches(self, top_n: int = 3) -> list[tuple]:
"""获取前N个最佳匹配"""
sorted_matches = sorted(self.match_scores.items(), key=lambda x: x[1], reverse=True)
return sorted_matches[:top_n]

View File

@@ -1,6 +1,6 @@
import json
from typing import Optional, Any, Dict
from dataclasses import dataclass, field
from typing import Any
from . import BaseDataModel
@@ -10,7 +10,7 @@ class DatabaseUserInfo(BaseDataModel):
platform: str = field(default_factory=str)
user_id: str = field(default_factory=str)
user_nickname: str = field(default_factory=str)
user_cardname: Optional[str] = None
user_cardname: str | None = None
# def __post_init__(self):
# assert isinstance(self.platform, str), "platform must be a string"
@@ -25,7 +25,7 @@ class DatabaseUserInfo(BaseDataModel):
class DatabaseGroupInfo(BaseDataModel):
group_id: str = field(default_factory=str)
group_name: str = field(default_factory=str)
group_platform: Optional[str] = None
group_platform: str | None = None
# def __post_init__(self):
# assert isinstance(self.group_id, str), "group_id must be a string"
@@ -42,7 +42,7 @@ class DatabaseChatInfo(BaseDataModel):
create_time: float = field(default_factory=float)
last_active_time: float = field(default_factory=float)
user_info: DatabaseUserInfo = field(default_factory=DatabaseUserInfo)
group_info: Optional[DatabaseGroupInfo] = None
group_info: DatabaseGroupInfo | None = None
# def __post_init__(self):
# assert isinstance(self.stream_id, str), "stream_id must be a string"
@@ -62,41 +62,41 @@ class DatabaseMessages(BaseDataModel):
message_id: str = "",
time: float = 0.0,
chat_id: str = "",
reply_to: Optional[str] = None,
interest_value: Optional[float] = None,
key_words: Optional[str] = None,
key_words_lite: Optional[str] = None,
is_mentioned: Optional[bool] = None,
is_at: Optional[bool] = None,
reply_probability_boost: Optional[float] = None,
processed_plain_text: Optional[str] = None,
display_message: Optional[str] = None,
priority_mode: Optional[str] = None,
priority_info: Optional[str] = None,
additional_config: Optional[str] = None,
reply_to: str | None = None,
interest_value: float | None = None,
key_words: str | None = None,
key_words_lite: str | None = None,
is_mentioned: bool | None = None,
is_at: bool | None = None,
reply_probability_boost: float | None = None,
processed_plain_text: str | None = None,
display_message: str | None = None,
priority_mode: str | None = None,
priority_info: str | None = None,
additional_config: str | None = None,
is_emoji: bool = False,
is_picid: bool = False,
is_command: bool = False,
is_notify: bool = False,
selected_expressions: Optional[str] = None,
selected_expressions: str | None = None,
is_read: bool = False,
user_id: str = "",
user_nickname: str = "",
user_cardname: Optional[str] = None,
user_cardname: str | None = None,
user_platform: str = "",
chat_info_group_id: Optional[str] = None,
chat_info_group_name: Optional[str] = None,
chat_info_group_platform: Optional[str] = None,
chat_info_group_id: str | None = None,
chat_info_group_name: str | None = None,
chat_info_group_platform: str | None = None,
chat_info_user_id: str = "",
chat_info_user_nickname: str = "",
chat_info_user_cardname: Optional[str] = None,
chat_info_user_cardname: str | None = None,
chat_info_user_platform: str = "",
chat_info_stream_id: str = "",
chat_info_platform: str = "",
chat_info_create_time: float = 0.0,
chat_info_last_active_time: float = 0.0,
# 新增字段
actions: Optional[list] = None,
actions: list | None = None,
should_reply: bool = False,
**kwargs: Any,
):
@@ -132,7 +132,7 @@ class DatabaseMessages(BaseDataModel):
self.selected_expressions = selected_expressions
self.is_read = is_read
self.group_info: Optional[DatabaseGroupInfo] = None
self.group_info: DatabaseGroupInfo | None = None
self.user_info = DatabaseUserInfo(
user_id=user_id,
user_nickname=user_nickname,
@@ -172,7 +172,7 @@ class DatabaseMessages(BaseDataModel):
# assert isinstance(self.interest_value, float) or self.interest_value is None, (
# "interest_value must be a float or None"
# )
def flatten(self) -> Dict[str, Any]:
def flatten(self) -> dict[str, Any]:
"""
将消息数据模型转换为字典格式,便于存储或传输
"""
@@ -255,7 +255,7 @@ class DatabaseMessages(BaseDataModel):
"""
return self.actions or []
def get_message_summary(self) -> Dict[str, Any]:
def get_message_summary(self) -> dict[str, Any]:
"""
获取消息摘要信息

View File

@@ -1,30 +1,32 @@
from dataclasses import dataclass, field
from typing import Optional, Dict, List, TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
from src.plugin_system.base.component_types import ChatType
from . import BaseDataModel
if TYPE_CHECKING:
from .database_data_model import DatabaseMessages
from src.plugin_system.base.component_types import ActionInfo, ChatMode
from .database_data_model import DatabaseMessages
@dataclass
class TargetPersonInfo(BaseDataModel):
platform: str = field(default_factory=str)
user_id: str = field(default_factory=str)
user_nickname: str = field(default_factory=str)
person_id: Optional[str] = None
person_name: Optional[str] = None
person_id: str | None = None
person_name: str | None = None
@dataclass
class ActionPlannerInfo(BaseDataModel):
action_type: str = field(default_factory=str)
reasoning: Optional[str] = None
action_data: Optional[Dict] = None
reasoning: str | None = None
action_data: dict | None = None
action_message: Optional["DatabaseMessages"] = None
available_actions: Optional[Dict[str, "ActionInfo"]] = None
available_actions: dict[str, "ActionInfo"] | None = None
@dataclass
@@ -36,7 +38,7 @@ class InterestScore(BaseDataModel):
interest_match_score: float
relationship_score: float
mentioned_score: float
details: Dict[str, str]
details: dict[str, str]
@dataclass
@@ -50,10 +52,10 @@ class Plan(BaseDataModel):
chat_type: "ChatType"
# Generator 填充
available_actions: Dict[str, "ActionInfo"] = field(default_factory=dict)
chat_history: List["DatabaseMessages"] = field(default_factory=list)
target_info: Optional[TargetPersonInfo] = None
available_actions: dict[str, "ActionInfo"] = field(default_factory=dict)
chat_history: list["DatabaseMessages"] = field(default_factory=list)
target_info: TargetPersonInfo | None = None
# Filter 填充
llm_prompt: Optional[str] = None
decided_actions: Optional[List[ActionPlannerInfo]] = None
llm_prompt: str | None = None
decided_actions: list[ActionPlannerInfo] | None = None

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Optional, List, Tuple, TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any
from . import BaseDataModel
@@ -9,10 +9,10 @@ if TYPE_CHECKING:
@dataclass
class LLMGenerationDataModel(BaseDataModel):
content: Optional[str] = None
reasoning: Optional[str] = None
model: Optional[str] = None
tool_calls: Optional[List["ToolCall"]] = None
prompt: Optional[str] = None
selected_expressions: Optional[List[int]] = None
reply_set: Optional[List[Tuple[str, Any]]] = None
content: str | None = None
reasoning: str | None = None
model: str | None = None
tool_calls: list["ToolCall"] | None = None
prompt: str | None = None
selected_expressions: list[int] | None = None
reply_set: list[tuple[str, Any]] | None = None

View File

@@ -7,11 +7,12 @@ import asyncio
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional, TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
from src.common.logger import get_logger
from src.plugin_system.base.component_types import ChatMode, ChatType
from . import BaseDataModel
from src.plugin_system.base.component_types import ChatMode, ChatType
from src.common.logger import get_logger
if TYPE_CHECKING:
from .database_data_model import DatabaseMessages
@@ -34,11 +35,11 @@ class StreamContext(BaseDataModel):
stream_id: str
chat_type: ChatType = ChatType.PRIVATE # 聊天类型,默认为私聊
chat_mode: ChatMode = ChatMode.NORMAL # 聊天模式,默认为普通模式
unread_messages: List["DatabaseMessages"] = field(default_factory=list)
history_messages: List["DatabaseMessages"] = field(default_factory=list)
unread_messages: list["DatabaseMessages"] = field(default_factory=list)
history_messages: list["DatabaseMessages"] = field(default_factory=list)
last_check_time: float = field(default_factory=time.time)
is_active: bool = True
processing_task: Optional[asyncio.Task] = None
processing_task: asyncio.Task | None = None
interruption_count: int = 0 # 打断计数器
last_interruption_time: float = 0.0 # 上次打断时间
afc_threshold_adjustment: float = 0.0 # afc阈值调整量
@@ -49,8 +50,8 @@ class StreamContext(BaseDataModel):
# 新增字段以替代ChatMessageContext功能
current_message: Optional["DatabaseMessages"] = None
priority_mode: Optional[str] = None
priority_info: Optional[dict] = None
priority_mode: str | None = None
priority_info: dict | None = None
def add_message(self, message: "DatabaseMessages"):
"""添加消息到上下文"""
@@ -150,11 +151,11 @@ class StreamContext(BaseDataModel):
self.unread_messages.remove(msg)
break
def get_unread_messages(self) -> List["DatabaseMessages"]:
def get_unread_messages(self) -> list["DatabaseMessages"]:
"""获取未读消息"""
return [msg for msg in self.unread_messages if not msg.is_read]
def get_history_messages(self, limit: int = 20) -> List["DatabaseMessages"]:
def get_history_messages(self, limit: int = 20) -> list["DatabaseMessages"]:
"""获取历史消息"""
# 优先返回最近的历史消息和所有未读消息
recent_history = self.history_messages[-limit:] if len(self.history_messages) > limit else self.history_messages
@@ -230,7 +231,7 @@ class StreamContext(BaseDataModel):
"""设置当前消息"""
self.current_message = message
def get_template_name(self) -> Optional[str]:
def get_template_name(self) -> str | None:
"""获取模板名称"""
if (
self.current_message
@@ -336,11 +337,11 @@ class StreamContext(BaseDataModel):
return False
return True
def get_priority_mode(self) -> Optional[str]:
def get_priority_mode(self) -> str | None:
"""获取优先级模式"""
return self.priority_mode
def get_priority_info(self) -> Optional[dict]:
def get_priority_info(self) -> dict | None:
"""获取优先级信息"""
return self.priority_info