typing fix

This commit is contained in:
UnCLAS-Prommer
2025-07-17 00:10:41 +08:00
parent 6e838ccc74
commit 1aa2734d62
26 changed files with 329 additions and 293 deletions

View File

@@ -61,7 +61,7 @@ __all__ = [
"ConfigField",
# 工具函数
"ManifestValidator",
"ManifestGenerator",
"validate_plugin_manifest",
"generate_plugin_manifest",
# "ManifestGenerator",
# "validate_plugin_manifest",
# "generate_plugin_manifest",
]

View File

@@ -111,7 +111,7 @@ async def _send_to_target(
is_head=True,
is_emoji=(message_type == "emoji"),
thinking_start_time=current_time,
reply_to = reply_to_platform_id
reply_to=reply_to_platform_id,
)
# 发送消息
@@ -137,6 +137,7 @@ async def _send_to_target(
async def _find_reply_message(target_stream, reply_to: str) -> Optional[MessageRecv]:
# sourcery skip: inline-variable, use-named-expression
"""查找要回复的消息
Args:
@@ -184,14 +185,11 @@ async def _find_reply_message(target_stream, reply_to: str) -> Optional[MessageR
# 检查是否有 回复<aaa:bbb> 字段
reply_pattern = r"回复<([^:<>]+):([^:<>]+)>"
match = re.search(reply_pattern, translate_text)
if match:
if match := re.search(reply_pattern, translate_text):
aaa = match.group(1)
bbb = match.group(2)
reply_person_id = get_person_info_manager().get_person_id(platform, bbb)
reply_person_name = await get_person_info_manager().get_value(reply_person_id, "person_name")
if not reply_person_name:
reply_person_name = aaa
reply_person_name = await get_person_info_manager().get_value(reply_person_id, "person_name") or aaa
# 在内容前加上回复信息
translate_text = re.sub(reply_pattern, f"回复 {reply_person_name}", translate_text, count=1)
@@ -206,9 +204,7 @@ async def _find_reply_message(target_stream, reply_to: str) -> Optional[MessageR
aaa = m.group(1)
bbb = m.group(2)
at_person_id = get_person_info_manager().get_person_id(platform, bbb)
at_person_name = await get_person_info_manager().get_value(at_person_id, "person_name")
if not at_person_name:
at_person_name = aaa
at_person_name = await get_person_info_manager().get_value(at_person_id, "person_name") or aaa
new_content += f"@{at_person_name}"
last_end = m.end()
new_content += translate_text[last_end:]
@@ -370,7 +366,14 @@ async def custom_to_stream(
bool: 是否发送成功
"""
return await _send_to_target(
message_type, content, stream_id, display_message, typing, reply_to, storage_message, show_log
message_type,
content,
stream_id,
display_message,
typing,
reply_to,
storage_message=storage_message,
show_log=show_log,
)
@@ -396,7 +399,7 @@ async def text_to_group(
"""
stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message)
return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message=storage_message)
async def text_to_user(
@@ -420,7 +423,7 @@ async def text_to_user(
bool: 是否发送成功
"""
stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message)
return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message=storage_message)
async def emoji_to_group(emoji_base64: str, group_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
@@ -543,7 +546,9 @@ async def custom_to_group(
bool: 是否发送成功
"""
stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
return await _send_to_target(message_type, content, stream_id, display_message, typing, reply_to, storage_message)
return await _send_to_target(
message_type, content, stream_id, display_message, typing, reply_to, storage_message=storage_message
)
async def custom_to_user(
@@ -571,7 +576,9 @@ async def custom_to_user(
bool: 是否发送成功
"""
stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
return await _send_to_target(message_type, content, stream_id, display_message, typing, reply_to, storage_message)
return await _send_to_target(
message_type, content, stream_id, display_message, typing, reply_to, storage_message=storage_message
)
async def custom_message(
@@ -611,4 +618,6 @@ async def custom_message(
await send_api.custom_message("audio", audio_base64, "123456", True, reply_to="张三:你好")
"""
stream_id = get_chat_manager().get_stream_id(platform, target_id, is_group)
return await _send_to_target(message_type, content, stream_id, display_message, typing, reply_to, storage_message)
return await _send_to_target(
message_type, content, stream_id, display_message, typing, reply_to, storage_message=storage_message
)

View File

@@ -38,7 +38,7 @@ class BaseAction(ABC):
chat_stream: ChatStream,
log_prefix: str = "",
plugin_config: Optional[dict] = None,
action_message: dict = None,
action_message: Optional[dict] = None,
**kwargs,
):
"""初始化Action组件
@@ -63,7 +63,7 @@ class BaseAction(ABC):
self.cycle_timers = cycle_timers
self.thinking_id = thinking_id
self.log_prefix = log_prefix
# 保存插件配置
self.plugin_config = plugin_config or {}
@@ -92,10 +92,10 @@ class BaseAction(ABC):
self.chat_stream = chat_stream or kwargs.get("chat_stream")
self.chat_id = self.chat_stream.stream_id
self.platform = getattr(self.chat_stream, "platform", None)
# 初始化基础信息(带类型注解)
self.action_message = action_message
self.group_id = None
self.group_name = None
self.user_id = None
@@ -103,15 +103,17 @@ class BaseAction(ABC):
self.is_group = False
self.target_id = None
self.has_action_message = False
if self.action_message:
self.has_action_message = True
else:
self.action_message = {}
if self.has_action_message:
if self.action_name != "no_reply":
self.group_id = str(self.action_message.get("chat_info_group_id", None))
self.group_name = self.action_message.get("chat_info_group_name", None)
self.user_id = str(self.action_message.get("user_id", None))
self.user_nickname = self.action_message.get("user_nickname", None)
if self.group_id:
@@ -132,8 +134,6 @@ class BaseAction(ABC):
self.is_group = False
self.target_id = self.user_id
logger.debug(f"{self.log_prefix} Action组件初始化完成")
logger.info(
f"{self.log_prefix} 聊天信息: 类型={'群聊' if self.is_group else '私聊'}, 平台={self.platform}, 目标={self.target_id}"
@@ -199,7 +199,9 @@ class BaseAction(ABC):
logger.error(f"{self.log_prefix} 等待新消息时发生错误: {e}")
return False, f"等待新消息失败: {str(e)}"
async def send_text(self, content: str, reply_to: str = "", reply_to_platform_id: str = "", typing: bool = False) -> bool:
async def send_text(
self, content: str, reply_to: str = "", reply_to_platform_id: str = "", typing: bool = False
) -> bool:
"""发送文本消息
Args:
@@ -299,7 +301,7 @@ class BaseAction(ABC):
)
async def send_command(
self, command_name: str, args: dict = None, display_message: str = None, storage_message: bool = True
self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True
) -> bool:
"""发送命令消息

View File

@@ -135,7 +135,7 @@ class BaseCommand(ABC):
)
async def send_command(
self, command_name: str, args: dict = None, display_message: str = "", storage_message: bool = True
self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True
) -> bool:
"""发送命令消息

View File

@@ -346,67 +346,67 @@ class ComponentRegistry:
# === 状态管理方法 ===
def enable_component(self, component_name: str, component_type: ComponentType = None) -> bool:
# -------------------------------- NEED REFACTORING --------------------------------
# -------------------------------- LOGIC ERROR -------------------------------------
"""启用组件,支持命名空间解析"""
# 首先尝试找到正确的命名空间化名称
component_info = self.get_component_info(component_name, component_type)
if not component_info:
return False
# def enable_component(self, component_name: str, component_type: ComponentType = None) -> bool:
# # -------------------------------- NEED REFACTORING --------------------------------
# # -------------------------------- LOGIC ERROR -------------------------------------
# """启用组件,支持命名空间解析"""
# # 首先尝试找到正确的命名空间化名称
# component_info = self.get_component_info(component_name, component_type)
# if not component_info:
# return False
# 根据组件类型构造正确的命名空间化名称
if component_info.component_type == ComponentType.ACTION:
namespaced_name = f"action.{component_name}" if "." not in component_name else component_name
elif component_info.component_type == ComponentType.COMMAND:
namespaced_name = f"command.{component_name}" if "." not in component_name else component_name
else:
namespaced_name = (
f"{component_info.component_type.value}.{component_name}"
if "." not in component_name
else component_name
)
# # 根据组件类型构造正确的命名空间化名称
# if component_info.component_type == ComponentType.ACTION:
# namespaced_name = f"action.{component_name}" if "." not in component_name else component_name
# elif component_info.component_type == ComponentType.COMMAND:
# namespaced_name = f"command.{component_name}" if "." not in component_name else component_name
# else:
# namespaced_name = (
# f"{component_info.component_type.value}.{component_name}"
# if "." not in component_name
# else component_name
# )
if namespaced_name in self._components:
self._components[namespaced_name].enabled = True
# 如果是Action更新默认动作集
# ---- HERE ----
# if isinstance(component_info, ActionInfo):
# self._action_descriptions[component_name] = component_info.description
logger.debug(f"已启用组件: {component_name} -> {namespaced_name}")
return True
return False
# if namespaced_name in self._components:
# self._components[namespaced_name].enabled = True
# # 如果是Action更新默认动作集
# # ---- HERE ----
# # if isinstance(component_info, ActionInfo):
# # self._action_descriptions[component_name] = component_info.description
# logger.debug(f"已启用组件: {component_name} -> {namespaced_name}")
# return True
# return False
def disable_component(self, component_name: str, component_type: ComponentType = None) -> bool:
# -------------------------------- NEED REFACTORING --------------------------------
# -------------------------------- LOGIC ERROR -------------------------------------
"""禁用组件,支持命名空间解析"""
# 首先尝试找到正确的命名空间化名称
component_info = self.get_component_info(component_name, component_type)
if not component_info:
return False
# def disable_component(self, component_name: str, component_type: ComponentType = None) -> bool:
# # -------------------------------- NEED REFACTORING --------------------------------
# # -------------------------------- LOGIC ERROR -------------------------------------
# """禁用组件,支持命名空间解析"""
# # 首先尝试找到正确的命名空间化名称
# component_info = self.get_component_info(component_name, component_type)
# if not component_info:
# return False
# 根据组件类型构造正确的命名空间化名称
if component_info.component_type == ComponentType.ACTION:
namespaced_name = f"action.{component_name}" if "." not in component_name else component_name
elif component_info.component_type == ComponentType.COMMAND:
namespaced_name = f"command.{component_name}" if "." not in component_name else component_name
else:
namespaced_name = (
f"{component_info.component_type.value}.{component_name}"
if "." not in component_name
else component_name
)
# # 根据组件类型构造正确的命名空间化名称
# if component_info.component_type == ComponentType.ACTION:
# namespaced_name = f"action.{component_name}" if "." not in component_name else component_name
# elif component_info.component_type == ComponentType.COMMAND:
# namespaced_name = f"command.{component_name}" if "." not in component_name else component_name
# else:
# namespaced_name = (
# f"{component_info.component_type.value}.{component_name}"
# if "." not in component_name
# else component_name
# )
if namespaced_name in self._components:
self._components[namespaced_name].enabled = False
# 如果是Action从默认动作集中移除
# ---- HERE ----
# if component_name in self._action_descriptions:
# del self._action_descriptions[component_name]
logger.debug(f"已禁用组件: {component_name} -> {namespaced_name}")
return True
return False
# if namespaced_name in self._components:
# self._components[namespaced_name].enabled = False
# # 如果是Action从默认动作集中移除
# # ---- HERE ----
# # if component_name in self._action_descriptions:
# # del self._action_descriptions[component_name]
# logger.debug(f"已禁用组件: {component_name} -> {namespaced_name}")
# return True
# return False
def get_registry_stats(self) -> Dict[str, Any]:
"""获取注册中心统计信息"""

View File

@@ -7,7 +7,7 @@
import subprocess
import sys
import importlib
from typing import List, Dict, Tuple
from typing import List, Dict, Tuple, Any
from src.common.logger import get_logger
from src.plugin_system.base.component_types import PythonDependency
@@ -176,7 +176,7 @@ class DependencyManager:
logger.error(f"生成requirements文件失败: {str(e)}")
return False
def get_install_summary(self) -> Dict[str, any]:
def get_install_summary(self) -> Dict[str, Any]:
"""获取安装摘要"""
return {
"install_log": self.install_log.copy(),

View File

@@ -197,29 +197,29 @@ class PluginManager:
"""获取所有启用的插件信息"""
return list(component_registry.get_enabled_plugins().values())
def enable_plugin(self, plugin_name: str) -> bool:
# -------------------------------- NEED REFACTORING --------------------------------
"""启用插件"""
if plugin_info := component_registry.get_plugin_info(plugin_name):
plugin_info.enabled = True
# 启用插件的所有组件
for component in plugin_info.components:
component_registry.enable_component(component.name)
logger.debug(f"已启用插件: {plugin_name}")
return True
return False
# def enable_plugin(self, plugin_name: str) -> bool:
# # -------------------------------- NEED REFACTORING --------------------------------
# """启用插件"""
# if plugin_info := component_registry.get_plugin_info(plugin_name):
# plugin_info.enabled = True
# # 启用插件的所有组件
# for component in plugin_info.components:
# component_registry.enable_component(component.name)
# logger.debug(f"已启用插件: {plugin_name}")
# return True
# return False
def disable_plugin(self, plugin_name: str) -> bool:
# -------------------------------- NEED REFACTORING --------------------------------
"""禁用插件"""
if plugin_info := component_registry.get_plugin_info(plugin_name):
plugin_info.enabled = False
# 禁用插件的所有组件
for component in plugin_info.components:
component_registry.disable_component(component.name)
logger.debug(f"已禁用插件: {plugin_name}")
return True
return False
# def disable_plugin(self, plugin_name: str) -> bool:
# # -------------------------------- NEED REFACTORING --------------------------------
# """禁用插件"""
# if plugin_info := component_registry.get_plugin_info(plugin_name):
# plugin_info.enabled = False
# # 禁用插件的所有组件
# for component in plugin_info.components:
# component_registry.disable_component(component.name)
# logger.debug(f"已禁用插件: {plugin_name}")
# return True
# return False
def get_plugin_instance(self, plugin_name: str) -> Optional["PluginBase"]:
"""获取插件实例