ruff,私聊视为提及了bot
This commit is contained in:
@@ -6,6 +6,7 @@ class BaseDataModel:
|
||||
def deepcopy(self):
|
||||
return copy.deepcopy(self)
|
||||
|
||||
|
||||
def temporarily_transform_class_to_dict(obj: Any) -> Any:
|
||||
# sourcery skip: assign-if-exp, reintroduce-else
|
||||
"""
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
机器人兴趣标签数据模型
|
||||
定义机器人的兴趣标签和相关的embedding数据结构
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Dict, Optional, Any
|
||||
from datetime import datetime
|
||||
@@ -12,6 +13,7 @@ from . import BaseDataModel
|
||||
@dataclass
|
||||
class BotInterestTag(BaseDataModel):
|
||||
"""机器人兴趣标签"""
|
||||
|
||||
tag_name: str
|
||||
weight: float = 1.0 # 权重,表示对这个兴趣的喜好程度 (0.0-1.0)
|
||||
embedding: Optional[List[float]] = None # 标签的embedding向量
|
||||
@@ -27,7 +29,7 @@ class BotInterestTag(BaseDataModel):
|
||||
"embedding": self.embedding,
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"updated_at": self.updated_at.isoformat(),
|
||||
"is_active": self.is_active
|
||||
"is_active": self.is_active,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -39,13 +41,14 @@ class BotInterestTag(BaseDataModel):
|
||||
embedding=data.get("embedding"),
|
||||
created_at=datetime.fromisoformat(data["created_at"]) if data.get("created_at") else datetime.now(),
|
||||
updated_at=datetime.fromisoformat(data["updated_at"]) if data.get("updated_at") else datetime.now(),
|
||||
is_active=data.get("is_active", True)
|
||||
is_active=data.get("is_active", True),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BotPersonalityInterests(BaseDataModel):
|
||||
"""机器人人格化兴趣配置"""
|
||||
|
||||
personality_id: str
|
||||
personality_description: str # 人设描述文本
|
||||
interest_tags: List[BotInterestTag] = field(default_factory=list)
|
||||
@@ -57,7 +60,6 @@ class BotPersonalityInterests(BaseDataModel):
|
||||
"""获取活跃的兴趣标签"""
|
||||
return [tag for tag in self.interest_tags if tag.is_active]
|
||||
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典格式"""
|
||||
return {
|
||||
@@ -66,7 +68,7 @@ class BotPersonalityInterests(BaseDataModel):
|
||||
"interest_tags": [tag.to_dict() for tag in self.interest_tags],
|
||||
"embedding_model": self.embedding_model,
|
||||
"last_updated": self.last_updated.isoformat(),
|
||||
"version": self.version
|
||||
"version": self.version,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -78,13 +80,14 @@ class BotPersonalityInterests(BaseDataModel):
|
||||
interest_tags=[BotInterestTag.from_dict(tag_data) for tag_data in data.get("interest_tags", [])],
|
||||
embedding_model=data.get("embedding_model", "text-embedding-ada-002"),
|
||||
last_updated=datetime.fromisoformat(data["last_updated"]) if data.get("last_updated") else datetime.now(),
|
||||
version=data.get("version", 1)
|
||||
version=data.get("version", 1),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InterestMatchResult(BaseDataModel):
|
||||
"""兴趣匹配结果"""
|
||||
|
||||
message_id: str
|
||||
matched_tags: List[str] = field(default_factory=list)
|
||||
match_scores: Dict[str, float] = field(default_factory=dict) # tag_name -> score
|
||||
@@ -120,7 +123,9 @@ class InterestMatchResult(BaseDataModel):
|
||||
# 计算置信度(基于匹配标签数量和分数分布)
|
||||
if len(self.match_scores) > 0:
|
||||
avg_score = self.overall_score
|
||||
score_variance = sum((score - avg_score) ** 2 for score in self.match_scores.values()) / len(self.match_scores)
|
||||
score_variance = sum((score - avg_score) ** 2 for score in self.match_scores.values()) / len(
|
||||
self.match_scores
|
||||
)
|
||||
# 分数越集中,置信度越高
|
||||
self.confidence = max(0.0, 1.0 - score_variance)
|
||||
else:
|
||||
@@ -129,4 +134,4 @@ class InterestMatchResult(BaseDataModel):
|
||||
def get_top_matches(self, top_n: int = 3) -> List[tuple]:
|
||||
"""获取前N个最佳匹配"""
|
||||
sorted_matches = sorted(self.match_scores.items(), key=lambda x: x[1], reverse=True)
|
||||
return sorted_matches[:top_n]
|
||||
return sorted_matches[:top_n]
|
||||
|
||||
@@ -208,6 +208,7 @@ class DatabaseMessages(BaseDataModel):
|
||||
"chat_info_user_cardname": self.chat_info.user_info.user_cardname,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(init=False)
|
||||
class DatabaseActionRecords(BaseDataModel):
|
||||
def __init__(
|
||||
@@ -235,4 +236,4 @@ class DatabaseActionRecords(BaseDataModel):
|
||||
self.action_prompt_display = action_prompt_display
|
||||
self.chat_id = chat_id
|
||||
self.chat_info_stream_id = chat_info_stream_id
|
||||
self.chat_info_platform = chat_info_platform
|
||||
self.chat_info_platform = chat_info_platform
|
||||
|
||||
@@ -28,6 +28,7 @@ class ActionPlannerInfo(BaseDataModel):
|
||||
@dataclass
|
||||
class InterestScore(BaseDataModel):
|
||||
"""兴趣度评分结果"""
|
||||
|
||||
message_id: str
|
||||
total_score: float
|
||||
interest_match_score: float
|
||||
@@ -41,6 +42,7 @@ class Plan(BaseDataModel):
|
||||
"""
|
||||
统一规划数据模型
|
||||
"""
|
||||
|
||||
chat_id: str
|
||||
mode: "ChatMode"
|
||||
|
||||
|
||||
@@ -2,9 +2,11 @@ from dataclasses import dataclass
|
||||
from typing import Optional, List, Tuple, TYPE_CHECKING, Any
|
||||
|
||||
from . import BaseDataModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.llm_models.payload_content.tool_option import ToolCall
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMGenerationDataModel(BaseDataModel):
|
||||
content: Optional[str] = None
|
||||
@@ -13,4 +15,4 @@ class LLMGenerationDataModel(BaseDataModel):
|
||||
tool_calls: Optional[List["ToolCall"]] = None
|
||||
prompt: Optional[str] = None
|
||||
selected_expressions: Optional[List[int]] = None
|
||||
reply_set: Optional[List[Tuple[str, Any]]] = None
|
||||
reply_set: Optional[List[Tuple[str, Any]]] = None
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
消息管理模块数据模型
|
||||
定义消息管理器使用的数据结构
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
@@ -16,14 +17,16 @@ if TYPE_CHECKING:
|
||||
|
||||
class MessageStatus(Enum):
|
||||
"""消息状态枚举"""
|
||||
UNREAD = "unread" # 未读消息
|
||||
READ = "read" # 已读消息
|
||||
|
||||
UNREAD = "unread" # 未读消息
|
||||
READ = "read" # 已读消息
|
||||
PROCESSING = "processing" # 处理中
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamContext(BaseDataModel):
|
||||
"""聊天流上下文信息"""
|
||||
|
||||
stream_id: str
|
||||
unread_messages: List["DatabaseMessages"] = field(default_factory=list)
|
||||
history_messages: List["DatabaseMessages"] = field(default_factory=list)
|
||||
@@ -59,6 +62,7 @@ class StreamContext(BaseDataModel):
|
||||
@dataclass
|
||||
class MessageManagerStats(BaseDataModel):
|
||||
"""消息管理器统计信息"""
|
||||
|
||||
total_streams: int = 0
|
||||
active_streams: int = 0
|
||||
total_unread_messages: int = 0
|
||||
@@ -74,9 +78,10 @@ class MessageManagerStats(BaseDataModel):
|
||||
@dataclass
|
||||
class StreamStats(BaseDataModel):
|
||||
"""聊天流统计信息"""
|
||||
|
||||
stream_id: str
|
||||
is_active: bool
|
||||
unread_count: int
|
||||
history_count: int
|
||||
last_check_time: float
|
||||
has_active_task: bool
|
||||
has_active_task: bool
|
||||
|
||||
Reference in New Issue
Block a user