Merge afc branch into dev, prioritizing afc changes and migrating database async modifications from dev

This commit is contained in:
Windpicker-owo
2025-09-27 23:37:40 +08:00
138 changed files with 12183 additions and 5968 deletions

View File

@@ -6,6 +6,7 @@ class BaseDataModel:
def deepcopy(self):
return copy.deepcopy(self)
def temporarily_transform_class_to_dict(obj: Any) -> Any:
# sourcery skip: assign-if-exp, reintroduce-else
"""

View File

@@ -0,0 +1,137 @@
"""
机器人兴趣标签数据模型
定义机器人的兴趣标签和相关的embedding数据结构
"""
from dataclasses import dataclass, field
from typing import List, Dict, Optional, Any
from datetime import datetime
from . import BaseDataModel
@dataclass
class BotInterestTag(BaseDataModel):
"""机器人兴趣标签"""
tag_name: str
weight: float = 1.0 # 权重,表示对这个兴趣的喜好程度 (0.0-1.0)
embedding: Optional[List[float]] = None # 标签的embedding向量
created_at: datetime = field(default_factory=datetime.now)
updated_at: datetime = field(default_factory=datetime.now)
is_active: bool = True
def to_dict(self) -> Dict[str, Any]:
"""转换为字典格式"""
return {
"tag_name": self.tag_name,
"weight": self.weight,
"embedding": self.embedding,
"created_at": self.created_at.isoformat(),
"updated_at": self.updated_at.isoformat(),
"is_active": self.is_active,
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "BotInterestTag":
"""从字典创建对象"""
return cls(
tag_name=data["tag_name"],
weight=data.get("weight", 1.0),
embedding=data.get("embedding"),
created_at=datetime.fromisoformat(data["created_at"]) if data.get("created_at") else datetime.now(),
updated_at=datetime.fromisoformat(data["updated_at"]) if data.get("updated_at") else datetime.now(),
is_active=data.get("is_active", True),
)
@dataclass
class BotPersonalityInterests(BaseDataModel):
"""机器人人格化兴趣配置"""
personality_id: str
personality_description: str # 人设描述文本
interest_tags: List[BotInterestTag] = field(default_factory=list)
embedding_model: str = "text-embedding-ada-002" # 使用的embedding模型
last_updated: datetime = field(default_factory=datetime.now)
version: int = 1 # 版本号,用于追踪更新
def get_active_tags(self) -> List[BotInterestTag]:
"""获取活跃的兴趣标签"""
return [tag for tag in self.interest_tags if tag.is_active]
def to_dict(self) -> Dict[str, Any]:
"""转换为字典格式"""
return {
"personality_id": self.personality_id,
"personality_description": self.personality_description,
"interest_tags": [tag.to_dict() for tag in self.interest_tags],
"embedding_model": self.embedding_model,
"last_updated": self.last_updated.isoformat(),
"version": self.version,
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "BotPersonalityInterests":
"""从字典创建对象"""
return cls(
personality_id=data["personality_id"],
personality_description=data["personality_description"],
interest_tags=[BotInterestTag.from_dict(tag_data) for tag_data in data.get("interest_tags", [])],
embedding_model=data.get("embedding_model", "text-embedding-ada-002"),
last_updated=datetime.fromisoformat(data["last_updated"]) if data.get("last_updated") else datetime.now(),
version=data.get("version", 1),
)
@dataclass
class InterestMatchResult(BaseDataModel):
"""兴趣匹配结果"""
message_id: str
matched_tags: List[str] = field(default_factory=list)
match_scores: Dict[str, float] = field(default_factory=dict) # tag_name -> score
overall_score: float = 0.0
top_tag: Optional[str] = None
confidence: float = 0.0 # 匹配置信度 (0.0-1.0)
matched_keywords: List[str] = field(default_factory=list)
def add_match(self, tag_name: str, score: float, keywords: List[str] = None):
"""添加匹配结果"""
self.matched_tags.append(tag_name)
self.match_scores[tag_name] = score
if keywords:
self.matched_keywords.extend(keywords)
def calculate_overall_score(self):
"""计算总体匹配分数"""
if not self.match_scores:
self.overall_score = 0.0
self.top_tag = None
return
# 使用加权平均计算总体分数
total_weight = len(self.match_scores)
if total_weight > 0:
self.overall_score = sum(self.match_scores.values()) / total_weight
# 设置最佳匹配标签
self.top_tag = max(self.match_scores.items(), key=lambda x: x[1])[0]
else:
self.overall_score = 0.0
self.top_tag = None
# 计算置信度(基于匹配标签数量和分数分布)
if len(self.match_scores) > 0:
avg_score = self.overall_score
score_variance = sum((score - avg_score) ** 2 for score in self.match_scores.values()) / len(
self.match_scores
)
# 分数越集中,置信度越高
self.confidence = max(0.0, 1.0 - score_variance)
else:
self.confidence = 0.0
def get_top_matches(self, top_n: int = 3) -> List[tuple]:
"""获取前N个最佳匹配"""
sorted_matches = sorted(self.match_scores.items(), key=lambda x: x[1], reverse=True)
return sorted_matches[:top_n]

View File

@@ -79,6 +79,7 @@ class DatabaseMessages(BaseDataModel):
is_command: bool = False,
is_notify: bool = False,
selected_expressions: Optional[str] = None,
is_read: bool = False,
user_id: str = "",
user_nickname: str = "",
user_cardname: Optional[str] = None,
@@ -94,6 +95,9 @@ class DatabaseMessages(BaseDataModel):
chat_info_platform: str = "",
chat_info_create_time: float = 0.0,
chat_info_last_active_time: float = 0.0,
# 新增字段
actions: Optional[list] = None,
should_reply: bool = False,
**kwargs: Any,
):
self.message_id = message_id
@@ -102,6 +106,10 @@ class DatabaseMessages(BaseDataModel):
self.reply_to = reply_to
self.interest_value = interest_value
# 新增字段
self.actions = actions
self.should_reply = should_reply
self.key_words = key_words
self.key_words_lite = key_words_lite
self.is_mentioned = is_mentioned
@@ -122,6 +130,7 @@ class DatabaseMessages(BaseDataModel):
self.is_notify = is_notify
self.selected_expressions = selected_expressions
self.is_read = is_read
self.group_info: Optional[DatabaseGroupInfo] = None
self.user_info = DatabaseUserInfo(
@@ -188,6 +197,10 @@ class DatabaseMessages(BaseDataModel):
"is_command": self.is_command,
"is_notify": self.is_notify,
"selected_expressions": self.selected_expressions,
"is_read": self.is_read,
# 新增字段
"actions": self.actions,
"should_reply": self.should_reply,
"user_id": self.user_info.user_id,
"user_nickname": self.user_info.user_nickname,
"user_cardname": self.user_info.user_cardname,
@@ -205,6 +218,61 @@ class DatabaseMessages(BaseDataModel):
"chat_info_user_cardname": self.chat_info.user_info.user_cardname,
}
def update_message_info(self, interest_value: float = None, actions: list = None, should_reply: bool = None):
"""
更新消息信息
Args:
interest_value: 兴趣度值
actions: 执行的动作列表
should_reply: 是否应该回复
"""
if interest_value is not None:
self.interest_value = interest_value
if actions is not None:
self.actions = actions
if should_reply is not None:
self.should_reply = should_reply
def add_action(self, action: str):
"""
添加执行的动作到消息中
Args:
action: 要添加的动作名称
"""
if self.actions is None:
self.actions = []
if action not in self.actions: # 避免重复添加
self.actions.append(action)
def get_actions(self) -> list:
"""
获取执行的动作列表
Returns:
动作列表,如果没有动作则返回空列表
"""
return self.actions or []
def get_message_summary(self) -> Dict[str, Any]:
"""
获取消息摘要信息
Returns:
包含关键字段的消息摘要
"""
return {
"message_id": self.message_id,
"time": self.time,
"interest_value": self.interest_value,
"actions": self.actions,
"should_reply": self.should_reply,
"user_nickname": self.user_info.user_nickname,
"display_message": self.display_message,
}
@dataclass(init=False)
class DatabaseActionRecords(BaseDataModel):
def __init__(
@@ -232,4 +300,4 @@ class DatabaseActionRecords(BaseDataModel):
self.action_prompt_display = action_prompt_display
self.chat_id = chat_id
self.chat_info_stream_id = chat_info_stream_id
self.chat_info_platform = chat_info_platform
self.chat_info_platform = chat_info_platform

View File

@@ -1,10 +1,12 @@
from dataclasses import dataclass, field
from typing import Optional, Dict, List, TYPE_CHECKING
from src.plugin_system.base.component_types import ChatType
from . import BaseDataModel
if TYPE_CHECKING:
pass
from .database_data_model import DatabaseMessages
from src.plugin_system.base.component_types import ActionInfo, ChatMode
@dataclass
@@ -21,23 +23,37 @@ class ActionPlannerInfo(BaseDataModel):
action_type: str = field(default_factory=str)
reasoning: Optional[str] = None
action_data: Optional[Dict] = None
action_message: Optional[Dict] = None
action_message: Optional["DatabaseMessages"] = None
available_actions: Optional[Dict[str, "ActionInfo"]] = None
@dataclass
class InterestScore(BaseDataModel):
"""兴趣度评分结果"""
message_id: str
total_score: float
interest_match_score: float
relationship_score: float
mentioned_score: float
details: Dict[str, str]
@dataclass
class Plan(BaseDataModel):
"""
统一规划数据模型
"""
chat_id: str
mode: "ChatMode"
chat_type: "ChatType"
# Generator 填充
available_actions: Dict[str, "ActionInfo"] = field(default_factory=dict)
chat_history: List["DatabaseMessages"] = field(default_factory=list)
target_info: Optional[TargetPersonInfo] = None
# Filter 填充
llm_prompt: Optional[str] = None
decided_actions: Optional[List[ActionPlannerInfo]] = None

View File

@@ -6,6 +6,7 @@ from . import BaseDataModel
if TYPE_CHECKING:
pass
@dataclass
class LLMGenerationDataModel(BaseDataModel):
content: Optional[str] = None
@@ -14,4 +15,4 @@ class LLMGenerationDataModel(BaseDataModel):
tool_calls: Optional[List["ToolCall"]] = None
prompt: Optional[str] = None
selected_expressions: Optional[List[int]] = None
reply_set: Optional[List[Tuple[str, Any]]] = None
reply_set: Optional[List[Tuple[str, Any]]] = None

View File

@@ -1,36 +0,0 @@
from dataclasses import dataclass, field
from typing import Optional, TYPE_CHECKING
from . import BaseDataModel
if TYPE_CHECKING:
pass
@dataclass
class MessageAndActionModel(BaseDataModel):
chat_id: str = field(default_factory=str)
time: float = field(default_factory=float)
user_id: str = field(default_factory=str)
user_platform: str = field(default_factory=str)
user_nickname: str = field(default_factory=str)
user_cardname: Optional[str] = None
processed_plain_text: Optional[str] = None
display_message: Optional[str] = None
chat_info_platform: str = field(default_factory=str)
is_action_record: bool = field(default=False)
action_name: Optional[str] = None
@classmethod
def from_DatabaseMessages(cls, message: "DatabaseMessages"):
return cls(
chat_id=message.chat_id,
time=message.time,
user_id=message.user_info.user_id,
user_platform=message.user_info.platform,
user_nickname=message.user_info.user_nickname,
user_cardname=message.user_info.user_cardname,
processed_plain_text=message.processed_plain_text,
display_message=message.display_message,
chat_info_platform=message.chat_info.platform,
)

View File

@@ -0,0 +1,373 @@
"""
消息管理模块数据模型
定义消息管理器使用的数据结构
"""
import asyncio
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional, TYPE_CHECKING
from . import BaseDataModel
from src.plugin_system.base.component_types import ChatMode, ChatType
from src.common.logger import get_logger
if TYPE_CHECKING:
from .database_data_model import DatabaseMessages
logger = get_logger("stream_context")
class MessageStatus(Enum):
"""消息状态枚举"""
UNREAD = "unread" # 未读消息
READ = "read" # 已读消息
PROCESSING = "processing" # 处理中
@dataclass
class StreamContext(BaseDataModel):
"""聊天流上下文信息"""
stream_id: str
chat_type: ChatType = ChatType.PRIVATE # 聊天类型,默认为私聊
chat_mode: ChatMode = ChatMode.NORMAL # 聊天模式,默认为普通模式
unread_messages: List["DatabaseMessages"] = field(default_factory=list)
history_messages: List["DatabaseMessages"] = field(default_factory=list)
last_check_time: float = field(default_factory=time.time)
is_active: bool = True
processing_task: Optional[asyncio.Task] = None
interruption_count: int = 0 # 打断计数器
last_interruption_time: float = 0.0 # 上次打断时间
afc_threshold_adjustment: float = 0.0 # afc阈值调整量
# 独立分发周期字段
next_check_time: float = field(default_factory=time.time) # 下次检查时间
distribution_interval: float = 5.0 # 当前分发周期(秒)
# 新增字段以替代ChatMessageContext功能
current_message: Optional["DatabaseMessages"] = None
priority_mode: Optional[str] = None
priority_info: Optional[dict] = None
def add_message(self, message: "DatabaseMessages"):
"""添加消息到上下文"""
message.is_read = False
self.unread_messages.append(message)
# 自动检测和更新chat type
self._detect_chat_type(message)
def update_message_info(
self, message_id: str, interest_value: float = None, actions: list = None, should_reply: bool = None
):
"""
更新消息信息
Args:
message_id: 消息ID
interest_value: 兴趣度值
actions: 执行的动作列表
should_reply: 是否应该回复
"""
# 在未读消息中查找并更新
for message in self.unread_messages:
if message.message_id == message_id:
message.update_message_info(interest_value, actions, should_reply)
break
# 在历史消息中查找并更新
for message in self.history_messages:
if message.message_id == message_id:
message.update_message_info(interest_value, actions, should_reply)
break
def add_action_to_message(self, message_id: str, action: str):
"""
向指定消息添加执行的动作
Args:
message_id: 消息ID
action: 要添加的动作名称
"""
# 在未读消息中查找并更新
for message in self.unread_messages:
if message.message_id == message_id:
message.add_action(action)
break
# 在历史消息中查找并更新
for message in self.history_messages:
if message.message_id == message_id:
message.add_action(action)
break
def _detect_chat_type(self, message: "DatabaseMessages"):
"""根据消息内容自动检测聊天类型"""
# 只有在第一次添加消息时才检测聊天类型,避免后续消息改变类型
if len(self.unread_messages) == 1: # 只有这条消息
# 如果消息包含群组信息,则为群聊
if hasattr(message, "chat_info_group_id") and message.chat_info_group_id:
self.chat_type = ChatType.GROUP
elif hasattr(message, "chat_info_group_name") and message.chat_info_group_name:
self.chat_type = ChatType.GROUP
else:
self.chat_type = ChatType.PRIVATE
def update_chat_type(self, chat_type: ChatType):
"""手动更新聊天类型"""
self.chat_type = chat_type
def set_chat_mode(self, chat_mode: ChatMode):
"""设置聊天模式"""
self.chat_mode = chat_mode
def is_group_chat(self) -> bool:
"""检查是否为群聊"""
return self.chat_type == ChatType.GROUP
def is_private_chat(self) -> bool:
"""检查是否为私聊"""
return self.chat_type == ChatType.PRIVATE
def get_chat_type_display(self) -> str:
"""获取聊天类型的显示名称"""
if self.chat_type == ChatType.GROUP:
return "群聊"
elif self.chat_type == ChatType.PRIVATE:
return "私聊"
else:
return "未知类型"
def mark_message_as_read(self, message_id: str):
"""标记消息为已读"""
for msg in self.unread_messages:
if msg.message_id == message_id:
msg.is_read = True
self.history_messages.append(msg)
self.unread_messages.remove(msg)
break
def get_unread_messages(self) -> List["DatabaseMessages"]:
"""获取未读消息"""
return [msg for msg in self.unread_messages if not msg.is_read]
def get_history_messages(self, limit: int = 20) -> List["DatabaseMessages"]:
"""获取历史消息"""
# 优先返回最近的历史消息和所有未读消息
recent_history = self.history_messages[-limit:] if len(self.history_messages) > limit else self.history_messages
return recent_history
def calculate_interruption_probability(self, max_limit: int, probability_factor: float) -> float:
"""计算打断概率"""
if max_limit <= 0:
return 0.0
# 计算打断比例
interruption_ratio = self.interruption_count / max_limit
# 如果已达到或超过最大次数,完全禁止打断
if self.interruption_count >= max_limit:
return 0.0
# 如果超过概率因子,概率下降
if interruption_ratio > probability_factor:
# 使用指数衰减,超过限制越多,概率越低
excess_ratio = interruption_ratio - probability_factor
probability = 0.8 * (0.5**excess_ratio) # 基础概率0.8,指数衰减
else:
# 在限制内,保持较高概率
probability = 0.8
return max(0.0, min(1.0, probability))
def increment_interruption_count(self):
"""增加打断计数"""
self.interruption_count += 1
self.last_interruption_time = time.time()
# 同步打断计数到ChatStream
self._sync_interruption_count_to_stream()
def reset_interruption_count(self):
"""重置打断计数和afc阈值调整"""
self.interruption_count = 0
self.last_interruption_time = 0.0
self.afc_threshold_adjustment = 0.0
# 同步打断计数到ChatStream
self._sync_interruption_count_to_stream()
def apply_interruption_afc_reduction(self, reduction_value: float):
"""应用打断导致的afc阈值降低"""
self.afc_threshold_adjustment += reduction_value
logger.debug(f"应用afc阈值降低: {reduction_value}, 总调整量: {self.afc_threshold_adjustment}")
def get_afc_threshold_adjustment(self) -> float:
"""获取当前的afc阈值调整量"""
return self.afc_threshold_adjustment
def _sync_interruption_count_to_stream(self):
"""同步打断计数到ChatStream"""
try:
from src.chat.message_receive.chat_stream import get_chat_manager
chat_manager = get_chat_manager()
if chat_manager:
chat_stream = chat_manager.get_stream(self.stream_id)
if chat_stream and hasattr(chat_stream, "interruption_count"):
# 在这里我们只是标记需要保存实际的保存会在下次save时进行
chat_stream.saved = False
logger.debug(
f"已同步StreamContext {self.stream_id} 的打断计数 {self.interruption_count} 到ChatStream"
)
except Exception as e:
logger.warning(f"同步打断计数到ChatStream失败: {e}")
def set_current_message(self, message: "DatabaseMessages"):
"""设置当前消息"""
self.current_message = message
def get_template_name(self) -> Optional[str]:
"""获取模板名称"""
if (
self.current_message
and hasattr(self.current_message, "additional_config")
and self.current_message.additional_config
):
try:
import json
config = json.loads(self.current_message.additional_config)
if config.get("template_info") and not config.get("template_default", True):
return config.get("template_name")
except (json.JSONDecodeError, AttributeError):
pass
return None
def get_last_message(self) -> Optional["DatabaseMessages"]:
"""获取最后一条消息"""
if self.current_message:
return self.current_message
if self.unread_messages:
return self.unread_messages[-1]
if self.history_messages:
return self.history_messages[-1]
return None
def check_types(self, types: list) -> bool:
"""
检查当前消息是否支持指定的类型
Args:
types: 需要检查的消息类型列表,如 ["text", "image", "emoji"]
Returns:
bool: 如果消息支持所有指定的类型则返回True否则返回False
"""
if not self.current_message:
return False
if not types:
# 如果没有指定类型要求,默认为支持
return True
# 优先从additional_config中获取format_info
if hasattr(self.current_message, "additional_config") and self.current_message.additional_config:
try:
import orjson
config = orjson.loads(self.current_message.additional_config)
# 检查format_info结构
if "format_info" in config:
format_info = config["format_info"]
# 方法1: 直接检查accept_format字段
if "accept_format" in format_info:
accept_format = format_info["accept_format"]
# 确保accept_format是列表类型
if isinstance(accept_format, str):
accept_format = [accept_format]
elif isinstance(accept_format, list):
pass
else:
# 如果accept_format不是字符串或列表尝试转换为列表
accept_format = list(accept_format) if hasattr(accept_format, "__iter__") else []
# 检查所有请求的类型是否都被支持
for requested_type in types:
if requested_type not in accept_format:
logger.debug(f"消息不支持类型 '{requested_type}',支持的类型: {accept_format}")
return False
return True
# 方法2: 检查content_format字段向后兼容
elif "content_format" in format_info:
content_format = format_info["content_format"]
# 确保content_format是列表类型
if isinstance(content_format, str):
content_format = [content_format]
elif isinstance(content_format, list):
pass
else:
content_format = list(content_format) if hasattr(content_format, "__iter__") else []
# 检查所有请求的类型是否都被支持
for requested_type in types:
if requested_type not in content_format:
logger.debug(f"消息不支持类型 '{requested_type}',支持的内容格式: {content_format}")
return False
return True
except (orjson.JSONDecodeError, AttributeError, TypeError) as e:
logger.debug(f"解析消息格式信息失败: {e}")
# 备用方案如果无法从additional_config获取格式信息使用默认支持的类型
# 大多数消息至少支持text类型
default_supported_types = ["text", "emoji"]
for requested_type in types:
if requested_type not in default_supported_types:
logger.debug(f"使用默认类型检查,消息可能不支持类型 '{requested_type}'")
# 对于非基础类型返回False以避免错误
if requested_type not in ["text", "emoji", "reply"]:
return False
return True
def get_priority_mode(self) -> Optional[str]:
"""获取优先级模式"""
return self.priority_mode
def get_priority_info(self) -> Optional[dict]:
"""获取优先级信息"""
return self.priority_info
@dataclass
class MessageManagerStats(BaseDataModel):
"""消息管理器统计信息"""
total_streams: int = 0
active_streams: int = 0
total_unread_messages: int = 0
total_processed_messages: int = 0
start_time: float = field(default_factory=time.time)
@property
def uptime(self) -> float:
"""运行时间"""
return time.time() - self.start_time
@dataclass
class StreamStats(BaseDataModel):
"""聊天流统计信息"""
stream_id: str
is_active: bool
unread_count: int
history_count: int
last_check_time: float
has_active_task: bool

View File

@@ -30,6 +30,7 @@ from src.common.database.sqlalchemy_models import (
Schedule,
MaiZoneScheduleStatus,
CacheEntries,
UserRelationships,
)
from src.common.logger import get_logger
@@ -54,6 +55,7 @@ MODEL_MAPPING = {
"Schedule": Schedule,
"MaiZoneScheduleStatus": MaiZoneScheduleStatus,
"CacheEntries": CacheEntries,
"UserRelationships": UserRelationships,
}

View File

@@ -55,7 +55,17 @@ class ChatStreams(Base):
user_cardname = Column(Text, nullable=True)
energy_value = Column(Float, nullable=True, default=5.0)
sleep_pressure = Column(Float, nullable=True, default=0.0)
focus_energy = Column(Float, nullable=True, default=1.0)
focus_energy = Column(Float, nullable=True, default=0.5)
# 动态兴趣度系统字段
base_interest_energy = Column(Float, nullable=True, default=0.5)
message_interest_total = Column(Float, nullable=True, default=0.0)
message_count = Column(Integer, nullable=True, default=0)
action_count = Column(Integer, nullable=True, default=0)
reply_count = Column(Integer, nullable=True, default=0)
last_interaction_time = Column(Float, nullable=True, default=None)
consecutive_no_reply = Column(Integer, nullable=True, default=0)
# 消息打断系统字段
interruption_count = Column(Integer, nullable=True, default=0)
__table_args__ = (
Index("idx_chatstreams_stream_id", "stream_id"),
@@ -165,11 +175,16 @@ class Messages(Base):
is_command = Column(Boolean, nullable=False, default=False)
is_notify = Column(Boolean, nullable=False, default=False)
# 兴趣度系统字段
actions = Column(Text, nullable=True) # JSON格式存储动作列表
should_reply = Column(Boolean, nullable=True, default=False)
__table_args__ = (
Index("idx_messages_message_id", "message_id"),
Index("idx_messages_chat_id", "chat_id"),
Index("idx_messages_time", "time"),
Index("idx_messages_user_id", "user_id"),
Index("idx_messages_should_reply", "should_reply"),
)
@@ -300,6 +315,26 @@ class PersonInfo(Base):
)
class BotPersonalityInterests(Base):
"""机器人人格兴趣标签模型"""
__tablename__ = "bot_personality_interests"
id = Column(Integer, primary_key=True, autoincrement=True)
personality_id = Column(get_string_field(100), nullable=False, index=True)
personality_description = Column(Text, nullable=False)
interest_tags = Column(Text, nullable=False) # JSON格式存储的兴趣标签列表
embedding_model = Column(get_string_field(100), nullable=False, default="text-embedding-ada-002")
version = Column(Integer, nullable=False, default=1)
last_updated = Column(DateTime, nullable=False, default=datetime.datetime.now, index=True)
__table_args__ = (
Index("idx_botpersonality_personality_id", "personality_id"),
Index("idx_botpersonality_version", "version"),
Index("idx_botpersonality_last_updated", "last_updated"),
)
class Memory(Base):
"""记忆模型"""
@@ -722,3 +757,23 @@ class UserPermissions(Base):
Index("idx_user_permission", "platform", "user_id", "permission_node"),
Index("idx_permission_granted", "permission_node", "granted"),
)
class UserRelationships(Base):
"""用户关系模型 - 存储用户与bot的关系数据"""
__tablename__ = "user_relationships"
id = Column(Integer, primary_key=True, autoincrement=True)
user_id = Column(get_string_field(100), nullable=False, unique=True, index=True) # 用户ID
user_name = Column(get_string_field(100), nullable=True) # 用户名
relationship_text = Column(Text, nullable=True) # 关系印象描述
relationship_score = Column(Float, nullable=False, default=0.3) # 关系分数(0-1)
last_updated = Column(Float, nullable=False, default=time.time) # 最后更新时间
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) # 创建时间
__table_args__ = (
Index("idx_user_relationship_id", "user_id"),
Index("idx_relationship_score", "relationship_score"),
Index("idx_relationship_updated", "last_updated"),
)

View File

@@ -350,6 +350,10 @@ MODULE_COLORS = {
"memory": "\033[38;5;117m", # 天蓝色
"hfc": "\033[38;5;81m", # 稍微暗一些的青色,保持可读
"action_manager": "\033[38;5;208m", # 橙色不与replyer重复
"message_manager": "\033[38;5;27m", # 深蓝色,消息管理器
"chatter_manager": "\033[38;5;129m", # 紫色,聊天管理器
"chatter_interest_scoring": "\033[38;5;214m", # 橙黄色,兴趣评分
"plan_executor": "\033[38;5;172m", # 橙褐色,计划执行器
# 关系系统
"relation": "\033[38;5;139m", # 柔和的紫色,不刺眼
# 聊天相关模块
@@ -551,6 +555,10 @@ MODULE_ALIASES = {
"llm_models": "模型",
"person_info": "人物",
"chat_stream": "聊天流",
"message_manager": "消息管理",
"chatter_manager": "聊天管理",
"chatter_interest_scoring": "兴趣评分",
"plan_executor": "计划执行",
"planner": "规划器",
"replyer": "言语",
"config": "配置",

View File

@@ -23,15 +23,15 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
maim_message_config = global_config.maim_message
# 设置基本参数
host = os.getenv("HOST", "127.0.0.1")
port_str = os.getenv("PORT", "8000")
try:
port = int(port_str)
except ValueError:
port = 8000
kwargs = {
"host": host,
"port": port,

View File

@@ -22,10 +22,15 @@ def _model_to_dict(instance: Base) -> Dict[str, Any]:
"""
将 SQLAlchemy 模型实例转换为字典。
"""
return {col.name: getattr(instance, col.name) for col in instance.__table__.columns}
try:
return {col.name: getattr(instance, col.name) for col in instance.__table__.columns}
except Exception as e:
# 如果对象已经脱离会话尝试从instance.__dict__中获取数据
logger.warning(f"从数据库对象获取属性失败尝试使用__dict__: {e}")
return {col.name: instance.__dict__.get(col.name) for col in instance.__table__.columns}
async def find_messages(
def find_messages(
message_filter: dict[str, Any],
sort: Optional[List[tuple[str, int]]] = None,
limit: int = 0,
@@ -46,7 +51,7 @@ async def find_messages(
消息字典列表,如果出错则返回空列表。
"""
try:
async with get_db_session() as session:
with get_db_session() as session:
query = select(Messages)
# 应用过滤器
@@ -96,7 +101,7 @@ async def find_messages(
# 获取时间最早的 limit 条记录,已经是正序
query = query.order_by(Messages.time.asc()).limit(limit)
try:
results = (await session.execute(query)).scalars().all()
results = session.execute(query).scalars().all()
except Exception as e:
logger.error(f"执行earliest查询失败: {e}")
results = []
@@ -104,7 +109,7 @@ async def find_messages(
# 获取时间最晚的 limit 条记录
query = query.order_by(Messages.time.desc()).limit(limit)
try:
latest_results = (await session.execute(query)).scalars().all()
latest_results = session.execute(query).scalars().all()
# 将结果按时间正序排列
results = sorted(latest_results, key=lambda msg: msg.time)
except Exception as e:
@@ -128,11 +133,12 @@ async def find_messages(
if sort_terms:
query = query.order_by(*sort_terms)
try:
results = (await session.execute(query)).scalars().all()
results = session.execute(query).scalars().all()
except Exception as e:
logger.error(f"执行无限制查询失败: {e}")
results = []
# 在会话内将结果转换为字典,避免会话分离错误
return [_model_to_dict(msg) for msg in results]
except Exception as e:
log_message = (
@@ -143,7 +149,7 @@ async def find_messages(
return []
async def count_messages(message_filter: dict[str, Any]) -> int:
def count_messages(message_filter: dict[str, Any]) -> int:
"""
根据提供的过滤器计算消息数量。
@@ -154,7 +160,7 @@ async def count_messages(message_filter: dict[str, Any]) -> int:
符合条件的消息数量,如果出错则返回 0。
"""
try:
async with get_db_session() as session:
with get_db_session() as session:
query = select(func.count(Messages.id))
# 应用过滤器
@@ -192,7 +198,7 @@ async def count_messages(message_filter: dict[str, Any]) -> int:
if conditions:
query = query.where(*conditions)
count = (await session.execute(query)).scalar()
count = session.execute(query).scalar()
return count or 0
except Exception as e:
log_message = f"使用 SQLAlchemy 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"
@@ -201,5 +207,5 @@ async def count_messages(message_filter: dict[str, Any]) -> int:
# 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。
# 注意:对于 SQLAlchemy插入操作通常是使用 session.add() 和 await session.commit()。
# 注意:对于 SQLAlchemy插入操作通常是使用 session.add() 和 session.commit()。
# 查找单个消息可以使用 session.execute(select(Messages).where(...)).scalar_one_or_none()。

View File

@@ -31,7 +31,9 @@ class TelemetryHeartBeatTask(AsyncTask):
self.client_uuid: str | None = local_storage["mofox_uuid"] if "mofox_uuid" in local_storage else None # type: ignore
"""客户端UUID"""
self.private_key_pem: str | None = local_storage["mofox_private_key"] if "mofox_private_key" in local_storage else None # type: ignore
self.private_key_pem: str | None = (
local_storage["mofox_private_key"] if "mofox_private_key" in local_storage else None
) # type: ignore
"""客户端私钥"""
self.info_dict = self._get_sys_info()
@@ -61,78 +63,65 @@ class TelemetryHeartBeatTask(AsyncTask):
def _generate_signature(self, request_body: dict) -> tuple[str, str]:
"""
生成RSA签名
Returns:
tuple[str, str]: (timestamp, signature_b64)
"""
if not self.private_key_pem:
raise ValueError("私钥未初始化")
# 生成时间戳
timestamp = datetime.now(timezone.utc).isoformat()
# 创建签名数据字符串
sign_data = f"{self.client_uuid}:{timestamp}:{json.dumps(request_body, separators=(',', ':'))}"
# 加载私钥
private_key = serialization.load_pem_private_key(
self.private_key_pem.encode('utf-8'),
password=None
)
private_key = serialization.load_pem_private_key(self.private_key_pem.encode("utf-8"), password=None)
# 确保是RSA私钥
if not isinstance(private_key, rsa.RSAPrivateKey):
raise ValueError("私钥必须是RSA格式")
# 生成签名
signature = private_key.sign(
sign_data.encode('utf-8'),
padding.PSS(
mgf=padding.MGF1(hashes.SHA256()),
salt_length=padding.PSS.MAX_LENGTH
),
hashes.SHA256()
sign_data.encode("utf-8"),
padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH),
hashes.SHA256(),
)
# Base64编码
signature_b64 = base64.b64encode(signature).decode('utf-8')
signature_b64 = base64.b64encode(signature).decode("utf-8")
return timestamp, signature_b64
def _decrypt_challenge(self, challenge_b64: str) -> str:
"""
解密挑战数据
Args:
challenge_b64: Base64编码的挑战数据
Returns:
str: 解密后的UUID字符串
"""
if not self.private_key_pem:
raise ValueError("私钥未初始化")
# 加载私钥
private_key = serialization.load_pem_private_key(
self.private_key_pem.encode('utf-8'),
password=None
)
private_key = serialization.load_pem_private_key(self.private_key_pem.encode("utf-8"), password=None)
# 确保是RSA私钥
if not isinstance(private_key, rsa.RSAPrivateKey):
raise ValueError("私钥必须是RSA格式")
# 解密挑战数据
decrypted_bytes = private_key.decrypt(
base64.b64decode(challenge_b64),
padding.OAEP(
mgf=padding.MGF1(hashes.SHA256()),
algorithm=hashes.SHA256(),
label=None
)
padding.OAEP(mgf=padding.MGF1(hashes.SHA256()), algorithm=hashes.SHA256(), label=None),
)
return decrypted_bytes.decode('utf-8')
return decrypted_bytes.decode("utf-8")
async def _req_uuid(self) -> bool:
"""
@@ -155,28 +144,26 @@ class TelemetryHeartBeatTask(AsyncTask):
if response.status != 200:
response_text = await response.text()
logger.error(
f"注册步骤1失败状态码: {response.status}, 响应内容: {response_text}"
)
logger.error(f"注册步骤1失败状态码: {response.status}, 响应内容: {response_text}")
raise aiohttp.ClientResponseError(
request_info=response.request_info,
history=response.history,
status=response.status,
message=f"Step1 failed: {response_text}"
message=f"Step1 failed: {response_text}",
)
step1_data = await response.json()
temp_uuid = step1_data.get("temp_uuid")
private_key = step1_data.get("private_key")
challenge = step1_data.get("challenge")
if not all([temp_uuid, private_key, challenge]):
logger.error("Step1响应缺少必要字段temp_uuid, private_key 或 challenge")
raise ValueError("Step1响应数据不完整")
# 临时保存私钥用于解密
self.private_key_pem = private_key
# 解密挑战数据
logger.debug("解密挑战数据...")
try:
@@ -184,21 +171,18 @@ class TelemetryHeartBeatTask(AsyncTask):
except Exception as e:
logger.error(f"解密挑战数据失败: {e}")
raise
# 验证解密结果
if decrypted_uuid != temp_uuid:
logger.error(f"解密结果验证失败: 期望 {temp_uuid}, 实际 {decrypted_uuid}")
raise ValueError("解密结果与临时UUID不匹配")
logger.debug("挑战数据解密成功开始注册步骤2")
# Step 2: 发送解密结果完成注册
async with session.post(
f"{TELEMETRY_SERVER_URL}/stat/reg_client_step2",
json={
"temp_uuid": temp_uuid,
"decrypted_uuid": decrypted_uuid
},
json={"temp_uuid": temp_uuid, "decrypted_uuid": decrypted_uuid},
timeout=aiohttp.ClientTimeout(total=5),
) as response:
logger.debug(f"Step2 Response status: {response.status}")
@@ -206,7 +190,7 @@ class TelemetryHeartBeatTask(AsyncTask):
if response.status == 200:
step2_data = await response.json()
mofox_uuid = step2_data.get("mofox_uuid")
if mofox_uuid:
# 将正式UUID和私钥存储到本地
local_storage["mofox_uuid"] = mofox_uuid
@@ -225,23 +209,19 @@ class TelemetryHeartBeatTask(AsyncTask):
raise ValueError(f"Step2失败: {response_text}")
else:
response_text = await response.text()
logger.error(
f"注册步骤2失败状态码: {response.status}, 响应内容: {response_text}"
)
logger.error(f"注册步骤2失败状态码: {response.status}, 响应内容: {response_text}")
raise aiohttp.ClientResponseError(
request_info=response.request_info,
history=response.history,
status=response.status,
message=f"Step2 failed: {response_text}"
message=f"Step2 failed: {response_text}",
)
except Exception as e:
import traceback
error_msg = str(e) or "未知错误"
logger.warning(
f"注册客户端出错,不过你还是可以正常使用墨狐: {type(e).__name__}: {error_msg}"
)
logger.warning(f"注册客户端出错,不过你还是可以正常使用墨狐: {type(e).__name__}: {error_msg}")
logger.debug(f"完整错误信息: {traceback.format_exc()}")
# 请求失败,重试次数+1
@@ -264,13 +244,13 @@ class TelemetryHeartBeatTask(AsyncTask):
try:
# 生成签名
timestamp, signature = self._generate_signature(self.info_dict)
headers = {
"X-mofox-UUID": self.client_uuid,
"X-mofox-Signature": signature,
"X-mofox-Timestamp": timestamp,
"User-Agent": f"MofoxClient/{self.client_uuid[:8]}",
"Content-Type": "application/json"
"Content-Type": "application/json",
}
logger.debug(f"正在发送心跳到服务器: {self.server_url}")
@@ -347,4 +327,4 @@ class TelemetryHeartBeatTask(AsyncTask):
logger.warning("客户端注册失败,跳过此次心跳")
return
await self._send_heartbeat()
await self._send_heartbeat()

View File

@@ -99,14 +99,13 @@ def get_global_server() -> Server:
"""获取全局服务器实例"""
global global_server
if global_server is None:
host = os.getenv("HOST", "127.0.0.1")
port_str = os.getenv("PORT", "8000")
try:
port = int(port_str)
except ValueError:
port = 8000
global_server = Server(host=host, port=port)
return global_server