Merge afc branch into dev, prioritizing afc changes and migrating database async modifications from dev
This commit is contained in:
@@ -6,6 +6,7 @@ class BaseDataModel:
|
||||
def deepcopy(self):
|
||||
return copy.deepcopy(self)
|
||||
|
||||
|
||||
def temporarily_transform_class_to_dict(obj: Any) -> Any:
|
||||
# sourcery skip: assign-if-exp, reintroduce-else
|
||||
"""
|
||||
|
||||
137
src/common/data_models/bot_interest_data_model.py
Normal file
137
src/common/data_models/bot_interest_data_model.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""
|
||||
机器人兴趣标签数据模型
|
||||
定义机器人的兴趣标签和相关的embedding数据结构
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Dict, Optional, Any
|
||||
from datetime import datetime
|
||||
|
||||
from . import BaseDataModel
|
||||
|
||||
|
||||
@dataclass
|
||||
class BotInterestTag(BaseDataModel):
|
||||
"""机器人兴趣标签"""
|
||||
|
||||
tag_name: str
|
||||
weight: float = 1.0 # 权重,表示对这个兴趣的喜好程度 (0.0-1.0)
|
||||
embedding: Optional[List[float]] = None # 标签的embedding向量
|
||||
created_at: datetime = field(default_factory=datetime.now)
|
||||
updated_at: datetime = field(default_factory=datetime.now)
|
||||
is_active: bool = True
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典格式"""
|
||||
return {
|
||||
"tag_name": self.tag_name,
|
||||
"weight": self.weight,
|
||||
"embedding": self.embedding,
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"updated_at": self.updated_at.isoformat(),
|
||||
"is_active": self.is_active,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "BotInterestTag":
|
||||
"""从字典创建对象"""
|
||||
return cls(
|
||||
tag_name=data["tag_name"],
|
||||
weight=data.get("weight", 1.0),
|
||||
embedding=data.get("embedding"),
|
||||
created_at=datetime.fromisoformat(data["created_at"]) if data.get("created_at") else datetime.now(),
|
||||
updated_at=datetime.fromisoformat(data["updated_at"]) if data.get("updated_at") else datetime.now(),
|
||||
is_active=data.get("is_active", True),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BotPersonalityInterests(BaseDataModel):
|
||||
"""机器人人格化兴趣配置"""
|
||||
|
||||
personality_id: str
|
||||
personality_description: str # 人设描述文本
|
||||
interest_tags: List[BotInterestTag] = field(default_factory=list)
|
||||
embedding_model: str = "text-embedding-ada-002" # 使用的embedding模型
|
||||
last_updated: datetime = field(default_factory=datetime.now)
|
||||
version: int = 1 # 版本号,用于追踪更新
|
||||
|
||||
def get_active_tags(self) -> List[BotInterestTag]:
|
||||
"""获取活跃的兴趣标签"""
|
||||
return [tag for tag in self.interest_tags if tag.is_active]
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典格式"""
|
||||
return {
|
||||
"personality_id": self.personality_id,
|
||||
"personality_description": self.personality_description,
|
||||
"interest_tags": [tag.to_dict() for tag in self.interest_tags],
|
||||
"embedding_model": self.embedding_model,
|
||||
"last_updated": self.last_updated.isoformat(),
|
||||
"version": self.version,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "BotPersonalityInterests":
|
||||
"""从字典创建对象"""
|
||||
return cls(
|
||||
personality_id=data["personality_id"],
|
||||
personality_description=data["personality_description"],
|
||||
interest_tags=[BotInterestTag.from_dict(tag_data) for tag_data in data.get("interest_tags", [])],
|
||||
embedding_model=data.get("embedding_model", "text-embedding-ada-002"),
|
||||
last_updated=datetime.fromisoformat(data["last_updated"]) if data.get("last_updated") else datetime.now(),
|
||||
version=data.get("version", 1),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InterestMatchResult(BaseDataModel):
|
||||
"""兴趣匹配结果"""
|
||||
|
||||
message_id: str
|
||||
matched_tags: List[str] = field(default_factory=list)
|
||||
match_scores: Dict[str, float] = field(default_factory=dict) # tag_name -> score
|
||||
overall_score: float = 0.0
|
||||
top_tag: Optional[str] = None
|
||||
confidence: float = 0.0 # 匹配置信度 (0.0-1.0)
|
||||
matched_keywords: List[str] = field(default_factory=list)
|
||||
|
||||
def add_match(self, tag_name: str, score: float, keywords: List[str] = None):
|
||||
"""添加匹配结果"""
|
||||
self.matched_tags.append(tag_name)
|
||||
self.match_scores[tag_name] = score
|
||||
if keywords:
|
||||
self.matched_keywords.extend(keywords)
|
||||
|
||||
def calculate_overall_score(self):
|
||||
"""计算总体匹配分数"""
|
||||
if not self.match_scores:
|
||||
self.overall_score = 0.0
|
||||
self.top_tag = None
|
||||
return
|
||||
|
||||
# 使用加权平均计算总体分数
|
||||
total_weight = len(self.match_scores)
|
||||
if total_weight > 0:
|
||||
self.overall_score = sum(self.match_scores.values()) / total_weight
|
||||
# 设置最佳匹配标签
|
||||
self.top_tag = max(self.match_scores.items(), key=lambda x: x[1])[0]
|
||||
else:
|
||||
self.overall_score = 0.0
|
||||
self.top_tag = None
|
||||
|
||||
# 计算置信度(基于匹配标签数量和分数分布)
|
||||
if len(self.match_scores) > 0:
|
||||
avg_score = self.overall_score
|
||||
score_variance = sum((score - avg_score) ** 2 for score in self.match_scores.values()) / len(
|
||||
self.match_scores
|
||||
)
|
||||
# 分数越集中,置信度越高
|
||||
self.confidence = max(0.0, 1.0 - score_variance)
|
||||
else:
|
||||
self.confidence = 0.0
|
||||
|
||||
def get_top_matches(self, top_n: int = 3) -> List[tuple]:
|
||||
"""获取前N个最佳匹配"""
|
||||
sorted_matches = sorted(self.match_scores.items(), key=lambda x: x[1], reverse=True)
|
||||
return sorted_matches[:top_n]
|
||||
@@ -79,6 +79,7 @@ class DatabaseMessages(BaseDataModel):
|
||||
is_command: bool = False,
|
||||
is_notify: bool = False,
|
||||
selected_expressions: Optional[str] = None,
|
||||
is_read: bool = False,
|
||||
user_id: str = "",
|
||||
user_nickname: str = "",
|
||||
user_cardname: Optional[str] = None,
|
||||
@@ -94,6 +95,9 @@ class DatabaseMessages(BaseDataModel):
|
||||
chat_info_platform: str = "",
|
||||
chat_info_create_time: float = 0.0,
|
||||
chat_info_last_active_time: float = 0.0,
|
||||
# 新增字段
|
||||
actions: Optional[list] = None,
|
||||
should_reply: bool = False,
|
||||
**kwargs: Any,
|
||||
):
|
||||
self.message_id = message_id
|
||||
@@ -102,6 +106,10 @@ class DatabaseMessages(BaseDataModel):
|
||||
self.reply_to = reply_to
|
||||
self.interest_value = interest_value
|
||||
|
||||
# 新增字段
|
||||
self.actions = actions
|
||||
self.should_reply = should_reply
|
||||
|
||||
self.key_words = key_words
|
||||
self.key_words_lite = key_words_lite
|
||||
self.is_mentioned = is_mentioned
|
||||
@@ -122,6 +130,7 @@ class DatabaseMessages(BaseDataModel):
|
||||
self.is_notify = is_notify
|
||||
|
||||
self.selected_expressions = selected_expressions
|
||||
self.is_read = is_read
|
||||
|
||||
self.group_info: Optional[DatabaseGroupInfo] = None
|
||||
self.user_info = DatabaseUserInfo(
|
||||
@@ -188,6 +197,10 @@ class DatabaseMessages(BaseDataModel):
|
||||
"is_command": self.is_command,
|
||||
"is_notify": self.is_notify,
|
||||
"selected_expressions": self.selected_expressions,
|
||||
"is_read": self.is_read,
|
||||
# 新增字段
|
||||
"actions": self.actions,
|
||||
"should_reply": self.should_reply,
|
||||
"user_id": self.user_info.user_id,
|
||||
"user_nickname": self.user_info.user_nickname,
|
||||
"user_cardname": self.user_info.user_cardname,
|
||||
@@ -205,6 +218,61 @@ class DatabaseMessages(BaseDataModel):
|
||||
"chat_info_user_cardname": self.chat_info.user_info.user_cardname,
|
||||
}
|
||||
|
||||
def update_message_info(self, interest_value: float = None, actions: list = None, should_reply: bool = None):
|
||||
"""
|
||||
更新消息信息
|
||||
|
||||
Args:
|
||||
interest_value: 兴趣度值
|
||||
actions: 执行的动作列表
|
||||
should_reply: 是否应该回复
|
||||
"""
|
||||
if interest_value is not None:
|
||||
self.interest_value = interest_value
|
||||
if actions is not None:
|
||||
self.actions = actions
|
||||
if should_reply is not None:
|
||||
self.should_reply = should_reply
|
||||
|
||||
def add_action(self, action: str):
|
||||
"""
|
||||
添加执行的动作到消息中
|
||||
|
||||
Args:
|
||||
action: 要添加的动作名称
|
||||
"""
|
||||
if self.actions is None:
|
||||
self.actions = []
|
||||
if action not in self.actions: # 避免重复添加
|
||||
self.actions.append(action)
|
||||
|
||||
def get_actions(self) -> list:
|
||||
"""
|
||||
获取执行的动作列表
|
||||
|
||||
Returns:
|
||||
动作列表,如果没有动作则返回空列表
|
||||
"""
|
||||
return self.actions or []
|
||||
|
||||
def get_message_summary(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取消息摘要信息
|
||||
|
||||
Returns:
|
||||
包含关键字段的消息摘要
|
||||
"""
|
||||
return {
|
||||
"message_id": self.message_id,
|
||||
"time": self.time,
|
||||
"interest_value": self.interest_value,
|
||||
"actions": self.actions,
|
||||
"should_reply": self.should_reply,
|
||||
"user_nickname": self.user_info.user_nickname,
|
||||
"display_message": self.display_message,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(init=False)
|
||||
class DatabaseActionRecords(BaseDataModel):
|
||||
def __init__(
|
||||
@@ -232,4 +300,4 @@ class DatabaseActionRecords(BaseDataModel):
|
||||
self.action_prompt_display = action_prompt_display
|
||||
self.chat_id = chat_id
|
||||
self.chat_info_stream_id = chat_info_stream_id
|
||||
self.chat_info_platform = chat_info_platform
|
||||
self.chat_info_platform = chat_info_platform
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Dict, List, TYPE_CHECKING
|
||||
|
||||
from src.plugin_system.base.component_types import ChatType
|
||||
from . import BaseDataModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
from .database_data_model import DatabaseMessages
|
||||
from src.plugin_system.base.component_types import ActionInfo, ChatMode
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -21,23 +23,37 @@ class ActionPlannerInfo(BaseDataModel):
|
||||
action_type: str = field(default_factory=str)
|
||||
reasoning: Optional[str] = None
|
||||
action_data: Optional[Dict] = None
|
||||
action_message: Optional[Dict] = None
|
||||
action_message: Optional["DatabaseMessages"] = None
|
||||
available_actions: Optional[Dict[str, "ActionInfo"]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class InterestScore(BaseDataModel):
|
||||
"""兴趣度评分结果"""
|
||||
|
||||
message_id: str
|
||||
total_score: float
|
||||
interest_match_score: float
|
||||
relationship_score: float
|
||||
mentioned_score: float
|
||||
details: Dict[str, str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Plan(BaseDataModel):
|
||||
"""
|
||||
统一规划数据模型
|
||||
"""
|
||||
|
||||
chat_id: str
|
||||
mode: "ChatMode"
|
||||
|
||||
|
||||
chat_type: "ChatType"
|
||||
# Generator 填充
|
||||
available_actions: Dict[str, "ActionInfo"] = field(default_factory=dict)
|
||||
chat_history: List["DatabaseMessages"] = field(default_factory=list)
|
||||
target_info: Optional[TargetPersonInfo] = None
|
||||
|
||||
|
||||
# Filter 填充
|
||||
llm_prompt: Optional[str] = None
|
||||
decided_actions: Optional[List[ActionPlannerInfo]] = None
|
||||
|
||||
@@ -6,6 +6,7 @@ from . import BaseDataModel
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMGenerationDataModel(BaseDataModel):
|
||||
content: Optional[str] = None
|
||||
@@ -14,4 +15,4 @@ class LLMGenerationDataModel(BaseDataModel):
|
||||
tool_calls: Optional[List["ToolCall"]] = None
|
||||
prompt: Optional[str] = None
|
||||
selected_expressions: Optional[List[int]] = None
|
||||
reply_set: Optional[List[Tuple[str, Any]]] = None
|
||||
reply_set: Optional[List[Tuple[str, Any]]] = None
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
from . import BaseDataModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageAndActionModel(BaseDataModel):
|
||||
chat_id: str = field(default_factory=str)
|
||||
time: float = field(default_factory=float)
|
||||
user_id: str = field(default_factory=str)
|
||||
user_platform: str = field(default_factory=str)
|
||||
user_nickname: str = field(default_factory=str)
|
||||
user_cardname: Optional[str] = None
|
||||
processed_plain_text: Optional[str] = None
|
||||
display_message: Optional[str] = None
|
||||
chat_info_platform: str = field(default_factory=str)
|
||||
is_action_record: bool = field(default=False)
|
||||
action_name: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def from_DatabaseMessages(cls, message: "DatabaseMessages"):
|
||||
return cls(
|
||||
chat_id=message.chat_id,
|
||||
time=message.time,
|
||||
user_id=message.user_info.user_id,
|
||||
user_platform=message.user_info.platform,
|
||||
user_nickname=message.user_info.user_nickname,
|
||||
user_cardname=message.user_info.user_cardname,
|
||||
processed_plain_text=message.processed_plain_text,
|
||||
display_message=message.display_message,
|
||||
chat_info_platform=message.chat_info.platform,
|
||||
)
|
||||
373
src/common/data_models/message_manager_data_model.py
Normal file
373
src/common/data_models/message_manager_data_model.py
Normal file
@@ -0,0 +1,373 @@
|
||||
"""
|
||||
消息管理模块数据模型
|
||||
定义消息管理器使用的数据结构
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import List, Optional, TYPE_CHECKING
|
||||
|
||||
from . import BaseDataModel
|
||||
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||
from src.common.logger import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .database_data_model import DatabaseMessages
|
||||
|
||||
logger = get_logger("stream_context")
|
||||
|
||||
|
||||
class MessageStatus(Enum):
|
||||
"""消息状态枚举"""
|
||||
|
||||
UNREAD = "unread" # 未读消息
|
||||
READ = "read" # 已读消息
|
||||
PROCESSING = "processing" # 处理中
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamContext(BaseDataModel):
|
||||
"""聊天流上下文信息"""
|
||||
|
||||
stream_id: str
|
||||
chat_type: ChatType = ChatType.PRIVATE # 聊天类型,默认为私聊
|
||||
chat_mode: ChatMode = ChatMode.NORMAL # 聊天模式,默认为普通模式
|
||||
unread_messages: List["DatabaseMessages"] = field(default_factory=list)
|
||||
history_messages: List["DatabaseMessages"] = field(default_factory=list)
|
||||
last_check_time: float = field(default_factory=time.time)
|
||||
is_active: bool = True
|
||||
processing_task: Optional[asyncio.Task] = None
|
||||
interruption_count: int = 0 # 打断计数器
|
||||
last_interruption_time: float = 0.0 # 上次打断时间
|
||||
afc_threshold_adjustment: float = 0.0 # afc阈值调整量
|
||||
|
||||
# 独立分发周期字段
|
||||
next_check_time: float = field(default_factory=time.time) # 下次检查时间
|
||||
distribution_interval: float = 5.0 # 当前分发周期(秒)
|
||||
|
||||
# 新增字段以替代ChatMessageContext功能
|
||||
current_message: Optional["DatabaseMessages"] = None
|
||||
priority_mode: Optional[str] = None
|
||||
priority_info: Optional[dict] = None
|
||||
|
||||
def add_message(self, message: "DatabaseMessages"):
|
||||
"""添加消息到上下文"""
|
||||
message.is_read = False
|
||||
self.unread_messages.append(message)
|
||||
|
||||
# 自动检测和更新chat type
|
||||
self._detect_chat_type(message)
|
||||
|
||||
def update_message_info(
|
||||
self, message_id: str, interest_value: float = None, actions: list = None, should_reply: bool = None
|
||||
):
|
||||
"""
|
||||
更新消息信息
|
||||
|
||||
Args:
|
||||
message_id: 消息ID
|
||||
interest_value: 兴趣度值
|
||||
actions: 执行的动作列表
|
||||
should_reply: 是否应该回复
|
||||
"""
|
||||
# 在未读消息中查找并更新
|
||||
for message in self.unread_messages:
|
||||
if message.message_id == message_id:
|
||||
message.update_message_info(interest_value, actions, should_reply)
|
||||
break
|
||||
|
||||
# 在历史消息中查找并更新
|
||||
for message in self.history_messages:
|
||||
if message.message_id == message_id:
|
||||
message.update_message_info(interest_value, actions, should_reply)
|
||||
break
|
||||
|
||||
def add_action_to_message(self, message_id: str, action: str):
|
||||
"""
|
||||
向指定消息添加执行的动作
|
||||
|
||||
Args:
|
||||
message_id: 消息ID
|
||||
action: 要添加的动作名称
|
||||
"""
|
||||
# 在未读消息中查找并更新
|
||||
for message in self.unread_messages:
|
||||
if message.message_id == message_id:
|
||||
message.add_action(action)
|
||||
break
|
||||
|
||||
# 在历史消息中查找并更新
|
||||
for message in self.history_messages:
|
||||
if message.message_id == message_id:
|
||||
message.add_action(action)
|
||||
break
|
||||
|
||||
def _detect_chat_type(self, message: "DatabaseMessages"):
|
||||
"""根据消息内容自动检测聊天类型"""
|
||||
# 只有在第一次添加消息时才检测聊天类型,避免后续消息改变类型
|
||||
if len(self.unread_messages) == 1: # 只有这条消息
|
||||
# 如果消息包含群组信息,则为群聊
|
||||
if hasattr(message, "chat_info_group_id") and message.chat_info_group_id:
|
||||
self.chat_type = ChatType.GROUP
|
||||
elif hasattr(message, "chat_info_group_name") and message.chat_info_group_name:
|
||||
self.chat_type = ChatType.GROUP
|
||||
else:
|
||||
self.chat_type = ChatType.PRIVATE
|
||||
|
||||
def update_chat_type(self, chat_type: ChatType):
|
||||
"""手动更新聊天类型"""
|
||||
self.chat_type = chat_type
|
||||
|
||||
def set_chat_mode(self, chat_mode: ChatMode):
|
||||
"""设置聊天模式"""
|
||||
self.chat_mode = chat_mode
|
||||
|
||||
def is_group_chat(self) -> bool:
|
||||
"""检查是否为群聊"""
|
||||
return self.chat_type == ChatType.GROUP
|
||||
|
||||
def is_private_chat(self) -> bool:
|
||||
"""检查是否为私聊"""
|
||||
return self.chat_type == ChatType.PRIVATE
|
||||
|
||||
def get_chat_type_display(self) -> str:
|
||||
"""获取聊天类型的显示名称"""
|
||||
if self.chat_type == ChatType.GROUP:
|
||||
return "群聊"
|
||||
elif self.chat_type == ChatType.PRIVATE:
|
||||
return "私聊"
|
||||
else:
|
||||
return "未知类型"
|
||||
|
||||
def mark_message_as_read(self, message_id: str):
|
||||
"""标记消息为已读"""
|
||||
for msg in self.unread_messages:
|
||||
if msg.message_id == message_id:
|
||||
msg.is_read = True
|
||||
self.history_messages.append(msg)
|
||||
self.unread_messages.remove(msg)
|
||||
break
|
||||
|
||||
def get_unread_messages(self) -> List["DatabaseMessages"]:
|
||||
"""获取未读消息"""
|
||||
return [msg for msg in self.unread_messages if not msg.is_read]
|
||||
|
||||
def get_history_messages(self, limit: int = 20) -> List["DatabaseMessages"]:
|
||||
"""获取历史消息"""
|
||||
# 优先返回最近的历史消息和所有未读消息
|
||||
recent_history = self.history_messages[-limit:] if len(self.history_messages) > limit else self.history_messages
|
||||
return recent_history
|
||||
|
||||
def calculate_interruption_probability(self, max_limit: int, probability_factor: float) -> float:
|
||||
"""计算打断概率"""
|
||||
if max_limit <= 0:
|
||||
return 0.0
|
||||
|
||||
# 计算打断比例
|
||||
interruption_ratio = self.interruption_count / max_limit
|
||||
|
||||
# 如果已达到或超过最大次数,完全禁止打断
|
||||
if self.interruption_count >= max_limit:
|
||||
return 0.0
|
||||
|
||||
# 如果超过概率因子,概率下降
|
||||
if interruption_ratio > probability_factor:
|
||||
# 使用指数衰减,超过限制越多,概率越低
|
||||
excess_ratio = interruption_ratio - probability_factor
|
||||
probability = 0.8 * (0.5**excess_ratio) # 基础概率0.8,指数衰减
|
||||
else:
|
||||
# 在限制内,保持较高概率
|
||||
probability = 0.8
|
||||
|
||||
return max(0.0, min(1.0, probability))
|
||||
|
||||
def increment_interruption_count(self):
|
||||
"""增加打断计数"""
|
||||
self.interruption_count += 1
|
||||
self.last_interruption_time = time.time()
|
||||
|
||||
# 同步打断计数到ChatStream
|
||||
self._sync_interruption_count_to_stream()
|
||||
|
||||
def reset_interruption_count(self):
|
||||
"""重置打断计数和afc阈值调整"""
|
||||
self.interruption_count = 0
|
||||
self.last_interruption_time = 0.0
|
||||
self.afc_threshold_adjustment = 0.0
|
||||
|
||||
# 同步打断计数到ChatStream
|
||||
self._sync_interruption_count_to_stream()
|
||||
|
||||
def apply_interruption_afc_reduction(self, reduction_value: float):
|
||||
"""应用打断导致的afc阈值降低"""
|
||||
self.afc_threshold_adjustment += reduction_value
|
||||
logger.debug(f"应用afc阈值降低: {reduction_value}, 总调整量: {self.afc_threshold_adjustment}")
|
||||
|
||||
def get_afc_threshold_adjustment(self) -> float:
|
||||
"""获取当前的afc阈值调整量"""
|
||||
return self.afc_threshold_adjustment
|
||||
|
||||
def _sync_interruption_count_to_stream(self):
|
||||
"""同步打断计数到ChatStream"""
|
||||
try:
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
chat_manager = get_chat_manager()
|
||||
if chat_manager:
|
||||
chat_stream = chat_manager.get_stream(self.stream_id)
|
||||
if chat_stream and hasattr(chat_stream, "interruption_count"):
|
||||
# 在这里我们只是标记需要保存,实际的保存会在下次save时进行
|
||||
chat_stream.saved = False
|
||||
logger.debug(
|
||||
f"已同步StreamContext {self.stream_id} 的打断计数 {self.interruption_count} 到ChatStream"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"同步打断计数到ChatStream失败: {e}")
|
||||
|
||||
def set_current_message(self, message: "DatabaseMessages"):
|
||||
"""设置当前消息"""
|
||||
self.current_message = message
|
||||
|
||||
def get_template_name(self) -> Optional[str]:
|
||||
"""获取模板名称"""
|
||||
if (
|
||||
self.current_message
|
||||
and hasattr(self.current_message, "additional_config")
|
||||
and self.current_message.additional_config
|
||||
):
|
||||
try:
|
||||
import json
|
||||
|
||||
config = json.loads(self.current_message.additional_config)
|
||||
if config.get("template_info") and not config.get("template_default", True):
|
||||
return config.get("template_name")
|
||||
except (json.JSONDecodeError, AttributeError):
|
||||
pass
|
||||
return None
|
||||
|
||||
def get_last_message(self) -> Optional["DatabaseMessages"]:
|
||||
"""获取最后一条消息"""
|
||||
if self.current_message:
|
||||
return self.current_message
|
||||
if self.unread_messages:
|
||||
return self.unread_messages[-1]
|
||||
if self.history_messages:
|
||||
return self.history_messages[-1]
|
||||
return None
|
||||
|
||||
def check_types(self, types: list) -> bool:
|
||||
"""
|
||||
检查当前消息是否支持指定的类型
|
||||
|
||||
Args:
|
||||
types: 需要检查的消息类型列表,如 ["text", "image", "emoji"]
|
||||
|
||||
Returns:
|
||||
bool: 如果消息支持所有指定的类型则返回True,否则返回False
|
||||
"""
|
||||
if not self.current_message:
|
||||
return False
|
||||
|
||||
if not types:
|
||||
# 如果没有指定类型要求,默认为支持
|
||||
return True
|
||||
|
||||
# 优先从additional_config中获取format_info
|
||||
if hasattr(self.current_message, "additional_config") and self.current_message.additional_config:
|
||||
try:
|
||||
import orjson
|
||||
|
||||
config = orjson.loads(self.current_message.additional_config)
|
||||
|
||||
# 检查format_info结构
|
||||
if "format_info" in config:
|
||||
format_info = config["format_info"]
|
||||
|
||||
# 方法1: 直接检查accept_format字段
|
||||
if "accept_format" in format_info:
|
||||
accept_format = format_info["accept_format"]
|
||||
# 确保accept_format是列表类型
|
||||
if isinstance(accept_format, str):
|
||||
accept_format = [accept_format]
|
||||
elif isinstance(accept_format, list):
|
||||
pass
|
||||
else:
|
||||
# 如果accept_format不是字符串或列表,尝试转换为列表
|
||||
accept_format = list(accept_format) if hasattr(accept_format, "__iter__") else []
|
||||
|
||||
# 检查所有请求的类型是否都被支持
|
||||
for requested_type in types:
|
||||
if requested_type not in accept_format:
|
||||
logger.debug(f"消息不支持类型 '{requested_type}',支持的类型: {accept_format}")
|
||||
return False
|
||||
return True
|
||||
|
||||
# 方法2: 检查content_format字段(向后兼容)
|
||||
elif "content_format" in format_info:
|
||||
content_format = format_info["content_format"]
|
||||
# 确保content_format是列表类型
|
||||
if isinstance(content_format, str):
|
||||
content_format = [content_format]
|
||||
elif isinstance(content_format, list):
|
||||
pass
|
||||
else:
|
||||
content_format = list(content_format) if hasattr(content_format, "__iter__") else []
|
||||
|
||||
# 检查所有请求的类型是否都被支持
|
||||
for requested_type in types:
|
||||
if requested_type not in content_format:
|
||||
logger.debug(f"消息不支持类型 '{requested_type}',支持的内容格式: {content_format}")
|
||||
return False
|
||||
return True
|
||||
|
||||
except (orjson.JSONDecodeError, AttributeError, TypeError) as e:
|
||||
logger.debug(f"解析消息格式信息失败: {e}")
|
||||
|
||||
# 备用方案:如果无法从additional_config获取格式信息,使用默认支持的类型
|
||||
# 大多数消息至少支持text类型
|
||||
default_supported_types = ["text", "emoji"]
|
||||
for requested_type in types:
|
||||
if requested_type not in default_supported_types:
|
||||
logger.debug(f"使用默认类型检查,消息可能不支持类型 '{requested_type}'")
|
||||
# 对于非基础类型,返回False以避免错误
|
||||
if requested_type not in ["text", "emoji", "reply"]:
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_priority_mode(self) -> Optional[str]:
|
||||
"""获取优先级模式"""
|
||||
return self.priority_mode
|
||||
|
||||
def get_priority_info(self) -> Optional[dict]:
|
||||
"""获取优先级信息"""
|
||||
return self.priority_info
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageManagerStats(BaseDataModel):
|
||||
"""消息管理器统计信息"""
|
||||
|
||||
total_streams: int = 0
|
||||
active_streams: int = 0
|
||||
total_unread_messages: int = 0
|
||||
total_processed_messages: int = 0
|
||||
start_time: float = field(default_factory=time.time)
|
||||
|
||||
@property
|
||||
def uptime(self) -> float:
|
||||
"""运行时间"""
|
||||
return time.time() - self.start_time
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamStats(BaseDataModel):
|
||||
"""聊天流统计信息"""
|
||||
|
||||
stream_id: str
|
||||
is_active: bool
|
||||
unread_count: int
|
||||
history_count: int
|
||||
last_check_time: float
|
||||
has_active_task: bool
|
||||
@@ -30,6 +30,7 @@ from src.common.database.sqlalchemy_models import (
|
||||
Schedule,
|
||||
MaiZoneScheduleStatus,
|
||||
CacheEntries,
|
||||
UserRelationships,
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -54,6 +55,7 @@ MODEL_MAPPING = {
|
||||
"Schedule": Schedule,
|
||||
"MaiZoneScheduleStatus": MaiZoneScheduleStatus,
|
||||
"CacheEntries": CacheEntries,
|
||||
"UserRelationships": UserRelationships,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -55,7 +55,17 @@ class ChatStreams(Base):
|
||||
user_cardname = Column(Text, nullable=True)
|
||||
energy_value = Column(Float, nullable=True, default=5.0)
|
||||
sleep_pressure = Column(Float, nullable=True, default=0.0)
|
||||
focus_energy = Column(Float, nullable=True, default=1.0)
|
||||
focus_energy = Column(Float, nullable=True, default=0.5)
|
||||
# 动态兴趣度系统字段
|
||||
base_interest_energy = Column(Float, nullable=True, default=0.5)
|
||||
message_interest_total = Column(Float, nullable=True, default=0.0)
|
||||
message_count = Column(Integer, nullable=True, default=0)
|
||||
action_count = Column(Integer, nullable=True, default=0)
|
||||
reply_count = Column(Integer, nullable=True, default=0)
|
||||
last_interaction_time = Column(Float, nullable=True, default=None)
|
||||
consecutive_no_reply = Column(Integer, nullable=True, default=0)
|
||||
# 消息打断系统字段
|
||||
interruption_count = Column(Integer, nullable=True, default=0)
|
||||
|
||||
__table_args__ = (
|
||||
Index("idx_chatstreams_stream_id", "stream_id"),
|
||||
@@ -165,11 +175,16 @@ class Messages(Base):
|
||||
is_command = Column(Boolean, nullable=False, default=False)
|
||||
is_notify = Column(Boolean, nullable=False, default=False)
|
||||
|
||||
# 兴趣度系统字段
|
||||
actions = Column(Text, nullable=True) # JSON格式存储动作列表
|
||||
should_reply = Column(Boolean, nullable=True, default=False)
|
||||
|
||||
__table_args__ = (
|
||||
Index("idx_messages_message_id", "message_id"),
|
||||
Index("idx_messages_chat_id", "chat_id"),
|
||||
Index("idx_messages_time", "time"),
|
||||
Index("idx_messages_user_id", "user_id"),
|
||||
Index("idx_messages_should_reply", "should_reply"),
|
||||
)
|
||||
|
||||
|
||||
@@ -300,6 +315,26 @@ class PersonInfo(Base):
|
||||
)
|
||||
|
||||
|
||||
class BotPersonalityInterests(Base):
|
||||
"""机器人人格兴趣标签模型"""
|
||||
|
||||
__tablename__ = "bot_personality_interests"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
personality_id = Column(get_string_field(100), nullable=False, index=True)
|
||||
personality_description = Column(Text, nullable=False)
|
||||
interest_tags = Column(Text, nullable=False) # JSON格式存储的兴趣标签列表
|
||||
embedding_model = Column(get_string_field(100), nullable=False, default="text-embedding-ada-002")
|
||||
version = Column(Integer, nullable=False, default=1)
|
||||
last_updated = Column(DateTime, nullable=False, default=datetime.datetime.now, index=True)
|
||||
|
||||
__table_args__ = (
|
||||
Index("idx_botpersonality_personality_id", "personality_id"),
|
||||
Index("idx_botpersonality_version", "version"),
|
||||
Index("idx_botpersonality_last_updated", "last_updated"),
|
||||
)
|
||||
|
||||
|
||||
class Memory(Base):
|
||||
"""记忆模型"""
|
||||
|
||||
@@ -722,3 +757,23 @@ class UserPermissions(Base):
|
||||
Index("idx_user_permission", "platform", "user_id", "permission_node"),
|
||||
Index("idx_permission_granted", "permission_node", "granted"),
|
||||
)
|
||||
|
||||
|
||||
class UserRelationships(Base):
|
||||
"""用户关系模型 - 存储用户与bot的关系数据"""
|
||||
|
||||
__tablename__ = "user_relationships"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
user_id = Column(get_string_field(100), nullable=False, unique=True, index=True) # 用户ID
|
||||
user_name = Column(get_string_field(100), nullable=True) # 用户名
|
||||
relationship_text = Column(Text, nullable=True) # 关系印象描述
|
||||
relationship_score = Column(Float, nullable=False, default=0.3) # 关系分数(0-1)
|
||||
last_updated = Column(Float, nullable=False, default=time.time) # 最后更新时间
|
||||
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) # 创建时间
|
||||
|
||||
__table_args__ = (
|
||||
Index("idx_user_relationship_id", "user_id"),
|
||||
Index("idx_relationship_score", "relationship_score"),
|
||||
Index("idx_relationship_updated", "last_updated"),
|
||||
)
|
||||
|
||||
@@ -350,6 +350,10 @@ MODULE_COLORS = {
|
||||
"memory": "\033[38;5;117m", # 天蓝色
|
||||
"hfc": "\033[38;5;81m", # 稍微暗一些的青色,保持可读
|
||||
"action_manager": "\033[38;5;208m", # 橙色,不与replyer重复
|
||||
"message_manager": "\033[38;5;27m", # 深蓝色,消息管理器
|
||||
"chatter_manager": "\033[38;5;129m", # 紫色,聊天管理器
|
||||
"chatter_interest_scoring": "\033[38;5;214m", # 橙黄色,兴趣评分
|
||||
"plan_executor": "\033[38;5;172m", # 橙褐色,计划执行器
|
||||
# 关系系统
|
||||
"relation": "\033[38;5;139m", # 柔和的紫色,不刺眼
|
||||
# 聊天相关模块
|
||||
@@ -551,6 +555,10 @@ MODULE_ALIASES = {
|
||||
"llm_models": "模型",
|
||||
"person_info": "人物",
|
||||
"chat_stream": "聊天流",
|
||||
"message_manager": "消息管理",
|
||||
"chatter_manager": "聊天管理",
|
||||
"chatter_interest_scoring": "兴趣评分",
|
||||
"plan_executor": "计划执行",
|
||||
"planner": "规划器",
|
||||
"replyer": "言语",
|
||||
"config": "配置",
|
||||
|
||||
@@ -23,15 +23,15 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
|
||||
maim_message_config = global_config.maim_message
|
||||
|
||||
# 设置基本参数
|
||||
|
||||
|
||||
host = os.getenv("HOST", "127.0.0.1")
|
||||
port_str = os.getenv("PORT", "8000")
|
||||
|
||||
|
||||
try:
|
||||
port = int(port_str)
|
||||
except ValueError:
|
||||
port = 8000
|
||||
|
||||
|
||||
kwargs = {
|
||||
"host": host,
|
||||
"port": port,
|
||||
|
||||
@@ -22,10 +22,15 @@ def _model_to_dict(instance: Base) -> Dict[str, Any]:
|
||||
"""
|
||||
将 SQLAlchemy 模型实例转换为字典。
|
||||
"""
|
||||
return {col.name: getattr(instance, col.name) for col in instance.__table__.columns}
|
||||
try:
|
||||
return {col.name: getattr(instance, col.name) for col in instance.__table__.columns}
|
||||
except Exception as e:
|
||||
# 如果对象已经脱离会话,尝试从instance.__dict__中获取数据
|
||||
logger.warning(f"从数据库对象获取属性失败,尝试使用__dict__: {e}")
|
||||
return {col.name: instance.__dict__.get(col.name) for col in instance.__table__.columns}
|
||||
|
||||
|
||||
async def find_messages(
|
||||
def find_messages(
|
||||
message_filter: dict[str, Any],
|
||||
sort: Optional[List[tuple[str, int]]] = None,
|
||||
limit: int = 0,
|
||||
@@ -46,7 +51,7 @@ async def find_messages(
|
||||
消息字典列表,如果出错则返回空列表。
|
||||
"""
|
||||
try:
|
||||
async with get_db_session() as session:
|
||||
with get_db_session() as session:
|
||||
query = select(Messages)
|
||||
|
||||
# 应用过滤器
|
||||
@@ -96,7 +101,7 @@ async def find_messages(
|
||||
# 获取时间最早的 limit 条记录,已经是正序
|
||||
query = query.order_by(Messages.time.asc()).limit(limit)
|
||||
try:
|
||||
results = (await session.execute(query)).scalars().all()
|
||||
results = session.execute(query).scalars().all()
|
||||
except Exception as e:
|
||||
logger.error(f"执行earliest查询失败: {e}")
|
||||
results = []
|
||||
@@ -104,7 +109,7 @@ async def find_messages(
|
||||
# 获取时间最晚的 limit 条记录
|
||||
query = query.order_by(Messages.time.desc()).limit(limit)
|
||||
try:
|
||||
latest_results = (await session.execute(query)).scalars().all()
|
||||
latest_results = session.execute(query).scalars().all()
|
||||
# 将结果按时间正序排列
|
||||
results = sorted(latest_results, key=lambda msg: msg.time)
|
||||
except Exception as e:
|
||||
@@ -128,11 +133,12 @@ async def find_messages(
|
||||
if sort_terms:
|
||||
query = query.order_by(*sort_terms)
|
||||
try:
|
||||
results = (await session.execute(query)).scalars().all()
|
||||
results = session.execute(query).scalars().all()
|
||||
except Exception as e:
|
||||
logger.error(f"执行无限制查询失败: {e}")
|
||||
results = []
|
||||
|
||||
# 在会话内将结果转换为字典,避免会话分离错误
|
||||
return [_model_to_dict(msg) for msg in results]
|
||||
except Exception as e:
|
||||
log_message = (
|
||||
@@ -143,7 +149,7 @@ async def find_messages(
|
||||
return []
|
||||
|
||||
|
||||
async def count_messages(message_filter: dict[str, Any]) -> int:
|
||||
def count_messages(message_filter: dict[str, Any]) -> int:
|
||||
"""
|
||||
根据提供的过滤器计算消息数量。
|
||||
|
||||
@@ -154,7 +160,7 @@ async def count_messages(message_filter: dict[str, Any]) -> int:
|
||||
符合条件的消息数量,如果出错则返回 0。
|
||||
"""
|
||||
try:
|
||||
async with get_db_session() as session:
|
||||
with get_db_session() as session:
|
||||
query = select(func.count(Messages.id))
|
||||
|
||||
# 应用过滤器
|
||||
@@ -192,7 +198,7 @@ async def count_messages(message_filter: dict[str, Any]) -> int:
|
||||
if conditions:
|
||||
query = query.where(*conditions)
|
||||
|
||||
count = (await session.execute(query)).scalar()
|
||||
count = session.execute(query).scalar()
|
||||
return count or 0
|
||||
except Exception as e:
|
||||
log_message = f"使用 SQLAlchemy 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"
|
||||
@@ -201,5 +207,5 @@ async def count_messages(message_filter: dict[str, Any]) -> int:
|
||||
|
||||
|
||||
# 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。
|
||||
# 注意:对于 SQLAlchemy,插入操作通常是使用 session.add() 和 await session.commit()。
|
||||
# 注意:对于 SQLAlchemy,插入操作通常是使用 session.add() 和 session.commit()。
|
||||
# 查找单个消息可以使用 session.execute(select(Messages).where(...)).scalar_one_or_none()。
|
||||
|
||||
@@ -31,7 +31,9 @@ class TelemetryHeartBeatTask(AsyncTask):
|
||||
self.client_uuid: str | None = local_storage["mofox_uuid"] if "mofox_uuid" in local_storage else None # type: ignore
|
||||
"""客户端UUID"""
|
||||
|
||||
self.private_key_pem: str | None = local_storage["mofox_private_key"] if "mofox_private_key" in local_storage else None # type: ignore
|
||||
self.private_key_pem: str | None = (
|
||||
local_storage["mofox_private_key"] if "mofox_private_key" in local_storage else None
|
||||
) # type: ignore
|
||||
"""客户端私钥"""
|
||||
|
||||
self.info_dict = self._get_sys_info()
|
||||
@@ -61,78 +63,65 @@ class TelemetryHeartBeatTask(AsyncTask):
|
||||
def _generate_signature(self, request_body: dict) -> tuple[str, str]:
|
||||
"""
|
||||
生成RSA签名
|
||||
|
||||
|
||||
Returns:
|
||||
tuple[str, str]: (timestamp, signature_b64)
|
||||
"""
|
||||
if not self.private_key_pem:
|
||||
raise ValueError("私钥未初始化")
|
||||
|
||||
|
||||
# 生成时间戳
|
||||
timestamp = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
# 创建签名数据字符串
|
||||
sign_data = f"{self.client_uuid}:{timestamp}:{json.dumps(request_body, separators=(',', ':'))}"
|
||||
|
||||
|
||||
# 加载私钥
|
||||
private_key = serialization.load_pem_private_key(
|
||||
self.private_key_pem.encode('utf-8'),
|
||||
password=None
|
||||
)
|
||||
|
||||
private_key = serialization.load_pem_private_key(self.private_key_pem.encode("utf-8"), password=None)
|
||||
|
||||
# 确保是RSA私钥
|
||||
if not isinstance(private_key, rsa.RSAPrivateKey):
|
||||
raise ValueError("私钥必须是RSA格式")
|
||||
|
||||
|
||||
# 生成签名
|
||||
signature = private_key.sign(
|
||||
sign_data.encode('utf-8'),
|
||||
padding.PSS(
|
||||
mgf=padding.MGF1(hashes.SHA256()),
|
||||
salt_length=padding.PSS.MAX_LENGTH
|
||||
),
|
||||
hashes.SHA256()
|
||||
sign_data.encode("utf-8"),
|
||||
padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH),
|
||||
hashes.SHA256(),
|
||||
)
|
||||
|
||||
|
||||
# Base64编码
|
||||
signature_b64 = base64.b64encode(signature).decode('utf-8')
|
||||
|
||||
signature_b64 = base64.b64encode(signature).decode("utf-8")
|
||||
|
||||
return timestamp, signature_b64
|
||||
|
||||
def _decrypt_challenge(self, challenge_b64: str) -> str:
|
||||
"""
|
||||
解密挑战数据
|
||||
|
||||
|
||||
Args:
|
||||
challenge_b64: Base64编码的挑战数据
|
||||
|
||||
|
||||
Returns:
|
||||
str: 解密后的UUID字符串
|
||||
"""
|
||||
if not self.private_key_pem:
|
||||
raise ValueError("私钥未初始化")
|
||||
|
||||
|
||||
# 加载私钥
|
||||
private_key = serialization.load_pem_private_key(
|
||||
self.private_key_pem.encode('utf-8'),
|
||||
password=None
|
||||
)
|
||||
|
||||
private_key = serialization.load_pem_private_key(self.private_key_pem.encode("utf-8"), password=None)
|
||||
|
||||
# 确保是RSA私钥
|
||||
if not isinstance(private_key, rsa.RSAPrivateKey):
|
||||
raise ValueError("私钥必须是RSA格式")
|
||||
|
||||
|
||||
# 解密挑战数据
|
||||
decrypted_bytes = private_key.decrypt(
|
||||
base64.b64decode(challenge_b64),
|
||||
padding.OAEP(
|
||||
mgf=padding.MGF1(hashes.SHA256()),
|
||||
algorithm=hashes.SHA256(),
|
||||
label=None
|
||||
)
|
||||
padding.OAEP(mgf=padding.MGF1(hashes.SHA256()), algorithm=hashes.SHA256(), label=None),
|
||||
)
|
||||
|
||||
return decrypted_bytes.decode('utf-8')
|
||||
|
||||
return decrypted_bytes.decode("utf-8")
|
||||
|
||||
async def _req_uuid(self) -> bool:
|
||||
"""
|
||||
@@ -155,28 +144,26 @@ class TelemetryHeartBeatTask(AsyncTask):
|
||||
|
||||
if response.status != 200:
|
||||
response_text = await response.text()
|
||||
logger.error(
|
||||
f"注册步骤1失败,状态码: {response.status}, 响应内容: {response_text}"
|
||||
)
|
||||
logger.error(f"注册步骤1失败,状态码: {response.status}, 响应内容: {response_text}")
|
||||
raise aiohttp.ClientResponseError(
|
||||
request_info=response.request_info,
|
||||
history=response.history,
|
||||
status=response.status,
|
||||
message=f"Step1 failed: {response_text}"
|
||||
message=f"Step1 failed: {response_text}",
|
||||
)
|
||||
|
||||
step1_data = await response.json()
|
||||
temp_uuid = step1_data.get("temp_uuid")
|
||||
private_key = step1_data.get("private_key")
|
||||
challenge = step1_data.get("challenge")
|
||||
|
||||
|
||||
if not all([temp_uuid, private_key, challenge]):
|
||||
logger.error("Step1响应缺少必要字段:temp_uuid, private_key 或 challenge")
|
||||
raise ValueError("Step1响应数据不完整")
|
||||
|
||||
# 临时保存私钥用于解密
|
||||
self.private_key_pem = private_key
|
||||
|
||||
|
||||
# 解密挑战数据
|
||||
logger.debug("解密挑战数据...")
|
||||
try:
|
||||
@@ -184,21 +171,18 @@ class TelemetryHeartBeatTask(AsyncTask):
|
||||
except Exception as e:
|
||||
logger.error(f"解密挑战数据失败: {e}")
|
||||
raise
|
||||
|
||||
|
||||
# 验证解密结果
|
||||
if decrypted_uuid != temp_uuid:
|
||||
logger.error(f"解密结果验证失败: 期望 {temp_uuid}, 实际 {decrypted_uuid}")
|
||||
raise ValueError("解密结果与临时UUID不匹配")
|
||||
|
||||
|
||||
logger.debug("挑战数据解密成功,开始注册步骤2")
|
||||
|
||||
# Step 2: 发送解密结果完成注册
|
||||
async with session.post(
|
||||
f"{TELEMETRY_SERVER_URL}/stat/reg_client_step2",
|
||||
json={
|
||||
"temp_uuid": temp_uuid,
|
||||
"decrypted_uuid": decrypted_uuid
|
||||
},
|
||||
json={"temp_uuid": temp_uuid, "decrypted_uuid": decrypted_uuid},
|
||||
timeout=aiohttp.ClientTimeout(total=5),
|
||||
) as response:
|
||||
logger.debug(f"Step2 Response status: {response.status}")
|
||||
@@ -206,7 +190,7 @@ class TelemetryHeartBeatTask(AsyncTask):
|
||||
if response.status == 200:
|
||||
step2_data = await response.json()
|
||||
mofox_uuid = step2_data.get("mofox_uuid")
|
||||
|
||||
|
||||
if mofox_uuid:
|
||||
# 将正式UUID和私钥存储到本地
|
||||
local_storage["mofox_uuid"] = mofox_uuid
|
||||
@@ -225,23 +209,19 @@ class TelemetryHeartBeatTask(AsyncTask):
|
||||
raise ValueError(f"Step2失败: {response_text}")
|
||||
else:
|
||||
response_text = await response.text()
|
||||
logger.error(
|
||||
f"注册步骤2失败,状态码: {response.status}, 响应内容: {response_text}"
|
||||
)
|
||||
logger.error(f"注册步骤2失败,状态码: {response.status}, 响应内容: {response_text}")
|
||||
raise aiohttp.ClientResponseError(
|
||||
request_info=response.request_info,
|
||||
history=response.history,
|
||||
status=response.status,
|
||||
message=f"Step2 failed: {response_text}"
|
||||
message=f"Step2 failed: {response_text}",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
error_msg = str(e) or "未知错误"
|
||||
logger.warning(
|
||||
f"注册客户端出错,不过你还是可以正常使用墨狐: {type(e).__name__}: {error_msg}"
|
||||
)
|
||||
logger.warning(f"注册客户端出错,不过你还是可以正常使用墨狐: {type(e).__name__}: {error_msg}")
|
||||
logger.debug(f"完整错误信息: {traceback.format_exc()}")
|
||||
|
||||
# 请求失败,重试次数+1
|
||||
@@ -264,13 +244,13 @@ class TelemetryHeartBeatTask(AsyncTask):
|
||||
try:
|
||||
# 生成签名
|
||||
timestamp, signature = self._generate_signature(self.info_dict)
|
||||
|
||||
|
||||
headers = {
|
||||
"X-mofox-UUID": self.client_uuid,
|
||||
"X-mofox-Signature": signature,
|
||||
"X-mofox-Timestamp": timestamp,
|
||||
"User-Agent": f"MofoxClient/{self.client_uuid[:8]}",
|
||||
"Content-Type": "application/json"
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
logger.debug(f"正在发送心跳到服务器: {self.server_url}")
|
||||
@@ -347,4 +327,4 @@ class TelemetryHeartBeatTask(AsyncTask):
|
||||
logger.warning("客户端注册失败,跳过此次心跳")
|
||||
return
|
||||
|
||||
await self._send_heartbeat()
|
||||
await self._send_heartbeat()
|
||||
|
||||
@@ -99,14 +99,13 @@ def get_global_server() -> Server:
|
||||
"""获取全局服务器实例"""
|
||||
global global_server
|
||||
if global_server is None:
|
||||
|
||||
host = os.getenv("HOST", "127.0.0.1")
|
||||
port_str = os.getenv("PORT", "8000")
|
||||
|
||||
|
||||
try:
|
||||
port = int(port_str)
|
||||
except ValueError:
|
||||
port = 8000
|
||||
|
||||
|
||||
global_server = Server(host=host, port=port)
|
||||
return global_server
|
||||
|
||||
Reference in New Issue
Block a user