Merge afc branch into dev, prioritizing afc changes and migrating database async modifications from dev

This commit is contained in:
Windpicker-owo
2025-09-27 23:37:40 +08:00
138 changed files with 12183 additions and 5968 deletions

View File

@@ -6,6 +6,7 @@ class BaseDataModel:
def deepcopy(self):
return copy.deepcopy(self)
def temporarily_transform_class_to_dict(obj: Any) -> Any:
# sourcery skip: assign-if-exp, reintroduce-else
"""

View File

@@ -0,0 +1,137 @@
"""
机器人兴趣标签数据模型
定义机器人的兴趣标签和相关的embedding数据结构
"""
from dataclasses import dataclass, field
from typing import List, Dict, Optional, Any
from datetime import datetime
from . import BaseDataModel
@dataclass
class BotInterestTag(BaseDataModel):
"""机器人兴趣标签"""
tag_name: str
weight: float = 1.0 # 权重,表示对这个兴趣的喜好程度 (0.0-1.0)
embedding: Optional[List[float]] = None # 标签的embedding向量
created_at: datetime = field(default_factory=datetime.now)
updated_at: datetime = field(default_factory=datetime.now)
is_active: bool = True
def to_dict(self) -> Dict[str, Any]:
"""转换为字典格式"""
return {
"tag_name": self.tag_name,
"weight": self.weight,
"embedding": self.embedding,
"created_at": self.created_at.isoformat(),
"updated_at": self.updated_at.isoformat(),
"is_active": self.is_active,
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "BotInterestTag":
"""从字典创建对象"""
return cls(
tag_name=data["tag_name"],
weight=data.get("weight", 1.0),
embedding=data.get("embedding"),
created_at=datetime.fromisoformat(data["created_at"]) if data.get("created_at") else datetime.now(),
updated_at=datetime.fromisoformat(data["updated_at"]) if data.get("updated_at") else datetime.now(),
is_active=data.get("is_active", True),
)
@dataclass
class BotPersonalityInterests(BaseDataModel):
"""机器人人格化兴趣配置"""
personality_id: str
personality_description: str # 人设描述文本
interest_tags: List[BotInterestTag] = field(default_factory=list)
embedding_model: str = "text-embedding-ada-002" # 使用的embedding模型
last_updated: datetime = field(default_factory=datetime.now)
version: int = 1 # 版本号,用于追踪更新
def get_active_tags(self) -> List[BotInterestTag]:
"""获取活跃的兴趣标签"""
return [tag for tag in self.interest_tags if tag.is_active]
def to_dict(self) -> Dict[str, Any]:
"""转换为字典格式"""
return {
"personality_id": self.personality_id,
"personality_description": self.personality_description,
"interest_tags": [tag.to_dict() for tag in self.interest_tags],
"embedding_model": self.embedding_model,
"last_updated": self.last_updated.isoformat(),
"version": self.version,
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "BotPersonalityInterests":
"""从字典创建对象"""
return cls(
personality_id=data["personality_id"],
personality_description=data["personality_description"],
interest_tags=[BotInterestTag.from_dict(tag_data) for tag_data in data.get("interest_tags", [])],
embedding_model=data.get("embedding_model", "text-embedding-ada-002"),
last_updated=datetime.fromisoformat(data["last_updated"]) if data.get("last_updated") else datetime.now(),
version=data.get("version", 1),
)
@dataclass
class InterestMatchResult(BaseDataModel):
"""兴趣匹配结果"""
message_id: str
matched_tags: List[str] = field(default_factory=list)
match_scores: Dict[str, float] = field(default_factory=dict) # tag_name -> score
overall_score: float = 0.0
top_tag: Optional[str] = None
confidence: float = 0.0 # 匹配置信度 (0.0-1.0)
matched_keywords: List[str] = field(default_factory=list)
def add_match(self, tag_name: str, score: float, keywords: List[str] = None):
"""添加匹配结果"""
self.matched_tags.append(tag_name)
self.match_scores[tag_name] = score
if keywords:
self.matched_keywords.extend(keywords)
def calculate_overall_score(self):
"""计算总体匹配分数"""
if not self.match_scores:
self.overall_score = 0.0
self.top_tag = None
return
# 使用加权平均计算总体分数
total_weight = len(self.match_scores)
if total_weight > 0:
self.overall_score = sum(self.match_scores.values()) / total_weight
# 设置最佳匹配标签
self.top_tag = max(self.match_scores.items(), key=lambda x: x[1])[0]
else:
self.overall_score = 0.0
self.top_tag = None
# 计算置信度(基于匹配标签数量和分数分布)
if len(self.match_scores) > 0:
avg_score = self.overall_score
score_variance = sum((score - avg_score) ** 2 for score in self.match_scores.values()) / len(
self.match_scores
)
# 分数越集中,置信度越高
self.confidence = max(0.0, 1.0 - score_variance)
else:
self.confidence = 0.0
def get_top_matches(self, top_n: int = 3) -> List[tuple]:
"""获取前N个最佳匹配"""
sorted_matches = sorted(self.match_scores.items(), key=lambda x: x[1], reverse=True)
return sorted_matches[:top_n]

View File

@@ -79,6 +79,7 @@ class DatabaseMessages(BaseDataModel):
is_command: bool = False,
is_notify: bool = False,
selected_expressions: Optional[str] = None,
is_read: bool = False,
user_id: str = "",
user_nickname: str = "",
user_cardname: Optional[str] = None,
@@ -94,6 +95,9 @@ class DatabaseMessages(BaseDataModel):
chat_info_platform: str = "",
chat_info_create_time: float = 0.0,
chat_info_last_active_time: float = 0.0,
# 新增字段
actions: Optional[list] = None,
should_reply: bool = False,
**kwargs: Any,
):
self.message_id = message_id
@@ -102,6 +106,10 @@ class DatabaseMessages(BaseDataModel):
self.reply_to = reply_to
self.interest_value = interest_value
# 新增字段
self.actions = actions
self.should_reply = should_reply
self.key_words = key_words
self.key_words_lite = key_words_lite
self.is_mentioned = is_mentioned
@@ -122,6 +130,7 @@ class DatabaseMessages(BaseDataModel):
self.is_notify = is_notify
self.selected_expressions = selected_expressions
self.is_read = is_read
self.group_info: Optional[DatabaseGroupInfo] = None
self.user_info = DatabaseUserInfo(
@@ -188,6 +197,10 @@ class DatabaseMessages(BaseDataModel):
"is_command": self.is_command,
"is_notify": self.is_notify,
"selected_expressions": self.selected_expressions,
"is_read": self.is_read,
# 新增字段
"actions": self.actions,
"should_reply": self.should_reply,
"user_id": self.user_info.user_id,
"user_nickname": self.user_info.user_nickname,
"user_cardname": self.user_info.user_cardname,
@@ -205,6 +218,61 @@ class DatabaseMessages(BaseDataModel):
"chat_info_user_cardname": self.chat_info.user_info.user_cardname,
}
def update_message_info(self, interest_value: float = None, actions: list = None, should_reply: bool = None):
"""
更新消息信息
Args:
interest_value: 兴趣度值
actions: 执行的动作列表
should_reply: 是否应该回复
"""
if interest_value is not None:
self.interest_value = interest_value
if actions is not None:
self.actions = actions
if should_reply is not None:
self.should_reply = should_reply
def add_action(self, action: str):
"""
添加执行的动作到消息中
Args:
action: 要添加的动作名称
"""
if self.actions is None:
self.actions = []
if action not in self.actions: # 避免重复添加
self.actions.append(action)
def get_actions(self) -> list:
"""
获取执行的动作列表
Returns:
动作列表,如果没有动作则返回空列表
"""
return self.actions or []
def get_message_summary(self) -> Dict[str, Any]:
"""
获取消息摘要信息
Returns:
包含关键字段的消息摘要
"""
return {
"message_id": self.message_id,
"time": self.time,
"interest_value": self.interest_value,
"actions": self.actions,
"should_reply": self.should_reply,
"user_nickname": self.user_info.user_nickname,
"display_message": self.display_message,
}
@dataclass(init=False)
class DatabaseActionRecords(BaseDataModel):
def __init__(
@@ -232,4 +300,4 @@ class DatabaseActionRecords(BaseDataModel):
self.action_prompt_display = action_prompt_display
self.chat_id = chat_id
self.chat_info_stream_id = chat_info_stream_id
self.chat_info_platform = chat_info_platform
self.chat_info_platform = chat_info_platform

View File

@@ -1,10 +1,12 @@
from dataclasses import dataclass, field
from typing import Optional, Dict, List, TYPE_CHECKING
from src.plugin_system.base.component_types import ChatType
from . import BaseDataModel
if TYPE_CHECKING:
pass
from .database_data_model import DatabaseMessages
from src.plugin_system.base.component_types import ActionInfo, ChatMode
@dataclass
@@ -21,23 +23,37 @@ class ActionPlannerInfo(BaseDataModel):
action_type: str = field(default_factory=str)
reasoning: Optional[str] = None
action_data: Optional[Dict] = None
action_message: Optional[Dict] = None
action_message: Optional["DatabaseMessages"] = None
available_actions: Optional[Dict[str, "ActionInfo"]] = None
@dataclass
class InterestScore(BaseDataModel):
"""兴趣度评分结果"""
message_id: str
total_score: float
interest_match_score: float
relationship_score: float
mentioned_score: float
details: Dict[str, str]
@dataclass
class Plan(BaseDataModel):
"""
统一规划数据模型
"""
chat_id: str
mode: "ChatMode"
chat_type: "ChatType"
# Generator 填充
available_actions: Dict[str, "ActionInfo"] = field(default_factory=dict)
chat_history: List["DatabaseMessages"] = field(default_factory=list)
target_info: Optional[TargetPersonInfo] = None
# Filter 填充
llm_prompt: Optional[str] = None
decided_actions: Optional[List[ActionPlannerInfo]] = None

View File

@@ -6,6 +6,7 @@ from . import BaseDataModel
if TYPE_CHECKING:
pass
@dataclass
class LLMGenerationDataModel(BaseDataModel):
content: Optional[str] = None
@@ -14,4 +15,4 @@ class LLMGenerationDataModel(BaseDataModel):
tool_calls: Optional[List["ToolCall"]] = None
prompt: Optional[str] = None
selected_expressions: Optional[List[int]] = None
reply_set: Optional[List[Tuple[str, Any]]] = None
reply_set: Optional[List[Tuple[str, Any]]] = None

View File

@@ -1,36 +0,0 @@
from dataclasses import dataclass, field
from typing import Optional, TYPE_CHECKING
from . import BaseDataModel
if TYPE_CHECKING:
pass
@dataclass
class MessageAndActionModel(BaseDataModel):
chat_id: str = field(default_factory=str)
time: float = field(default_factory=float)
user_id: str = field(default_factory=str)
user_platform: str = field(default_factory=str)
user_nickname: str = field(default_factory=str)
user_cardname: Optional[str] = None
processed_plain_text: Optional[str] = None
display_message: Optional[str] = None
chat_info_platform: str = field(default_factory=str)
is_action_record: bool = field(default=False)
action_name: Optional[str] = None
@classmethod
def from_DatabaseMessages(cls, message: "DatabaseMessages"):
return cls(
chat_id=message.chat_id,
time=message.time,
user_id=message.user_info.user_id,
user_platform=message.user_info.platform,
user_nickname=message.user_info.user_nickname,
user_cardname=message.user_info.user_cardname,
processed_plain_text=message.processed_plain_text,
display_message=message.display_message,
chat_info_platform=message.chat_info.platform,
)

View File

@@ -0,0 +1,373 @@
"""
消息管理模块数据模型
定义消息管理器使用的数据结构
"""
import asyncio
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional, TYPE_CHECKING
from . import BaseDataModel
from src.plugin_system.base.component_types import ChatMode, ChatType
from src.common.logger import get_logger
if TYPE_CHECKING:
from .database_data_model import DatabaseMessages
logger = get_logger("stream_context")
class MessageStatus(Enum):
"""消息状态枚举"""
UNREAD = "unread" # 未读消息
READ = "read" # 已读消息
PROCESSING = "processing" # 处理中
@dataclass
class StreamContext(BaseDataModel):
"""聊天流上下文信息"""
stream_id: str
chat_type: ChatType = ChatType.PRIVATE # 聊天类型,默认为私聊
chat_mode: ChatMode = ChatMode.NORMAL # 聊天模式,默认为普通模式
unread_messages: List["DatabaseMessages"] = field(default_factory=list)
history_messages: List["DatabaseMessages"] = field(default_factory=list)
last_check_time: float = field(default_factory=time.time)
is_active: bool = True
processing_task: Optional[asyncio.Task] = None
interruption_count: int = 0 # 打断计数器
last_interruption_time: float = 0.0 # 上次打断时间
afc_threshold_adjustment: float = 0.0 # afc阈值调整量
# 独立分发周期字段
next_check_time: float = field(default_factory=time.time) # 下次检查时间
distribution_interval: float = 5.0 # 当前分发周期(秒)
# 新增字段以替代ChatMessageContext功能
current_message: Optional["DatabaseMessages"] = None
priority_mode: Optional[str] = None
priority_info: Optional[dict] = None
def add_message(self, message: "DatabaseMessages"):
"""添加消息到上下文"""
message.is_read = False
self.unread_messages.append(message)
# 自动检测和更新chat type
self._detect_chat_type(message)
def update_message_info(
self, message_id: str, interest_value: float = None, actions: list = None, should_reply: bool = None
):
"""
更新消息信息
Args:
message_id: 消息ID
interest_value: 兴趣度值
actions: 执行的动作列表
should_reply: 是否应该回复
"""
# 在未读消息中查找并更新
for message in self.unread_messages:
if message.message_id == message_id:
message.update_message_info(interest_value, actions, should_reply)
break
# 在历史消息中查找并更新
for message in self.history_messages:
if message.message_id == message_id:
message.update_message_info(interest_value, actions, should_reply)
break
def add_action_to_message(self, message_id: str, action: str):
"""
向指定消息添加执行的动作
Args:
message_id: 消息ID
action: 要添加的动作名称
"""
# 在未读消息中查找并更新
for message in self.unread_messages:
if message.message_id == message_id:
message.add_action(action)
break
# 在历史消息中查找并更新
for message in self.history_messages:
if message.message_id == message_id:
message.add_action(action)
break
def _detect_chat_type(self, message: "DatabaseMessages"):
"""根据消息内容自动检测聊天类型"""
# 只有在第一次添加消息时才检测聊天类型,避免后续消息改变类型
if len(self.unread_messages) == 1: # 只有这条消息
# 如果消息包含群组信息,则为群聊
if hasattr(message, "chat_info_group_id") and message.chat_info_group_id:
self.chat_type = ChatType.GROUP
elif hasattr(message, "chat_info_group_name") and message.chat_info_group_name:
self.chat_type = ChatType.GROUP
else:
self.chat_type = ChatType.PRIVATE
def update_chat_type(self, chat_type: ChatType):
"""手动更新聊天类型"""
self.chat_type = chat_type
def set_chat_mode(self, chat_mode: ChatMode):
"""设置聊天模式"""
self.chat_mode = chat_mode
def is_group_chat(self) -> bool:
"""检查是否为群聊"""
return self.chat_type == ChatType.GROUP
def is_private_chat(self) -> bool:
"""检查是否为私聊"""
return self.chat_type == ChatType.PRIVATE
def get_chat_type_display(self) -> str:
"""获取聊天类型的显示名称"""
if self.chat_type == ChatType.GROUP:
return "群聊"
elif self.chat_type == ChatType.PRIVATE:
return "私聊"
else:
return "未知类型"
def mark_message_as_read(self, message_id: str):
"""标记消息为已读"""
for msg in self.unread_messages:
if msg.message_id == message_id:
msg.is_read = True
self.history_messages.append(msg)
self.unread_messages.remove(msg)
break
def get_unread_messages(self) -> List["DatabaseMessages"]:
"""获取未读消息"""
return [msg for msg in self.unread_messages if not msg.is_read]
def get_history_messages(self, limit: int = 20) -> List["DatabaseMessages"]:
"""获取历史消息"""
# 优先返回最近的历史消息和所有未读消息
recent_history = self.history_messages[-limit:] if len(self.history_messages) > limit else self.history_messages
return recent_history
def calculate_interruption_probability(self, max_limit: int, probability_factor: float) -> float:
"""计算打断概率"""
if max_limit <= 0:
return 0.0
# 计算打断比例
interruption_ratio = self.interruption_count / max_limit
# 如果已达到或超过最大次数,完全禁止打断
if self.interruption_count >= max_limit:
return 0.0
# 如果超过概率因子,概率下降
if interruption_ratio > probability_factor:
# 使用指数衰减,超过限制越多,概率越低
excess_ratio = interruption_ratio - probability_factor
probability = 0.8 * (0.5**excess_ratio) # 基础概率0.8,指数衰减
else:
# 在限制内,保持较高概率
probability = 0.8
return max(0.0, min(1.0, probability))
def increment_interruption_count(self):
"""增加打断计数"""
self.interruption_count += 1
self.last_interruption_time = time.time()
# 同步打断计数到ChatStream
self._sync_interruption_count_to_stream()
def reset_interruption_count(self):
"""重置打断计数和afc阈值调整"""
self.interruption_count = 0
self.last_interruption_time = 0.0
self.afc_threshold_adjustment = 0.0
# 同步打断计数到ChatStream
self._sync_interruption_count_to_stream()
def apply_interruption_afc_reduction(self, reduction_value: float):
"""应用打断导致的afc阈值降低"""
self.afc_threshold_adjustment += reduction_value
logger.debug(f"应用afc阈值降低: {reduction_value}, 总调整量: {self.afc_threshold_adjustment}")
def get_afc_threshold_adjustment(self) -> float:
"""获取当前的afc阈值调整量"""
return self.afc_threshold_adjustment
def _sync_interruption_count_to_stream(self):
"""同步打断计数到ChatStream"""
try:
from src.chat.message_receive.chat_stream import get_chat_manager
chat_manager = get_chat_manager()
if chat_manager:
chat_stream = chat_manager.get_stream(self.stream_id)
if chat_stream and hasattr(chat_stream, "interruption_count"):
# 在这里我们只是标记需要保存实际的保存会在下次save时进行
chat_stream.saved = False
logger.debug(
f"已同步StreamContext {self.stream_id} 的打断计数 {self.interruption_count} 到ChatStream"
)
except Exception as e:
logger.warning(f"同步打断计数到ChatStream失败: {e}")
def set_current_message(self, message: "DatabaseMessages"):
"""设置当前消息"""
self.current_message = message
def get_template_name(self) -> Optional[str]:
"""获取模板名称"""
if (
self.current_message
and hasattr(self.current_message, "additional_config")
and self.current_message.additional_config
):
try:
import json
config = json.loads(self.current_message.additional_config)
if config.get("template_info") and not config.get("template_default", True):
return config.get("template_name")
except (json.JSONDecodeError, AttributeError):
pass
return None
def get_last_message(self) -> Optional["DatabaseMessages"]:
"""获取最后一条消息"""
if self.current_message:
return self.current_message
if self.unread_messages:
return self.unread_messages[-1]
if self.history_messages:
return self.history_messages[-1]
return None
def check_types(self, types: list) -> bool:
"""
检查当前消息是否支持指定的类型
Args:
types: 需要检查的消息类型列表,如 ["text", "image", "emoji"]
Returns:
bool: 如果消息支持所有指定的类型则返回True否则返回False
"""
if not self.current_message:
return False
if not types:
# 如果没有指定类型要求,默认为支持
return True
# 优先从additional_config中获取format_info
if hasattr(self.current_message, "additional_config") and self.current_message.additional_config:
try:
import orjson
config = orjson.loads(self.current_message.additional_config)
# 检查format_info结构
if "format_info" in config:
format_info = config["format_info"]
# 方法1: 直接检查accept_format字段
if "accept_format" in format_info:
accept_format = format_info["accept_format"]
# 确保accept_format是列表类型
if isinstance(accept_format, str):
accept_format = [accept_format]
elif isinstance(accept_format, list):
pass
else:
# 如果accept_format不是字符串或列表尝试转换为列表
accept_format = list(accept_format) if hasattr(accept_format, "__iter__") else []
# 检查所有请求的类型是否都被支持
for requested_type in types:
if requested_type not in accept_format:
logger.debug(f"消息不支持类型 '{requested_type}',支持的类型: {accept_format}")
return False
return True
# 方法2: 检查content_format字段向后兼容
elif "content_format" in format_info:
content_format = format_info["content_format"]
# 确保content_format是列表类型
if isinstance(content_format, str):
content_format = [content_format]
elif isinstance(content_format, list):
pass
else:
content_format = list(content_format) if hasattr(content_format, "__iter__") else []
# 检查所有请求的类型是否都被支持
for requested_type in types:
if requested_type not in content_format:
logger.debug(f"消息不支持类型 '{requested_type}',支持的内容格式: {content_format}")
return False
return True
except (orjson.JSONDecodeError, AttributeError, TypeError) as e:
logger.debug(f"解析消息格式信息失败: {e}")
# 备用方案如果无法从additional_config获取格式信息使用默认支持的类型
# 大多数消息至少支持text类型
default_supported_types = ["text", "emoji"]
for requested_type in types:
if requested_type not in default_supported_types:
logger.debug(f"使用默认类型检查,消息可能不支持类型 '{requested_type}'")
# 对于非基础类型返回False以避免错误
if requested_type not in ["text", "emoji", "reply"]:
return False
return True
def get_priority_mode(self) -> Optional[str]:
"""获取优先级模式"""
return self.priority_mode
def get_priority_info(self) -> Optional[dict]:
"""获取优先级信息"""
return self.priority_info
@dataclass
class MessageManagerStats(BaseDataModel):
"""消息管理器统计信息"""
total_streams: int = 0
active_streams: int = 0
total_unread_messages: int = 0
total_processed_messages: int = 0
start_time: float = field(default_factory=time.time)
@property
def uptime(self) -> float:
"""运行时间"""
return time.time() - self.start_time
@dataclass
class StreamStats(BaseDataModel):
"""聊天流统计信息"""
stream_id: str
is_active: bool
unread_count: int
history_count: int
last_check_time: float
has_active_task: bool