重构:统一平台字段命名,更新相关数据模型和消息处理逻辑
This commit is contained in:
@@ -57,7 +57,7 @@ class ChatStream:
|
||||
"""转换为字典格式"""
|
||||
return {
|
||||
"stream_id": self.stream_id,
|
||||
"platform": self.platform,
|
||||
"platform": self.platform or "",
|
||||
"user_info": self.user_info.to_dict() if self.user_info else None,
|
||||
"group_info": self.group_info.to_dict() if self.group_info else None,
|
||||
"create_time": self.create_time,
|
||||
@@ -81,7 +81,7 @@ class ChatStream:
|
||||
|
||||
instance = cls(
|
||||
stream_id=data["stream_id"],
|
||||
platform=data["platform"],
|
||||
platform=data.get("platform", "") or "",
|
||||
user_info=user_info, # type: ignore
|
||||
group_info=group_info,
|
||||
data=data,
|
||||
@@ -342,9 +342,9 @@ class ChatManager:
|
||||
def register_message(self, message: DatabaseMessages):
|
||||
"""注册消息到聊天流"""
|
||||
# 从 DatabaseMessages 提取平台和用户/群组信息
|
||||
from mofox_wire import GroupInfo, UserInfo
|
||||
from src.common.data_models.database_data_model import DatabaseGroupInfo, DatabaseUserInfo
|
||||
|
||||
user_info = UserInfo(
|
||||
user_info = DatabaseUserInfo(
|
||||
platform=message.user_info.platform,
|
||||
user_id=message.user_info.user_id,
|
||||
user_nickname=message.user_info.user_nickname,
|
||||
@@ -353,8 +353,8 @@ class ChatManager:
|
||||
|
||||
group_info = None
|
||||
if message.group_info:
|
||||
group_info = GroupInfo(
|
||||
platform=message.group_info.group_platform or "",
|
||||
group_info = DatabaseGroupInfo(
|
||||
platform=message.group_info.platform or "",
|
||||
group_id=message.group_info.group_id,
|
||||
group_name=message.group_info.group_name
|
||||
)
|
||||
@@ -595,14 +595,14 @@ class ChatManager:
|
||||
user_info_d = s_data_dict.get("user_info")
|
||||
group_info_d = s_data_dict.get("group_info")
|
||||
fields_to_save = {
|
||||
"platform": s_data_dict["platform"],
|
||||
"platform": s_data_dict.get("platform", "") or "",
|
||||
"create_time": s_data_dict["create_time"],
|
||||
"last_active_time": s_data_dict["last_active_time"],
|
||||
"user_platform": user_info_d["platform"] if user_info_d else "",
|
||||
"user_id": user_info_d["user_id"] if user_info_d else "",
|
||||
"user_nickname": user_info_d["user_nickname"] if user_info_d else "",
|
||||
"user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None,
|
||||
"group_platform": group_info_d["platform"] if group_info_d else "",
|
||||
"group_platform": group_info_d.get("platform", "") or "" if group_info_d else "",
|
||||
"group_id": group_info_d["group_id"] if group_info_d else "",
|
||||
"group_name": group_info_d["group_name"] if group_info_d else "",
|
||||
"energy_value": s_data_dict.get("energy_value", 5.0),
|
||||
@@ -636,7 +636,7 @@ class ChatManager:
|
||||
await _db_save_stream_async(stream_data_dict)
|
||||
stream.saved = True
|
||||
except Exception as e:
|
||||
logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (SQLAlchemy): {e}", exc_info=True)
|
||||
logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (SQLAlchemy): {e}")
|
||||
|
||||
async def _save_all_streams(self):
|
||||
"""保存所有聊天流"""
|
||||
|
||||
@@ -41,7 +41,7 @@ from src.chat.message_manager import message_manager
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.utils.prompt import global_prompt_manager
|
||||
from src.chat.utils.utils import is_mentioned_bot_in_message
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.data_models.database_data_model import DatabaseMessages, DatabaseUserInfo, DatabaseGroupInfo
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.mood.mood_manager import mood_manager
|
||||
@@ -226,7 +226,7 @@ class MessageHandler:
|
||||
logger.debug(f"消息处理流程控制: {exc}")
|
||||
else:
|
||||
message_id = envelope.get("message_info", {}).get("message_id", "UNKNOWN")
|
||||
logger.error(f"处理消息 {message_id} 时出错: {exc}", exc_info=True)
|
||||
logger.error(f"处理消息 {message_id} 时出错: {exc}")
|
||||
|
||||
async def _handle_adapter_response_route(self, envelope: MessageEnvelope) -> MessageEnvelope | None:
|
||||
"""
|
||||
@@ -264,12 +264,12 @@ class MessageHandler:
|
||||
|
||||
# 获取或创建聊天流
|
||||
platform = message_info.get("platform", "unknown")
|
||||
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform=platform,
|
||||
user_info=user_info, # type: ignore
|
||||
group_info=group_info,
|
||||
user_info=DatabaseUserInfo.from_dict(user_info) if user_info else None, # type: ignore
|
||||
group_info=DatabaseGroupInfo.from_dict(group_info) if group_info else None,
|
||||
)
|
||||
|
||||
# 将消息信封转换为 DatabaseMessages
|
||||
|
||||
@@ -494,7 +494,7 @@ def get_message_info_from_db_message(db_message: DatabaseMessages) -> MessageInf
|
||||
group_info: GroupInfoPayload | None = None
|
||||
if db_message.group_info:
|
||||
group_info = {
|
||||
"platform": db_message.group_info.group_platform or "",
|
||||
"platform": db_message.group_info.platform or "",
|
||||
"group_id": db_message.group_info.group_id,
|
||||
"group_name": db_message.group_info.group_name,
|
||||
}
|
||||
|
||||
@@ -161,11 +161,11 @@ class MessageStorageBatcher:
|
||||
if processed_plain_text:
|
||||
processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text)
|
||||
filtered_processed_plain_text = re.sub(
|
||||
pattern, processed_plain_text or "", flags=re.DOTALL
|
||||
pattern, "", processed_plain_text or "", flags=re.DOTALL
|
||||
)
|
||||
|
||||
display_message = message.display_message or message.processed_plain_text or ""
|
||||
filtered_display_message = re.sub(pattern, display_message, flags=re.DOTALL)
|
||||
filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
|
||||
|
||||
msg_id = message.message_id
|
||||
msg_time = message.time
|
||||
@@ -202,7 +202,7 @@ class MessageStorageBatcher:
|
||||
chat_info_user_id = message.chat_info.user_info.user_id if message.chat_info and message.chat_info.user_info else ""
|
||||
chat_info_user_nickname = message.chat_info.user_info.user_nickname if message.chat_info and message.chat_info.user_info else ""
|
||||
chat_info_user_cardname = message.chat_info.user_info.user_cardname if message.chat_info and message.chat_info.user_info else None
|
||||
chat_info_group_platform = message.group_info.group_platform if message.group_info else None
|
||||
chat_info_group_platform = message.group_info.platform if message.group_info else None
|
||||
chat_info_group_id = message.group_info.group_id if message.group_info else None
|
||||
chat_info_group_name = message.group_info.group_name if message.group_info else None
|
||||
|
||||
|
||||
@@ -1809,7 +1809,7 @@ class DefaultReplyer:
|
||||
message_info["group_info"] = {
|
||||
"group_id": self.chat_stream.group_info.group_id,
|
||||
"group_name": self.chat_stream.group_info.group_name,
|
||||
"platform": self.chat_stream.group_info.group_platform,
|
||||
"platform": self.chat_stream.group_info.platform,
|
||||
}
|
||||
|
||||
return {
|
||||
|
||||
@@ -42,7 +42,7 @@ class DatabaseGroupInfo(BaseDataModel):
|
||||
"""
|
||||
group_id: str = field(default_factory=str) # 群组唯一标识 ID
|
||||
group_name: str = field(default_factory=str) # 群组名称
|
||||
group_platform: str | None = None # 群组所在平台,可为空
|
||||
platform: str | None = None # 群组所在平台,可为空
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "DatabaseGroupInfo":
|
||||
@@ -50,7 +50,7 @@ class DatabaseGroupInfo(BaseDataModel):
|
||||
return cls(
|
||||
group_id=data.get("group_id", ""),
|
||||
group_name=data.get("group_name", ""),
|
||||
group_platform=data.get("group_platform"),
|
||||
platform=data.get("platform"),
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
@@ -58,7 +58,7 @@ class DatabaseGroupInfo(BaseDataModel):
|
||||
return {
|
||||
"group_id": self.group_id,
|
||||
"group_name": self.group_name,
|
||||
"group_platform": self.group_platform,
|
||||
"group_platform": self.platform,
|
||||
}
|
||||
|
||||
@dataclass
|
||||
@@ -168,7 +168,7 @@ class DatabaseMessages(BaseDataModel):
|
||||
self.group_info = DatabaseGroupInfo(
|
||||
group_id=chat_info_group_id,
|
||||
group_name=chat_info_group_name,
|
||||
group_platform=chat_info_group_platform,
|
||||
platform=chat_info_group_platform,
|
||||
)
|
||||
|
||||
# 构建聊天信息对象
|
||||
@@ -234,7 +234,7 @@ class DatabaseMessages(BaseDataModel):
|
||||
"user_platform": self.user_info.platform,
|
||||
"chat_info_group_id": self.group_info.group_id if self.group_info else None,
|
||||
"chat_info_group_name": self.group_info.group_name if self.group_info else None,
|
||||
"chat_info_group_platform": self.group_info.group_platform if self.group_info else None,
|
||||
"chat_info_group_platform": self.group_info.platform if self.group_info else None,
|
||||
"chat_info_stream_id": self.chat_info.stream_id,
|
||||
"chat_info_platform": self.chat_info.platform,
|
||||
"chat_info_create_time": self.chat_info.create_time,
|
||||
|
||||
@@ -508,13 +508,14 @@ class StreamContext(BaseDataModel):
|
||||
logger.debug(f"历史信息已初始化,stream={self.stream_id}, 当前条数={len(self.history_messages)}")
|
||||
return
|
||||
|
||||
logger.info(f"?? [历史加载] 开始从数据库读取历史消息: {self.stream_id}")
|
||||
logger.info(f"[历史加载] 开始从数据库读取历史消息: {self.stream_id}")
|
||||
self._history_initialized = True
|
||||
|
||||
try:
|
||||
logger.debug(f"开始加载数据库历史消息: {self.stream_id}")
|
||||
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
db_messages = await get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=self.stream_id,
|
||||
|
||||
@@ -455,7 +455,7 @@ DEFAULT_MODULE_COLORS = {
|
||||
"main": "#FFFFFF", # 亮白色+粗体 (主程序)
|
||||
"api": "#00FF00", # 亮绿色
|
||||
"emoji": "#FFAF00", # 橙黄色,偏向橙色但与replyer和action_manager不同
|
||||
"chat": "#00FF00", # 亮蓝色
|
||||
"message_handler": "#00FF00", # 亮蓝色
|
||||
"config": "#FFFF00", # 亮黄色
|
||||
"common": "#FF00FF", # 亮紫色
|
||||
"tools": "#00FFFF", # 亮青色
|
||||
@@ -665,7 +665,7 @@ DEFAULT_MODULE_ALIASES = {
|
||||
"memory": "记忆",
|
||||
"tool_executor": "工具",
|
||||
"hfc": "聊天节奏",
|
||||
"chat": "所见",
|
||||
"message_handler": "所见",
|
||||
"anti_injector": "反注入",
|
||||
"anti_injector.detector": "反注入检测",
|
||||
"anti_injector.shield": "反注入加盾",
|
||||
|
||||
@@ -208,7 +208,7 @@ def _build_message_envelope(
|
||||
message_info["group_info"] = {
|
||||
"group_id": target_stream.group_info.group_id,
|
||||
"group_name": target_stream.group_info.group_name,
|
||||
"platform": target_stream.group_info.group_platform,
|
||||
"platform": target_stream.group_info.platform,
|
||||
}
|
||||
|
||||
return {
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -348,7 +348,7 @@ class ChatterPlanFilter:
|
||||
|
||||
# 获取真正的已读和未读消息
|
||||
read_messages = (
|
||||
stream_context.context.history_messages
|
||||
stream_context.history_messages
|
||||
) # 已读消息存储在history_messages中
|
||||
if not read_messages:
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
Reference in New Issue
Block a user