refactor(core): 统一消息对象类型并增强代码健壮性
本次提交对多个核心模块进行了重构和修复,主要目标是统一内部消息对象的类型为 `DatabaseMessages`,并增加多处空值检查和类型注解,以提升代码的健壮性和可维护性。
- **统一消息类型**: 在 `action_manager` 中,将 `action_message` 和 `target_message` 的类型注解和处理逻辑统一为 `DatabaseMessages`,消除了对 `dict` 类型的兼容代码,使逻辑更清晰。
- **增强健壮性**:
- 在 `permission_api` 中,为所有对外方法增加了对 `_permission_manager` 未初始化时的空值检查,防止在管理器未就绪时调用引发异常。
- 在 `chat_api` 和 `cross_context_api` 中,增加了对 `stream.user_info` 的存在性检查,避免在私聊场景下 `user_info` 为空时导致 `AttributeError`。
- **类型修复**: 修正了 `action_modifier` 和 `plugin_base` 中的类型注解错误,并解决了 `action_modifier` 中因 `chat_stream` 未初始化可能导致的潜在问题。
- **代码简化**: 移除了 `action_manager` 中因兼容 `dict` 类型而产生的冗余代码分支,使逻辑更直接。
This commit is contained in:
committed by
Windpicker-owo
parent
3e8373e0ec
commit
8a0075ee92
@@ -51,7 +51,7 @@ class ChatterActionManager:
|
|||||||
chat_stream: ChatStream,
|
chat_stream: ChatStream,
|
||||||
log_prefix: str,
|
log_prefix: str,
|
||||||
shutting_down: bool = False,
|
shutting_down: bool = False,
|
||||||
action_message: dict | None = None,
|
action_message: DatabaseMessages | None = None,
|
||||||
) -> BaseAction | None:
|
) -> BaseAction | None:
|
||||||
"""
|
"""
|
||||||
创建动作处理器实例
|
创建动作处理器实例
|
||||||
@@ -143,7 +143,7 @@ class ChatterActionManager:
|
|||||||
self,
|
self,
|
||||||
action_name: str,
|
action_name: str,
|
||||||
chat_id: str,
|
chat_id: str,
|
||||||
target_message: dict | DatabaseMessages | None = None,
|
target_message: DatabaseMessages | None = None,
|
||||||
reasoning: str = "",
|
reasoning: str = "",
|
||||||
action_data: dict | None = None,
|
action_data: dict | None = None,
|
||||||
thinking_id: str | None = None,
|
thinking_id: str | None = None,
|
||||||
@@ -264,10 +264,8 @@ class ChatterActionManager:
|
|||||||
)
|
)
|
||||||
if not success or not response_set:
|
if not success or not response_set:
|
||||||
# 安全地获取 processed_plain_text
|
# 安全地获取 processed_plain_text
|
||||||
if isinstance(target_message, DatabaseMessages):
|
if target_message:
|
||||||
msg_text = target_message.processed_plain_text or "未知消息"
|
msg_text = target_message.processed_plain_text or "未知消息"
|
||||||
elif target_message:
|
|
||||||
msg_text = target_message.get("processed_plain_text", "未知消息")
|
|
||||||
else:
|
else:
|
||||||
msg_text = "未知消息"
|
msg_text = "未知消息"
|
||||||
|
|
||||||
@@ -336,10 +334,7 @@ class ChatterActionManager:
|
|||||||
# 获取目标消息ID
|
# 获取目标消息ID
|
||||||
target_message_id = None
|
target_message_id = None
|
||||||
if target_message:
|
if target_message:
|
||||||
if isinstance(target_message, DatabaseMessages):
|
target_message_id = target_message.message_id
|
||||||
target_message_id = target_message.message_id
|
|
||||||
elif isinstance(target_message, dict):
|
|
||||||
target_message_id = target_message.get("message_id")
|
|
||||||
elif action_data and isinstance(action_data, dict):
|
elif action_data and isinstance(action_data, dict):
|
||||||
target_message_id = action_data.get("target_message_id")
|
target_message_id = action_data.get("target_message_id")
|
||||||
|
|
||||||
@@ -508,14 +503,12 @@ class ChatterActionManager:
|
|||||||
person_info_manager = get_person_info_manager()
|
person_info_manager = get_person_info_manager()
|
||||||
|
|
||||||
# 获取 platform,如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值
|
# 获取 platform,如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值
|
||||||
if isinstance(action_message, DatabaseMessages):
|
if action_message:
|
||||||
platform = action_message.chat_info.platform
|
platform = action_message.chat_info.platform
|
||||||
user_id = action_message.user_info.user_id
|
user_id = action_message.user_info.user_id
|
||||||
else:
|
else:
|
||||||
platform = action_message.get("chat_info_platform")
|
platform = getattr(chat_stream, "platform", "unknown")
|
||||||
if platform is None:
|
user_id = ""
|
||||||
platform = getattr(chat_stream, "platform", "unknown")
|
|
||||||
user_id = action_message.get("user_id", "")
|
|
||||||
|
|
||||||
# 获取用户信息并生成回复提示
|
# 获取用户信息并生成回复提示
|
||||||
person_id = person_info_manager.get_person_id(
|
person_id = person_info_manager.get_person_id(
|
||||||
@@ -593,11 +586,8 @@ class ChatterActionManager:
|
|||||||
# 根据新消息数量决定是否需要引用回复
|
# 根据新消息数量决定是否需要引用回复
|
||||||
reply_text = ""
|
reply_text = ""
|
||||||
# 检查是否为主动思考消息
|
# 检查是否为主动思考消息
|
||||||
if isinstance(message_data, DatabaseMessages):
|
if message_data:
|
||||||
# DatabaseMessages 对象没有 message_type 字段,默认为 False
|
is_proactive_thinking = getattr(message_data, "message_type", None) == "proactive_thinking"
|
||||||
is_proactive_thinking = False
|
|
||||||
elif message_data:
|
|
||||||
is_proactive_thinking = message_data.get("message_type") == "proactive_thinking"
|
|
||||||
else:
|
else:
|
||||||
is_proactive_thinking = True
|
is_proactive_thinking = True
|
||||||
|
|
||||||
|
|||||||
@@ -2,11 +2,9 @@ import asyncio
|
|||||||
import hashlib
|
import hashlib
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any, cast
|
||||||
|
|
||||||
import orjson
|
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||||
|
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
|
||||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
|
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
|
||||||
from src.common.data_models.message_manager_data_model import StreamContext
|
from src.common.data_models.message_manager_data_model import StreamContext
|
||||||
@@ -34,7 +32,7 @@ class ActionModifier:
|
|||||||
"""初始化动作处理器"""
|
"""初始化动作处理器"""
|
||||||
self.chat_id = chat_id
|
self.chat_id = chat_id
|
||||||
# chat_stream 和 log_prefix 将在异步方法中初始化
|
# chat_stream 和 log_prefix 将在异步方法中初始化
|
||||||
self.chat_stream = None # type: ignore
|
self.chat_stream: ChatStream | None = None
|
||||||
self.log_prefix = f"[{chat_id}]"
|
self.log_prefix = f"[{chat_id}]"
|
||||||
|
|
||||||
self.action_manager = action_manager
|
self.action_manager = action_manager
|
||||||
@@ -113,7 +111,7 @@ class ActionModifier:
|
|||||||
logger.debug(f"{self.log_prefix} - 移除 {action_name}: {reason}")
|
logger.debug(f"{self.log_prefix} - 移除 {action_name}: {reason}")
|
||||||
|
|
||||||
message_list_before_now_half = await get_raw_msg_before_timestamp_with_chat(
|
message_list_before_now_half = await get_raw_msg_before_timestamp_with_chat(
|
||||||
chat_id=self.chat_stream.stream_id,
|
chat_id=self.chat_id,
|
||||||
timestamp=time.time(),
|
timestamp=time.time(),
|
||||||
limit=min(int(global_config.chat.max_context_size * 0.33), 10),
|
limit=min(int(global_config.chat.max_context_size * 0.33), 10),
|
||||||
)
|
)
|
||||||
@@ -139,6 +137,9 @@ class ActionModifier:
|
|||||||
logger.debug(f"{self.log_prefix}阶段一移除动作: {disabled_action_name},原因: 用户自行禁用")
|
logger.debug(f"{self.log_prefix}阶段一移除动作: {disabled_action_name},原因: 用户自行禁用")
|
||||||
|
|
||||||
# === 第二阶段:检查动作的关联类型 ===
|
# === 第二阶段:检查动作的关联类型 ===
|
||||||
|
if not self.chat_stream:
|
||||||
|
logger.error(f"{self.log_prefix} chat_stream 未初始化,无法执行第二阶段")
|
||||||
|
return
|
||||||
chat_context = self.chat_stream.context_manager.context
|
chat_context = self.chat_stream.context_manager.context
|
||||||
current_actions_s2 = self.action_manager.get_using_actions()
|
current_actions_s2 = self.action_manager.get_using_actions()
|
||||||
type_mismatched_actions = self._check_action_associated_types(current_actions_s2, chat_context)
|
type_mismatched_actions = self._check_action_associated_types(current_actions_s2, chat_context)
|
||||||
@@ -331,6 +332,7 @@ class ActionModifier:
|
|||||||
deactivated_actions = []
|
deactivated_actions = []
|
||||||
|
|
||||||
# 获取 Action 类注册表
|
# 获取 Action 类注册表
|
||||||
|
from src.plugin_system.base.base_action import BaseAction
|
||||||
from src.plugin_system.base.component_types import ComponentType
|
from src.plugin_system.base.component_types import ComponentType
|
||||||
from src.plugin_system.core.component_registry import component_registry
|
from src.plugin_system.core.component_registry import component_registry
|
||||||
|
|
||||||
@@ -354,15 +356,13 @@ class ActionModifier:
|
|||||||
try:
|
try:
|
||||||
# 创建一个最小化的实例
|
# 创建一个最小化的实例
|
||||||
action_instance = object.__new__(action_class)
|
action_instance = object.__new__(action_class)
|
||||||
|
# 使用 cast 来“欺骗”类型检查器
|
||||||
|
action_instance = cast(BaseAction, action_instance)
|
||||||
# 设置必要的属性
|
# 设置必要的属性
|
||||||
action_instance.action_name = action_name
|
|
||||||
action_instance.log_prefix = self.log_prefix
|
action_instance.log_prefix = self.log_prefix
|
||||||
# 设置聊天内容,用于激活判断
|
# 调用 go_activate 方法
|
||||||
action_instance._activation_chat_content = chat_content
|
|
||||||
|
|
||||||
# 调用 go_activate 方法(不再需要传入 chat_content)
|
|
||||||
task = action_instance.go_activate(
|
task = action_instance.go_activate(
|
||||||
llm_judge_model=self.llm_judge,
|
llm_judge_model=self.llm_judge
|
||||||
)
|
)
|
||||||
activation_tasks.append(task)
|
activation_tasks.append(task)
|
||||||
task_action_names.append(action_name)
|
task_action_names.append(action_name)
|
||||||
|
|||||||
@@ -172,6 +172,7 @@ class ChatManager:
|
|||||||
for stream in get_chat_manager().streams.values():
|
for stream in get_chat_manager().streams.values():
|
||||||
if (
|
if (
|
||||||
not stream.group_info
|
not stream.group_info
|
||||||
|
and stream.user_info
|
||||||
and str(stream.user_info.user_id) == str(user_id)
|
and str(stream.user_info.user_id) == str(user_id)
|
||||||
and stream.platform == platform
|
and stream.platform == platform
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -32,8 +32,10 @@ async def get_context_group(chat_id: str) -> ContextGroup | None:
|
|||||||
if is_group:
|
if is_group:
|
||||||
assert current_stream.group_info is not None
|
assert current_stream.group_info is not None
|
||||||
current_chat_raw_id = current_stream.group_info.group_id
|
current_chat_raw_id = current_stream.group_info.group_id
|
||||||
else:
|
elif current_stream.user_info:
|
||||||
current_chat_raw_id = current_stream.user_info.user_id
|
current_chat_raw_id = current_stream.user_info.user_id
|
||||||
|
else:
|
||||||
|
return None
|
||||||
current_type = "group" if is_group else "private"
|
current_type = "group" if is_group else "private"
|
||||||
|
|
||||||
for group in global_config.cross_context.groups:
|
for group in global_config.cross_context.groups:
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ class PermissionAPI:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._permission_manager: IPermissionManager | None = None
|
self._permission_manager: IPermissionManager | None = None
|
||||||
# 需要保留的前缀(视为绝对节点名,不再自动加 plugins.<plugin>. 前缀)
|
# 需要保留的前缀(视为绝对节点名,不再自动加 plugins.<plugin>. 前缀)
|
||||||
self.RESERVED_PREFIXES: tuple[str, ...] = "system."
|
self.RESERVED_PREFIXES: tuple[str, ...] = ("system.",)
|
||||||
# 系统节点列表 (name, description, default_granted)
|
# 系统节点列表 (name, description, default_granted)
|
||||||
self._SYSTEM_NODES: list[tuple[str, str, bool]] = [
|
self._SYSTEM_NODES: list[tuple[str, str, bool]] = [
|
||||||
("system.superuser", "系统超级管理员:拥有所有权限", False),
|
("system.superuser", "系统超级管理员:拥有所有权限", False),
|
||||||
@@ -80,10 +80,14 @@ class PermissionAPI:
|
|||||||
|
|
||||||
async def check_permission(self, platform: str, user_id: str, permission_node: str) -> bool:
|
async def check_permission(self, platform: str, user_id: str, permission_node: str) -> bool:
|
||||||
self._ensure_manager()
|
self._ensure_manager()
|
||||||
|
if not self._permission_manager:
|
||||||
|
return False
|
||||||
return await self._permission_manager.check_permission(UserInfo(platform, user_id), permission_node)
|
return await self._permission_manager.check_permission(UserInfo(platform, user_id), permission_node)
|
||||||
|
|
||||||
async def is_master(self, platform: str, user_id: str) -> bool:
|
async def is_master(self, platform: str, user_id: str) -> bool:
|
||||||
self._ensure_manager()
|
self._ensure_manager()
|
||||||
|
if not self._permission_manager:
|
||||||
|
return False
|
||||||
return await self._permission_manager.is_master(UserInfo(platform, user_id))
|
return await self._permission_manager.is_master(UserInfo(platform, user_id))
|
||||||
|
|
||||||
async def register_permission_node(
|
async def register_permission_node(
|
||||||
@@ -109,6 +113,8 @@ class PermissionAPI:
|
|||||||
if original_name != node_name:
|
if original_name != node_name:
|
||||||
logger.debug(f"规范化权限节点 '{original_name}' -> '{node_name}'")
|
logger.debug(f"规范化权限节点 '{original_name}' -> '{node_name}'")
|
||||||
node = PermissionNode(node_name, description, plugin_name, default_granted)
|
node = PermissionNode(node_name, description, plugin_name, default_granted)
|
||||||
|
if not self._permission_manager:
|
||||||
|
return False
|
||||||
return await self._permission_manager.register_permission_node(node)
|
return await self._permission_manager.register_permission_node(node)
|
||||||
|
|
||||||
async def register_system_permission_node(
|
async def register_system_permission_node(
|
||||||
@@ -141,18 +147,26 @@ class PermissionAPI:
|
|||||||
|
|
||||||
async def grant_permission(self, platform: str, user_id: str, permission_node: str) -> bool:
|
async def grant_permission(self, platform: str, user_id: str, permission_node: str) -> bool:
|
||||||
self._ensure_manager()
|
self._ensure_manager()
|
||||||
|
if not self._permission_manager:
|
||||||
|
return False
|
||||||
return await self._permission_manager.grant_permission(UserInfo(platform, user_id), permission_node)
|
return await self._permission_manager.grant_permission(UserInfo(platform, user_id), permission_node)
|
||||||
|
|
||||||
async def revoke_permission(self, platform: str, user_id: str, permission_node: str) -> bool:
|
async def revoke_permission(self, platform: str, user_id: str, permission_node: str) -> bool:
|
||||||
self._ensure_manager()
|
self._ensure_manager()
|
||||||
|
if not self._permission_manager:
|
||||||
|
return False
|
||||||
return await self._permission_manager.revoke_permission(UserInfo(platform, user_id), permission_node)
|
return await self._permission_manager.revoke_permission(UserInfo(platform, user_id), permission_node)
|
||||||
|
|
||||||
async def get_user_permissions(self, platform: str, user_id: str) -> list[str]:
|
async def get_user_permissions(self, platform: str, user_id: str) -> list[str]:
|
||||||
self._ensure_manager()
|
self._ensure_manager()
|
||||||
|
if not self._permission_manager:
|
||||||
|
return []
|
||||||
return await self._permission_manager.get_user_permissions(UserInfo(platform, user_id))
|
return await self._permission_manager.get_user_permissions(UserInfo(platform, user_id))
|
||||||
|
|
||||||
async def get_all_permission_nodes(self) -> list[dict[str, Any]]:
|
async def get_all_permission_nodes(self) -> list[dict[str, Any]]:
|
||||||
self._ensure_manager()
|
self._ensure_manager()
|
||||||
|
if not self._permission_manager:
|
||||||
|
return []
|
||||||
nodes = await self._permission_manager.get_all_permission_nodes()
|
nodes = await self._permission_manager.get_all_permission_nodes()
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
@@ -166,6 +180,8 @@ class PermissionAPI:
|
|||||||
|
|
||||||
async def get_plugin_permission_nodes(self, plugin_name: str) -> list[dict[str, Any]]:
|
async def get_plugin_permission_nodes(self, plugin_name: str) -> list[dict[str, Any]]:
|
||||||
self._ensure_manager()
|
self._ensure_manager()
|
||||||
|
if not self._permission_manager:
|
||||||
|
return []
|
||||||
nodes = await self._permission_manager.get_plugin_permission_nodes(plugin_name)
|
nodes = await self._permission_manager.get_plugin_permission_nodes(plugin_name)
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -331,7 +331,7 @@ class PluginBase(ABC):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
with open(user_config_path, encoding="utf-8") as f:
|
with open(user_config_path, encoding="utf-8") as f:
|
||||||
user_config: ClassVar = toml.load(f) or {}
|
user_config: dict[str, Any] = toml.load(f) or {}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"{self.log_prefix} 加载用户配置文件 {user_config_path} 失败: {e}", exc_info=True)
|
logger.error(f"{self.log_prefix} 加载用户配置文件 {user_config_path} 失败: {e}", exc_info=True)
|
||||||
self.config = self._generate_config_from_schema() # 加载失败时使用默认 schema
|
self.config = self._generate_config_from_schema() # 加载失败时使用默认 schema
|
||||||
|
|||||||
Reference in New Issue
Block a user