Revert "refactor(core): 提升类型安全性并添加配置空值检查"
This reverts commit abfcf56941.
This commit is contained in:
@@ -43,7 +43,6 @@ from src.config.config import global_config
|
||||
from src.mood.mood_manager import mood_manager
|
||||
from src.plugin_system.base import BaseCommand, EventType
|
||||
from src.plugin_system.core import component_registry, event_manager, global_announcement_manager
|
||||
from typing import cast
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
@@ -56,25 +55,23 @@ PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
|
||||
|
||||
def _check_ban_words(text: str, chat: "ChatStream", userinfo) -> bool:
|
||||
"""检查消息是否包含过滤词"""
|
||||
if global_config and global_config.message_receive:
|
||||
for word in global_config.message_receive.ban_words:
|
||||
if word in text:
|
||||
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
|
||||
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
|
||||
logger.info(f"[过滤词识别]消息中含有{word},filtered")
|
||||
return True
|
||||
for word in global_config.message_receive.ban_words:
|
||||
if word in text:
|
||||
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
|
||||
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
|
||||
logger.info(f"[过滤词识别]消息中含有{word},filtered")
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _check_ban_regex(text: str, chat: "ChatStream", userinfo) -> bool:
|
||||
"""检查消息是否匹配过滤正则表达式"""
|
||||
if global_config and global_config.message_receive:
|
||||
for pattern in global_config.message_receive.ban_msgs_regex:
|
||||
if re.search(pattern, text):
|
||||
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
|
||||
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
|
||||
logger.info(f"[正则表达式过滤]消息匹配到{pattern},filtered")
|
||||
return True
|
||||
for pattern in global_config.message_receive.ban_msgs_regex:
|
||||
if re.search(pattern, text):
|
||||
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
|
||||
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
|
||||
logger.info(f"[正则表达式过滤]消息匹配到{pattern},filtered")
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@@ -284,7 +281,7 @@ class MessageHandler:
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform=platform,
|
||||
user_info=DatabaseUserInfo.from_dict(user_info) if user_info else None, # type: ignore
|
||||
group_info=DatabaseGroupInfo.from_dict(cast(dict, group_info)) if group_info else None,
|
||||
group_info=DatabaseGroupInfo.from_dict(group_info) if group_info else None,
|
||||
)
|
||||
|
||||
# 将消息信封转换为 DatabaseMessages
|
||||
@@ -434,7 +431,7 @@ class MessageHandler:
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform=platform,
|
||||
user_info=DatabaseUserInfo.from_dict(user_info) if user_info else None, # type: ignore
|
||||
group_info=DatabaseGroupInfo.from_dict(cast(dict, group_info)) if group_info else None,
|
||||
group_info=DatabaseGroupInfo.from_dict(group_info) if group_info else None,
|
||||
)
|
||||
|
||||
# 将消息信封转换为 DatabaseMessages
|
||||
@@ -538,9 +535,7 @@ class MessageHandler:
|
||||
text = message.processed_plain_text or ""
|
||||
|
||||
# 获取配置的命令前缀
|
||||
prefixes = []
|
||||
if global_config and global_config.command:
|
||||
prefixes = global_config.command.command_prefixes
|
||||
prefixes = global_config.command.command_prefixes
|
||||
|
||||
# 检查是否以任何前缀开头
|
||||
matched_prefix = None
|
||||
@@ -712,7 +707,7 @@ class MessageHandler:
|
||||
|
||||
# 检查是否需要处理消息
|
||||
should_process_in_manager = True
|
||||
if group_info and global_config and global_config.message_receive and str(group_info.group_id) in global_config.message_receive.mute_group_list:
|
||||
if group_info and str(group_info.group_id) in global_config.message_receive.mute_group_list:
|
||||
is_image_or_emoji = message.is_picid or message.is_emoji
|
||||
if not message.is_mentioned and not is_image_or_emoji:
|
||||
logger.debug(
|
||||
@@ -736,7 +731,7 @@ class MessageHandler:
|
||||
|
||||
# 情绪系统更新
|
||||
try:
|
||||
if global_config and global_config.mood and global_config.mood.enable_mood:
|
||||
if global_config.mood.enable_mood:
|
||||
interest_rate = message.interest_value or 0.0
|
||||
logger.debug(f"开始更新情绪状态,兴趣度: {interest_rate:.2f}")
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ async def process_message_from_dict(message_dict: MessageEnvelope, stream_id: st
|
||||
}
|
||||
|
||||
# 异步处理消息段,生成纯文本
|
||||
processed_plain_text = await _process_message_segments(message_segment, processing_state)
|
||||
processed_plain_text = await _process_message_segments(message_segment, processing_state, message_info)
|
||||
|
||||
# 解析 notice 信息
|
||||
is_notify = False
|
||||
@@ -155,13 +155,15 @@ async def process_message_from_dict(message_dict: MessageEnvelope, stream_id: st
|
||||
|
||||
async def _process_message_segments(
|
||||
segment: SegPayload | list[SegPayload],
|
||||
state: dict
|
||||
state: dict,
|
||||
message_info: MessageInfoPayload
|
||||
) -> str:
|
||||
"""递归处理消息段,转换为文字描述
|
||||
|
||||
Args:
|
||||
segment: 要处理的消息段(TypedDict 或列表)
|
||||
state: 处理状态字典(用于记录消息类型标记)
|
||||
message_info: 消息基础信息(TypedDict 格式)
|
||||
|
||||
Returns:
|
||||
str: 处理后的文本
|
||||
@@ -170,7 +172,7 @@ async def _process_message_segments(
|
||||
if isinstance(segment, list):
|
||||
segments_text = []
|
||||
for seg in segment:
|
||||
processed = await _process_message_segments(seg, state)
|
||||
processed = await _process_message_segments(seg, state, message_info)
|
||||
if processed:
|
||||
segments_text.append(processed)
|
||||
return " ".join(segments_text)
|
||||
@@ -184,26 +186,28 @@ async def _process_message_segments(
|
||||
if seg_type == "seglist" and isinstance(seg_data, list):
|
||||
segments_text = []
|
||||
for sub_seg in seg_data:
|
||||
processed = await _process_message_segments(sub_seg, state)
|
||||
processed = await _process_message_segments(sub_seg, state, message_info)
|
||||
if processed:
|
||||
segments_text.append(processed)
|
||||
return " ".join(segments_text)
|
||||
|
||||
# 处理其他类型
|
||||
return await _process_single_segment(segment, state)
|
||||
return await _process_single_segment(segment, state, message_info)
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
async def _process_single_segment(
|
||||
segment: SegPayload,
|
||||
state: dict
|
||||
state: dict,
|
||||
message_info: MessageInfoPayload
|
||||
) -> str:
|
||||
"""处理单个消息段
|
||||
|
||||
Args:
|
||||
segment: 消息段(TypedDict 格式)
|
||||
state: 处理状态字典
|
||||
message_info: 消息基础信息(TypedDict 格式)
|
||||
|
||||
Returns:
|
||||
str: 处理后的文本
|
||||
@@ -230,6 +234,7 @@ async def _process_single_segment(
|
||||
return f"@{seg_data}" if isinstance(seg_data, str) else "@未知用户"
|
||||
|
||||
elif seg_type == "image":
|
||||
# 如果是base64图片数据
|
||||
if isinstance(seg_data, str):
|
||||
state["has_picid"] = True
|
||||
state["is_picid"] = True
|
||||
@@ -242,17 +247,27 @@ async def _process_single_segment(
|
||||
state["has_emoji"] = True
|
||||
state["is_emoji"] = True
|
||||
if isinstance(seg_data, str):
|
||||
image_manager = get_image_manager()
|
||||
return await image_manager.get_emoji_description(seg_data)
|
||||
return await get_image_manager().get_emoji_description(seg_data)
|
||||
return "[发了一个表情包,网卡了加载不出来]"
|
||||
|
||||
elif seg_type == "voice":
|
||||
state["is_voice"] = True
|
||||
# 检查是否是自己发送的语音
|
||||
|
||||
# 检查消息是否由机器人自己发送
|
||||
user_info = message_info.get("user_info", {})
|
||||
user_id_str = str(user_info.get("user_id", ""))
|
||||
if user_id_str == str(global_config.bot.qq_account):
|
||||
logger.info(f"检测到机器人自身发送的语音消息 (User ID: {user_id_str}),尝试从缓存获取文本。")
|
||||
if isinstance(seg_data, str):
|
||||
cached_text = consume_self_voice_text(seg_data)
|
||||
if cached_text:
|
||||
logger.info(f"成功从缓存中获取语音文本: '{cached_text[:70]}...'")
|
||||
return f"[语音:{cached_text}]"
|
||||
else:
|
||||
logger.warning("机器人自身语音消息缓存未命中,将回退到标准语音识别。")
|
||||
|
||||
# 标准语音识别流程
|
||||
if isinstance(seg_data, str):
|
||||
cached_text = consume_self_voice_text(seg_data)
|
||||
if cached_text:
|
||||
return f"[语音:{cached_text}]"
|
||||
return await get_voice_text(seg_data)
|
||||
return "[发了一段语音,网卡了加载不出来]"
|
||||
|
||||
@@ -284,7 +299,7 @@ async def _process_single_segment(
|
||||
logger.warning("⚠️ Rust视频处理模块不可用,跳过视频分析")
|
||||
return "[视频]"
|
||||
|
||||
if global_config and global_config.video_analysis and global_config.video_analysis.enable:
|
||||
if global_config.video_analysis.enable:
|
||||
logger.info("已启用视频识别,开始识别")
|
||||
if isinstance(seg_data, dict):
|
||||
try:
|
||||
@@ -302,9 +317,8 @@ async def _process_single_segment(
|
||||
|
||||
# 使用video analyzer分析视频
|
||||
video_analyzer = get_video_analyzer()
|
||||
prompt = global_config.video_analysis.batch_analysis_prompt if global_config and global_config.video_analysis else ""
|
||||
result = await video_analyzer.analyze_video_from_bytes(
|
||||
video_bytes, filename, prompt=prompt
|
||||
video_bytes, filename, prompt=global_config.video_analysis.batch_analysis_prompt
|
||||
)
|
||||
|
||||
logger.info(f"视频分析结果: {result}")
|
||||
|
||||
@@ -13,10 +13,9 @@ import asyncio
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import TaskConfig
|
||||
from src.memory_graph.manager import MemoryManager
|
||||
from src.memory_graph.long_term_manager import LongTermMemoryManager
|
||||
from src.memory_graph.models import JudgeDecision, MemoryBlock, ShortTermMemory
|
||||
@@ -84,7 +83,7 @@ class UnifiedMemoryManager:
|
||||
self.long_term_manager: LongTermMemoryManager
|
||||
|
||||
# 底层 MemoryManager(长期记忆)
|
||||
self.memory_manager: MemoryManager = cast(MemoryManager, memory_manager)
|
||||
self.memory_manager: MemoryManager = memory_manager
|
||||
|
||||
# 配置参数存储(用于初始化)
|
||||
self._config = {
|
||||
@@ -331,11 +330,7 @@ class UnifiedMemoryManager:
|
||||
|
||||
"""
|
||||
|
||||
prompt = f"""你是一个记忆检索评估专家。你的任务是判断当前检索到的“感知记忆”(即时对话)和“短期记忆”(结构化信息)是否足以支撑一次有深度、有上下文的回复。
|
||||
|
||||
**核心原则:**
|
||||
- **不要轻易检索长期记忆!** 只有在当前对话需要深入探讨、回忆过去复杂事件或需要特定背景知识时,才认为记忆不足。
|
||||
- **闲聊、简单问候、表情互动或无特定主题的对话,现有记忆通常是充足的。** 频繁检索长期记忆会拖慢响应速度。
|
||||
prompt = f"""你是一个记忆检索评估专家。请判断检索到的记忆是否足以回答用户的问题。
|
||||
|
||||
**用户查询:**
|
||||
{query}
|
||||
@@ -346,36 +341,27 @@ class UnifiedMemoryManager:
|
||||
**检索到的短期记忆(结构化信息,自然语言描述):**
|
||||
{short_term_desc or '(无)'}
|
||||
|
||||
**评估指南:**
|
||||
1. **分析用户意图**:用户是在闲聊,还是在讨论一个需要深入挖掘的话题?
|
||||
2. **检查现有记忆**:当前的感知和短期记忆是否已经包含了足够的信息来回应用户的查询?
|
||||
- 对于闲聊(如“你好”、“哈哈”、“[表情]”),现有记忆总是充足的 (`"is_sufficient": true`)。
|
||||
- 对于需要回忆具体细节、深入探讨个人经历或专业知识的查询,如果现有记忆中没有相关信息,则可能不充足。
|
||||
3. **决策**:
|
||||
- 如果记忆充足,设置 `"is_sufficient": true`。
|
||||
- 如果确实需要更多信息才能进行有意义的对话,设置 `"is_sufficient": false`,并提供具体的补充查询。
|
||||
**任务要求:**
|
||||
1. 判断这些记忆是否足以回答用户的问题
|
||||
2. 如果不充足,分析缺少哪些方面的信息
|
||||
3. 生成额外需要检索的 query(用于在长期记忆中检索)
|
||||
|
||||
**输出格式(JSON):**
|
||||
```json
|
||||
{{
|
||||
"is_sufficient": true/false,
|
||||
"confidence": 0.85,
|
||||
"reasoning": "在这里解释你的判断理由。例如:‘用户只是在打招呼,现有记忆已足够’或‘用户问到了一个具体的历史事件,需要检索长期记忆’。",
|
||||
"reasoning": "判断理由",
|
||||
"missing_aspects": ["缺失的信息1", "缺失的信息2"],
|
||||
"additional_queries": ["补充query1", "补充query2"]
|
||||
}}
|
||||
```
|
||||
|
||||
请严格按照上述原则进行判断,并输出JSON:"""
|
||||
请输出JSON:"""
|
||||
|
||||
# 调用记忆裁判模型
|
||||
model_set = (
|
||||
model_config.model_task_config.memory_judge
|
||||
if model_config and model_config.model_task_config
|
||||
else TaskConfig(model_name="deepseek/deepseek-v2", provider="deepseek")
|
||||
)
|
||||
llm = LLMRequest(
|
||||
model_set=model_set,
|
||||
model_set=model_config.model_task_config.memory_judge,
|
||||
request_type="unified_memory.judge",
|
||||
)
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ class CoreActionsPlugin(BasePlugin):
|
||||
"""返回插件包含的组件列表"""
|
||||
|
||||
# --- 根据配置注册组件 ---
|
||||
components = []
|
||||
components: ClassVar = []
|
||||
|
||||
# 注册 reply 动作
|
||||
if self.get_config("components.enable_reply", True):
|
||||
|
||||
@@ -317,9 +317,6 @@ class NapcatAdapterPlugin(BasePlugin):
|
||||
"ignore_non_self_poke": ConfigField(type=bool, default=False, description="是否忽略不是针对自己的戳一戳消息"),
|
||||
"poke_debounce_seconds": ConfigField(type=float, default=2.0, description="戳一戳防抖时间(秒)"),
|
||||
"enable_emoji_like": ConfigField(type=bool, default=True, description="是否启用群聊表情回复处理"),
|
||||
"enable_reply_at": ConfigField(type=bool, default=True, description="是否在回复时自动@原消息发送者"),
|
||||
"reply_at_rate": ConfigField(type=float, default=0.5, description="回复时@的概率(0.0-1.0)"),
|
||||
"enable_video_processing": ConfigField(type=bool, default=True, description="是否启用视频消息处理(下载和解析)"),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
import base64
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
import uuid
|
||||
|
||||
from mofox_wire import MessageBuilder
|
||||
@@ -214,9 +214,6 @@ class MessageHandler:
|
||||
case RealMessageType.record:
|
||||
return await self._handle_record_message(segment)
|
||||
case RealMessageType.video:
|
||||
if not config_api.get_plugin_config(self.plugin_config, "features.enable_video_processing", True):
|
||||
logger.debug("视频消息处理已禁用,跳过")
|
||||
return None
|
||||
return await self._handle_video_message(segment)
|
||||
case RealMessageType.rps:
|
||||
return await self._handle_rps_message(segment)
|
||||
@@ -334,13 +331,10 @@ class MessageHandler:
|
||||
{"type": seg.get("type", "text"), "data": seg.get("data", "")} for seg in reply_segments
|
||||
] or [{"type": "text", "data": "[无法获取被引用的消息]"}]
|
||||
|
||||
return cast(
|
||||
SegPayload,
|
||||
{
|
||||
"type": "seglist",
|
||||
"data": [{"type": "text", "data": prefix_text}, *brief_segments, {"type": "text", "data": suffix_text}],
|
||||
},
|
||||
)
|
||||
return {
|
||||
"type": "seglist",
|
||||
"data": [{"type": "text", "data": prefix_text}, *brief_segments, {"type": "text", "data": suffix_text}],
|
||||
}
|
||||
|
||||
async def _handle_record_message(self, segment: dict) -> SegPayload | None:
|
||||
"""处理语音消息"""
|
||||
@@ -386,17 +380,14 @@ class MessageHandler:
|
||||
video_base64 = base64.b64encode(video_data).decode("utf-8")
|
||||
logger.debug(f"视频文件大小: {len(video_data) / (1024 * 1024):.2f} MB")
|
||||
|
||||
return cast(
|
||||
SegPayload,
|
||||
{
|
||||
"type": "video",
|
||||
"data": {
|
||||
"base64": video_base64,
|
||||
"filename": Path(file_path).name,
|
||||
"size_mb": len(video_data) / (1024 * 1024),
|
||||
},
|
||||
return {
|
||||
"type": "video",
|
||||
"data": {
|
||||
"base64": video_base64,
|
||||
"filename": Path(file_path).name,
|
||||
"size_mb": len(video_data) / (1024 * 1024),
|
||||
},
|
||||
)
|
||||
}
|
||||
elif video_url:
|
||||
# URL下载处理
|
||||
from ..video_handler import get_video_downloader
|
||||
@@ -410,18 +401,15 @@ class MessageHandler:
|
||||
video_base64 = base64.b64encode(download_result["data"]).decode("utf-8")
|
||||
logger.debug(f"视频下载成功,大小: {len(download_result['data']) / (1024 * 1024):.2f} MB")
|
||||
|
||||
return cast(
|
||||
SegPayload,
|
||||
{
|
||||
"type": "video",
|
||||
"data": {
|
||||
"base64": video_base64,
|
||||
"filename": download_result.get("filename", "video.mp4"),
|
||||
"size_mb": len(download_result["data"]) / (1024 * 1024),
|
||||
"url": video_url,
|
||||
},
|
||||
return {
|
||||
"type": "video",
|
||||
"data": {
|
||||
"base64": video_base64,
|
||||
"filename": download_result.get("filename", "video.mp4"),
|
||||
"size_mb": len(download_result["data"]) / (1024 * 1024),
|
||||
"url": video_url,
|
||||
},
|
||||
)
|
||||
}
|
||||
else:
|
||||
logger.warning("既没有有效的本地文件路径,也没有有效的视频URL")
|
||||
return None
|
||||
@@ -466,39 +454,34 @@ class MessageHandler:
|
||||
processed_message = handled_message
|
||||
|
||||
forward_hint = {"type": "text", "data": "这是一条转发消息:\n"}
|
||||
return cast(SegPayload, {"type": "seglist", "data": [forward_hint, processed_message]})
|
||||
return {"type": "seglist", "data": [forward_hint, processed_message]}
|
||||
|
||||
async def _recursive_parse_image_seg(self, seg_data: SegPayload, to_image: bool) -> SegPayload:
|
||||
# sourcery skip: merge-else-if-into-elif
|
||||
if seg_data.get("type") == "seglist":
|
||||
new_seg_list = []
|
||||
for i_seg in seg_data.get("data", []):
|
||||
if isinstance(i_seg, dict): # 确保是字典类型
|
||||
parsed_seg = await self._recursive_parse_image_seg(i_seg, to_image)
|
||||
new_seg_list.append(parsed_seg)
|
||||
parsed_seg = await self._recursive_parse_image_seg(i_seg, to_image)
|
||||
new_seg_list.append(parsed_seg)
|
||||
return {"type": "seglist", "data": new_seg_list}
|
||||
|
||||
if to_image:
|
||||
if seg_data.get("type") == "image":
|
||||
image_url = seg_data.get("data")
|
||||
if isinstance(image_url, str):
|
||||
try:
|
||||
encoded_image = await get_image_base64(image_url)
|
||||
except Exception as e:
|
||||
logger.error(f"图片处理失败: {str(e)}")
|
||||
return {"type": "text", "data": "[图片]"}
|
||||
return {"type": "image", "data": encoded_image}
|
||||
return {"type": "text", "data": "[图片]"}
|
||||
try:
|
||||
encoded_image = await get_image_base64(image_url)
|
||||
except Exception as e:
|
||||
logger.error(f"图片处理失败: {str(e)}")
|
||||
return {"type": "text", "data": "[图片]"}
|
||||
return {"type": "image", "data": encoded_image}
|
||||
if seg_data.get("type") == "emoji":
|
||||
image_url = seg_data.get("data")
|
||||
if isinstance(image_url, str):
|
||||
try:
|
||||
encoded_image = await get_image_base64(image_url)
|
||||
except Exception as e:
|
||||
logger.error(f"图片处理失败: {str(e)}")
|
||||
return {"type": "text", "data": "[表情包]"}
|
||||
return {"type": "emoji", "data": encoded_image}
|
||||
return {"type": "text", "data": "[表情包]"}
|
||||
try:
|
||||
encoded_image = await get_image_base64(image_url)
|
||||
except Exception as e:
|
||||
logger.error(f"图片处理失败: {str(e)}")
|
||||
return {"type": "text", "data": "[表情包]"}
|
||||
return {"type": "emoji", "data": encoded_image}
|
||||
logger.debug(f"不处理类型: {seg_data.get('type')}")
|
||||
return seg_data
|
||||
|
||||
@@ -612,7 +595,7 @@ class MessageHandler:
|
||||
"id": file_id,
|
||||
}
|
||||
|
||||
return cast(SegPayload, {"type": "file", "data": file_data})
|
||||
return {"type": "file", "data": file_data}
|
||||
|
||||
async def _handle_json_message(self, segment: dict) -> SegPayload | None:
|
||||
"""
|
||||
@@ -640,7 +623,7 @@ class MessageHandler:
|
||||
# 从回声消息中提取文件信息
|
||||
file_info = self._extract_file_info_from_echo(nested_data)
|
||||
if file_info:
|
||||
return cast(SegPayload, {"type": "file", "data": file_info})
|
||||
return {"type": "file", "data": file_info}
|
||||
|
||||
# 检查是否是QQ小程序分享消息
|
||||
if "app" in nested_data and "com.tencent.miniapp" in str(nested_data.get("app", "")):
|
||||
|
||||
Reference in New Issue
Block a user