fix typing, api change
This commit is contained in:
@@ -16,3 +16,7 @@
|
||||
1. 现在所有的匹配模式不再是关键字了,而是枚举类。**(可能有遗漏)**
|
||||
2. 修复了一下显示插件信息不显示的问题。同时精简了一下显示内容
|
||||
3. 修复了插件系统混用了`plugin_name`和`display_name`的问题。现在所有的插件信息都使用`display_name`来显示,而内部标识仍然使用`plugin_name`。**(可能有遗漏)**
|
||||
3. 部分API的参数类型和返回值进行了调整
|
||||
- `chat_api.py`中获取流的参数中可以使用一个特殊的枚举类型来获得所有平台的 ChatStream 了。
|
||||
- `config_api.py`中的`get_global_config`和`get_plugin_config`方法现在支持嵌套访问的配置键名。
|
||||
- `database_api.py`中的`db_query`方法调整了参数顺序以增强参数限制的同时,保证了typing正确;`db_get`方法增加了`single_result`参数,与`db_query`保持一致。
|
||||
|
||||
@@ -13,23 +13,29 @@
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from src.common.logger import get_logger
|
||||
from enum import Enum
|
||||
|
||||
# 导入依赖
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
|
||||
logger = get_logger("chat_api")
|
||||
|
||||
|
||||
class SpecialTypes(Enum):
|
||||
"""特殊枚举类型"""
|
||||
|
||||
ALL_PLATFORMS = "all_platforms"
|
||||
|
||||
|
||||
class ChatManager:
|
||||
"""聊天管理器 - 专门负责聊天信息的查询和管理"""
|
||||
|
||||
@staticmethod
|
||||
def get_all_streams(platform: str = "qq") -> List[ChatStream]:
|
||||
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
|
||||
"""获取所有聊天流
|
||||
|
||||
Args:
|
||||
platform: 平台筛选,默认为"qq"
|
||||
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
|
||||
|
||||
Returns:
|
||||
List[ChatStream]: 聊天流列表
|
||||
@@ -37,7 +43,7 @@ class ChatManager:
|
||||
streams = []
|
||||
try:
|
||||
for _, stream in get_chat_manager().streams.items():
|
||||
if stream.platform == platform:
|
||||
if platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform:
|
||||
streams.append(stream)
|
||||
logger.debug(f"[ChatAPI] 获取到 {len(streams)} 个 {platform} 平台的聊天流")
|
||||
except Exception as e:
|
||||
@@ -45,11 +51,11 @@ class ChatManager:
|
||||
return streams
|
||||
|
||||
@staticmethod
|
||||
def get_group_streams(platform: str = "qq") -> List[ChatStream]:
|
||||
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
|
||||
"""获取所有群聊聊天流
|
||||
|
||||
Args:
|
||||
platform: 平台筛选,默认为"qq"
|
||||
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
|
||||
|
||||
Returns:
|
||||
List[ChatStream]: 群聊聊天流列表
|
||||
@@ -57,7 +63,7 @@ class ChatManager:
|
||||
streams = []
|
||||
try:
|
||||
for _, stream in get_chat_manager().streams.items():
|
||||
if stream.platform == platform and stream.group_info:
|
||||
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and stream.group_info:
|
||||
streams.append(stream)
|
||||
logger.debug(f"[ChatAPI] 获取到 {len(streams)} 个 {platform} 平台的群聊流")
|
||||
except Exception as e:
|
||||
@@ -65,11 +71,11 @@ class ChatManager:
|
||||
return streams
|
||||
|
||||
@staticmethod
|
||||
def get_private_streams(platform: str = "qq") -> List[ChatStream]:
|
||||
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
|
||||
"""获取所有私聊聊天流
|
||||
|
||||
Args:
|
||||
platform: 平台筛选,默认为"qq"
|
||||
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
|
||||
|
||||
Returns:
|
||||
List[ChatStream]: 私聊聊天流列表
|
||||
@@ -77,7 +83,7 @@ class ChatManager:
|
||||
streams = []
|
||||
try:
|
||||
for _, stream in get_chat_manager().streams.items():
|
||||
if stream.platform == platform and not stream.group_info:
|
||||
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and not stream.group_info:
|
||||
streams.append(stream)
|
||||
logger.debug(f"[ChatAPI] 获取到 {len(streams)} 个 {platform} 平台的私聊流")
|
||||
except Exception as e:
|
||||
@@ -85,12 +91,14 @@ class ChatManager:
|
||||
return streams
|
||||
|
||||
@staticmethod
|
||||
def get_stream_by_group_id(group_id: str, platform: str = "qq") -> Optional[ChatStream]:
|
||||
def get_group_stream_by_group_id(
|
||||
group_id: str, platform: Optional[str] | SpecialTypes = "qq"
|
||||
) -> Optional[ChatStream]:
|
||||
"""根据群ID获取聊天流
|
||||
|
||||
Args:
|
||||
group_id: 群聊ID
|
||||
platform: 平台,默认为"qq"
|
||||
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
|
||||
|
||||
Returns:
|
||||
Optional[ChatStream]: 聊天流对象,如果未找到返回None
|
||||
@@ -110,12 +118,14 @@ class ChatManager:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_stream_by_user_id(user_id: str, platform: str = "qq") -> Optional[ChatStream]:
|
||||
def get_private_stream_by_user_id(
|
||||
user_id: str, platform: Optional[str] | SpecialTypes = "qq"
|
||||
) -> Optional[ChatStream]:
|
||||
"""根据用户ID获取私聊流
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
platform: 平台,默认为"qq"
|
||||
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
|
||||
|
||||
Returns:
|
||||
Optional[ChatStream]: 聊天流对象,如果未找到返回None
|
||||
@@ -145,7 +155,7 @@ class ChatManager:
|
||||
str: 聊天类型 ("group", "private", "unknown")
|
||||
"""
|
||||
if not chat_stream:
|
||||
return "unknown"
|
||||
raise ValueError("chat_stream cannot be None")
|
||||
|
||||
if hasattr(chat_stream, "group_info"):
|
||||
return "group" if chat_stream.group_info else "private"
|
||||
@@ -165,7 +175,7 @@ class ChatManager:
|
||||
return {}
|
||||
|
||||
try:
|
||||
info = {
|
||||
info: Dict[str, Any] = {
|
||||
"stream_id": chat_stream.stream_id,
|
||||
"platform": chat_stream.platform,
|
||||
"type": ChatManager.get_stream_type(chat_stream),
|
||||
@@ -200,9 +210,9 @@ class ChatManager:
|
||||
Dict[str, int]: 包含各种统计信息的字典
|
||||
"""
|
||||
try:
|
||||
all_streams = ChatManager.get_all_streams()
|
||||
group_streams = ChatManager.get_group_streams()
|
||||
private_streams = ChatManager.get_private_streams()
|
||||
all_streams = ChatManager.get_all_streams(SpecialTypes.ALL_PLATFORMS)
|
||||
group_streams = ChatManager.get_group_streams(SpecialTypes.ALL_PLATFORMS)
|
||||
private_streams = ChatManager.get_private_streams(SpecialTypes.ALL_PLATFORMS)
|
||||
|
||||
summary = {
|
||||
"total_streams": len(all_streams),
|
||||
@@ -215,7 +225,12 @@ class ChatManager:
|
||||
return summary
|
||||
except Exception as e:
|
||||
logger.error(f"[ChatAPI] 获取聊天流统计失败: {e}")
|
||||
return {"total_streams": 0, "group_streams": 0, "private_streams": 0, "qq_streams": 0}
|
||||
return {
|
||||
"total_streams": 0,
|
||||
"group_streams": 0,
|
||||
"private_streams": 0,
|
||||
"qq_streams": 0,
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -223,41 +238,41 @@ class ChatManager:
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_all_streams(platform: str = "qq") -> List[ChatStream]:
|
||||
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq"):
|
||||
"""获取所有聊天流的便捷函数"""
|
||||
return ChatManager.get_all_streams(platform)
|
||||
|
||||
|
||||
def get_group_streams(platform: str = "qq") -> List[ChatStream]:
|
||||
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq"):
|
||||
"""获取群聊聊天流的便捷函数"""
|
||||
return ChatManager.get_group_streams(platform)
|
||||
|
||||
|
||||
def get_private_streams(platform: str = "qq") -> List[ChatStream]:
|
||||
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq"):
|
||||
"""获取私聊聊天流的便捷函数"""
|
||||
return ChatManager.get_private_streams(platform)
|
||||
|
||||
|
||||
def get_stream_by_group_id(group_id: str, platform: str = "qq") -> Optional[ChatStream]:
|
||||
def get_stream_by_group_id(group_id: str, platform: Optional[str] | SpecialTypes = "qq"):
|
||||
"""根据群ID获取聊天流的便捷函数"""
|
||||
return ChatManager.get_stream_by_group_id(group_id, platform)
|
||||
return ChatManager.get_group_stream_by_group_id(group_id, platform)
|
||||
|
||||
|
||||
def get_stream_by_user_id(user_id: str, platform: str = "qq") -> Optional[ChatStream]:
|
||||
def get_stream_by_user_id(user_id: str, platform: Optional[str] | SpecialTypes = "qq"):
|
||||
"""根据用户ID获取私聊流的便捷函数"""
|
||||
return ChatManager.get_stream_by_user_id(user_id, platform)
|
||||
return ChatManager.get_private_stream_by_user_id(user_id, platform)
|
||||
|
||||
|
||||
def get_stream_type(chat_stream: ChatStream) -> str:
|
||||
def get_stream_type(chat_stream: ChatStream):
|
||||
"""获取聊天流类型的便捷函数"""
|
||||
return ChatManager.get_stream_type(chat_stream)
|
||||
|
||||
|
||||
def get_stream_info(chat_stream: ChatStream) -> Dict[str, Any]:
|
||||
def get_stream_info(chat_stream: ChatStream):
|
||||
"""获取聊天流信息的便捷函数"""
|
||||
return ChatManager.get_stream_info(chat_stream)
|
||||
|
||||
|
||||
def get_streams_summary() -> Dict[str, int]:
|
||||
def get_streams_summary():
|
||||
"""获取聊天流统计摘要的便捷函数"""
|
||||
return ChatManager.get_streams_summary()
|
||||
|
||||
@@ -26,7 +26,7 @@ def get_global_config(key: str, default: Any = None) -> Any:
|
||||
插件应使用此方法读取全局配置,以保证只读和隔离性。
|
||||
|
||||
Args:
|
||||
key: 配置键名,支持嵌套访问如 "section.subsection.key"
|
||||
key: 命名空间式配置键名,支持嵌套访问,如 "section.subsection.key",大小写敏感
|
||||
default: 如果配置不存在时返回的默认值
|
||||
|
||||
Returns:
|
||||
@@ -41,7 +41,7 @@ def get_global_config(key: str, default: Any = None) -> Any:
|
||||
if hasattr(current, k):
|
||||
current = getattr(current, k)
|
||||
else:
|
||||
return default
|
||||
raise KeyError(f"配置中不存在子空间或键 '{k}'")
|
||||
return current
|
||||
except Exception as e:
|
||||
logger.warning(f"[ConfigAPI] 获取全局配置 {key} 失败: {e}")
|
||||
@@ -54,26 +54,28 @@ def get_plugin_config(plugin_config: dict, key: str, default: Any = None) -> Any
|
||||
|
||||
Args:
|
||||
plugin_config: 插件配置字典
|
||||
key: 配置键名,支持嵌套访问如 "section.subsection.key"
|
||||
key: 配置键名,支持嵌套访问如 "section.subsection.key",大小写敏感
|
||||
default: 如果配置不存在时返回的默认值
|
||||
|
||||
Returns:
|
||||
Any: 配置值或默认值
|
||||
"""
|
||||
if not plugin_config:
|
||||
return default
|
||||
|
||||
# 支持嵌套键访问
|
||||
keys = key.split(".")
|
||||
current = plugin_config
|
||||
|
||||
for k in keys:
|
||||
if isinstance(current, dict) and k in current:
|
||||
current = current[k]
|
||||
else:
|
||||
return default
|
||||
|
||||
return current
|
||||
try:
|
||||
for k in keys:
|
||||
if isinstance(current, dict) and k in current:
|
||||
current = current[k]
|
||||
elif hasattr(current, k):
|
||||
current = getattr(current, k)
|
||||
else:
|
||||
raise KeyError(f"配置中不存在子空间或键 '{k}'")
|
||||
return current
|
||||
except Exception as e:
|
||||
logger.warning(f"[ConfigAPI] 获取插件配置 {key} 失败: {e}")
|
||||
return default
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -82,7 +84,7 @@ def get_plugin_config(plugin_config: dict, key: str, default: Any = None) -> Any
|
||||
|
||||
|
||||
async def get_user_id_by_person_name(person_name: str) -> tuple[str, str]:
|
||||
"""根据用户名获取用户ID
|
||||
"""根据内部用户名获取用户ID
|
||||
|
||||
Args:
|
||||
person_name: 用户名
|
||||
@@ -93,8 +95,8 @@ async def get_user_id_by_person_name(person_name: str) -> tuple[str, str]:
|
||||
try:
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_id = person_info_manager.get_person_id_by_person_name(person_name)
|
||||
user_id = await person_info_manager.get_value(person_id, "user_id")
|
||||
platform = await person_info_manager.get_value(person_id, "platform")
|
||||
user_id: str = await person_info_manager.get_value(person_id, "user_id") # type: ignore
|
||||
platform: str = await person_info_manager.get_value(person_id, "platform") # type: ignore
|
||||
return platform, user_id
|
||||
except Exception as e:
|
||||
logger.error(f"[ConfigAPI] 根据用户名获取用户ID失败: {e}")
|
||||
@@ -114,7 +116,10 @@ async def get_person_info(person_id: str, key: str, default: Any = None) -> Any:
|
||||
"""
|
||||
try:
|
||||
person_info_manager = get_person_info_manager()
|
||||
return await person_info_manager.get_value(person_id, key, default)
|
||||
response = await person_info_manager.get_value(person_id, key)
|
||||
if not response:
|
||||
raise ValueError(f"[ConfigAPI] 获取用户 {person_id} 的信息 '{key}' 失败,返回默认值")
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"[ConfigAPI] 获取用户信息失败: {e}")
|
||||
return default
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
"""
|
||||
|
||||
import traceback
|
||||
from typing import Dict, List, Any, Union, Type
|
||||
from typing import Dict, List, Any, Union, Type, Optional
|
||||
from src.common.logger import get_logger
|
||||
from peewee import Model, DoesNotExist
|
||||
|
||||
@@ -21,12 +21,12 @@ logger = get_logger("database_api")
|
||||
|
||||
async def db_query(
|
||||
model_class: Type[Model],
|
||||
query_type: str = "get",
|
||||
filters: Dict[str, Any] = None,
|
||||
data: Dict[str, Any] = None,
|
||||
limit: int = None,
|
||||
order_by: List[str] = None,
|
||||
single_result: bool = False,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
query_type: Optional[str] = "get",
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
limit: Optional[int] = None,
|
||||
order_by: Optional[List[str]] = None,
|
||||
single_result: Optional[bool] = False,
|
||||
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
|
||||
"""执行数据库查询操作
|
||||
|
||||
@@ -34,11 +34,11 @@ async def db_query(
|
||||
|
||||
Args:
|
||||
model_class: Peewee 模型类,例如 ActionRecords, Messages 等
|
||||
data: 用于创建或更新的数据字典
|
||||
query_type: 查询类型,可选值: "get", "create", "update", "delete", "count"
|
||||
filters: 过滤条件字典,键为字段名,值为要匹配的值
|
||||
data: 用于创建或更新的数据字典
|
||||
limit: 限制结果数量
|
||||
order_by: 排序字段列表,使用字段名,前缀'-'表示降序
|
||||
order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间字段(即time字段)降序
|
||||
single_result: 是否只返回单个结果
|
||||
|
||||
Returns:
|
||||
@@ -48,7 +48,8 @@ async def db_query(
|
||||
- "update": 返回受影响的行数
|
||||
- "delete": 返回受影响的行数
|
||||
- "count": 返回记录数量
|
||||
|
||||
"""
|
||||
"""
|
||||
示例:
|
||||
# 查询最近10条消息
|
||||
messages = await database_api.db_query(
|
||||
@@ -62,16 +63,16 @@ async def db_query(
|
||||
# 创建一条记录
|
||||
new_record = await database_api.db_query(
|
||||
ActionRecords,
|
||||
data={"action_id": "123", "time": time.time(), "action_name": "TestAction"},
|
||||
query_type="create",
|
||||
data={"action_id": "123", "time": time.time(), "action_name": "TestAction"}
|
||||
)
|
||||
|
||||
# 更新记录
|
||||
updated_count = await database_api.db_query(
|
||||
ActionRecords,
|
||||
data={"action_done": True},
|
||||
query_type="update",
|
||||
filters={"action_id": "123"},
|
||||
data={"action_done": True}
|
||||
)
|
||||
|
||||
# 删除记录
|
||||
@@ -129,7 +130,7 @@ async def db_query(
|
||||
# 创建记录
|
||||
record = model_class.create(**data)
|
||||
# 返回创建的记录
|
||||
return model_class.select().where(model_class.id == record.id).dicts().get()
|
||||
return model_class.select().where(model_class.id == record.id).dicts().get() # type: ignore
|
||||
|
||||
elif query_type == "update":
|
||||
if not data:
|
||||
@@ -168,7 +169,7 @@ async def db_query(
|
||||
|
||||
|
||||
async def db_save(
|
||||
model_class: Type[Model], data: Dict[str, Any], key_field: str = None, key_value: Any = None
|
||||
model_class: Type[Model], data: Dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None
|
||||
) -> Union[Dict[str, Any], None]:
|
||||
"""保存数据到数据库(创建或更新)
|
||||
|
||||
@@ -213,14 +214,14 @@ async def db_save(
|
||||
existing_record.save()
|
||||
|
||||
# 返回更新后的记录
|
||||
updated_record = model_class.select().where(model_class.id == existing_record.id).dicts().get()
|
||||
updated_record = model_class.select().where(model_class.id == existing_record.id).dicts().get() # type: ignore
|
||||
return updated_record
|
||||
|
||||
# 如果没有找到现有记录或未提供key_field和key_value,创建新记录
|
||||
new_record = model_class.create(**data)
|
||||
|
||||
# 返回创建的记录
|
||||
created_record = model_class.select().where(model_class.id == new_record.id).dicts().get()
|
||||
created_record = model_class.select().where(model_class.id == new_record.id).dicts().get() # type: ignore
|
||||
return created_record
|
||||
|
||||
except Exception as e:
|
||||
@@ -230,7 +231,11 @@ async def db_save(
|
||||
|
||||
|
||||
async def db_get(
|
||||
model_class: Type[Model], filters: Dict[str, Any] = None, order_by: str = None, limit: int = None
|
||||
model_class: Type[Model],
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
limit: Optional[int] = None,
|
||||
order_by: Optional[str] = None,
|
||||
single_result: Optional[bool] = False,
|
||||
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
|
||||
"""从数据库获取记录
|
||||
|
||||
@@ -239,11 +244,12 @@ async def db_get(
|
||||
Args:
|
||||
model_class: Peewee模型类
|
||||
filters: 过滤条件,字段名和值的字典
|
||||
order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间降序
|
||||
limit: 结果数量限制,如果为1则返回单个记录而不是列表
|
||||
order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间字段(即time字段)降序
|
||||
limit: 结果数量限制
|
||||
single_result: 是否只返回单个结果,如果为True,则返回单个记录字典或None;否则返回记录字典列表或空列表
|
||||
|
||||
Returns:
|
||||
如果limit=1,返回单个记录字典或None;
|
||||
如果single_result为True,返回单个记录字典或None;
|
||||
否则返回记录字典列表或空列表。
|
||||
|
||||
示例:
|
||||
@@ -258,8 +264,8 @@ async def db_get(
|
||||
records = await database_api.db_get(
|
||||
Messages,
|
||||
filters={"chat_id": chat_stream.stream_id},
|
||||
limit=10,
|
||||
order_by="-time",
|
||||
limit=10
|
||||
)
|
||||
"""
|
||||
try:
|
||||
@@ -286,14 +292,14 @@ async def db_get(
|
||||
results = list(query.dicts())
|
||||
|
||||
# 返回结果
|
||||
if limit == 1:
|
||||
if single_result:
|
||||
return results[0] if results else None
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[DatabaseAPI] 获取数据库记录出错: {e}")
|
||||
traceback.print_exc()
|
||||
return None if limit == 1 else []
|
||||
return None if single_result else []
|
||||
|
||||
|
||||
async def store_action_info(
|
||||
@@ -302,7 +308,7 @@ async def store_action_info(
|
||||
action_prompt_display: str = "",
|
||||
action_done: bool = True,
|
||||
thinking_id: str = "",
|
||||
action_data: dict = None,
|
||||
action_data: Optional[dict] = None,
|
||||
action_name: str = "",
|
||||
) -> Union[Dict[str, Any], None]:
|
||||
"""存储动作信息到数据库
|
||||
|
||||
@@ -102,7 +102,7 @@ class PluginManager:
|
||||
"""
|
||||
加载已经注册的插件类
|
||||
"""
|
||||
plugin_class: Type[BasePlugin] = self.plugin_classes.get(plugin_name)
|
||||
plugin_class = self.plugin_classes.get(plugin_name)
|
||||
if not plugin_class:
|
||||
logger.error(f"插件 {plugin_name} 的插件类未注册或不存在")
|
||||
return False, 1
|
||||
@@ -115,7 +115,10 @@ class PluginManager:
|
||||
plugin_dir = self._find_plugin_directory(plugin_class)
|
||||
if plugin_dir:
|
||||
self.plugin_paths[plugin_name] = plugin_dir # 更新路径
|
||||
plugin_instance = plugin_class(plugin_dir=plugin_dir) # 实例化插件(可能因为缺少manifest而失败)
|
||||
plugin_instance = plugin_class(plugin_dir=plugin_dir) # 实例化插件(可能因为缺少manifest而失败)
|
||||
if not plugin_instance:
|
||||
logger.error(f"插件 {plugin_name} 实例化失败")
|
||||
return False, 1
|
||||
# 检查插件是否启用
|
||||
if not plugin_instance.enable_plugin:
|
||||
logger.info(f"插件 {plugin_name} 已禁用,跳过加载")
|
||||
@@ -248,7 +251,7 @@ class PluginManager:
|
||||
"failed_plugin_details": self.failed_plugins.copy(),
|
||||
}
|
||||
|
||||
def check_all_dependencies(self, auto_install: bool = False) -> Dict[str, any]:
|
||||
def check_all_dependencies(self, auto_install: bool = False) -> Dict[str, Any]:
|
||||
"""检查所有插件的Python依赖包
|
||||
|
||||
Args:
|
||||
@@ -381,7 +384,7 @@ class PluginManager:
|
||||
|
||||
return loaded_count, failed_count
|
||||
|
||||
def _find_plugin_directory(self, plugin_class: str) -> Optional[str]:
|
||||
def _find_plugin_directory(self, plugin_class: Type[BasePlugin]) -> Optional[str]:
|
||||
"""查找插件类对应的目录路径"""
|
||||
try:
|
||||
module = getmodule(plugin_class)
|
||||
|
||||
Reference in New Issue
Block a user