Merge afc branch into dev, prioritizing afc changes and migrating database async modifications from dev
This commit is contained in:
@@ -6,6 +6,7 @@ class BaseDataModel:
|
||||
def deepcopy(self):
|
||||
return copy.deepcopy(self)
|
||||
|
||||
|
||||
def temporarily_transform_class_to_dict(obj: Any) -> Any:
|
||||
# sourcery skip: assign-if-exp, reintroduce-else
|
||||
"""
|
||||
|
||||
137
src/common/data_models/bot_interest_data_model.py
Normal file
137
src/common/data_models/bot_interest_data_model.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""
|
||||
机器人兴趣标签数据模型
|
||||
定义机器人的兴趣标签和相关的embedding数据结构
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Dict, Optional, Any
|
||||
from datetime import datetime
|
||||
|
||||
from . import BaseDataModel
|
||||
|
||||
|
||||
@dataclass
|
||||
class BotInterestTag(BaseDataModel):
|
||||
"""机器人兴趣标签"""
|
||||
|
||||
tag_name: str
|
||||
weight: float = 1.0 # 权重,表示对这个兴趣的喜好程度 (0.0-1.0)
|
||||
embedding: Optional[List[float]] = None # 标签的embedding向量
|
||||
created_at: datetime = field(default_factory=datetime.now)
|
||||
updated_at: datetime = field(default_factory=datetime.now)
|
||||
is_active: bool = True
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典格式"""
|
||||
return {
|
||||
"tag_name": self.tag_name,
|
||||
"weight": self.weight,
|
||||
"embedding": self.embedding,
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"updated_at": self.updated_at.isoformat(),
|
||||
"is_active": self.is_active,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "BotInterestTag":
|
||||
"""从字典创建对象"""
|
||||
return cls(
|
||||
tag_name=data["tag_name"],
|
||||
weight=data.get("weight", 1.0),
|
||||
embedding=data.get("embedding"),
|
||||
created_at=datetime.fromisoformat(data["created_at"]) if data.get("created_at") else datetime.now(),
|
||||
updated_at=datetime.fromisoformat(data["updated_at"]) if data.get("updated_at") else datetime.now(),
|
||||
is_active=data.get("is_active", True),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BotPersonalityInterests(BaseDataModel):
|
||||
"""机器人人格化兴趣配置"""
|
||||
|
||||
personality_id: str
|
||||
personality_description: str # 人设描述文本
|
||||
interest_tags: List[BotInterestTag] = field(default_factory=list)
|
||||
embedding_model: str = "text-embedding-ada-002" # 使用的embedding模型
|
||||
last_updated: datetime = field(default_factory=datetime.now)
|
||||
version: int = 1 # 版本号,用于追踪更新
|
||||
|
||||
def get_active_tags(self) -> List[BotInterestTag]:
|
||||
"""获取活跃的兴趣标签"""
|
||||
return [tag for tag in self.interest_tags if tag.is_active]
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典格式"""
|
||||
return {
|
||||
"personality_id": self.personality_id,
|
||||
"personality_description": self.personality_description,
|
||||
"interest_tags": [tag.to_dict() for tag in self.interest_tags],
|
||||
"embedding_model": self.embedding_model,
|
||||
"last_updated": self.last_updated.isoformat(),
|
||||
"version": self.version,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "BotPersonalityInterests":
|
||||
"""从字典创建对象"""
|
||||
return cls(
|
||||
personality_id=data["personality_id"],
|
||||
personality_description=data["personality_description"],
|
||||
interest_tags=[BotInterestTag.from_dict(tag_data) for tag_data in data.get("interest_tags", [])],
|
||||
embedding_model=data.get("embedding_model", "text-embedding-ada-002"),
|
||||
last_updated=datetime.fromisoformat(data["last_updated"]) if data.get("last_updated") else datetime.now(),
|
||||
version=data.get("version", 1),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InterestMatchResult(BaseDataModel):
|
||||
"""兴趣匹配结果"""
|
||||
|
||||
message_id: str
|
||||
matched_tags: List[str] = field(default_factory=list)
|
||||
match_scores: Dict[str, float] = field(default_factory=dict) # tag_name -> score
|
||||
overall_score: float = 0.0
|
||||
top_tag: Optional[str] = None
|
||||
confidence: float = 0.0 # 匹配置信度 (0.0-1.0)
|
||||
matched_keywords: List[str] = field(default_factory=list)
|
||||
|
||||
def add_match(self, tag_name: str, score: float, keywords: List[str] = None):
|
||||
"""添加匹配结果"""
|
||||
self.matched_tags.append(tag_name)
|
||||
self.match_scores[tag_name] = score
|
||||
if keywords:
|
||||
self.matched_keywords.extend(keywords)
|
||||
|
||||
def calculate_overall_score(self):
|
||||
"""计算总体匹配分数"""
|
||||
if not self.match_scores:
|
||||
self.overall_score = 0.0
|
||||
self.top_tag = None
|
||||
return
|
||||
|
||||
# 使用加权平均计算总体分数
|
||||
total_weight = len(self.match_scores)
|
||||
if total_weight > 0:
|
||||
self.overall_score = sum(self.match_scores.values()) / total_weight
|
||||
# 设置最佳匹配标签
|
||||
self.top_tag = max(self.match_scores.items(), key=lambda x: x[1])[0]
|
||||
else:
|
||||
self.overall_score = 0.0
|
||||
self.top_tag = None
|
||||
|
||||
# 计算置信度(基于匹配标签数量和分数分布)
|
||||
if len(self.match_scores) > 0:
|
||||
avg_score = self.overall_score
|
||||
score_variance = sum((score - avg_score) ** 2 for score in self.match_scores.values()) / len(
|
||||
self.match_scores
|
||||
)
|
||||
# 分数越集中,置信度越高
|
||||
self.confidence = max(0.0, 1.0 - score_variance)
|
||||
else:
|
||||
self.confidence = 0.0
|
||||
|
||||
def get_top_matches(self, top_n: int = 3) -> List[tuple]:
|
||||
"""获取前N个最佳匹配"""
|
||||
sorted_matches = sorted(self.match_scores.items(), key=lambda x: x[1], reverse=True)
|
||||
return sorted_matches[:top_n]
|
||||
@@ -79,6 +79,7 @@ class DatabaseMessages(BaseDataModel):
|
||||
is_command: bool = False,
|
||||
is_notify: bool = False,
|
||||
selected_expressions: Optional[str] = None,
|
||||
is_read: bool = False,
|
||||
user_id: str = "",
|
||||
user_nickname: str = "",
|
||||
user_cardname: Optional[str] = None,
|
||||
@@ -94,6 +95,9 @@ class DatabaseMessages(BaseDataModel):
|
||||
chat_info_platform: str = "",
|
||||
chat_info_create_time: float = 0.0,
|
||||
chat_info_last_active_time: float = 0.0,
|
||||
# 新增字段
|
||||
actions: Optional[list] = None,
|
||||
should_reply: bool = False,
|
||||
**kwargs: Any,
|
||||
):
|
||||
self.message_id = message_id
|
||||
@@ -102,6 +106,10 @@ class DatabaseMessages(BaseDataModel):
|
||||
self.reply_to = reply_to
|
||||
self.interest_value = interest_value
|
||||
|
||||
# 新增字段
|
||||
self.actions = actions
|
||||
self.should_reply = should_reply
|
||||
|
||||
self.key_words = key_words
|
||||
self.key_words_lite = key_words_lite
|
||||
self.is_mentioned = is_mentioned
|
||||
@@ -122,6 +130,7 @@ class DatabaseMessages(BaseDataModel):
|
||||
self.is_notify = is_notify
|
||||
|
||||
self.selected_expressions = selected_expressions
|
||||
self.is_read = is_read
|
||||
|
||||
self.group_info: Optional[DatabaseGroupInfo] = None
|
||||
self.user_info = DatabaseUserInfo(
|
||||
@@ -188,6 +197,10 @@ class DatabaseMessages(BaseDataModel):
|
||||
"is_command": self.is_command,
|
||||
"is_notify": self.is_notify,
|
||||
"selected_expressions": self.selected_expressions,
|
||||
"is_read": self.is_read,
|
||||
# 新增字段
|
||||
"actions": self.actions,
|
||||
"should_reply": self.should_reply,
|
||||
"user_id": self.user_info.user_id,
|
||||
"user_nickname": self.user_info.user_nickname,
|
||||
"user_cardname": self.user_info.user_cardname,
|
||||
@@ -205,6 +218,61 @@ class DatabaseMessages(BaseDataModel):
|
||||
"chat_info_user_cardname": self.chat_info.user_info.user_cardname,
|
||||
}
|
||||
|
||||
def update_message_info(self, interest_value: float = None, actions: list = None, should_reply: bool = None):
|
||||
"""
|
||||
更新消息信息
|
||||
|
||||
Args:
|
||||
interest_value: 兴趣度值
|
||||
actions: 执行的动作列表
|
||||
should_reply: 是否应该回复
|
||||
"""
|
||||
if interest_value is not None:
|
||||
self.interest_value = interest_value
|
||||
if actions is not None:
|
||||
self.actions = actions
|
||||
if should_reply is not None:
|
||||
self.should_reply = should_reply
|
||||
|
||||
def add_action(self, action: str):
|
||||
"""
|
||||
添加执行的动作到消息中
|
||||
|
||||
Args:
|
||||
action: 要添加的动作名称
|
||||
"""
|
||||
if self.actions is None:
|
||||
self.actions = []
|
||||
if action not in self.actions: # 避免重复添加
|
||||
self.actions.append(action)
|
||||
|
||||
def get_actions(self) -> list:
|
||||
"""
|
||||
获取执行的动作列表
|
||||
|
||||
Returns:
|
||||
动作列表,如果没有动作则返回空列表
|
||||
"""
|
||||
return self.actions or []
|
||||
|
||||
def get_message_summary(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取消息摘要信息
|
||||
|
||||
Returns:
|
||||
包含关键字段的消息摘要
|
||||
"""
|
||||
return {
|
||||
"message_id": self.message_id,
|
||||
"time": self.time,
|
||||
"interest_value": self.interest_value,
|
||||
"actions": self.actions,
|
||||
"should_reply": self.should_reply,
|
||||
"user_nickname": self.user_info.user_nickname,
|
||||
"display_message": self.display_message,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(init=False)
|
||||
class DatabaseActionRecords(BaseDataModel):
|
||||
def __init__(
|
||||
@@ -232,4 +300,4 @@ class DatabaseActionRecords(BaseDataModel):
|
||||
self.action_prompt_display = action_prompt_display
|
||||
self.chat_id = chat_id
|
||||
self.chat_info_stream_id = chat_info_stream_id
|
||||
self.chat_info_platform = chat_info_platform
|
||||
self.chat_info_platform = chat_info_platform
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Dict, List, TYPE_CHECKING
|
||||
|
||||
from src.plugin_system.base.component_types import ChatType
|
||||
from . import BaseDataModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
from .database_data_model import DatabaseMessages
|
||||
from src.plugin_system.base.component_types import ActionInfo, ChatMode
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -21,23 +23,37 @@ class ActionPlannerInfo(BaseDataModel):
|
||||
action_type: str = field(default_factory=str)
|
||||
reasoning: Optional[str] = None
|
||||
action_data: Optional[Dict] = None
|
||||
action_message: Optional[Dict] = None
|
||||
action_message: Optional["DatabaseMessages"] = None
|
||||
available_actions: Optional[Dict[str, "ActionInfo"]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class InterestScore(BaseDataModel):
|
||||
"""兴趣度评分结果"""
|
||||
|
||||
message_id: str
|
||||
total_score: float
|
||||
interest_match_score: float
|
||||
relationship_score: float
|
||||
mentioned_score: float
|
||||
details: Dict[str, str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Plan(BaseDataModel):
|
||||
"""
|
||||
统一规划数据模型
|
||||
"""
|
||||
|
||||
chat_id: str
|
||||
mode: "ChatMode"
|
||||
|
||||
|
||||
chat_type: "ChatType"
|
||||
# Generator 填充
|
||||
available_actions: Dict[str, "ActionInfo"] = field(default_factory=dict)
|
||||
chat_history: List["DatabaseMessages"] = field(default_factory=list)
|
||||
target_info: Optional[TargetPersonInfo] = None
|
||||
|
||||
|
||||
# Filter 填充
|
||||
llm_prompt: Optional[str] = None
|
||||
decided_actions: Optional[List[ActionPlannerInfo]] = None
|
||||
|
||||
@@ -6,6 +6,7 @@ from . import BaseDataModel
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMGenerationDataModel(BaseDataModel):
|
||||
content: Optional[str] = None
|
||||
@@ -14,4 +15,4 @@ class LLMGenerationDataModel(BaseDataModel):
|
||||
tool_calls: Optional[List["ToolCall"]] = None
|
||||
prompt: Optional[str] = None
|
||||
selected_expressions: Optional[List[int]] = None
|
||||
reply_set: Optional[List[Tuple[str, Any]]] = None
|
||||
reply_set: Optional[List[Tuple[str, Any]]] = None
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
from . import BaseDataModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageAndActionModel(BaseDataModel):
|
||||
chat_id: str = field(default_factory=str)
|
||||
time: float = field(default_factory=float)
|
||||
user_id: str = field(default_factory=str)
|
||||
user_platform: str = field(default_factory=str)
|
||||
user_nickname: str = field(default_factory=str)
|
||||
user_cardname: Optional[str] = None
|
||||
processed_plain_text: Optional[str] = None
|
||||
display_message: Optional[str] = None
|
||||
chat_info_platform: str = field(default_factory=str)
|
||||
is_action_record: bool = field(default=False)
|
||||
action_name: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def from_DatabaseMessages(cls, message: "DatabaseMessages"):
|
||||
return cls(
|
||||
chat_id=message.chat_id,
|
||||
time=message.time,
|
||||
user_id=message.user_info.user_id,
|
||||
user_platform=message.user_info.platform,
|
||||
user_nickname=message.user_info.user_nickname,
|
||||
user_cardname=message.user_info.user_cardname,
|
||||
processed_plain_text=message.processed_plain_text,
|
||||
display_message=message.display_message,
|
||||
chat_info_platform=message.chat_info.platform,
|
||||
)
|
||||
373
src/common/data_models/message_manager_data_model.py
Normal file
373
src/common/data_models/message_manager_data_model.py
Normal file
@@ -0,0 +1,373 @@
|
||||
"""
|
||||
消息管理模块数据模型
|
||||
定义消息管理器使用的数据结构
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import List, Optional, TYPE_CHECKING
|
||||
|
||||
from . import BaseDataModel
|
||||
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||
from src.common.logger import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .database_data_model import DatabaseMessages
|
||||
|
||||
logger = get_logger("stream_context")
|
||||
|
||||
|
||||
class MessageStatus(Enum):
|
||||
"""消息状态枚举"""
|
||||
|
||||
UNREAD = "unread" # 未读消息
|
||||
READ = "read" # 已读消息
|
||||
PROCESSING = "processing" # 处理中
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamContext(BaseDataModel):
|
||||
"""聊天流上下文信息"""
|
||||
|
||||
stream_id: str
|
||||
chat_type: ChatType = ChatType.PRIVATE # 聊天类型,默认为私聊
|
||||
chat_mode: ChatMode = ChatMode.NORMAL # 聊天模式,默认为普通模式
|
||||
unread_messages: List["DatabaseMessages"] = field(default_factory=list)
|
||||
history_messages: List["DatabaseMessages"] = field(default_factory=list)
|
||||
last_check_time: float = field(default_factory=time.time)
|
||||
is_active: bool = True
|
||||
processing_task: Optional[asyncio.Task] = None
|
||||
interruption_count: int = 0 # 打断计数器
|
||||
last_interruption_time: float = 0.0 # 上次打断时间
|
||||
afc_threshold_adjustment: float = 0.0 # afc阈值调整量
|
||||
|
||||
# 独立分发周期字段
|
||||
next_check_time: float = field(default_factory=time.time) # 下次检查时间
|
||||
distribution_interval: float = 5.0 # 当前分发周期(秒)
|
||||
|
||||
# 新增字段以替代ChatMessageContext功能
|
||||
current_message: Optional["DatabaseMessages"] = None
|
||||
priority_mode: Optional[str] = None
|
||||
priority_info: Optional[dict] = None
|
||||
|
||||
def add_message(self, message: "DatabaseMessages"):
|
||||
"""添加消息到上下文"""
|
||||
message.is_read = False
|
||||
self.unread_messages.append(message)
|
||||
|
||||
# 自动检测和更新chat type
|
||||
self._detect_chat_type(message)
|
||||
|
||||
def update_message_info(
|
||||
self, message_id: str, interest_value: float = None, actions: list = None, should_reply: bool = None
|
||||
):
|
||||
"""
|
||||
更新消息信息
|
||||
|
||||
Args:
|
||||
message_id: 消息ID
|
||||
interest_value: 兴趣度值
|
||||
actions: 执行的动作列表
|
||||
should_reply: 是否应该回复
|
||||
"""
|
||||
# 在未读消息中查找并更新
|
||||
for message in self.unread_messages:
|
||||
if message.message_id == message_id:
|
||||
message.update_message_info(interest_value, actions, should_reply)
|
||||
break
|
||||
|
||||
# 在历史消息中查找并更新
|
||||
for message in self.history_messages:
|
||||
if message.message_id == message_id:
|
||||
message.update_message_info(interest_value, actions, should_reply)
|
||||
break
|
||||
|
||||
def add_action_to_message(self, message_id: str, action: str):
|
||||
"""
|
||||
向指定消息添加执行的动作
|
||||
|
||||
Args:
|
||||
message_id: 消息ID
|
||||
action: 要添加的动作名称
|
||||
"""
|
||||
# 在未读消息中查找并更新
|
||||
for message in self.unread_messages:
|
||||
if message.message_id == message_id:
|
||||
message.add_action(action)
|
||||
break
|
||||
|
||||
# 在历史消息中查找并更新
|
||||
for message in self.history_messages:
|
||||
if message.message_id == message_id:
|
||||
message.add_action(action)
|
||||
break
|
||||
|
||||
def _detect_chat_type(self, message: "DatabaseMessages"):
|
||||
"""根据消息内容自动检测聊天类型"""
|
||||
# 只有在第一次添加消息时才检测聊天类型,避免后续消息改变类型
|
||||
if len(self.unread_messages) == 1: # 只有这条消息
|
||||
# 如果消息包含群组信息,则为群聊
|
||||
if hasattr(message, "chat_info_group_id") and message.chat_info_group_id:
|
||||
self.chat_type = ChatType.GROUP
|
||||
elif hasattr(message, "chat_info_group_name") and message.chat_info_group_name:
|
||||
self.chat_type = ChatType.GROUP
|
||||
else:
|
||||
self.chat_type = ChatType.PRIVATE
|
||||
|
||||
def update_chat_type(self, chat_type: ChatType):
|
||||
"""手动更新聊天类型"""
|
||||
self.chat_type = chat_type
|
||||
|
||||
def set_chat_mode(self, chat_mode: ChatMode):
|
||||
"""设置聊天模式"""
|
||||
self.chat_mode = chat_mode
|
||||
|
||||
def is_group_chat(self) -> bool:
|
||||
"""检查是否为群聊"""
|
||||
return self.chat_type == ChatType.GROUP
|
||||
|
||||
def is_private_chat(self) -> bool:
|
||||
"""检查是否为私聊"""
|
||||
return self.chat_type == ChatType.PRIVATE
|
||||
|
||||
def get_chat_type_display(self) -> str:
|
||||
"""获取聊天类型的显示名称"""
|
||||
if self.chat_type == ChatType.GROUP:
|
||||
return "群聊"
|
||||
elif self.chat_type == ChatType.PRIVATE:
|
||||
return "私聊"
|
||||
else:
|
||||
return "未知类型"
|
||||
|
||||
def mark_message_as_read(self, message_id: str):
|
||||
"""标记消息为已读"""
|
||||
for msg in self.unread_messages:
|
||||
if msg.message_id == message_id:
|
||||
msg.is_read = True
|
||||
self.history_messages.append(msg)
|
||||
self.unread_messages.remove(msg)
|
||||
break
|
||||
|
||||
def get_unread_messages(self) -> List["DatabaseMessages"]:
|
||||
"""获取未读消息"""
|
||||
return [msg for msg in self.unread_messages if not msg.is_read]
|
||||
|
||||
def get_history_messages(self, limit: int = 20) -> List["DatabaseMessages"]:
|
||||
"""获取历史消息"""
|
||||
# 优先返回最近的历史消息和所有未读消息
|
||||
recent_history = self.history_messages[-limit:] if len(self.history_messages) > limit else self.history_messages
|
||||
return recent_history
|
||||
|
||||
def calculate_interruption_probability(self, max_limit: int, probability_factor: float) -> float:
|
||||
"""计算打断概率"""
|
||||
if max_limit <= 0:
|
||||
return 0.0
|
||||
|
||||
# 计算打断比例
|
||||
interruption_ratio = self.interruption_count / max_limit
|
||||
|
||||
# 如果已达到或超过最大次数,完全禁止打断
|
||||
if self.interruption_count >= max_limit:
|
||||
return 0.0
|
||||
|
||||
# 如果超过概率因子,概率下降
|
||||
if interruption_ratio > probability_factor:
|
||||
# 使用指数衰减,超过限制越多,概率越低
|
||||
excess_ratio = interruption_ratio - probability_factor
|
||||
probability = 0.8 * (0.5**excess_ratio) # 基础概率0.8,指数衰减
|
||||
else:
|
||||
# 在限制内,保持较高概率
|
||||
probability = 0.8
|
||||
|
||||
return max(0.0, min(1.0, probability))
|
||||
|
||||
def increment_interruption_count(self):
|
||||
"""增加打断计数"""
|
||||
self.interruption_count += 1
|
||||
self.last_interruption_time = time.time()
|
||||
|
||||
# 同步打断计数到ChatStream
|
||||
self._sync_interruption_count_to_stream()
|
||||
|
||||
def reset_interruption_count(self):
|
||||
"""重置打断计数和afc阈值调整"""
|
||||
self.interruption_count = 0
|
||||
self.last_interruption_time = 0.0
|
||||
self.afc_threshold_adjustment = 0.0
|
||||
|
||||
# 同步打断计数到ChatStream
|
||||
self._sync_interruption_count_to_stream()
|
||||
|
||||
def apply_interruption_afc_reduction(self, reduction_value: float):
|
||||
"""应用打断导致的afc阈值降低"""
|
||||
self.afc_threshold_adjustment += reduction_value
|
||||
logger.debug(f"应用afc阈值降低: {reduction_value}, 总调整量: {self.afc_threshold_adjustment}")
|
||||
|
||||
def get_afc_threshold_adjustment(self) -> float:
|
||||
"""获取当前的afc阈值调整量"""
|
||||
return self.afc_threshold_adjustment
|
||||
|
||||
def _sync_interruption_count_to_stream(self):
|
||||
"""同步打断计数到ChatStream"""
|
||||
try:
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
chat_manager = get_chat_manager()
|
||||
if chat_manager:
|
||||
chat_stream = chat_manager.get_stream(self.stream_id)
|
||||
if chat_stream and hasattr(chat_stream, "interruption_count"):
|
||||
# 在这里我们只是标记需要保存,实际的保存会在下次save时进行
|
||||
chat_stream.saved = False
|
||||
logger.debug(
|
||||
f"已同步StreamContext {self.stream_id} 的打断计数 {self.interruption_count} 到ChatStream"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"同步打断计数到ChatStream失败: {e}")
|
||||
|
||||
def set_current_message(self, message: "DatabaseMessages"):
|
||||
"""设置当前消息"""
|
||||
self.current_message = message
|
||||
|
||||
def get_template_name(self) -> Optional[str]:
|
||||
"""获取模板名称"""
|
||||
if (
|
||||
self.current_message
|
||||
and hasattr(self.current_message, "additional_config")
|
||||
and self.current_message.additional_config
|
||||
):
|
||||
try:
|
||||
import json
|
||||
|
||||
config = json.loads(self.current_message.additional_config)
|
||||
if config.get("template_info") and not config.get("template_default", True):
|
||||
return config.get("template_name")
|
||||
except (json.JSONDecodeError, AttributeError):
|
||||
pass
|
||||
return None
|
||||
|
||||
def get_last_message(self) -> Optional["DatabaseMessages"]:
|
||||
"""获取最后一条消息"""
|
||||
if self.current_message:
|
||||
return self.current_message
|
||||
if self.unread_messages:
|
||||
return self.unread_messages[-1]
|
||||
if self.history_messages:
|
||||
return self.history_messages[-1]
|
||||
return None
|
||||
|
||||
def check_types(self, types: list) -> bool:
|
||||
"""
|
||||
检查当前消息是否支持指定的类型
|
||||
|
||||
Args:
|
||||
types: 需要检查的消息类型列表,如 ["text", "image", "emoji"]
|
||||
|
||||
Returns:
|
||||
bool: 如果消息支持所有指定的类型则返回True,否则返回False
|
||||
"""
|
||||
if not self.current_message:
|
||||
return False
|
||||
|
||||
if not types:
|
||||
# 如果没有指定类型要求,默认为支持
|
||||
return True
|
||||
|
||||
# 优先从additional_config中获取format_info
|
||||
if hasattr(self.current_message, "additional_config") and self.current_message.additional_config:
|
||||
try:
|
||||
import orjson
|
||||
|
||||
config = orjson.loads(self.current_message.additional_config)
|
||||
|
||||
# 检查format_info结构
|
||||
if "format_info" in config:
|
||||
format_info = config["format_info"]
|
||||
|
||||
# 方法1: 直接检查accept_format字段
|
||||
if "accept_format" in format_info:
|
||||
accept_format = format_info["accept_format"]
|
||||
# 确保accept_format是列表类型
|
||||
if isinstance(accept_format, str):
|
||||
accept_format = [accept_format]
|
||||
elif isinstance(accept_format, list):
|
||||
pass
|
||||
else:
|
||||
# 如果accept_format不是字符串或列表,尝试转换为列表
|
||||
accept_format = list(accept_format) if hasattr(accept_format, "__iter__") else []
|
||||
|
||||
# 检查所有请求的类型是否都被支持
|
||||
for requested_type in types:
|
||||
if requested_type not in accept_format:
|
||||
logger.debug(f"消息不支持类型 '{requested_type}',支持的类型: {accept_format}")
|
||||
return False
|
||||
return True
|
||||
|
||||
# 方法2: 检查content_format字段(向后兼容)
|
||||
elif "content_format" in format_info:
|
||||
content_format = format_info["content_format"]
|
||||
# 确保content_format是列表类型
|
||||
if isinstance(content_format, str):
|
||||
content_format = [content_format]
|
||||
elif isinstance(content_format, list):
|
||||
pass
|
||||
else:
|
||||
content_format = list(content_format) if hasattr(content_format, "__iter__") else []
|
||||
|
||||
# 检查所有请求的类型是否都被支持
|
||||
for requested_type in types:
|
||||
if requested_type not in content_format:
|
||||
logger.debug(f"消息不支持类型 '{requested_type}',支持的内容格式: {content_format}")
|
||||
return False
|
||||
return True
|
||||
|
||||
except (orjson.JSONDecodeError, AttributeError, TypeError) as e:
|
||||
logger.debug(f"解析消息格式信息失败: {e}")
|
||||
|
||||
# 备用方案:如果无法从additional_config获取格式信息,使用默认支持的类型
|
||||
# 大多数消息至少支持text类型
|
||||
default_supported_types = ["text", "emoji"]
|
||||
for requested_type in types:
|
||||
if requested_type not in default_supported_types:
|
||||
logger.debug(f"使用默认类型检查,消息可能不支持类型 '{requested_type}'")
|
||||
# 对于非基础类型,返回False以避免错误
|
||||
if requested_type not in ["text", "emoji", "reply"]:
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_priority_mode(self) -> Optional[str]:
|
||||
"""获取优先级模式"""
|
||||
return self.priority_mode
|
||||
|
||||
def get_priority_info(self) -> Optional[dict]:
|
||||
"""获取优先级信息"""
|
||||
return self.priority_info
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageManagerStats(BaseDataModel):
|
||||
"""消息管理器统计信息"""
|
||||
|
||||
total_streams: int = 0
|
||||
active_streams: int = 0
|
||||
total_unread_messages: int = 0
|
||||
total_processed_messages: int = 0
|
||||
start_time: float = field(default_factory=time.time)
|
||||
|
||||
@property
|
||||
def uptime(self) -> float:
|
||||
"""运行时间"""
|
||||
return time.time() - self.start_time
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamStats(BaseDataModel):
|
||||
"""聊天流统计信息"""
|
||||
|
||||
stream_id: str
|
||||
is_active: bool
|
||||
unread_count: int
|
||||
history_count: int
|
||||
last_check_time: float
|
||||
has_active_task: bool
|
||||
Reference in New Issue
Block a user