ruff,私聊视为提及了bot

This commit is contained in:
Windpicker-owo
2025-09-20 22:34:22 +08:00
parent 006f9130b9
commit 444f1ca315
76 changed files with 1066 additions and 882 deletions

View File

@@ -6,6 +6,7 @@ class BaseDataModel:
def deepcopy(self):
return copy.deepcopy(self)
def temporarily_transform_class_to_dict(obj: Any) -> Any:
# sourcery skip: assign-if-exp, reintroduce-else
"""

View File

@@ -2,6 +2,7 @@
机器人兴趣标签数据模型
定义机器人的兴趣标签和相关的embedding数据结构
"""
from dataclasses import dataclass, field
from typing import List, Dict, Optional, Any
from datetime import datetime
@@ -12,6 +13,7 @@ from . import BaseDataModel
@dataclass
class BotInterestTag(BaseDataModel):
"""机器人兴趣标签"""
tag_name: str
weight: float = 1.0 # 权重,表示对这个兴趣的喜好程度 (0.0-1.0)
embedding: Optional[List[float]] = None # 标签的embedding向量
@@ -27,7 +29,7 @@ class BotInterestTag(BaseDataModel):
"embedding": self.embedding,
"created_at": self.created_at.isoformat(),
"updated_at": self.updated_at.isoformat(),
"is_active": self.is_active
"is_active": self.is_active,
}
@classmethod
@@ -39,13 +41,14 @@ class BotInterestTag(BaseDataModel):
embedding=data.get("embedding"),
created_at=datetime.fromisoformat(data["created_at"]) if data.get("created_at") else datetime.now(),
updated_at=datetime.fromisoformat(data["updated_at"]) if data.get("updated_at") else datetime.now(),
is_active=data.get("is_active", True)
is_active=data.get("is_active", True),
)
@dataclass
class BotPersonalityInterests(BaseDataModel):
"""机器人人格化兴趣配置"""
personality_id: str
personality_description: str # 人设描述文本
interest_tags: List[BotInterestTag] = field(default_factory=list)
@@ -57,7 +60,6 @@ class BotPersonalityInterests(BaseDataModel):
"""获取活跃的兴趣标签"""
return [tag for tag in self.interest_tags if tag.is_active]
def to_dict(self) -> Dict[str, Any]:
"""转换为字典格式"""
return {
@@ -66,7 +68,7 @@ class BotPersonalityInterests(BaseDataModel):
"interest_tags": [tag.to_dict() for tag in self.interest_tags],
"embedding_model": self.embedding_model,
"last_updated": self.last_updated.isoformat(),
"version": self.version
"version": self.version,
}
@classmethod
@@ -78,13 +80,14 @@ class BotPersonalityInterests(BaseDataModel):
interest_tags=[BotInterestTag.from_dict(tag_data) for tag_data in data.get("interest_tags", [])],
embedding_model=data.get("embedding_model", "text-embedding-ada-002"),
last_updated=datetime.fromisoformat(data["last_updated"]) if data.get("last_updated") else datetime.now(),
version=data.get("version", 1)
version=data.get("version", 1),
)
@dataclass
class InterestMatchResult(BaseDataModel):
"""兴趣匹配结果"""
message_id: str
matched_tags: List[str] = field(default_factory=list)
match_scores: Dict[str, float] = field(default_factory=dict) # tag_name -> score
@@ -120,7 +123,9 @@ class InterestMatchResult(BaseDataModel):
# 计算置信度(基于匹配标签数量和分数分布)
if len(self.match_scores) > 0:
avg_score = self.overall_score
score_variance = sum((score - avg_score) ** 2 for score in self.match_scores.values()) / len(self.match_scores)
score_variance = sum((score - avg_score) ** 2 for score in self.match_scores.values()) / len(
self.match_scores
)
# 分数越集中,置信度越高
self.confidence = max(0.0, 1.0 - score_variance)
else:
@@ -129,4 +134,4 @@ class InterestMatchResult(BaseDataModel):
def get_top_matches(self, top_n: int = 3) -> List[tuple]:
"""获取前N个最佳匹配"""
sorted_matches = sorted(self.match_scores.items(), key=lambda x: x[1], reverse=True)
return sorted_matches[:top_n]
return sorted_matches[:top_n]

View File

@@ -208,6 +208,7 @@ class DatabaseMessages(BaseDataModel):
"chat_info_user_cardname": self.chat_info.user_info.user_cardname,
}
@dataclass(init=False)
class DatabaseActionRecords(BaseDataModel):
def __init__(
@@ -235,4 +236,4 @@ class DatabaseActionRecords(BaseDataModel):
self.action_prompt_display = action_prompt_display
self.chat_id = chat_id
self.chat_info_stream_id = chat_info_stream_id
self.chat_info_platform = chat_info_platform
self.chat_info_platform = chat_info_platform

View File

@@ -28,6 +28,7 @@ class ActionPlannerInfo(BaseDataModel):
@dataclass
class InterestScore(BaseDataModel):
"""兴趣度评分结果"""
message_id: str
total_score: float
interest_match_score: float
@@ -41,6 +42,7 @@ class Plan(BaseDataModel):
"""
统一规划数据模型
"""
chat_id: str
mode: "ChatMode"

View File

@@ -2,9 +2,11 @@ from dataclasses import dataclass
from typing import Optional, List, Tuple, TYPE_CHECKING, Any
from . import BaseDataModel
if TYPE_CHECKING:
from src.llm_models.payload_content.tool_option import ToolCall
@dataclass
class LLMGenerationDataModel(BaseDataModel):
content: Optional[str] = None
@@ -13,4 +15,4 @@ class LLMGenerationDataModel(BaseDataModel):
tool_calls: Optional[List["ToolCall"]] = None
prompt: Optional[str] = None
selected_expressions: Optional[List[int]] = None
reply_set: Optional[List[Tuple[str, Any]]] = None
reply_set: Optional[List[Tuple[str, Any]]] = None

View File

@@ -2,6 +2,7 @@
消息管理模块数据模型
定义消息管理器使用的数据结构
"""
import asyncio
import time
from dataclasses import dataclass, field
@@ -16,14 +17,16 @@ if TYPE_CHECKING:
class MessageStatus(Enum):
"""消息状态枚举"""
UNREAD = "unread" # 未读消息
READ = "read" # 读消息
UNREAD = "unread" # 读消息
READ = "read" # 已读消息
PROCESSING = "processing" # 处理中
@dataclass
class StreamContext(BaseDataModel):
"""聊天流上下文信息"""
stream_id: str
unread_messages: List["DatabaseMessages"] = field(default_factory=list)
history_messages: List["DatabaseMessages"] = field(default_factory=list)
@@ -59,6 +62,7 @@ class StreamContext(BaseDataModel):
@dataclass
class MessageManagerStats(BaseDataModel):
"""消息管理器统计信息"""
total_streams: int = 0
active_streams: int = 0
total_unread_messages: int = 0
@@ -74,9 +78,10 @@ class MessageManagerStats(BaseDataModel):
@dataclass
class StreamStats(BaseDataModel):
"""聊天流统计信息"""
stream_id: str
is_active: bool
unread_count: int
history_count: int
last_check_time: float
has_active_task: bool
has_active_task: bool

View File

@@ -23,15 +23,15 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
maim_message_config = global_config.maim_message
# 设置基本参数
host = os.getenv("HOST", "127.0.0.1")
port_str = os.getenv("PORT", "8000")
try:
port = int(port_str)
except ValueError:
port = 8000
kwargs = {
"host": host,
"port": port,

View File

@@ -31,7 +31,9 @@ class TelemetryHeartBeatTask(AsyncTask):
self.client_uuid: str | None = local_storage["mofox_uuid"] if "mofox_uuid" in local_storage else None # type: ignore
"""客户端UUID"""
self.private_key_pem: str | None = local_storage["mofox_private_key"] if "mofox_private_key" in local_storage else None # type: ignore
self.private_key_pem: str | None = (
local_storage["mofox_private_key"] if "mofox_private_key" in local_storage else None
) # type: ignore
"""客户端私钥"""
self.info_dict = self._get_sys_info()
@@ -61,78 +63,65 @@ class TelemetryHeartBeatTask(AsyncTask):
def _generate_signature(self, request_body: dict) -> tuple[str, str]:
"""
生成RSA签名
Returns:
tuple[str, str]: (timestamp, signature_b64)
"""
if not self.private_key_pem:
raise ValueError("私钥未初始化")
# 生成时间戳
timestamp = datetime.now(timezone.utc).isoformat()
# 创建签名数据字符串
sign_data = f"{self.client_uuid}:{timestamp}:{json.dumps(request_body, separators=(',', ':'))}"
# 加载私钥
private_key = serialization.load_pem_private_key(
self.private_key_pem.encode('utf-8'),
password=None
)
private_key = serialization.load_pem_private_key(self.private_key_pem.encode("utf-8"), password=None)
# 确保是RSA私钥
if not isinstance(private_key, rsa.RSAPrivateKey):
raise ValueError("私钥必须是RSA格式")
# 生成签名
signature = private_key.sign(
sign_data.encode('utf-8'),
padding.PSS(
mgf=padding.MGF1(hashes.SHA256()),
salt_length=padding.PSS.MAX_LENGTH
),
hashes.SHA256()
sign_data.encode("utf-8"),
padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH),
hashes.SHA256(),
)
# Base64编码
signature_b64 = base64.b64encode(signature).decode('utf-8')
signature_b64 = base64.b64encode(signature).decode("utf-8")
return timestamp, signature_b64
def _decrypt_challenge(self, challenge_b64: str) -> str:
"""
解密挑战数据
Args:
challenge_b64: Base64编码的挑战数据
Returns:
str: 解密后的UUID字符串
"""
if not self.private_key_pem:
raise ValueError("私钥未初始化")
# 加载私钥
private_key = serialization.load_pem_private_key(
self.private_key_pem.encode('utf-8'),
password=None
)
private_key = serialization.load_pem_private_key(self.private_key_pem.encode("utf-8"), password=None)
# 确保是RSA私钥
if not isinstance(private_key, rsa.RSAPrivateKey):
raise ValueError("私钥必须是RSA格式")
# 解密挑战数据
decrypted_bytes = private_key.decrypt(
base64.b64decode(challenge_b64),
padding.OAEP(
mgf=padding.MGF1(hashes.SHA256()),
algorithm=hashes.SHA256(),
label=None
)
padding.OAEP(mgf=padding.MGF1(hashes.SHA256()), algorithm=hashes.SHA256(), label=None),
)
return decrypted_bytes.decode('utf-8')
return decrypted_bytes.decode("utf-8")
async def _req_uuid(self) -> bool:
"""
@@ -155,28 +144,26 @@ class TelemetryHeartBeatTask(AsyncTask):
if response.status != 200:
response_text = await response.text()
logger.error(
f"注册步骤1失败状态码: {response.status}, 响应内容: {response_text}"
)
logger.error(f"注册步骤1失败状态码: {response.status}, 响应内容: {response_text}")
raise aiohttp.ClientResponseError(
request_info=response.request_info,
history=response.history,
status=response.status,
message=f"Step1 failed: {response_text}"
message=f"Step1 failed: {response_text}",
)
step1_data = await response.json()
temp_uuid = step1_data.get("temp_uuid")
private_key = step1_data.get("private_key")
challenge = step1_data.get("challenge")
if not all([temp_uuid, private_key, challenge]):
logger.error("Step1响应缺少必要字段temp_uuid, private_key 或 challenge")
raise ValueError("Step1响应数据不完整")
# 临时保存私钥用于解密
self.private_key_pem = private_key
# 解密挑战数据
logger.debug("解密挑战数据...")
try:
@@ -184,21 +171,18 @@ class TelemetryHeartBeatTask(AsyncTask):
except Exception as e:
logger.error(f"解密挑战数据失败: {e}")
raise
# 验证解密结果
if decrypted_uuid != temp_uuid:
logger.error(f"解密结果验证失败: 期望 {temp_uuid}, 实际 {decrypted_uuid}")
raise ValueError("解密结果与临时UUID不匹配")
logger.debug("挑战数据解密成功开始注册步骤2")
# Step 2: 发送解密结果完成注册
async with session.post(
f"{TELEMETRY_SERVER_URL}/stat/reg_client_step2",
json={
"temp_uuid": temp_uuid,
"decrypted_uuid": decrypted_uuid
},
json={"temp_uuid": temp_uuid, "decrypted_uuid": decrypted_uuid},
timeout=aiohttp.ClientTimeout(total=5),
) as response:
logger.debug(f"Step2 Response status: {response.status}")
@@ -206,7 +190,7 @@ class TelemetryHeartBeatTask(AsyncTask):
if response.status == 200:
step2_data = await response.json()
mofox_uuid = step2_data.get("mofox_uuid")
if mofox_uuid:
# 将正式UUID和私钥存储到本地
local_storage["mofox_uuid"] = mofox_uuid
@@ -225,23 +209,19 @@ class TelemetryHeartBeatTask(AsyncTask):
raise ValueError(f"Step2失败: {response_text}")
else:
response_text = await response.text()
logger.error(
f"注册步骤2失败状态码: {response.status}, 响应内容: {response_text}"
)
logger.error(f"注册步骤2失败状态码: {response.status}, 响应内容: {response_text}")
raise aiohttp.ClientResponseError(
request_info=response.request_info,
history=response.history,
status=response.status,
message=f"Step2 failed: {response_text}"
message=f"Step2 failed: {response_text}",
)
except Exception as e:
import traceback
error_msg = str(e) or "未知错误"
logger.warning(
f"注册客户端出错,不过你还是可以正常使用墨狐: {type(e).__name__}: {error_msg}"
)
logger.warning(f"注册客户端出错,不过你还是可以正常使用墨狐: {type(e).__name__}: {error_msg}")
logger.debug(f"完整错误信息: {traceback.format_exc()}")
# 请求失败,重试次数+1
@@ -264,13 +244,13 @@ class TelemetryHeartBeatTask(AsyncTask):
try:
# 生成签名
timestamp, signature = self._generate_signature(self.info_dict)
headers = {
"X-mofox-UUID": self.client_uuid,
"X-mofox-Signature": signature,
"X-mofox-Timestamp": timestamp,
"User-Agent": f"MofoxClient/{self.client_uuid[:8]}",
"Content-Type": "application/json"
"Content-Type": "application/json",
}
logger.debug(f"正在发送心跳到服务器: {self.server_url}")
@@ -347,4 +327,4 @@ class TelemetryHeartBeatTask(AsyncTask):
logger.warning("客户端注册失败,跳过此次心跳")
return
await self._send_heartbeat()
await self._send_heartbeat()

View File

@@ -99,14 +99,13 @@ def get_global_server() -> Server:
"""获取全局服务器实例"""
global global_server
if global_server is None:
host = os.getenv("HOST", "127.0.0.1")
port_str = os.getenv("PORT", "8000")
try:
port = int(port_str)
except ValueError:
port = 8000
global_server = Server(host=host, port=port)
return global_server