This commit is contained in:
SengokuCola
2025-07-15 02:54:02 +08:00
5 changed files with 110 additions and 77 deletions

View File

@@ -16,3 +16,7 @@
1. 现在所有的匹配模式不再是关键字了,而是枚举类。**(可能有遗漏)**
2. 修复了一下显示插件信息不显示的问题。同时精简了一下显示内容
3. 修复了插件系统混用了`plugin_name``display_name`的问题。现在所有的插件信息都使用`display_name`来显示,而内部标识仍然使用`plugin_name`。**(可能有遗漏)**
3. 部分API的参数类型和返回值进行了调整
- `chat_api.py`中获取流的参数中可以使用一个特殊的枚举类型来获得所有平台的 ChatStream 了。
- `config_api.py`中的`get_global_config``get_plugin_config`方法现在支持嵌套访问的配置键名。
- `database_api.py`中的`db_query`方法调整了参数顺序以增强参数限制的同时保证了typing正确`db_get`方法增加了`single_result`参数,与`db_query`保持一致。

View File

@@ -13,23 +13,29 @@
"""
from typing import List, Dict, Any, Optional
from src.common.logger import get_logger
from enum import Enum
# 导入依赖
from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
logger = get_logger("chat_api")
class SpecialTypes(Enum):
"""特殊枚举类型"""
ALL_PLATFORMS = "all_platforms"
class ChatManager:
"""聊天管理器 - 专门负责聊天信息的查询和管理"""
@staticmethod
def get_all_streams(platform: str = "qq") -> List[ChatStream]:
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
"""获取所有聊天流
Args:
platform: 平台筛选,默认为"qq"
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
Returns:
List[ChatStream]: 聊天流列表
@@ -37,7 +43,7 @@ class ChatManager:
streams = []
try:
for _, stream in get_chat_manager().streams.items():
if stream.platform == platform:
if platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform:
streams.append(stream)
logger.debug(f"[ChatAPI] 获取到 {len(streams)}{platform} 平台的聊天流")
except Exception as e:
@@ -45,11 +51,11 @@ class ChatManager:
return streams
@staticmethod
def get_group_streams(platform: str = "qq") -> List[ChatStream]:
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
"""获取所有群聊聊天流
Args:
platform: 平台筛选,默认为"qq"
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
Returns:
List[ChatStream]: 群聊聊天流列表
@@ -57,7 +63,7 @@ class ChatManager:
streams = []
try:
for _, stream in get_chat_manager().streams.items():
if stream.platform == platform and stream.group_info:
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and stream.group_info:
streams.append(stream)
logger.debug(f"[ChatAPI] 获取到 {len(streams)}{platform} 平台的群聊流")
except Exception as e:
@@ -65,11 +71,11 @@ class ChatManager:
return streams
@staticmethod
def get_private_streams(platform: str = "qq") -> List[ChatStream]:
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
"""获取所有私聊聊天流
Args:
platform: 平台筛选,默认为"qq"
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
Returns:
List[ChatStream]: 私聊聊天流列表
@@ -77,7 +83,7 @@ class ChatManager:
streams = []
try:
for _, stream in get_chat_manager().streams.items():
if stream.platform == platform and not stream.group_info:
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and not stream.group_info:
streams.append(stream)
logger.debug(f"[ChatAPI] 获取到 {len(streams)}{platform} 平台的私聊流")
except Exception as e:
@@ -85,12 +91,14 @@ class ChatManager:
return streams
@staticmethod
def get_stream_by_group_id(group_id: str, platform: str = "qq") -> Optional[ChatStream]:
def get_group_stream_by_group_id(
group_id: str, platform: Optional[str] | SpecialTypes = "qq"
) -> Optional[ChatStream]:
"""根据群ID获取聊天流
Args:
group_id: 群聊ID
platform: 平台,默认为"qq"
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
Returns:
Optional[ChatStream]: 聊天流对象如果未找到返回None
@@ -110,12 +118,14 @@ class ChatManager:
return None
@staticmethod
def get_stream_by_user_id(user_id: str, platform: str = "qq") -> Optional[ChatStream]:
def get_private_stream_by_user_id(
user_id: str, platform: Optional[str] | SpecialTypes = "qq"
) -> Optional[ChatStream]:
"""根据用户ID获取私聊流
Args:
user_id: 用户ID
platform: 平台,默认为"qq"
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
Returns:
Optional[ChatStream]: 聊天流对象如果未找到返回None
@@ -145,7 +155,7 @@ class ChatManager:
str: 聊天类型 ("group", "private", "unknown")
"""
if not chat_stream:
return "unknown"
raise ValueError("chat_stream cannot be None")
if hasattr(chat_stream, "group_info"):
return "group" if chat_stream.group_info else "private"
@@ -165,7 +175,7 @@ class ChatManager:
return {}
try:
info = {
info: Dict[str, Any] = {
"stream_id": chat_stream.stream_id,
"platform": chat_stream.platform,
"type": ChatManager.get_stream_type(chat_stream),
@@ -200,9 +210,9 @@ class ChatManager:
Dict[str, int]: 包含各种统计信息的字典
"""
try:
all_streams = ChatManager.get_all_streams()
group_streams = ChatManager.get_group_streams()
private_streams = ChatManager.get_private_streams()
all_streams = ChatManager.get_all_streams(SpecialTypes.ALL_PLATFORMS)
group_streams = ChatManager.get_group_streams(SpecialTypes.ALL_PLATFORMS)
private_streams = ChatManager.get_private_streams(SpecialTypes.ALL_PLATFORMS)
summary = {
"total_streams": len(all_streams),
@@ -215,7 +225,12 @@ class ChatManager:
return summary
except Exception as e:
logger.error(f"[ChatAPI] 获取聊天流统计失败: {e}")
return {"total_streams": 0, "group_streams": 0, "private_streams": 0, "qq_streams": 0}
return {
"total_streams": 0,
"group_streams": 0,
"private_streams": 0,
"qq_streams": 0,
}
# =============================================================================
@@ -223,41 +238,41 @@ class ChatManager:
# =============================================================================
def get_all_streams(platform: str = "qq") -> List[ChatStream]:
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq"):
"""获取所有聊天流的便捷函数"""
return ChatManager.get_all_streams(platform)
def get_group_streams(platform: str = "qq") -> List[ChatStream]:
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq"):
"""获取群聊聊天流的便捷函数"""
return ChatManager.get_group_streams(platform)
def get_private_streams(platform: str = "qq") -> List[ChatStream]:
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq"):
"""获取私聊聊天流的便捷函数"""
return ChatManager.get_private_streams(platform)
def get_stream_by_group_id(group_id: str, platform: str = "qq") -> Optional[ChatStream]:
def get_stream_by_group_id(group_id: str, platform: Optional[str] | SpecialTypes = "qq"):
"""根据群ID获取聊天流的便捷函数"""
return ChatManager.get_stream_by_group_id(group_id, platform)
return ChatManager.get_group_stream_by_group_id(group_id, platform)
def get_stream_by_user_id(user_id: str, platform: str = "qq") -> Optional[ChatStream]:
def get_stream_by_user_id(user_id: str, platform: Optional[str] | SpecialTypes = "qq"):
"""根据用户ID获取私聊流的便捷函数"""
return ChatManager.get_stream_by_user_id(user_id, platform)
return ChatManager.get_private_stream_by_user_id(user_id, platform)
def get_stream_type(chat_stream: ChatStream) -> str:
def get_stream_type(chat_stream: ChatStream):
"""获取聊天流类型的便捷函数"""
return ChatManager.get_stream_type(chat_stream)
def get_stream_info(chat_stream: ChatStream) -> Dict[str, Any]:
def get_stream_info(chat_stream: ChatStream):
"""获取聊天流信息的便捷函数"""
return ChatManager.get_stream_info(chat_stream)
def get_streams_summary() -> Dict[str, int]:
def get_streams_summary():
"""获取聊天流统计摘要的便捷函数"""
return ChatManager.get_streams_summary()

View File

@@ -26,7 +26,7 @@ def get_global_config(key: str, default: Any = None) -> Any:
插件应使用此方法读取全局配置,以保证只读和隔离性。
Args:
key: 配置键名,支持嵌套访问如 "section.subsection.key"
key: 命名空间式配置键名,支持嵌套访问"section.subsection.key",大小写敏感
default: 如果配置不存在时返回的默认值
Returns:
@@ -41,7 +41,7 @@ def get_global_config(key: str, default: Any = None) -> Any:
if hasattr(current, k):
current = getattr(current, k)
else:
return default
raise KeyError(f"配置中不存在子空间或键 '{k}'")
return current
except Exception as e:
logger.warning(f"[ConfigAPI] 获取全局配置 {key} 失败: {e}")
@@ -54,26 +54,28 @@ def get_plugin_config(plugin_config: dict, key: str, default: Any = None) -> Any
Args:
plugin_config: 插件配置字典
key: 配置键名,支持嵌套访问如 "section.subsection.key"
key: 配置键名,支持嵌套访问如 "section.subsection.key",大小写敏感
default: 如果配置不存在时返回的默认值
Returns:
Any: 配置值或默认值
"""
if not plugin_config:
return default
# 支持嵌套键访问
keys = key.split(".")
current = plugin_config
for k in keys:
if isinstance(current, dict) and k in current:
current = current[k]
else:
return default
return current
try:
for k in keys:
if isinstance(current, dict) and k in current:
current = current[k]
elif hasattr(current, k):
current = getattr(current, k)
else:
raise KeyError(f"配置中不存在子空间或键 '{k}'")
return current
except Exception as e:
logger.warning(f"[ConfigAPI] 获取插件配置 {key} 失败: {e}")
return default
# =============================================================================
@@ -82,7 +84,7 @@ def get_plugin_config(plugin_config: dict, key: str, default: Any = None) -> Any
async def get_user_id_by_person_name(person_name: str) -> tuple[str, str]:
"""根据用户名获取用户ID
"""根据内部用户名获取用户ID
Args:
person_name: 用户名
@@ -93,8 +95,8 @@ async def get_user_id_by_person_name(person_name: str) -> tuple[str, str]:
try:
person_info_manager = get_person_info_manager()
person_id = person_info_manager.get_person_id_by_person_name(person_name)
user_id = await person_info_manager.get_value(person_id, "user_id")
platform = await person_info_manager.get_value(person_id, "platform")
user_id: str = await person_info_manager.get_value(person_id, "user_id") # type: ignore
platform: str = await person_info_manager.get_value(person_id, "platform") # type: ignore
return platform, user_id
except Exception as e:
logger.error(f"[ConfigAPI] 根据用户名获取用户ID失败: {e}")
@@ -114,7 +116,10 @@ async def get_person_info(person_id: str, key: str, default: Any = None) -> Any:
"""
try:
person_info_manager = get_person_info_manager()
return await person_info_manager.get_value(person_id, key, default)
response = await person_info_manager.get_value(person_id, key)
if not response:
raise ValueError(f"[ConfigAPI] 获取用户 {person_id} 的信息 '{key}' 失败,返回默认值")
return response
except Exception as e:
logger.error(f"[ConfigAPI] 获取用户信息失败: {e}")
return default

View File

@@ -8,7 +8,7 @@
"""
import traceback
from typing import Dict, List, Any, Union, Type
from typing import Dict, List, Any, Union, Type, Optional
from src.common.logger import get_logger
from peewee import Model, DoesNotExist
@@ -21,12 +21,12 @@ logger = get_logger("database_api")
async def db_query(
model_class: Type[Model],
query_type: str = "get",
filters: Dict[str, Any] = None,
data: Dict[str, Any] = None,
limit: int = None,
order_by: List[str] = None,
single_result: bool = False,
data: Optional[Dict[str, Any]] = None,
query_type: Optional[str] = "get",
filters: Optional[Dict[str, Any]] = None,
limit: Optional[int] = None,
order_by: Optional[List[str]] = None,
single_result: Optional[bool] = False,
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
"""执行数据库查询操作
@@ -34,11 +34,11 @@ async def db_query(
Args:
model_class: Peewee 模型类,例如 ActionRecords, Messages 等
data: 用于创建或更新的数据字典
query_type: 查询类型,可选值: "get", "create", "update", "delete", "count"
filters: 过滤条件字典,键为字段名,值为要匹配的值
data: 用于创建或更新的数据字典
limit: 限制结果数量
order_by: 排序字段列表,使用字段名,前缀'-'表示降序
order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间字段即time字段降序
single_result: 是否只返回单个结果
Returns:
@@ -48,7 +48,8 @@ async def db_query(
- "update": 返回受影响的行数
- "delete": 返回受影响的行数
- "count": 返回记录数量
"""
"""
示例:
# 查询最近10条消息
messages = await database_api.db_query(
@@ -62,16 +63,16 @@ async def db_query(
# 创建一条记录
new_record = await database_api.db_query(
ActionRecords,
data={"action_id": "123", "time": time.time(), "action_name": "TestAction"},
query_type="create",
data={"action_id": "123", "time": time.time(), "action_name": "TestAction"}
)
# 更新记录
updated_count = await database_api.db_query(
ActionRecords,
data={"action_done": True},
query_type="update",
filters={"action_id": "123"},
data={"action_done": True}
)
# 删除记录
@@ -129,7 +130,7 @@ async def db_query(
# 创建记录
record = model_class.create(**data)
# 返回创建的记录
return model_class.select().where(model_class.id == record.id).dicts().get()
return model_class.select().where(model_class.id == record.id).dicts().get() # type: ignore
elif query_type == "update":
if not data:
@@ -168,7 +169,7 @@ async def db_query(
async def db_save(
model_class: Type[Model], data: Dict[str, Any], key_field: str = None, key_value: Any = None
model_class: Type[Model], data: Dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None
) -> Union[Dict[str, Any], None]:
"""保存数据到数据库(创建或更新)
@@ -213,14 +214,14 @@ async def db_save(
existing_record.save()
# 返回更新后的记录
updated_record = model_class.select().where(model_class.id == existing_record.id).dicts().get()
updated_record = model_class.select().where(model_class.id == existing_record.id).dicts().get() # type: ignore
return updated_record
# 如果没有找到现有记录或未提供key_field和key_value创建新记录
new_record = model_class.create(**data)
# 返回创建的记录
created_record = model_class.select().where(model_class.id == new_record.id).dicts().get()
created_record = model_class.select().where(model_class.id == new_record.id).dicts().get() # type: ignore
return created_record
except Exception as e:
@@ -230,7 +231,11 @@ async def db_save(
async def db_get(
model_class: Type[Model], filters: Dict[str, Any] = None, order_by: str = None, limit: int = None
model_class: Type[Model],
filters: Optional[Dict[str, Any]] = None,
limit: Optional[int] = None,
order_by: Optional[str] = None,
single_result: Optional[bool] = False,
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
"""从数据库获取记录
@@ -239,11 +244,12 @@ async def db_get(
Args:
model_class: Peewee模型类
filters: 过滤条件,字段名和值的字典
order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间降序
limit: 结果数量限制如果为1则返回单个记录而不是列表
order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间字段即time字段降序
limit: 结果数量限制
single_result: 是否只返回单个结果如果为True则返回单个记录字典或None否则返回记录字典列表或空列表
Returns:
如果limit=1返回单个记录字典或None
如果single_result为True返回单个记录字典或None
否则返回记录字典列表或空列表。
示例:
@@ -258,8 +264,8 @@ async def db_get(
records = await database_api.db_get(
Messages,
filters={"chat_id": chat_stream.stream_id},
limit=10,
order_by="-time",
limit=10
)
"""
try:
@@ -286,14 +292,14 @@ async def db_get(
results = list(query.dicts())
# 返回结果
if limit == 1:
if single_result:
return results[0] if results else None
return results
except Exception as e:
logger.error(f"[DatabaseAPI] 获取数据库记录出错: {e}")
traceback.print_exc()
return None if limit == 1 else []
return None if single_result else []
async def store_action_info(
@@ -302,7 +308,7 @@ async def store_action_info(
action_prompt_display: str = "",
action_done: bool = True,
thinking_id: str = "",
action_data: dict = None,
action_data: Optional[dict] = None,
action_name: str = "",
) -> Union[Dict[str, Any], None]:
"""存储动作信息到数据库

View File

@@ -102,7 +102,7 @@ class PluginManager:
"""
加载已经注册的插件类
"""
plugin_class: Type[BasePlugin] = self.plugin_classes.get(plugin_name)
plugin_class = self.plugin_classes.get(plugin_name)
if not plugin_class:
logger.error(f"插件 {plugin_name} 的插件类未注册或不存在")
return False, 1
@@ -115,7 +115,10 @@ class PluginManager:
plugin_dir = self._find_plugin_directory(plugin_class)
if plugin_dir:
self.plugin_paths[plugin_name] = plugin_dir # 更新路径
plugin_instance = plugin_class(plugin_dir=plugin_dir) # 实例化插件可能因为缺少manifest而失败
plugin_instance = plugin_class(plugin_dir=plugin_dir) # 实例化插件可能因为缺少manifest而失败
if not plugin_instance:
logger.error(f"插件 {plugin_name} 实例化失败")
return False, 1
# 检查插件是否启用
if not plugin_instance.enable_plugin:
logger.info(f"插件 {plugin_name} 已禁用,跳过加载")
@@ -248,7 +251,7 @@ class PluginManager:
"failed_plugin_details": self.failed_plugins.copy(),
}
def check_all_dependencies(self, auto_install: bool = False) -> Dict[str, any]:
def check_all_dependencies(self, auto_install: bool = False) -> Dict[str, Any]:
"""检查所有插件的Python依赖包
Args:
@@ -381,7 +384,7 @@ class PluginManager:
return loaded_count, failed_count
def _find_plugin_directory(self, plugin_class: str) -> Optional[str]:
def _find_plugin_directory(self, plugin_class: Type[BasePlugin]) -> Optional[str]:
"""查找插件类对应的目录路径"""
try:
module = getmodule(plugin_class)