Merge branch 'dev' of https://github.com/MaiM-with-u/MaiBot into dev
This commit is contained in:
54
changes.md
54
changes.md
@@ -1,26 +1,60 @@
|
||||
# 插件API与规范修改
|
||||
|
||||
1. 现在`plugin_system`的`__init__.py`文件中包含了所有插件API的导入,用户可以直接使用`from plugin_system import *`来导入所有API。
|
||||
1. 现在`plugin_system`的`__init__.py`文件中包含了所有插件API的导入,用户可以直接使用`from src.plugin_system import *`来导入所有API。
|
||||
|
||||
2. register_plugin函数现在转移到了`plugin_system.apis.plugin_register_api`模块中,用户可以通过`from plugin_system.apis.plugin_register_api import register_plugin`来导入。
|
||||
2. register_plugin函数现在转移到了`plugin_system.apis.plugin_register_api`模块中,用户可以通过`from src.plugin_system.apis.plugin_register_api import register_plugin`来导入。
|
||||
- 顺便一提,按照1中说法,你可以这么用:
|
||||
```python
|
||||
from src.plugin_system import register_plugin
|
||||
```
|
||||
|
||||
3. 现在强制要求的property如下:
|
||||
3. 现在强制要求的property如下,即你必须覆盖的属性有:
|
||||
- `plugin_name`: 插件名称,必须是唯一的。(与文件夹相同)
|
||||
- `enable_plugin`: 是否启用插件,默认为`True`。
|
||||
- `dependencies`: 插件依赖的其他插件列表,默认为空。**现在并不检查(也许)**
|
||||
- `python_dependencies`: 插件依赖的Python包列表,默认为空。**现在并不检查**
|
||||
- `config_file_name`: 插件配置文件名,默认为`config.toml`。
|
||||
- `config_schema`: 插件配置文件的schema,用于自动生成配置文件。
|
||||
4. 部分API的参数类型和返回值进行了调整
|
||||
- `chat_api.py`中获取流的参数中可以使用一个特殊的枚举类型来获得所有平台的 ChatStream 了。
|
||||
- `config_api.py`中的`get_global_config`和`get_plugin_config`方法现在支持嵌套访问的配置键名。
|
||||
- `database_api.py`中的`db_query`方法调整了参数顺序以增强参数限制的同时,保证了typing正确;`db_get`方法增加了`single_result`参数,与`db_query`保持一致。
|
||||
5. 增加了`logging_api`,可以用`get_logger`来获取日志记录器。
|
||||
|
||||
# 插件系统修改
|
||||
1. 现在所有的匹配模式不再是关键字了,而是枚举类。**(可能有遗漏)**
|
||||
2. 修复了一下显示插件信息不显示的问题。同时精简了一下显示内容
|
||||
3. 修复了插件系统混用了`plugin_name`和`display_name`的问题。现在所有的插件信息都使用`display_name`来显示,而内部标识仍然使用`plugin_name`。**(可能有遗漏)**
|
||||
3. 部分API的参数类型和返回值进行了调整
|
||||
- `chat_api.py`中获取流的参数中可以使用一个特殊的枚举类型来获得所有平台的 ChatStream 了。
|
||||
- `config_api.py`中的`get_global_config`和`get_plugin_config`方法现在支持嵌套访问的配置键名。
|
||||
- `database_api.py`中的`db_query`方法调整了参数顺序以增强参数限制的同时,保证了typing正确;`db_get`方法增加了`single_result`参数,与`db_query`保持一致。
|
||||
3. 修复了插件系统混用了`plugin_name`和`display_name`的问题。现在所有的插件信息都使用`display_name`来显示,而内部标识仍然使用`plugin_name`。
|
||||
4. 现在增加了参数类型检查,完善了对应注释
|
||||
5. 现在插件抽象出了总基类 `PluginBase`
|
||||
- 基于`Action`和`Command`的插件基类现在为`BasePlugin`,它继承自`PluginBase`,由`register_plugin`装饰器注册。
|
||||
- 基于`Event`的插件基类现在为`BaseEventPlugin`,它也继承自`PluginBase`,由`register_event_plugin`装饰器注册。
|
||||
- <del>基于`Action`和`Command`的插件基类现在为`BasePlugin`。</del>
|
||||
- <del>基于`Event`的插件基类现在为`BaseEventPlugin`。</del>
|
||||
- 基于`Action`,`Command`和`Event`的插件基类现在为`BasePlugin`,所有插件都应该继承此基类。
|
||||
- `BasePlugin`继承自`PluginBase`。
|
||||
- 所有的插件类都由`register_plugin`装饰器注册。
|
||||
6. 现在我们终于可以让插件有自定义的名字了!
|
||||
- 真正实现了插件的`plugin_name`**不受文件夹名称限制**的功能。(吐槽:可乐你的某个小小细节导致我搞了好久……)
|
||||
- 通过在插件类中定义`plugin_name`属性来指定插件内部标识符。
|
||||
- 由于此更改一个文件中现在可以有多个插件类,但每个插件类必须有**唯一的**`plugin_name`。
|
||||
- 在某些插件加载失败时,现在会显示包名而不是插件内部标识符。
|
||||
- 例如:`MaiMBot.plugins.example_plugin`而不是`example_plugin`。
|
||||
- 仅在插件 import 失败时会如此,正常注册过程中失败的插件不会显示包名,而是显示插件内部标识符。(这是特性,但是基本上不可能出现这个情况)
|
||||
7. 现在不支持单文件插件了,加载方式已经完全删除。
|
||||
8. 把`BaseEventPlugin`合并到了`BasePlugin`中,所有插件都应该继承自`BasePlugin`。
|
||||
|
||||
|
||||
# 吐槽
|
||||
```python
|
||||
plugin_path = Path(plugin_file)
|
||||
if plugin_path.parent.name != "plugins":
|
||||
# 插件包格式:parent_dir.plugin
|
||||
module_name = f"plugins.{plugin_path.parent.name}.plugin"
|
||||
else:
|
||||
# 单文件格式:plugins.filename
|
||||
module_name = f"plugins.{plugin_path.stem}"
|
||||
```
|
||||
```python
|
||||
plugin_path = Path(plugin_file)
|
||||
module_name = ".".join(plugin_path.parent.parts)
|
||||
```
|
||||
这两个区别很大的。
|
||||
@@ -7,11 +7,13 @@ from src.plugin_system import (
|
||||
ComponentInfo,
|
||||
ActionActivationType,
|
||||
ConfigField,
|
||||
BaseEventHandler,
|
||||
EventType,
|
||||
MaiMessages,
|
||||
)
|
||||
|
||||
|
||||
# ===== Action组件 =====
|
||||
|
||||
|
||||
class HelloAction(BaseAction):
|
||||
"""问候Action - 简单的问候动作"""
|
||||
|
||||
@@ -82,7 +84,7 @@ class TimeCommand(BaseCommand):
|
||||
import datetime
|
||||
|
||||
# 获取当前时间
|
||||
time_format = self.get_config("time.format", "%Y-%m-%d %H:%M:%S")
|
||||
time_format: str = self.get_config("time.format", "%Y-%m-%d %H:%M:%S") # type: ignore
|
||||
now = datetime.datetime.now()
|
||||
time_str = now.strftime(time_format)
|
||||
|
||||
@@ -93,6 +95,20 @@ class TimeCommand(BaseCommand):
|
||||
return True, f"显示了当前时间: {time_str}"
|
||||
|
||||
|
||||
class PrintMessage(BaseEventHandler):
|
||||
"""打印消息事件处理器 - 处理打印消息事件"""
|
||||
|
||||
event_type = EventType.ON_MESSAGE
|
||||
handler_name = "print_message_handler"
|
||||
handler_description = "打印接收到的消息"
|
||||
|
||||
async def execute(self, message: MaiMessages) -> Tuple[bool, str | None]:
|
||||
"""执行打印消息事件处理"""
|
||||
# 打印接收到的消息
|
||||
print(f"接收到消息: {message.raw_message}")
|
||||
return True, "消息已打印"
|
||||
|
||||
|
||||
# ===== 插件注册 =====
|
||||
|
||||
|
||||
@@ -122,6 +138,7 @@ class HelloWorldPlugin(BasePlugin):
|
||||
"enable_emoji": ConfigField(type=bool, default=True, description="是否启用表情符号"),
|
||||
},
|
||||
"time": {"format": ConfigField(type=str, default="%Y-%m-%d %H:%M:%S", description="时间显示格式")},
|
||||
"print_message": {"enabled": ConfigField(type=bool, default=True, description="是否启用打印")},
|
||||
}
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||
@@ -129,4 +146,27 @@ class HelloWorldPlugin(BasePlugin):
|
||||
(HelloAction.get_action_info(), HelloAction),
|
||||
(ByeAction.get_action_info(), ByeAction), # 添加告别Action
|
||||
(TimeCommand.get_command_info(), TimeCommand),
|
||||
(PrintMessage.get_handler_info(), PrintMessage),
|
||||
]
|
||||
|
||||
|
||||
# @register_plugin
|
||||
# class HelloWorldEventPlugin(BaseEPlugin):
|
||||
# """Hello World事件插件 - 处理问候和告别事件"""
|
||||
|
||||
# plugin_name = "hello_world_event_plugin"
|
||||
# enable_plugin = False
|
||||
# dependencies = []
|
||||
# python_dependencies = []
|
||||
# config_file_name = "event_config.toml"
|
||||
|
||||
# config_schema = {
|
||||
# "plugin": {
|
||||
# "name": ConfigField(type=str, default="hello_world_event_plugin", description="插件名称"),
|
||||
# "version": ConfigField(type=str, default="1.0.0", description="插件版本"),
|
||||
# "enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
|
||||
# },
|
||||
# }
|
||||
|
||||
# def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||
# return [(PrintMessage.get_handler_info(), PrintMessage)]
|
||||
|
||||
@@ -179,8 +179,7 @@ class HeartFChatting:
|
||||
await asyncio.sleep(10)
|
||||
if self.loop_mode == ChatMode.NORMAL:
|
||||
self.energy_value -= 0.3
|
||||
if self.energy_value <= 0.3:
|
||||
self.energy_value = 0.3
|
||||
self.energy_value = max(self.energy_value, 0.3)
|
||||
|
||||
def print_cycle_info(self, cycle_timers):
|
||||
# 记录循环信息和计时器结果
|
||||
@@ -257,6 +256,7 @@ class HeartFChatting:
|
||||
return f"{person_name}:{message_data.get('processed_plain_text')}"
|
||||
|
||||
async def _observe(self, message_data: Optional[Dict[str, Any]] = None):
|
||||
# sourcery skip: hoist-statement-from-if, merge-comparisons, reintroduce-else
|
||||
if not message_data:
|
||||
message_data = {}
|
||||
action_type = "no_action"
|
||||
@@ -462,7 +462,7 @@ class HeartFChatting:
|
||||
在"兴趣"模式下,判断是否回复并生成内容。
|
||||
"""
|
||||
|
||||
interested_rate = message_data.get("interest_value", 0.0) * self.willing_amplifier
|
||||
interested_rate = (message_data.get("interest_value") or 0.0) * self.willing_amplifier
|
||||
|
||||
self.willing_manager.setup(message_data, self.chat_stream)
|
||||
|
||||
|
||||
@@ -294,10 +294,10 @@ class EmbeddingStore:
|
||||
"""
|
||||
if self.faiss_index is None:
|
||||
logger.debug("FaissIndex尚未构建,返回None")
|
||||
return None
|
||||
return []
|
||||
if self.idx2hash is None:
|
||||
logger.warning("idx2hash尚未构建,返回None")
|
||||
return None
|
||||
return []
|
||||
|
||||
# L2归一化
|
||||
faiss.normalize_L2(np.array([query], dtype=np.float32))
|
||||
@@ -318,15 +318,15 @@ class EmbeddingStore:
|
||||
class EmbeddingManager:
|
||||
def __init__(self):
|
||||
self.paragraphs_embedding_store = EmbeddingStore(
|
||||
local_storage['pg_namespace'],
|
||||
local_storage["pg_namespace"], # type: ignore
|
||||
EMBEDDING_DATA_DIR_STR,
|
||||
)
|
||||
self.entities_embedding_store = EmbeddingStore(
|
||||
local_storage['pg_namespace'],
|
||||
local_storage["pg_namespace"], # type: ignore
|
||||
EMBEDDING_DATA_DIR_STR,
|
||||
)
|
||||
self.relation_embedding_store = EmbeddingStore(
|
||||
local_storage['pg_namespace'],
|
||||
local_storage["pg_namespace"], # type: ignore
|
||||
EMBEDDING_DATA_DIR_STR,
|
||||
)
|
||||
self.stored_pg_hashes = set()
|
||||
|
||||
@@ -30,7 +30,7 @@ def _get_kg_dir():
|
||||
"""
|
||||
安全地获取KG数据目录路径
|
||||
"""
|
||||
root_path = local_storage['root_path']
|
||||
root_path: str = local_storage["root_path"]
|
||||
if root_path is None:
|
||||
# 如果 local_storage 中没有 root_path,使用当前文件的相对路径作为备用
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
@@ -38,7 +38,7 @@ def _get_kg_dir():
|
||||
logger.warning(f"local_storage 中未找到 root_path,使用备用路径: {root_path}")
|
||||
|
||||
# 获取RAG数据目录
|
||||
rag_data_dir = global_config["persistence"]["rag_data_dir"]
|
||||
rag_data_dir: str = global_config["persistence"]["rag_data_dir"]
|
||||
if rag_data_dir is None:
|
||||
kg_dir = os.path.join(root_path, "data/rag")
|
||||
else:
|
||||
@@ -65,9 +65,9 @@ class KGManager:
|
||||
|
||||
# 持久化相关 - 使用延迟初始化的路径
|
||||
self.dir_path = get_kg_dir_str()
|
||||
self.graph_data_path = self.dir_path + "/" + local_storage['rag_graph_namespace'] + ".graphml"
|
||||
self.ent_cnt_data_path = self.dir_path + "/" + local_storage['rag_ent_cnt_namespace'] + ".parquet"
|
||||
self.pg_hash_file_path = self.dir_path + "/" + local_storage['rag_pg_hash_namespace'] + ".json"
|
||||
self.graph_data_path = self.dir_path + "/" + local_storage["rag_graph_namespace"] + ".graphml"
|
||||
self.ent_cnt_data_path = self.dir_path + "/" + local_storage["rag_ent_cnt_namespace"] + ".parquet"
|
||||
self.pg_hash_file_path = self.dir_path + "/" + local_storage["rag_pg_hash_namespace"] + ".json"
|
||||
|
||||
def save_to_file(self):
|
||||
"""将KG数据保存到文件"""
|
||||
@@ -91,11 +91,11 @@ class KGManager:
|
||||
"""从文件加载KG数据"""
|
||||
# 确保文件存在
|
||||
if not os.path.exists(self.pg_hash_file_path):
|
||||
raise Exception(f"KG段落hash文件{self.pg_hash_file_path}不存在")
|
||||
raise FileNotFoundError(f"KG段落hash文件{self.pg_hash_file_path}不存在")
|
||||
if not os.path.exists(self.ent_cnt_data_path):
|
||||
raise Exception(f"KG实体计数文件{self.ent_cnt_data_path}不存在")
|
||||
raise FileNotFoundError(f"KG实体计数文件{self.ent_cnt_data_path}不存在")
|
||||
if not os.path.exists(self.graph_data_path):
|
||||
raise Exception(f"KG图文件{self.graph_data_path}不存在")
|
||||
raise FileNotFoundError(f"KG图文件{self.graph_data_path}不存在")
|
||||
|
||||
# 加载段落hash
|
||||
with open(self.pg_hash_file_path, "r", encoding="utf-8") as f:
|
||||
@@ -122,8 +122,8 @@ class KGManager:
|
||||
# 避免自连接
|
||||
continue
|
||||
# 一个triple就是一条边(同时构建双向联系)
|
||||
hash_key1 = local_storage['ent_namespace'] + "-" + get_sha256(triple[0])
|
||||
hash_key2 = local_storage['ent_namespace'] + "-" + get_sha256(triple[2])
|
||||
hash_key1 = local_storage["ent_namespace"] + "-" + get_sha256(triple[0])
|
||||
hash_key2 = local_storage["ent_namespace"] + "-" + get_sha256(triple[2])
|
||||
node_to_node[(hash_key1, hash_key2)] = node_to_node.get((hash_key1, hash_key2), 0) + 1.0
|
||||
node_to_node[(hash_key2, hash_key1)] = node_to_node.get((hash_key2, hash_key1), 0) + 1.0
|
||||
entity_set.add(hash_key1)
|
||||
@@ -141,8 +141,8 @@ class KGManager:
|
||||
"""构建实体节点与文段节点之间的关系"""
|
||||
for idx in triple_list_data:
|
||||
for triple in triple_list_data[idx]:
|
||||
ent_hash_key = local_storage['ent_namespace'] + "-" + get_sha256(triple[0])
|
||||
pg_hash_key = local_storage['pg_namespace'] + "-" + str(idx)
|
||||
ent_hash_key = local_storage["ent_namespace"] + "-" + get_sha256(triple[0])
|
||||
pg_hash_key = local_storage["pg_namespace"] + "-" + str(idx)
|
||||
node_to_node[(ent_hash_key, pg_hash_key)] = node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0
|
||||
|
||||
@staticmethod
|
||||
@@ -157,8 +157,8 @@ class KGManager:
|
||||
ent_hash_list = set()
|
||||
for triple_list in triple_list_data.values():
|
||||
for triple in triple_list:
|
||||
ent_hash_list.add(local_storage['ent_namespace'] + "-" + get_sha256(triple[0]))
|
||||
ent_hash_list.add(local_storage['ent_namespace'] + "-" + get_sha256(triple[2]))
|
||||
ent_hash_list.add(local_storage["ent_namespace"] + "-" + get_sha256(triple[0]))
|
||||
ent_hash_list.add(local_storage["ent_namespace"] + "-" + get_sha256(triple[2]))
|
||||
ent_hash_list = list(ent_hash_list)
|
||||
|
||||
synonym_hash_set = set()
|
||||
@@ -263,7 +263,7 @@ class KGManager:
|
||||
for src_tgt in node_to_node.keys():
|
||||
for node_hash in src_tgt:
|
||||
if node_hash not in existed_nodes:
|
||||
if node_hash.startswith(local_storage['ent_namespace']):
|
||||
if node_hash.startswith(local_storage["ent_namespace"]):
|
||||
# 新增实体节点
|
||||
node = embedding_manager.entities_embedding_store.store.get(node_hash)
|
||||
if node is None:
|
||||
@@ -275,7 +275,7 @@ class KGManager:
|
||||
node_item["type"] = "ent"
|
||||
node_item["create_time"] = now_time
|
||||
self.graph.update_node(node_item)
|
||||
elif node_hash.startswith(local_storage['pg_namespace']):
|
||||
elif node_hash.startswith(local_storage["pg_namespace"]):
|
||||
# 新增文段节点
|
||||
node = embedding_manager.paragraphs_embedding_store.store.get(node_hash)
|
||||
if node is None:
|
||||
@@ -359,7 +359,7 @@ class KGManager:
|
||||
# 关系三元组
|
||||
triple = relation[2:-2].split("', '")
|
||||
for ent in [(triple[0]), (triple[2])]:
|
||||
ent_hash = local_storage['ent_namespace'] + "-" + get_sha256(ent)
|
||||
ent_hash = local_storage["ent_namespace"] + "-" + get_sha256(ent)
|
||||
if ent_hash in existed_nodes: # 该实体需在KG中存在
|
||||
if ent_hash not in ent_sim_scores: # 尚未记录的实体
|
||||
ent_sim_scores[ent_hash] = []
|
||||
@@ -437,7 +437,9 @@ class KGManager:
|
||||
# 获取最终结果
|
||||
# 从搜索结果中提取文段节点的结果
|
||||
passage_node_res = [
|
||||
(node_key, score) for node_key, score in ppr_res.items() if node_key.startswith(local_storage['pg_namespace'])
|
||||
(node_key, score)
|
||||
for node_key, score in ppr_res.items()
|
||||
if node_key.startswith(local_storage["pg_namespace"])
|
||||
]
|
||||
del ppr_res
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ RAG_PG_HASH_NAMESPACE = "rag-pg-hash"
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
DATA_PATH = os.path.join(ROOT_PATH, "data")
|
||||
|
||||
|
||||
def _initialize_knowledge_local_storage():
|
||||
"""
|
||||
初始化知识库相关的本地存储配置
|
||||
@@ -41,23 +42,21 @@ def _initialize_knowledge_local_storage():
|
||||
# 定义所有需要初始化的配置项
|
||||
default_configs = {
|
||||
# 路径配置
|
||||
'root_path': ROOT_PATH,
|
||||
'data_path': f"{ROOT_PATH}/data",
|
||||
|
||||
"root_path": ROOT_PATH,
|
||||
"data_path": f"{ROOT_PATH}/data",
|
||||
# 实体和命名空间配置
|
||||
'lpmm_invalid_entity': INVALID_ENTITY,
|
||||
'pg_namespace': PG_NAMESPACE,
|
||||
'ent_namespace': ENT_NAMESPACE,
|
||||
'rel_namespace': REL_NAMESPACE,
|
||||
|
||||
"lpmm_invalid_entity": INVALID_ENTITY,
|
||||
"pg_namespace": PG_NAMESPACE,
|
||||
"ent_namespace": ENT_NAMESPACE,
|
||||
"rel_namespace": REL_NAMESPACE,
|
||||
# RAG相关命名空间配置
|
||||
'rag_graph_namespace': RAG_GRAPH_NAMESPACE,
|
||||
'rag_ent_cnt_namespace': RAG_ENT_CNT_NAMESPACE,
|
||||
'rag_pg_hash_namespace': RAG_PG_HASH_NAMESPACE
|
||||
"rag_graph_namespace": RAG_GRAPH_NAMESPACE,
|
||||
"rag_ent_cnt_namespace": RAG_ENT_CNT_NAMESPACE,
|
||||
"rag_pg_hash_namespace": RAG_PG_HASH_NAMESPACE,
|
||||
}
|
||||
|
||||
# 日志级别映射:重要配置用info,其他用debug
|
||||
important_configs = {'root_path', 'data_path'}
|
||||
important_configs = {"root_path", "data_path"}
|
||||
|
||||
# 批量设置配置项
|
||||
initialized_count = 0
|
||||
@@ -78,18 +77,23 @@ def _initialize_knowledge_local_storage():
|
||||
else:
|
||||
logger.debug("知识库本地存储配置已存在,跳过初始化")
|
||||
|
||||
|
||||
# 初始化本地存储路径
|
||||
# sourcery skip: dict-comprehension
|
||||
_initialize_knowledge_local_storage()
|
||||
|
||||
qa_manager = None
|
||||
inspire_manager = None
|
||||
|
||||
# 检查LPMM知识库是否启用
|
||||
if bot_global_config.lpmm_knowledge.enable:
|
||||
logger.info("正在初始化Mai-LPMM")
|
||||
logger.info("创建LLM客户端")
|
||||
llm_client_list = dict()
|
||||
llm_client_list = {}
|
||||
for key in global_config["llm_providers"]:
|
||||
llm_client_list[key] = LLMClient(
|
||||
global_config["llm_providers"][key]["base_url"],
|
||||
global_config["llm_providers"][key]["api_key"],
|
||||
global_config["llm_providers"][key]["base_url"], # type: ignore
|
||||
global_config["llm_providers"][key]["api_key"], # type: ignore
|
||||
)
|
||||
|
||||
# 初始化Embedding库
|
||||
@@ -98,7 +102,7 @@ if bot_global_config.lpmm_knowledge.enable:
|
||||
try:
|
||||
embed_manager.load_from_file()
|
||||
except Exception as e:
|
||||
logger.warning("此消息不会影响正常使用:从文件加载Embedding库时,{}".format(e))
|
||||
logger.warning(f"此消息不会影响正常使用:从文件加载Embedding库时,{e}")
|
||||
# logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误")
|
||||
logger.info("Embedding库加载完成")
|
||||
# 初始化KG
|
||||
@@ -107,7 +111,7 @@ if bot_global_config.lpmm_knowledge.enable:
|
||||
try:
|
||||
kg_manager.load_from_file()
|
||||
except Exception as e:
|
||||
logger.warning("此消息不会影响正常使用:从文件加载KG时,{}".format(e))
|
||||
logger.warning(f"此消息不会影响正常使用:从文件加载KG时,{e}")
|
||||
# logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误")
|
||||
logger.info("KG加载完成")
|
||||
|
||||
@@ -116,7 +120,7 @@ if bot_global_config.lpmm_knowledge.enable:
|
||||
|
||||
# 数据比对:Embedding库与KG的段落hash集合
|
||||
for pg_hash in kg_manager.stored_paragraph_hashes:
|
||||
key = PG_NAMESPACE + "-" + pg_hash
|
||||
key = f"{PG_NAMESPACE}-{pg_hash}"
|
||||
if key not in embed_manager.stored_pg_hashes:
|
||||
logger.warning(f"KG中存在Embedding库中不存在的段落:{key}")
|
||||
|
||||
@@ -134,5 +138,3 @@ if bot_global_config.lpmm_knowledge.enable:
|
||||
else:
|
||||
logger.info("LPMM知识库已禁用,跳过初始化")
|
||||
# 创建空的占位符对象,避免导入错误
|
||||
qa_manager = None
|
||||
inspire_manager = None
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from .llm_client import LLMMessage
|
||||
|
||||
entity_extract_system_prompt = """你是一个性能优异的实体提取系统。请从段落中提取出所有实体,并以JSON列表的形式输出。
|
||||
|
||||
输出格式示例:
|
||||
@@ -63,10 +61,10 @@ qa_system_prompt = """
|
||||
"""
|
||||
|
||||
|
||||
def build_qa_context(question: str, knowledge: list[tuple[str, str, str]]) -> list[LLMMessage]:
|
||||
knowledge = "\n".join([f"{i + 1}. 相关性:{k[0]}\n{k[1]}" for i, k in enumerate(knowledge)])
|
||||
messages = [
|
||||
LLMMessage("system", qa_system_prompt).to_dict(),
|
||||
LLMMessage("user", f"问题:\n{question}\n\n可能有帮助的信息:\n{knowledge}").to_dict(),
|
||||
]
|
||||
return messages
|
||||
# def build_qa_context(question: str, knowledge: list[tuple[str, str, str]]) -> list[LLMMessage]:
|
||||
# knowledge = "\n".join([f"{i + 1}. 相关性:{k[0]}\n{k[1]}" for i, k in enumerate(knowledge)])
|
||||
# messages = [
|
||||
# LLMMessage("system", qa_system_prompt).to_dict(),
|
||||
# LLMMessage("user", f"问题:\n{question}\n\n可能有帮助的信息:\n{knowledge}").to_dict(),
|
||||
# ]
|
||||
# return messages
|
||||
|
||||
@@ -9,6 +9,7 @@ from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.utils_image import get_image_manager
|
||||
from src.chat.utils.utils_voice import get_voice_text
|
||||
from .chat_stream import ChatStream
|
||||
|
||||
install(extra_lines=3)
|
||||
@@ -106,6 +107,7 @@ class MessageRecv(Message):
|
||||
self.has_emoji = False
|
||||
self.is_picid = False
|
||||
self.has_picid = False
|
||||
self.is_voice = False
|
||||
self.is_mentioned = None
|
||||
|
||||
self.is_command = False
|
||||
@@ -153,17 +155,27 @@ class MessageRecv(Message):
|
||||
self.has_emoji = True
|
||||
self.is_emoji = True
|
||||
self.is_picid = False
|
||||
self.is_voice = False
|
||||
if isinstance(segment.data, str):
|
||||
return await get_image_manager().get_emoji_description(segment.data)
|
||||
return "[发了一个表情包,网卡了加载不出来]"
|
||||
elif segment.type == "voice":
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_voice = True
|
||||
if isinstance(segment.data, str):
|
||||
return await get_voice_text(segment.data)
|
||||
return "[发了一段语音,网卡了加载不出来]"
|
||||
elif segment.type == "mention_bot":
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_voice = False
|
||||
self.is_mentioned = float(segment.data) # type: ignore
|
||||
return ""
|
||||
elif segment.type == "priority_info":
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_voice = False
|
||||
if isinstance(segment.data, dict):
|
||||
# 处理优先级信息
|
||||
self.priority_mode = "priority"
|
||||
@@ -212,10 +224,12 @@ class MessageRecvS4U(MessageRecv):
|
||||
"""
|
||||
try:
|
||||
if segment.type == "text":
|
||||
self.is_voice = False
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
return segment.data # type: ignore
|
||||
elif segment.type == "image":
|
||||
self.is_voice = False
|
||||
# 如果是base64图片数据
|
||||
if isinstance(segment.data, str):
|
||||
self.has_picid = True
|
||||
@@ -233,12 +247,22 @@ class MessageRecvS4U(MessageRecv):
|
||||
if isinstance(segment.data, str):
|
||||
return await get_image_manager().get_emoji_description(segment.data)
|
||||
return "[发了一个表情包,网卡了加载不出来]"
|
||||
elif segment.type == "voice":
|
||||
self.has_picid = False
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_voice = True
|
||||
if isinstance(segment.data, str):
|
||||
return await get_voice_text(segment.data)
|
||||
return "[发了一段语音,网卡了加载不出来]"
|
||||
elif segment.type == "mention_bot":
|
||||
self.is_voice = False
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_mentioned = float(segment.data) # type: ignore
|
||||
return ""
|
||||
elif segment.type == "priority_info":
|
||||
self.is_voice = False
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
if isinstance(segment.data, dict):
|
||||
@@ -253,6 +277,7 @@ class MessageRecvS4U(MessageRecv):
|
||||
"""
|
||||
return ""
|
||||
elif segment.type == "gift":
|
||||
self.is_voice = False
|
||||
self.is_gift = True
|
||||
# 解析gift_info,格式为"名称:数量"
|
||||
name, count = segment.data.split(":", 1) # type: ignore
|
||||
@@ -343,6 +368,10 @@ class MessageProcessBase(Message):
|
||||
if isinstance(seg.data, str):
|
||||
return await get_image_manager().get_emoji_description(seg.data)
|
||||
return "[表情,网卡了加载不出来]"
|
||||
elif seg.type == "voice":
|
||||
if isinstance(seg.data, str):
|
||||
return await get_voice_text(seg.data)
|
||||
return "[发了一段语音,网卡了加载不出来]"
|
||||
elif seg.type == "at":
|
||||
return f"[@{seg.data}]"
|
||||
elif seg.type == "reply":
|
||||
@@ -455,25 +484,25 @@ class MessageSending(MessageProcessBase):
|
||||
if self.message_segment:
|
||||
self.processed_plain_text = await self._process_message_segments(self.message_segment)
|
||||
|
||||
@classmethod
|
||||
def from_thinking(
|
||||
cls,
|
||||
thinking: MessageThinking,
|
||||
message_segment: Seg,
|
||||
is_head: bool = False,
|
||||
is_emoji: bool = False,
|
||||
) -> "MessageSending":
|
||||
"""从思考状态消息创建发送状态消息"""
|
||||
return cls(
|
||||
message_id=thinking.message_info.message_id, # type: ignore
|
||||
chat_stream=thinking.chat_stream,
|
||||
message_segment=message_segment,
|
||||
bot_user_info=thinking.message_info.user_info, # type: ignore
|
||||
reply=thinking.reply,
|
||||
is_head=is_head,
|
||||
is_emoji=is_emoji,
|
||||
sender_info=None,
|
||||
)
|
||||
# @classmethod
|
||||
# def from_thinking(
|
||||
# cls,
|
||||
# thinking: MessageThinking,
|
||||
# message_segment: Seg,
|
||||
# is_head: bool = False,
|
||||
# is_emoji: bool = False,
|
||||
# ) -> "MessageSending":
|
||||
# """从思考状态消息创建发送状态消息"""
|
||||
# return cls(
|
||||
# message_id=thinking.message_info.message_id, # type: ignore
|
||||
# chat_stream=thinking.chat_stream,
|
||||
# message_segment=message_segment,
|
||||
# bot_user_info=thinking.message_info.user_info, # type: ignore
|
||||
# reply=thinking.reply,
|
||||
# is_head=is_head,
|
||||
# is_emoji=is_emoji,
|
||||
# sender_info=None,
|
||||
# )
|
||||
|
||||
def to_dict(self):
|
||||
ret = super().to_dict()
|
||||
|
||||
@@ -262,4 +262,4 @@ class ActionManager:
|
||||
"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
return component_registry.get_component_class(action_name) # type: ignore
|
||||
return component_registry.get_component_class(action_name, ComponentType.ACTION) # type: ignore
|
||||
|
||||
35
src/chat/utils/utils_voice.py
Normal file
35
src/chat/utils/utils_voice.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import base64
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from rich.traceback import install
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("chat_voice")
|
||||
|
||||
async def get_voice_text(voice_base64: str) -> str:
|
||||
"""获取音频文件描述"""
|
||||
if not global_config.chat.enable_asr:
|
||||
logger.warning("语音识别未启用,无法处理语音消息")
|
||||
return "[语音]"
|
||||
try:
|
||||
# 解码base64音频数据
|
||||
# 确保base64字符串只包含ASCII字符
|
||||
if isinstance(voice_base64, str):
|
||||
voice_base64 = voice_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||
voice_bytes = base64.b64decode(voice_base64)
|
||||
_llm = LLMRequest(model=global_config.model.voice, request_type="voice")
|
||||
text = await _llm.generate_response_for_voice(voice_bytes)
|
||||
if text is None:
|
||||
logger.warning("未能生成语音文本")
|
||||
return "[语音(文本生成失败)]"
|
||||
|
||||
logger.debug(f"描述是{text}")
|
||||
|
||||
return f"[语音:{text}]"
|
||||
except Exception as e:
|
||||
logger.error(f"语音转文字失败: {str(e)}")
|
||||
return "[语音]"
|
||||
|
||||
@@ -21,6 +21,7 @@ class ClassicalWillingManager(BaseWillingManager):
|
||||
self._decay_task = asyncio.create_task(self._decay_reply_willing())
|
||||
|
||||
async def get_reply_probability(self, message_id):
|
||||
# sourcery skip: inline-immediately-returned-variable
|
||||
willing_info = self.ongoing_messages[message_id]
|
||||
chat_id = willing_info.chat_id
|
||||
current_willing = self.chat_reply_willing.get(chat_id, 0)
|
||||
|
||||
@@ -25,6 +25,8 @@ import asyncio
|
||||
import time
|
||||
import math
|
||||
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
|
||||
|
||||
class MxpWillingManager(BaseWillingManager):
|
||||
"""Mxp意愿管理器"""
|
||||
@@ -76,7 +78,7 @@ class MxpWillingManager(BaseWillingManager):
|
||||
self.chat_bot_message_time[w_info.chat_id].append(current_time)
|
||||
if len(self.chat_bot_message_time[w_info.chat_id]) == int(self.fatigue_messages_triggered_num):
|
||||
time_interval = 60 - (current_time - self.chat_bot_message_time[w_info.chat_id].pop(0))
|
||||
self.chat_fatigue_punishment_list[w_info.chat_id].append([current_time, time_interval * 2])
|
||||
self.chat_fatigue_punishment_list[w_info.chat_id].append((current_time, time_interval * 2))
|
||||
|
||||
async def after_generate_reply_handle(self, message_id: str):
|
||||
"""回复后处理"""
|
||||
@@ -87,12 +89,14 @@ class MxpWillingManager(BaseWillingManager):
|
||||
# rel_level = self._get_relationship_level_num(rel_value)
|
||||
# self.chat_person_reply_willing[w_info.chat_id][w_info.person_id] += rel_level * 0.05
|
||||
|
||||
now_chat_new_person = self.last_response_person.get(w_info.chat_id, [w_info.person_id, 0])
|
||||
now_chat_new_person = self.last_response_person.get(w_info.chat_id, (w_info.person_id, 0))
|
||||
if now_chat_new_person[0] == w_info.person_id:
|
||||
if now_chat_new_person[1] < 3:
|
||||
now_chat_new_person[1] += 1
|
||||
tmp_list = list(now_chat_new_person)
|
||||
tmp_list[1] += 1 # type: ignore
|
||||
self.last_response_person[w_info.chat_id] = tuple(tmp_list) # type: ignore
|
||||
else:
|
||||
self.last_response_person[w_info.chat_id] = [w_info.person_id, 0]
|
||||
self.last_response_person[w_info.chat_id] = (w_info.person_id, 0)
|
||||
|
||||
async def not_reply_handle(self, message_id: str):
|
||||
"""不回复处理"""
|
||||
@@ -108,11 +112,12 @@ class MxpWillingManager(BaseWillingManager):
|
||||
self.chat_person_reply_willing[w_info.chat_id][w_info.person_id] += self.single_chat_gain * (
|
||||
2 * self.last_response_person[w_info.chat_id][1] - 1
|
||||
)
|
||||
now_chat_new_person = self.last_response_person.get(w_info.chat_id, ["", 0])
|
||||
now_chat_new_person = self.last_response_person.get(w_info.chat_id, ("", 0))
|
||||
if now_chat_new_person[0] != w_info.person_id:
|
||||
self.last_response_person[w_info.chat_id] = [w_info.person_id, 0]
|
||||
self.last_response_person[w_info.chat_id] = (w_info.person_id, 0)
|
||||
|
||||
async def get_reply_probability(self, message_id: str):
|
||||
# sourcery skip: merge-duplicate-blocks, remove-redundant-if
|
||||
"""获取回复概率"""
|
||||
async with self.lock:
|
||||
w_info = self.ongoing_messages[message_id]
|
||||
@@ -121,17 +126,16 @@ class MxpWillingManager(BaseWillingManager):
|
||||
self.logger.debug(f"基础意愿值:{current_willing}")
|
||||
|
||||
if w_info.is_mentioned_bot:
|
||||
current_willing_ = self.mention_willing_gain / (int(current_willing) + 1)
|
||||
current_willing += current_willing_
|
||||
willing_gain = self.mention_willing_gain / (int(current_willing) + 1)
|
||||
current_willing += willing_gain
|
||||
if self.is_debug:
|
||||
self.logger.debug(f"提及增益:{current_willing_}")
|
||||
self.logger.debug(f"提及增益:{willing_gain}")
|
||||
|
||||
if w_info.interested_rate > 0:
|
||||
current_willing += math.atan(w_info.interested_rate / 2) / math.pi * 2 * self.interest_willing_gain
|
||||
willing_gain = math.atan(w_info.interested_rate / 2) / math.pi * 2 * self.interest_willing_gain
|
||||
current_willing += willing_gain
|
||||
if self.is_debug:
|
||||
self.logger.debug(
|
||||
f"兴趣增益:{math.atan(w_info.interested_rate / 2) / math.pi * 2 * self.interest_willing_gain}"
|
||||
)
|
||||
self.logger.debug(f"兴趣增益:{willing_gain}")
|
||||
|
||||
self.chat_person_reply_willing[w_info.chat_id][w_info.person_id] = current_willing
|
||||
|
||||
@@ -152,8 +156,8 @@ class MxpWillingManager(BaseWillingManager):
|
||||
self.logger.debug(f"疲劳衰减:{self.chat_fatigue_willing_attenuation.get(w_info.chat_id, 0)}")
|
||||
|
||||
chat_ongoing_messages = [msg for msg in self.ongoing_messages.values() if msg.chat_id == w_info.chat_id]
|
||||
chat_person_ogoing_messages = [msg for msg in chat_ongoing_messages if msg.person_id == w_info.person_id]
|
||||
if len(chat_person_ogoing_messages) >= 2:
|
||||
chat_person_ongoing_messages = [msg for msg in chat_ongoing_messages if msg.person_id == w_info.person_id]
|
||||
if len(chat_person_ongoing_messages) >= 2:
|
||||
current_willing = 0
|
||||
if self.is_debug:
|
||||
self.logger.debug("进行中消息惩罚:归0")
|
||||
@@ -191,34 +195,33 @@ class MxpWillingManager(BaseWillingManager):
|
||||
basic_willing + (willing - basic_willing) * self.intention_decay_rate
|
||||
)
|
||||
|
||||
def setup(self, message, chat, is_mentioned_bot, interested_rate):
|
||||
super().setup(message, chat, is_mentioned_bot, interested_rate)
|
||||
|
||||
self.chat_reply_willing[chat.stream_id] = self.chat_reply_willing.get(
|
||||
chat.stream_id, self.basic_maximum_willing
|
||||
def setup(self, message: dict, chat_stream: ChatStream):
|
||||
super().setup(message, chat_stream)
|
||||
stream_id = chat_stream.stream_id
|
||||
self.chat_reply_willing[stream_id] = self.chat_reply_willing.get(stream_id, self.basic_maximum_willing)
|
||||
self.chat_person_reply_willing[stream_id] = self.chat_person_reply_willing.get(stream_id, {})
|
||||
self.chat_person_reply_willing[stream_id][self.ongoing_messages[message.get("message_id", "")].person_id] = (
|
||||
self.chat_person_reply_willing[stream_id].get(
|
||||
self.ongoing_messages[message.get("message_id", "")].person_id,
|
||||
self.chat_reply_willing[stream_id],
|
||||
)
|
||||
self.chat_person_reply_willing[chat.stream_id] = self.chat_person_reply_willing.get(chat.stream_id, {})
|
||||
self.chat_person_reply_willing[chat.stream_id][
|
||||
self.ongoing_messages[message.message_info.message_id].person_id
|
||||
] = self.chat_person_reply_willing[chat.stream_id].get(
|
||||
self.ongoing_messages[message.message_info.message_id].person_id, self.chat_reply_willing[chat.stream_id]
|
||||
)
|
||||
|
||||
current_time = time.time()
|
||||
if chat.stream_id not in self.chat_new_message_time:
|
||||
self.chat_new_message_time[chat.stream_id] = []
|
||||
self.chat_new_message_time[chat.stream_id].append(current_time)
|
||||
if len(self.chat_new_message_time[chat.stream_id]) > self.number_of_message_storage:
|
||||
self.chat_new_message_time[chat.stream_id].pop(0)
|
||||
if stream_id not in self.chat_new_message_time:
|
||||
self.chat_new_message_time[stream_id] = []
|
||||
self.chat_new_message_time[stream_id].append(current_time)
|
||||
if len(self.chat_new_message_time[stream_id]) > self.number_of_message_storage:
|
||||
self.chat_new_message_time[stream_id].pop(0)
|
||||
|
||||
if chat.stream_id not in self.chat_fatigue_punishment_list:
|
||||
self.chat_fatigue_punishment_list[chat.stream_id] = [
|
||||
if stream_id not in self.chat_fatigue_punishment_list:
|
||||
self.chat_fatigue_punishment_list[stream_id] = [
|
||||
(
|
||||
current_time,
|
||||
self.number_of_message_storage * self.basic_maximum_willing / self.expected_replies_per_min * 60,
|
||||
)
|
||||
]
|
||||
self.chat_fatigue_willing_attenuation[chat.stream_id] = (
|
||||
self.chat_fatigue_willing_attenuation[stream_id] = (
|
||||
-2 * self.basic_maximum_willing * self.fatigue_coefficient
|
||||
)
|
||||
|
||||
@@ -227,12 +230,11 @@ class MxpWillingManager(BaseWillingManager):
|
||||
"""意愿值转化为概率"""
|
||||
willing = max(0, willing)
|
||||
if willing < 2:
|
||||
probability = math.atan(willing * 2) / math.pi * 2
|
||||
return math.atan(willing * 2) / math.pi * 2
|
||||
elif willing < 2.5:
|
||||
probability = math.atan(willing * 4) / math.pi * 2
|
||||
return math.atan(willing * 4) / math.pi * 2
|
||||
else:
|
||||
probability = 1
|
||||
return probability
|
||||
return 1
|
||||
|
||||
async def _chat_new_message_to_change_basic_willing(self):
|
||||
"""聊天流新消息改变基础意愿"""
|
||||
@@ -259,7 +261,7 @@ class MxpWillingManager(BaseWillingManager):
|
||||
update_time = 20
|
||||
elif len(message_times) == self.number_of_message_storage:
|
||||
time_interval = current_time - message_times[0]
|
||||
basic_willing = self._basic_willing_culculate(time_interval)
|
||||
basic_willing = self._basic_willing_calculate(time_interval)
|
||||
self.chat_reply_willing[chat_id] = basic_willing
|
||||
update_time = 17 * basic_willing / self.basic_maximum_willing + 3
|
||||
else:
|
||||
@@ -268,7 +270,7 @@ class MxpWillingManager(BaseWillingManager):
|
||||
if self.is_debug:
|
||||
self.logger.debug(f"聊天流意愿值更新:{self.chat_reply_willing}")
|
||||
|
||||
def _basic_willing_culculate(self, t: float) -> float:
|
||||
def _basic_willing_calculate(self, t: float) -> float:
|
||||
"""基础意愿值计算"""
|
||||
return math.tan(t * self.expected_replies_per_min * math.pi / 120 / self.number_of_message_storage) / 2
|
||||
|
||||
|
||||
@@ -104,7 +104,7 @@ class BaseWillingManager(ABC):
|
||||
is_mentioned_bot=message.get("is_mentioned", False),
|
||||
is_emoji=message.get("is_emoji", False),
|
||||
is_picid=message.get("is_picid", False),
|
||||
interested_rate=message.get("interest_value", 0),
|
||||
interested_rate = message.get("interest_value") or 0.0,
|
||||
)
|
||||
|
||||
def delete(self, message_id: str):
|
||||
|
||||
@@ -106,6 +106,9 @@ class ChatConfig(ConfigBase):
|
||||
focus_value: float = 1.0
|
||||
"""麦麦的专注思考能力,越低越容易专注,消耗token也越多"""
|
||||
|
||||
enable_asr: bool = False
|
||||
"""是否启用语音识别"""
|
||||
|
||||
def get_current_talk_frequency(self, chat_stream_id: Optional[str] = None) -> float:
|
||||
"""
|
||||
根据当前时间和聊天流获取对应的 talk_frequency
|
||||
@@ -630,6 +633,9 @@ class ModelConfig(ConfigBase):
|
||||
vlm: dict[str, Any] = field(default_factory=lambda: {})
|
||||
"""视觉语言模型配置"""
|
||||
|
||||
voice: dict[str, Any] = field(default_factory=lambda: {})
|
||||
"""语音识别模型配置"""
|
||||
|
||||
tool_use: dict[str, Any] = field(default_factory=lambda: {})
|
||||
"""专注工具使用模型配置"""
|
||||
|
||||
|
||||
@@ -216,6 +216,8 @@ class LLMRequest:
|
||||
prompt: str = None,
|
||||
image_base64: str = None,
|
||||
image_format: str = None,
|
||||
file_bytes: bytes = None,
|
||||
file_format: str = None,
|
||||
payload: dict = None,
|
||||
retry_policy: dict = None,
|
||||
) -> Dict[str, Any]:
|
||||
@@ -225,6 +227,8 @@ class LLMRequest:
|
||||
prompt: prompt文本
|
||||
image_base64: 图片的base64编码
|
||||
image_format: 图片格式
|
||||
file_bytes: 文件的二进制数据
|
||||
file_format: 文件格式
|
||||
payload: 请求体数据
|
||||
retry_policy: 自定义重试策略
|
||||
request_type: 请求类型
|
||||
@@ -246,9 +250,12 @@ class LLMRequest:
|
||||
# 构建请求体
|
||||
if image_base64:
|
||||
payload = await self._build_payload(prompt, image_base64, image_format)
|
||||
elif file_bytes:
|
||||
payload = await self._build_formdata_payload(file_bytes, file_format)
|
||||
elif payload is None:
|
||||
payload = await self._build_payload(prompt)
|
||||
|
||||
if not file_bytes:
|
||||
if stream_mode:
|
||||
payload["stream"] = stream_mode
|
||||
|
||||
@@ -278,6 +285,8 @@ class LLMRequest:
|
||||
"stream_mode": stream_mode,
|
||||
"image_base64": image_base64, # 保留必要的exception处理所需的原始数据
|
||||
"image_format": image_format,
|
||||
"file_bytes": file_bytes,
|
||||
"file_format": file_format,
|
||||
"prompt": prompt,
|
||||
}
|
||||
|
||||
@@ -287,6 +296,8 @@ class LLMRequest:
|
||||
prompt: str = None,
|
||||
image_base64: str = None,
|
||||
image_format: str = None,
|
||||
file_bytes: bytes = None,
|
||||
file_format: str = None,
|
||||
payload: dict = None,
|
||||
retry_policy: dict = None,
|
||||
response_handler: callable = None,
|
||||
@@ -299,6 +310,8 @@ class LLMRequest:
|
||||
prompt: prompt文本
|
||||
image_base64: 图片的base64编码
|
||||
image_format: 图片格式
|
||||
file_bytes: 文件的二进制数据
|
||||
file_format: 文件格式
|
||||
payload: 请求体数据
|
||||
retry_policy: 自定义重试策略
|
||||
response_handler: 自定义响应处理器
|
||||
@@ -307,25 +320,36 @@ class LLMRequest:
|
||||
"""
|
||||
# 获取请求配置
|
||||
request_content = await self._prepare_request(
|
||||
endpoint, prompt, image_base64, image_format, payload, retry_policy
|
||||
endpoint, prompt, image_base64, image_format, file_bytes, file_format, payload, retry_policy
|
||||
)
|
||||
if request_type is None:
|
||||
request_type = self.request_type
|
||||
for retry in range(request_content["policy"]["max_retries"]):
|
||||
try:
|
||||
# 使用上下文管理器处理会话
|
||||
headers = await self._build_headers()
|
||||
if file_bytes:
|
||||
headers = await self._build_headers(is_formdata=True)
|
||||
else:
|
||||
headers = await self._build_headers(is_formdata=False)
|
||||
# 似乎是openai流式必须要的东西,不过阿里云的qwq-plus加了这个没有影响
|
||||
if request_content["stream_mode"]:
|
||||
headers["Accept"] = "text/event-stream"
|
||||
async with aiohttp.ClientSession(connector=await get_tcp_connector()) as session:
|
||||
post_kwargs = {"headers": headers}
|
||||
#form-data数据上传方式不同
|
||||
if file_bytes:
|
||||
post_kwargs["data"] = request_content["payload"]
|
||||
else:
|
||||
post_kwargs["json"] = request_content["payload"]
|
||||
|
||||
async with session.post(
|
||||
request_content["api_url"], headers=headers, json=request_content["payload"]
|
||||
request_content["api_url"], **post_kwargs
|
||||
) as response:
|
||||
handled_result = await self._handle_response(
|
||||
response, request_content, retry, response_handler, user_id, request_type, endpoint
|
||||
)
|
||||
return handled_result
|
||||
|
||||
except Exception as e:
|
||||
handled_payload, count_delta = await self._handle_exception(e, retry, request_content)
|
||||
retry += count_delta # 降级不计入重试次数
|
||||
@@ -605,7 +629,7 @@ class LLMRequest:
|
||||
)
|
||||
# 安全地检查和记录请求详情
|
||||
handled_payload = await _safely_record(request_content, payload)
|
||||
logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {handled_payload[:100]}")
|
||||
logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {str(handled_payload)[:100]}")
|
||||
raise RuntimeError(
|
||||
f"模型 {self.model_name} API请求失败: 状态码 {exception.status}, {exception.message}"
|
||||
)
|
||||
@@ -619,7 +643,7 @@ class LLMRequest:
|
||||
logger.critical(f"模型 {self.model_name} 请求失败: {str(exception)}")
|
||||
# 安全地检查和记录请求详情
|
||||
handled_payload = await _safely_record(request_content, payload)
|
||||
logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {handled_payload[:100]}")
|
||||
logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {str(handled_payload)[:100]}")
|
||||
raise RuntimeError(f"模型 {self.model_name} API请求失败: {str(exception)}")
|
||||
|
||||
async def _transform_parameters(self, params: dict) -> dict:
|
||||
@@ -640,6 +664,33 @@ class LLMRequest:
|
||||
new_params["max_completion_tokens"] = new_params.pop("max_tokens")
|
||||
return new_params
|
||||
|
||||
async def _build_formdata_payload(self, file_bytes: bytes, file_format: str) -> aiohttp.FormData:
|
||||
"""构建form-data请求体"""
|
||||
# 目前只适配了音频文件
|
||||
# 如果后续要支持其他类型的文件,可以在这里添加更多的处理逻辑
|
||||
data = aiohttp.FormData()
|
||||
content_type_list = {
|
||||
"wav": "audio/wav",
|
||||
"mp3": "audio/mpeg",
|
||||
"ogg": "audio/ogg",
|
||||
"flac": "audio/flac",
|
||||
"aac": "audio/aac",
|
||||
}
|
||||
|
||||
content_type = content_type_list.get(file_format)
|
||||
if not content_type:
|
||||
logger.warning(f"暂不支持的文件类型: {file_format}")
|
||||
|
||||
data.add_field(
|
||||
"file",io.BytesIO(file_bytes),
|
||||
filename=f"file.{file_format}",
|
||||
content_type=f'{content_type}' # 根据实际文件类型设置
|
||||
)
|
||||
data.add_field(
|
||||
"model", self.model_name
|
||||
)
|
||||
return data
|
||||
|
||||
async def _build_payload(self, prompt: str, image_base64: str = None, image_format: str = None) -> dict:
|
||||
"""构建请求体"""
|
||||
# 复制一份参数,避免直接修改 self.params
|
||||
@@ -725,7 +776,8 @@ class LLMRequest:
|
||||
return content, reasoning_content, tool_calls
|
||||
else:
|
||||
return content, reasoning_content
|
||||
|
||||
elif "text" in result and result["text"]:
|
||||
return result["text"]
|
||||
return "没有返回结果", ""
|
||||
|
||||
@staticmethod
|
||||
@@ -739,11 +791,15 @@ class LLMRequest:
|
||||
reasoning = ""
|
||||
return content, reasoning
|
||||
|
||||
async def _build_headers(self, no_key: bool = False) -> dict:
|
||||
async def _build_headers(self, no_key: bool = False, is_formdata: bool = False) -> dict:
|
||||
"""构建请求头"""
|
||||
if no_key:
|
||||
if is_formdata:
|
||||
return {"Authorization": "Bearer **********"}
|
||||
return {"Authorization": "Bearer **********", "Content-Type": "application/json"}
|
||||
else:
|
||||
if is_formdata:
|
||||
return {"Authorization": f"Bearer {self.api_key}"}
|
||||
return {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
|
||||
# 防止小朋友们截图自己的key
|
||||
|
||||
@@ -761,6 +817,11 @@ class LLMRequest:
|
||||
content, reasoning_content = response
|
||||
return content, reasoning_content
|
||||
|
||||
async def generate_response_for_voice(self, voice_bytes: bytes) -> Tuple:
|
||||
"""根据输入的语音文件生成模型的异步响应"""
|
||||
response = await self._execute_request(endpoint="/audio/transcriptions",file_bytes=voice_bytes, file_format='wav')
|
||||
return response
|
||||
|
||||
async def generate_response_async(self, prompt: str, **kwargs) -> Union[str, Tuple]:
|
||||
"""异步方式根据输入的提示生成模型的响应"""
|
||||
# 构建请求体,不硬编码max_tokens
|
||||
|
||||
@@ -18,11 +18,16 @@ from .base import (
|
||||
CommandInfo,
|
||||
PluginInfo,
|
||||
PythonDependency,
|
||||
BaseEventHandler,
|
||||
EventHandlerInfo,
|
||||
EventType,
|
||||
MaiMessages,
|
||||
)
|
||||
from .core.plugin_manager import (
|
||||
from .core import (
|
||||
plugin_manager,
|
||||
component_registry,
|
||||
dependency_manager,
|
||||
events_manager,
|
||||
)
|
||||
|
||||
# 导入工具模块
|
||||
@@ -33,7 +38,7 @@ from .utils import (
|
||||
# generate_plugin_manifest,
|
||||
)
|
||||
|
||||
from .apis.plugin_register_api import register_plugin
|
||||
from .apis import register_plugin, get_logger
|
||||
|
||||
|
||||
__version__ = "1.0.0"
|
||||
@@ -43,6 +48,7 @@ __all__ = [
|
||||
"BasePlugin",
|
||||
"BaseAction",
|
||||
"BaseCommand",
|
||||
"BaseEventHandler",
|
||||
# 类型定义
|
||||
"ComponentType",
|
||||
"ActionActivationType",
|
||||
@@ -52,15 +58,21 @@ __all__ = [
|
||||
"CommandInfo",
|
||||
"PluginInfo",
|
||||
"PythonDependency",
|
||||
"EventHandlerInfo",
|
||||
"EventType",
|
||||
# 消息
|
||||
"MaiMessages",
|
||||
# 管理器
|
||||
"plugin_manager",
|
||||
"component_registry",
|
||||
"dependency_manager",
|
||||
"events_manager",
|
||||
# 装饰器
|
||||
"register_plugin",
|
||||
"ConfigField",
|
||||
# 工具函数
|
||||
"ManifestValidator",
|
||||
"get_logger",
|
||||
# "ManifestGenerator",
|
||||
# "validate_plugin_manifest",
|
||||
# "generate_plugin_manifest",
|
||||
|
||||
@@ -18,7 +18,8 @@ from src.plugin_system.apis import (
|
||||
utils_api,
|
||||
plugin_register_api,
|
||||
)
|
||||
|
||||
from .logging_api import get_logger
|
||||
from .plugin_register_api import register_plugin
|
||||
# 导出所有API模块,使它们可以通过 apis.xxx 方式访问
|
||||
__all__ = [
|
||||
"chat_api",
|
||||
@@ -32,4 +33,6 @@ __all__ = [
|
||||
"send_api",
|
||||
"utils_api",
|
||||
"plugin_register_api",
|
||||
"get_logger",
|
||||
"register_plugin",
|
||||
]
|
||||
|
||||
3
src/plugin_system/apis/logging_api.py
Normal file
3
src/plugin_system/apis/logging_api.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from src.common.logger import get_logger
|
||||
|
||||
__all__ = ["get_logger"]
|
||||
@@ -1,6 +1,8 @@
|
||||
from pathlib import Path
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("plugin_register")
|
||||
logger = get_logger("plugin_manager") # 复用plugin_manager名称
|
||||
|
||||
|
||||
def register_plugin(cls):
|
||||
@@ -22,16 +24,23 @@ def register_plugin(cls):
|
||||
|
||||
# 只是注册插件类,不立即实例化
|
||||
# 插件管理器会负责实例化和注册
|
||||
plugin_name = cls.plugin_name or cls.__name__
|
||||
plugin_manager.plugin_classes[plugin_name] = cls # type: ignore
|
||||
logger.debug(f"插件类已注册: {plugin_name}")
|
||||
plugin_name: str = cls.plugin_name # type: ignore
|
||||
if "." in plugin_name:
|
||||
logger.error(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代")
|
||||
raise ValueError(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代")
|
||||
plugin_manager.plugin_classes[plugin_name] = cls
|
||||
splitted_name = cls.__module__.split(".")
|
||||
root_path = Path(__file__)
|
||||
|
||||
# 查找项目根目录
|
||||
while not (root_path / "pyproject.toml").exists() and root_path.parent != root_path:
|
||||
root_path = root_path.parent
|
||||
|
||||
if not (root_path / "pyproject.toml").exists():
|
||||
logger.error(f"注册 {plugin_name} 无法找到项目根目录")
|
||||
return cls
|
||||
|
||||
def register_event_plugin(cls, *args, **kwargs):
|
||||
plugin_manager.plugin_paths[plugin_name] = str(Path(root_path, *splitted_name).resolve())
|
||||
logger.debug(f"插件类已注册: {plugin_name}, 路径: {plugin_manager.plugin_paths[plugin_name]}")
|
||||
|
||||
"""事件插件注册装饰器
|
||||
|
||||
用法:
|
||||
@register_event_plugin
|
||||
"""
|
||||
return cls
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
from .base_plugin import BasePlugin
|
||||
from .base_action import BaseAction
|
||||
from .base_command import BaseCommand
|
||||
from .base_events_handler import BaseEventHandler
|
||||
from .component_types import (
|
||||
ComponentType,
|
||||
ActionActivationType,
|
||||
@@ -16,6 +17,9 @@ from .component_types import (
|
||||
CommandInfo,
|
||||
PluginInfo,
|
||||
PythonDependency,
|
||||
EventHandlerInfo,
|
||||
EventType,
|
||||
MaiMessages,
|
||||
)
|
||||
from .config_types import ConfigField
|
||||
|
||||
@@ -32,4 +36,8 @@ __all__ = [
|
||||
"PluginInfo",
|
||||
"PythonDependency",
|
||||
"ConfigField",
|
||||
"EventHandlerInfo",
|
||||
"EventType",
|
||||
"BaseEventHandler",
|
||||
"MaiMessages",
|
||||
]
|
||||
|
||||
@@ -41,6 +41,7 @@ class BaseAction(ABC):
|
||||
action_message: Optional[dict] = None,
|
||||
**kwargs,
|
||||
):
|
||||
# sourcery skip: hoist-similar-statement-from-if, merge-else-if-into-elif, move-assign-in-block, swap-if-else-branches, swap-nested-ifs
|
||||
"""初始化Action组件
|
||||
|
||||
Args:
|
||||
@@ -355,7 +356,9 @@ class BaseAction(ABC):
|
||||
|
||||
# 从类属性读取名称,如果没有定义则使用类名自动生成
|
||||
name = getattr(cls, "action_name", cls.__name__.lower().replace("action", ""))
|
||||
|
||||
if "." in name:
|
||||
logger.error(f"Action名称 '{name}' 包含非法字符 '.',请使用下划线替代")
|
||||
raise ValueError(f"Action名称 '{name}' 包含非法字符 '.',请使用下划线替代")
|
||||
# 获取focus_activation_type和normal_activation_type
|
||||
focus_activation_type = getattr(cls, "focus_activation_type", ActionActivationType.ALWAYS)
|
||||
normal_activation_type = getattr(cls, "normal_activation_type", ActionActivationType.ALWAYS)
|
||||
|
||||
@@ -219,7 +219,9 @@ class BaseCommand(ABC):
|
||||
Returns:
|
||||
CommandInfo: 生成的Command信息对象
|
||||
"""
|
||||
|
||||
if "." in cls.command_name:
|
||||
logger.error(f"Command名称 '{cls.command_name}' 包含非法字符 '.',请使用下划线替代")
|
||||
raise ValueError(f"Command名称 '{cls.command_name}' 包含非法字符 '.',请使用下划线替代")
|
||||
return CommandInfo(
|
||||
name=cls.command_name,
|
||||
component_type=ComponentType.COMMAND,
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
from abc import abstractmethod
|
||||
|
||||
from .plugin_base import PluginBase
|
||||
from src.common.logger import get_logger
|
||||
|
||||
|
||||
class BaseEventPlugin(PluginBase):
|
||||
"""基于事件的插件基类
|
||||
|
||||
所有事件类型的插件都应该继承这个基类
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
51
src/plugin_system/base/base_events_handler.py
Normal file
51
src/plugin_system/base/base_events_handler.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Tuple, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from .component_types import MaiMessages, EventType, EventHandlerInfo, ComponentType
|
||||
|
||||
logger = get_logger("base_event_handler")
|
||||
|
||||
|
||||
class BaseEventHandler(ABC):
|
||||
"""事件处理器基类
|
||||
|
||||
所有事件处理器都应该继承这个基类,提供事件处理的基本接口
|
||||
"""
|
||||
|
||||
event_type: EventType = EventType.UNKNOWN # 事件类型,默认为未知
|
||||
handler_name: str = "" # 处理器名称
|
||||
handler_description: str = ""
|
||||
weight: int = 0 # 权重,数值越大优先级越高
|
||||
intercept_message: bool = False # 是否拦截消息,默认为否
|
||||
|
||||
def __init__(self):
|
||||
self.log_prefix = "[EventHandler]"
|
||||
if self.event_type == EventType.UNKNOWN:
|
||||
raise NotImplementedError("事件处理器必须指定 event_type")
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, message: MaiMessages) -> Tuple[bool, Optional[str]]:
|
||||
"""执行事件处理的抽象方法,子类必须实现
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Optional[str]]: (是否执行成功, 可选的返回消息)
|
||||
"""
|
||||
raise NotImplementedError("子类必须实现 execute 方法")
|
||||
|
||||
@classmethod
|
||||
def get_handler_info(cls) -> "EventHandlerInfo":
|
||||
"""获取事件处理器的信息"""
|
||||
# 从类属性读取名称,如果没有定义则使用类名自动生成
|
||||
name: str = getattr(cls, "handler_name", cls.__name__.lower().replace("handler", ""))
|
||||
if "." in name:
|
||||
logger.error(f"事件处理器名称 '{name}' 包含非法字符 '.',请使用下划线替代")
|
||||
raise ValueError(f"事件处理器名称 '{name}' 包含非法字符 '.',请使用下划线替代")
|
||||
return EventHandlerInfo(
|
||||
name=name,
|
||||
component_type=ComponentType.EVENT_HANDLER,
|
||||
description=getattr(cls, "handler_description", "events处理器"),
|
||||
event_type=cls.event_type,
|
||||
weight=cls.weight,
|
||||
intercept_message=cls.intercept_message,
|
||||
)
|
||||
@@ -1,9 +1,12 @@
|
||||
from abc import abstractmethod
|
||||
from typing import List, Type
|
||||
from typing import List, Type, Tuple, Union
|
||||
from .plugin_base import PluginBase
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import ComponentInfo
|
||||
from src.plugin_system.base.component_types import ComponentInfo, ActionInfo, CommandInfo, EventHandlerInfo
|
||||
from .base_action import BaseAction
|
||||
from .base_command import BaseCommand
|
||||
from .base_events_handler import BaseEventHandler
|
||||
|
||||
logger = get_logger("base_plugin")
|
||||
|
||||
@@ -21,7 +24,15 @@ class BasePlugin(PluginBase):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def get_plugin_components(self) -> List[tuple[ComponentInfo, Type]]:
|
||||
def get_plugin_components(
|
||||
self,
|
||||
) -> List[
|
||||
Union[
|
||||
Tuple[ActionInfo, Type[BaseAction]],
|
||||
Tuple[CommandInfo, Type[BaseCommand]],
|
||||
Tuple[EventHandlerInfo, Type[BaseEventHandler]],
|
||||
]
|
||||
]:
|
||||
"""获取插件包含的组件列表
|
||||
|
||||
子类必须实现此方法,返回组件信息和组件类的列表
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from enum import Enum
|
||||
from typing import Dict, Any, List
|
||||
from typing import Dict, Any, List, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from maim_message import Seg
|
||||
|
||||
|
||||
# 组件类型枚举
|
||||
@@ -10,7 +11,10 @@ class ComponentType(Enum):
|
||||
ACTION = "action" # 动作组件
|
||||
COMMAND = "command" # 命令组件
|
||||
SCHEDULER = "scheduler" # 定时任务组件(预留)
|
||||
LISTENER = "listener" # 事件监听组件(预留)
|
||||
EVENT_HANDLER = "event_handler" # 事件处理组件(预留)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
# 动作激活类型枚举
|
||||
@@ -46,12 +50,17 @@ class EventType(Enum):
|
||||
事件类型枚举类
|
||||
"""
|
||||
|
||||
ON_START = "on_start" # 启动事件,用于调用按时任务
|
||||
ON_MESSAGE = "on_message"
|
||||
ON_PLAN = "on_plan"
|
||||
POST_LLM = "post_llm"
|
||||
AFTER_LLM = "after_llm"
|
||||
POST_SEND = "post_send"
|
||||
AFTER_SEND = "after_send"
|
||||
UNKNOWN = "unknown" # 未知事件类型
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -142,6 +151,19 @@ class CommandInfo(ComponentInfo):
|
||||
self.component_type = ComponentType.COMMAND
|
||||
|
||||
|
||||
@dataclass
|
||||
class EventHandlerInfo(ComponentInfo):
|
||||
"""事件处理器组件信息"""
|
||||
|
||||
event_type: EventType = EventType.ON_MESSAGE # 监听事件类型
|
||||
intercept_message: bool = False # 是否拦截消息处理(默认不拦截)
|
||||
weight: int = 0 # 事件处理器权重,决定执行顺序
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
self.component_type = ComponentType.EVENT_HANDLER
|
||||
|
||||
|
||||
@dataclass
|
||||
class PluginInfo:
|
||||
"""插件信息"""
|
||||
@@ -198,3 +220,42 @@ class PluginInfo:
|
||||
def get_pip_requirements(self) -> List[str]:
|
||||
"""获取所有pip安装格式的依赖"""
|
||||
return [dep.get_pip_requirement() for dep in self.python_dependencies]
|
||||
|
||||
|
||||
@dataclass
|
||||
class MaiMessages:
|
||||
"""MaiM插件消息"""
|
||||
|
||||
message_segments: List[Seg] = field(default_factory=list)
|
||||
"""消息段列表,支持多段消息"""
|
||||
|
||||
message_base_info: Dict[str, Any] = field(default_factory=dict)
|
||||
"""消息基本信息,包含平台,用户信息等数据"""
|
||||
|
||||
plain_text: str = ""
|
||||
"""纯文本消息内容"""
|
||||
|
||||
raw_message: Optional[str] = None
|
||||
"""原始消息内容"""
|
||||
|
||||
is_group_message: bool = False
|
||||
"""是否为群组消息"""
|
||||
|
||||
is_private_message: bool = False
|
||||
"""是否为私聊消息"""
|
||||
|
||||
stream_id: Optional[str] = None
|
||||
"""流ID,用于标识消息流"""
|
||||
|
||||
llm_prompt: Optional[str] = None
|
||||
"""LLM提示词"""
|
||||
|
||||
llm_response: Optional[str] = None
|
||||
"""LLM响应内容"""
|
||||
|
||||
additional_data: Dict[Any, Any] = field(default_factory=dict)
|
||||
"""附加数据,可以存储额外信息"""
|
||||
|
||||
def __post_init__(self):
|
||||
if self.message_segments is None:
|
||||
self.message_segments = []
|
||||
|
||||
@@ -7,9 +7,11 @@
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
from src.plugin_system.core.dependency_manager import dependency_manager
|
||||
from src.plugin_system.core.events_manager import events_manager
|
||||
|
||||
__all__ = [
|
||||
"plugin_manager",
|
||||
"component_registry",
|
||||
"dependency_manager",
|
||||
"events_manager",
|
||||
]
|
||||
|
||||
@@ -1,16 +1,19 @@
|
||||
from typing import Dict, List, Optional, Any, Pattern, Tuple, Union, Type
|
||||
import re
|
||||
|
||||
from typing import Dict, List, Optional, Any, Pattern, Tuple, Union, Type
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import (
|
||||
ComponentInfo,
|
||||
ActionInfo,
|
||||
CommandInfo,
|
||||
EventHandlerInfo,
|
||||
PluginInfo,
|
||||
ComponentType,
|
||||
)
|
||||
|
||||
from src.plugin_system.base.base_command import BaseCommand
|
||||
from src.plugin_system.base.base_action import BaseAction
|
||||
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
||||
|
||||
logger = get_logger("component_registry")
|
||||
|
||||
@@ -23,12 +26,11 @@ class ComponentRegistry:
|
||||
|
||||
def __init__(self):
|
||||
# 组件注册表
|
||||
self._components: Dict[str, ComponentInfo] = {} # 组件名 -> 组件信息
|
||||
self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = {
|
||||
ComponentType.ACTION: {},
|
||||
ComponentType.COMMAND: {},
|
||||
}
|
||||
self._component_classes: Dict[str, Union[Type[BaseCommand], Type[BaseAction]]] = {} # 组件名 -> 组件类
|
||||
self._components: Dict[str, ComponentInfo] = {} # 命名空间式组件名 -> 组件信息
|
||||
# 类型 -> 命名空间式名称 -> 组件信息
|
||||
self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = {types: {} for types in ComponentType}
|
||||
# 命名空间式组件名 -> 组件类
|
||||
self._components_classes: Dict[str, Type[Union[BaseCommand, BaseAction, BaseEventHandler]]] = {}
|
||||
|
||||
# 插件注册表
|
||||
self._plugins: Dict[str, PluginInfo] = {} # 插件名 -> 插件信息
|
||||
@@ -39,20 +41,43 @@ class ComponentRegistry:
|
||||
|
||||
# Command特定注册表
|
||||
self._command_registry: Dict[str, Type[BaseCommand]] = {} # command名 -> command类
|
||||
self._command_patterns: Dict[Pattern, Type[BaseCommand]] = {} # 编译后的正则 -> command类
|
||||
self._command_patterns: Dict[Pattern, str] = {} # 编译后的正则 -> command名
|
||||
|
||||
# EventHandler特定注册表
|
||||
self._event_handler_registry: Dict[str, Type[BaseEventHandler]] = {} # event_handler名 -> event_handler类
|
||||
self._enabled_event_handlers: Dict[str, Type[BaseEventHandler]] = {} # 启用的事件处理器
|
||||
|
||||
logger.info("组件注册中心初始化完成")
|
||||
|
||||
# === 通用组件注册方法 ===
|
||||
# == 注册方法 ==
|
||||
|
||||
def register_plugin(self, plugin_info: PluginInfo) -> bool:
|
||||
"""注册插件
|
||||
|
||||
Args:
|
||||
plugin_info: 插件信息
|
||||
|
||||
Returns:
|
||||
bool: 是否注册成功
|
||||
"""
|
||||
plugin_name = plugin_info.name
|
||||
|
||||
if plugin_name in self._plugins:
|
||||
logger.warning(f"插件 {plugin_name} 已存在,跳过注册")
|
||||
return False
|
||||
|
||||
self._plugins[plugin_name] = plugin_info
|
||||
logger.debug(f"已注册插件: {plugin_name} (组件数量: {len(plugin_info.components)})")
|
||||
return True
|
||||
|
||||
def register_component(
|
||||
self, component_info: ComponentInfo, component_class: Union[Type[BaseCommand], Type[BaseAction]]
|
||||
self, component_info: ComponentInfo, component_class: Type[Union[BaseCommand, BaseAction, BaseEventHandler]]
|
||||
) -> bool:
|
||||
"""注册组件
|
||||
|
||||
Args:
|
||||
component_info: 组件信息
|
||||
component_class: 组件类
|
||||
component_info (ComponentInfo): 组件信息
|
||||
component_class (Type[Union[BaseCommand, BaseAction, BaseEventHandler]]): 组件类
|
||||
|
||||
Returns:
|
||||
bool: 是否注册成功
|
||||
@@ -60,68 +85,110 @@ class ComponentRegistry:
|
||||
component_name = component_info.name
|
||||
component_type = component_info.component_type
|
||||
plugin_name = getattr(component_info, "plugin_name", "unknown")
|
||||
if "." in component_name:
|
||||
logger.error(f"组件名称 '{component_name}' 包含非法字符 '.',请使用下划线替代")
|
||||
return False
|
||||
if "." in plugin_name:
|
||||
logger.error(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代")
|
||||
return False
|
||||
|
||||
# 🔥 系统级别自动区分:为不同类型的组件添加命名空间前缀
|
||||
if component_type == ComponentType.ACTION:
|
||||
namespaced_name = f"action.{component_name}"
|
||||
elif component_type == ComponentType.COMMAND:
|
||||
namespaced_name = f"command.{component_name}"
|
||||
else:
|
||||
# 未来扩展的组件类型
|
||||
namespaced_name = f"{component_type.value}.{component_name}"
|
||||
namespaced_name = f"{component_type}.{component_name}"
|
||||
|
||||
# 检查命名空间化的名称是否冲突
|
||||
if namespaced_name in self._components:
|
||||
existing_info = self._components[namespaced_name]
|
||||
existing_plugin = getattr(existing_info, "plugin_name", "unknown")
|
||||
|
||||
logger.warning(
|
||||
f"组件冲突: {component_type.value}组件 '{component_name}' "
|
||||
f"已被插件 '{existing_plugin}' 注册,跳过插件 '{plugin_name}' 的注册"
|
||||
f"组件名冲突: '{plugin_name}' 插件的 {component_type} 类型组件 '{component_name}' 已被插件 '{existing_plugin}' 注册,跳过此组件注册"
|
||||
)
|
||||
return False
|
||||
|
||||
# 注册到通用注册表(使用命名空间化的名称)
|
||||
self._components[namespaced_name] = component_info
|
||||
self._components[namespaced_name] = component_info # 注册到通用注册表(使用命名空间化的名称)
|
||||
self._components_by_type[component_type][component_name] = component_info # 类型内部仍使用原名
|
||||
self._component_classes[namespaced_name] = component_class
|
||||
self._components_classes[namespaced_name] = component_class
|
||||
|
||||
# 根据组件类型进行特定注册(使用原始名称)
|
||||
if component_type == ComponentType.ACTION:
|
||||
self._register_action_component(component_info, component_class) # type: ignore
|
||||
elif component_type == ComponentType.COMMAND:
|
||||
self._register_command_component(component_info, component_class) # type: ignore
|
||||
match component_type:
|
||||
case ComponentType.ACTION:
|
||||
ret = self._register_action_component(component_info, component_class) # type: ignore
|
||||
case ComponentType.COMMAND:
|
||||
ret = self._register_command_component(component_info, component_class) # type: ignore
|
||||
case ComponentType.EVENT_HANDLER:
|
||||
ret = self._register_event_handler_component(component_info, component_class) # type: ignore
|
||||
case _:
|
||||
logger.warning(f"未知组件类型: {component_type}")
|
||||
|
||||
if not ret:
|
||||
return False
|
||||
logger.debug(
|
||||
f"已注册{component_type.value}组件: '{component_name}' -> '{namespaced_name}' "
|
||||
f"已注册{component_type}组件: '{component_name}' -> '{namespaced_name}' "
|
||||
f"({component_class.__name__}) [插件: {plugin_name}]"
|
||||
)
|
||||
return True
|
||||
|
||||
def _register_action_component(self, action_info: ActionInfo, action_class: Type[BaseAction]):
|
||||
# -------------------------------- NEED REFACTORING --------------------------------
|
||||
# -------------------------------- LOGIC ERROR -------------------------------------
|
||||
def _register_action_component(self, action_info: ActionInfo, action_class: Type[BaseAction]) -> bool:
|
||||
"""注册Action组件到Action特定注册表"""
|
||||
action_name = action_info.name
|
||||
if not (action_name := action_info.name):
|
||||
logger.error(f"Action组件 {action_class.__name__} 必须指定名称")
|
||||
return False
|
||||
if not isinstance(action_info, ActionInfo) or not issubclass(action_class, BaseAction):
|
||||
logger.error(f"注册失败: {action_name} 不是有效的Action")
|
||||
return False
|
||||
|
||||
self._action_registry[action_name] = action_class
|
||||
|
||||
# 如果启用,添加到默认动作集
|
||||
if action_info.enabled:
|
||||
self._default_actions[action_name] = action_info
|
||||
|
||||
def _register_command_component(self, command_info: CommandInfo, command_class: Type[BaseCommand]):
|
||||
return True
|
||||
|
||||
def _register_command_component(self, command_info: CommandInfo, command_class: Type[BaseCommand]) -> bool:
|
||||
"""注册Command组件到Command特定注册表"""
|
||||
command_name = command_info.name
|
||||
if not (command_name := command_info.name):
|
||||
logger.error(f"Command组件 {command_class.__name__} 必须指定名称")
|
||||
return False
|
||||
if not isinstance(command_info, CommandInfo) or not issubclass(command_class, BaseCommand):
|
||||
logger.error(f"注册失败: {command_name} 不是有效的Command")
|
||||
return False
|
||||
|
||||
self._command_registry[command_name] = command_class
|
||||
|
||||
# 编译正则表达式并注册
|
||||
if command_info.command_pattern:
|
||||
# 如果启用了且有匹配模式
|
||||
if command_info.enabled and command_info.command_pattern:
|
||||
pattern = re.compile(command_info.command_pattern, re.IGNORECASE | re.DOTALL)
|
||||
self._command_patterns[pattern] = command_class
|
||||
if pattern not in self._command_patterns:
|
||||
self._command_patterns[pattern] = command_name
|
||||
|
||||
logger.warning(f"'{command_name}' 对应的命令模式与 '{self._command_patterns[pattern]}' 重复,忽略此命令")
|
||||
|
||||
return True
|
||||
|
||||
def _register_event_handler_component(
|
||||
self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler]
|
||||
) -> bool:
|
||||
if not (handler_name := handler_info.name):
|
||||
logger.error(f"EventHandler组件 {handler_class.__name__} 必须指定名称")
|
||||
return False
|
||||
if not isinstance(handler_info, EventHandlerInfo) or not issubclass(handler_class, BaseEventHandler):
|
||||
logger.error(f"注册失败: {handler_name} 不是有效的EventHandler")
|
||||
return False
|
||||
|
||||
self._event_handler_registry[handler_name] = handler_class
|
||||
|
||||
from .events_manager import events_manager # 延迟导入防止循环导入问题
|
||||
|
||||
if events_manager.register_event_subscriber(handler_info, handler_class):
|
||||
self._enabled_event_handlers[handler_name] = handler_class
|
||||
return True
|
||||
else:
|
||||
logger.error(f"注册事件处理器 {handler_name} 失败")
|
||||
return False
|
||||
|
||||
# === 组件查询方法 ===
|
||||
|
||||
def get_component_info(self, component_name: str, component_type: ComponentType = None) -> Optional[ComponentInfo]: # type: ignore
|
||||
def get_component_info(
|
||||
self, component_name: str, component_type: Optional[ComponentType] = None
|
||||
) -> Optional[ComponentInfo]:
|
||||
# sourcery skip: class-extract-method
|
||||
"""获取组件信息,支持自动命名空间解析
|
||||
|
||||
@@ -138,18 +205,12 @@ class ComponentRegistry:
|
||||
|
||||
# 2. 如果指定了组件类型,构造命名空间化的名称查找
|
||||
if component_type:
|
||||
if component_type == ComponentType.ACTION:
|
||||
namespaced_name = f"action.{component_name}"
|
||||
elif component_type == ComponentType.COMMAND:
|
||||
namespaced_name = f"command.{component_name}"
|
||||
else:
|
||||
namespaced_name = f"{component_type.value}.{component_name}"
|
||||
|
||||
namespaced_name = f"{component_type}.{component_name}"
|
||||
return self._components.get(namespaced_name)
|
||||
|
||||
# 3. 如果没有指定类型,尝试在所有命名空间中查找
|
||||
candidates = []
|
||||
for namespace_prefix in ["action", "command"]:
|
||||
for namespace_prefix in [types.value for types in ComponentType]:
|
||||
namespaced_name = f"{namespace_prefix}.{component_name}"
|
||||
if component_info := self._components.get(namespaced_name):
|
||||
candidates.append((namespace_prefix, namespaced_name, component_info))
|
||||
@@ -171,8 +232,8 @@ class ComponentRegistry:
|
||||
def get_component_class(
|
||||
self,
|
||||
component_name: str,
|
||||
component_type: ComponentType = None, # type: ignore
|
||||
) -> Optional[Union[Type[BaseCommand], Type[BaseAction]]]:
|
||||
component_type: Optional[ComponentType] = None,
|
||||
) -> Optional[Union[Type[BaseCommand], Type[BaseAction], Type[BaseEventHandler]]]:
|
||||
"""获取组件类,支持自动命名空间解析
|
||||
|
||||
Args:
|
||||
@@ -184,29 +245,23 @@ class ComponentRegistry:
|
||||
"""
|
||||
# 1. 如果已经是命名空间化的名称,直接查找
|
||||
if "." in component_name:
|
||||
return self._component_classes.get(component_name)
|
||||
return self._components_classes.get(component_name)
|
||||
|
||||
# 2. 如果指定了组件类型,构造命名空间化的名称查找
|
||||
if component_type:
|
||||
if component_type == ComponentType.ACTION:
|
||||
namespaced_name = f"action.{component_name}"
|
||||
elif component_type == ComponentType.COMMAND:
|
||||
namespaced_name = f"command.{component_name}"
|
||||
else:
|
||||
namespaced_name = f"{component_type.value}.{component_name}"
|
||||
|
||||
return self._component_classes.get(namespaced_name)
|
||||
return self._components_classes.get(namespaced_name)
|
||||
|
||||
# 3. 如果没有指定类型,尝试在所有命名空间中查找
|
||||
candidates = []
|
||||
for namespace_prefix in ["action", "command"]:
|
||||
for namespace_prefix in [types.value for types in ComponentType]:
|
||||
namespaced_name = f"{namespace_prefix}.{component_name}"
|
||||
if component_class := self._component_classes.get(namespaced_name):
|
||||
if component_class := self._components_classes.get(namespaced_name):
|
||||
candidates.append((namespace_prefix, namespaced_name, component_class))
|
||||
|
||||
if len(candidates) == 1:
|
||||
# 只有一个匹配,直接返回
|
||||
namespace, full_name, cls = candidates[0]
|
||||
_, full_name, cls = candidates[0]
|
||||
logger.debug(f"自动解析组件: '{component_name}' -> '{full_name}'")
|
||||
return cls
|
||||
elif len(candidates) > 1:
|
||||
@@ -235,7 +290,7 @@ class ComponentRegistry:
|
||||
"""获取Action注册表(用于兼容现有系统)"""
|
||||
return self._action_registry.copy()
|
||||
|
||||
def get_action_info(self, action_name: str) -> Optional[ActionInfo]:
|
||||
def get_registered_action_info(self, action_name: str) -> Optional[ActionInfo]:
|
||||
"""获取Action信息"""
|
||||
info = self.get_component_info(action_name, ComponentType.ACTION)
|
||||
return info if isinstance(info, ActionInfo) else None
|
||||
@@ -247,18 +302,18 @@ class ComponentRegistry:
|
||||
# === Command特定查询方法 ===
|
||||
|
||||
def get_command_registry(self) -> Dict[str, Type[BaseCommand]]:
|
||||
"""获取Command注册表(用于兼容现有系统)"""
|
||||
"""获取Command注册表"""
|
||||
return self._command_registry.copy()
|
||||
|
||||
def get_command_patterns(self) -> Dict[Pattern, Type[BaseCommand]]:
|
||||
"""获取Command模式注册表(用于兼容现有系统)"""
|
||||
return self._command_patterns.copy()
|
||||
|
||||
def get_command_info(self, command_name: str) -> Optional[CommandInfo]:
|
||||
def get_registered_command_info(self, command_name: str) -> Optional[CommandInfo]:
|
||||
"""获取Command信息"""
|
||||
info = self.get_component_info(command_name, ComponentType.COMMAND)
|
||||
return info if isinstance(info, CommandInfo) else None
|
||||
|
||||
def get_command_patterns(self) -> Dict[Pattern, str]:
|
||||
"""获取Command模式注册表"""
|
||||
return self._command_patterns.copy()
|
||||
|
||||
def find_command_by_text(self, text: str) -> Optional[Tuple[Type[BaseCommand], dict, bool, str]]:
|
||||
# sourcery skip: use-named-expression, use-next
|
||||
"""根据文本查找匹配的命令
|
||||
@@ -270,47 +325,36 @@ class ComponentRegistry:
|
||||
Tuple: (命令类, 匹配的命名组, 是否拦截消息, 插件名) 或 None
|
||||
"""
|
||||
|
||||
for pattern, command_class in self._command_patterns.items():
|
||||
if match := pattern.match(text):
|
||||
command_name = None
|
||||
# 查找对应的组件信息
|
||||
for name, cls in self._command_registry.items():
|
||||
if cls == command_class:
|
||||
command_name = name
|
||||
break
|
||||
|
||||
# 检查命令是否启用
|
||||
if command_name:
|
||||
command_info = self.get_command_info(command_name)
|
||||
if command_info and command_info.enabled:
|
||||
candidates = [pattern for pattern in self._command_patterns if pattern.match(text)]
|
||||
if not candidates:
|
||||
return None
|
||||
if len(candidates) > 1:
|
||||
logger.warning(f"文本 '{text}' 匹配到多个命令模式: {candidates},使用第一个匹配")
|
||||
command_name = self._command_patterns[candidates[0]]
|
||||
command_info: CommandInfo = self.get_registered_command_info(command_name) # type: ignore
|
||||
return (
|
||||
command_class,
|
||||
match.groupdict(),
|
||||
self._command_registry[command_name],
|
||||
candidates[0].match(text).groupdict(), # type: ignore
|
||||
command_info.intercept_message,
|
||||
command_info.plugin_name,
|
||||
)
|
||||
return None
|
||||
|
||||
# === 插件管理方法 ===
|
||||
# === 事件处理器特定查询方法 ===
|
||||
|
||||
def register_plugin(self, plugin_info: PluginInfo) -> bool:
|
||||
"""注册插件
|
||||
def get_event_handler_registry(self) -> Dict[str, Type[BaseEventHandler]]:
|
||||
"""获取事件处理器注册表"""
|
||||
return self._event_handler_registry.copy()
|
||||
|
||||
Args:
|
||||
plugin_info: 插件信息
|
||||
def get_registered_event_handler_info(self, handler_name: str) -> Optional[EventHandlerInfo]:
|
||||
"""获取事件处理器信息"""
|
||||
info = self.get_component_info(handler_name, ComponentType.EVENT_HANDLER)
|
||||
return info if isinstance(info, EventHandlerInfo) else None
|
||||
|
||||
Returns:
|
||||
bool: 是否注册成功
|
||||
"""
|
||||
plugin_name = plugin_info.name
|
||||
def get_enabled_event_handlers(self) -> Dict[str, Type[BaseEventHandler]]:
|
||||
"""获取启用的事件处理器"""
|
||||
return self._enabled_event_handlers.copy()
|
||||
|
||||
if plugin_name in self._plugins:
|
||||
logger.warning(f"插件 {plugin_name} 已存在,跳过注册")
|
||||
return False
|
||||
|
||||
self._plugins[plugin_name] = plugin_info
|
||||
logger.debug(f"已注册插件: {plugin_name} (组件数量: {len(plugin_info.components)})")
|
||||
return True
|
||||
# === 插件查询方法 ===
|
||||
|
||||
def get_plugin_info(self, plugin_name: str) -> Optional[PluginInfo]:
|
||||
"""获取插件信息"""
|
||||
@@ -344,82 +388,22 @@ class ComponentRegistry:
|
||||
plugin_instance = plugin_manager.get_plugin_instance(plugin_name)
|
||||
return plugin_instance.config if plugin_instance else None
|
||||
|
||||
# === 状态管理方法 ===
|
||||
|
||||
# def enable_component(self, component_name: str, component_type: ComponentType = None) -> bool:
|
||||
# # -------------------------------- NEED REFACTORING --------------------------------
|
||||
# # -------------------------------- LOGIC ERROR -------------------------------------
|
||||
# """启用组件,支持命名空间解析"""
|
||||
# # 首先尝试找到正确的命名空间化名称
|
||||
# component_info = self.get_component_info(component_name, component_type)
|
||||
# if not component_info:
|
||||
# return False
|
||||
|
||||
# # 根据组件类型构造正确的命名空间化名称
|
||||
# if component_info.component_type == ComponentType.ACTION:
|
||||
# namespaced_name = f"action.{component_name}" if "." not in component_name else component_name
|
||||
# elif component_info.component_type == ComponentType.COMMAND:
|
||||
# namespaced_name = f"command.{component_name}" if "." not in component_name else component_name
|
||||
# else:
|
||||
# namespaced_name = (
|
||||
# f"{component_info.component_type.value}.{component_name}"
|
||||
# if "." not in component_name
|
||||
# else component_name
|
||||
# )
|
||||
|
||||
# if namespaced_name in self._components:
|
||||
# self._components[namespaced_name].enabled = True
|
||||
# # 如果是Action,更新默认动作集
|
||||
# # ---- HERE ----
|
||||
# # if isinstance(component_info, ActionInfo):
|
||||
# # self._action_descriptions[component_name] = component_info.description
|
||||
# logger.debug(f"已启用组件: {component_name} -> {namespaced_name}")
|
||||
# return True
|
||||
# return False
|
||||
|
||||
# def disable_component(self, component_name: str, component_type: ComponentType = None) -> bool:
|
||||
# # -------------------------------- NEED REFACTORING --------------------------------
|
||||
# # -------------------------------- LOGIC ERROR -------------------------------------
|
||||
# """禁用组件,支持命名空间解析"""
|
||||
# # 首先尝试找到正确的命名空间化名称
|
||||
# component_info = self.get_component_info(component_name, component_type)
|
||||
# if not component_info:
|
||||
# return False
|
||||
|
||||
# # 根据组件类型构造正确的命名空间化名称
|
||||
# if component_info.component_type == ComponentType.ACTION:
|
||||
# namespaced_name = f"action.{component_name}" if "." not in component_name else component_name
|
||||
# elif component_info.component_type == ComponentType.COMMAND:
|
||||
# namespaced_name = f"command.{component_name}" if "." not in component_name else component_name
|
||||
# else:
|
||||
# namespaced_name = (
|
||||
# f"{component_info.component_type.value}.{component_name}"
|
||||
# if "." not in component_name
|
||||
# else component_name
|
||||
# )
|
||||
|
||||
# if namespaced_name in self._components:
|
||||
# self._components[namespaced_name].enabled = False
|
||||
# # 如果是Action,从默认动作集中移除
|
||||
# # ---- HERE ----
|
||||
# # if component_name in self._action_descriptions:
|
||||
# # del self._action_descriptions[component_name]
|
||||
# logger.debug(f"已禁用组件: {component_name} -> {namespaced_name}")
|
||||
# return True
|
||||
# return False
|
||||
|
||||
def get_registry_stats(self) -> Dict[str, Any]:
|
||||
"""获取注册中心统计信息"""
|
||||
action_components: int = 0
|
||||
command_components: int = 0
|
||||
events_handlers: int = 0
|
||||
for component in self._components.values():
|
||||
if component.component_type == ComponentType.ACTION:
|
||||
action_components += 1
|
||||
elif component.component_type == ComponentType.COMMAND:
|
||||
command_components += 1
|
||||
elif component.component_type == ComponentType.EVENT_HANDLER:
|
||||
events_handlers += 1
|
||||
return {
|
||||
"action_components": action_components,
|
||||
"command_components": command_components,
|
||||
"event_handlers": events_handlers,
|
||||
"total_components": len(self._components),
|
||||
"total_plugins": len(self._plugins),
|
||||
"components_by_type": {
|
||||
@@ -430,5 +414,4 @@ class ComponentRegistry:
|
||||
}
|
||||
|
||||
|
||||
# 全局组件注册中心实例
|
||||
component_registry = ComponentRegistry()
|
||||
|
||||
@@ -1,11 +1,136 @@
|
||||
from typing import List, Dict, Type
|
||||
import asyncio
|
||||
from typing import List, Dict, Optional, Type
|
||||
|
||||
from src.plugin_system.base.component_types import EventType
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import EventType, EventHandlerInfo, MaiMessages
|
||||
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
||||
|
||||
logger = get_logger("events_manager")
|
||||
|
||||
|
||||
class EventsManager:
|
||||
def __init__(self):
|
||||
# 有权重的 events 订阅者注册表
|
||||
self.events_subscribers: Dict[EventType, List[Dict[int, Type]]] = {event: [] for event in EventType}
|
||||
self.events_subscribers: Dict[EventType, List[BaseEventHandler]] = {event: [] for event in EventType}
|
||||
self.handler_mapping: Dict[str, Type[BaseEventHandler]] = {} # 事件处理器映射表
|
||||
|
||||
def register_event_subscriber(self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler]) -> bool:
|
||||
"""注册事件处理器
|
||||
|
||||
Args:
|
||||
handler_info (EventHandlerInfo): 事件处理器信息
|
||||
handler_class (Type[BaseEventHandler]): 事件处理器类
|
||||
|
||||
Returns:
|
||||
bool: 是否注册成功
|
||||
"""
|
||||
handler_name = handler_info.name
|
||||
plugin_name = getattr(handler_info, "plugin_name", "unknown")
|
||||
|
||||
namespace_name = f"{plugin_name}.{handler_name}"
|
||||
if namespace_name in self.handler_mapping:
|
||||
logger.warning(f"事件处理器 {namespace_name} 已存在,跳过注册")
|
||||
return False
|
||||
|
||||
if not issubclass(handler_class, BaseEventHandler):
|
||||
logger.error(f"类 {handler_class.__name__} 不是 BaseEventHandler 的子类")
|
||||
return False
|
||||
|
||||
self.handler_mapping[namespace_name] = handler_class
|
||||
|
||||
return self._insert_event_handler(handler_class)
|
||||
|
||||
async def handler_mai_events(
|
||||
self,
|
||||
event_type: EventType,
|
||||
message: MessageRecv,
|
||||
llm_prompt: Optional[str] = None,
|
||||
llm_response: Optional[str] = None,
|
||||
) -> None:
|
||||
"""处理 events"""
|
||||
transformed_message = self._transform_event_message(message, llm_prompt, llm_response)
|
||||
for handler in self.events_subscribers.get(event_type, []):
|
||||
if handler.intercept_message:
|
||||
await handler.execute(transformed_message)
|
||||
else:
|
||||
asyncio.create_task(handler.execute(transformed_message))
|
||||
|
||||
def _insert_event_handler(self, handler_class: Type[BaseEventHandler]) -> bool:
|
||||
"""插入事件处理器到对应的事件类型列表中"""
|
||||
if handler_class.event_type == EventType.UNKNOWN:
|
||||
logger.error(f"事件处理器 {handler_class.__name__} 的事件类型未知,无法注册")
|
||||
return False
|
||||
|
||||
self.events_subscribers[handler_class.event_type].append(handler_class())
|
||||
self.events_subscribers[handler_class.event_type].sort(key=lambda x: x.weight, reverse=True)
|
||||
|
||||
return True
|
||||
|
||||
def _remove_event_handler(self, handler_class: Type[BaseEventHandler]) -> bool:
|
||||
"""从事件类型列表中移除事件处理器"""
|
||||
if handler_class.event_type == EventType.UNKNOWN:
|
||||
logger.warning(f"事件处理器 {handler_class.__name__} 的事件类型未知,不存在于处理器列表中")
|
||||
return False
|
||||
|
||||
handlers = self.events_subscribers[handler_class.event_type]
|
||||
for i, handler in enumerate(handlers):
|
||||
if isinstance(handler, handler_class):
|
||||
del handlers[i]
|
||||
logger.debug(f"事件处理器 {handler_class.__name__} 已移除")
|
||||
return True
|
||||
|
||||
logger.warning(f"未找到事件处理器 {handler_class.__name__},无法移除")
|
||||
return False
|
||||
|
||||
def _transform_event_message(
|
||||
self, message: MessageRecv, llm_prompt: Optional[str] = None, llm_response: Optional[str] = None
|
||||
) -> MaiMessages:
|
||||
"""转换事件消息格式"""
|
||||
# 直接赋值部分内容
|
||||
transformed_message = MaiMessages(
|
||||
llm_prompt=llm_prompt,
|
||||
llm_response=llm_response,
|
||||
raw_message=message.raw_message,
|
||||
additional_data=message.message_info.additional_config or {},
|
||||
)
|
||||
|
||||
# 消息段处理
|
||||
if message.message_segment.type == "seglist":
|
||||
transformed_message.message_segments = list(message.message_segment.data) # type: ignore
|
||||
else:
|
||||
transformed_message.message_segments = [message.message_segment]
|
||||
|
||||
# stream_id 处理
|
||||
if hasattr(message, "chat_stream"):
|
||||
transformed_message.stream_id = message.chat_stream.stream_id
|
||||
|
||||
# 处理后文本
|
||||
transformed_message.plain_text = message.processed_plain_text
|
||||
|
||||
# 基本信息
|
||||
if message.message_info.platform:
|
||||
transformed_message.message_base_info["platform"] = message.message_info.platform
|
||||
if message.message_info.group_info:
|
||||
transformed_message.is_group_message = True
|
||||
transformed_message.message_base_info.update(
|
||||
{
|
||||
"group_id": message.message_info.group_info.group_id,
|
||||
"group_name": message.message_info.group_info.group_name,
|
||||
}
|
||||
)
|
||||
if message.message_info.user_info:
|
||||
if not transformed_message.is_group_message:
|
||||
transformed_message.is_private_message = True
|
||||
transformed_message.message_base_info.update(
|
||||
{
|
||||
"user_id": message.message_info.user_info.user_id,
|
||||
"user_cardname": message.message_info.user_info.user_cardname, # 用户群昵称
|
||||
"user_nickname": message.message_info.user_info.user_nickname, # 用户昵称(用户名)
|
||||
}
|
||||
)
|
||||
|
||||
return transformed_message
|
||||
|
||||
|
||||
events_manager = EventsManager()
|
||||
@@ -1,10 +1,12 @@
|
||||
from typing import Dict, List, Optional, Tuple, Type, Any
|
||||
import os
|
||||
from importlib.util import spec_from_file_location, module_from_spec
|
||||
from inspect import getmodule
|
||||
from pathlib import Path
|
||||
import inspect
|
||||
import traceback
|
||||
|
||||
from typing import Dict, List, Optional, Tuple, Type, Any
|
||||
from importlib.util import spec_from_file_location, module_from_spec
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
from src.plugin_system.core.dependency_manager import dependency_manager
|
||||
@@ -28,7 +30,7 @@ class PluginManager:
|
||||
self.plugin_paths: Dict[str, str] = {} # 记录插件名到目录路径的映射,插件名 -> 目录路径
|
||||
|
||||
self.loaded_plugins: Dict[str, PluginBase] = {} # 已加载的插件类实例注册表,插件名 -> 插件类实例
|
||||
self.failed_plugins: Dict[str, str] = {} # 记录加载失败的插件类及其错误信息,插件名 -> 错误信息
|
||||
self.failed_plugins: Dict[str, str] = {} # 记录加载失败的插件文件及其错误信息,插件名 -> 错误信息
|
||||
|
||||
# 确保插件目录存在
|
||||
self._ensure_plugin_directories()
|
||||
@@ -107,12 +109,8 @@ class PluginManager:
|
||||
# 使用记录的插件目录路径
|
||||
plugin_dir = self.plugin_paths.get(plugin_name)
|
||||
|
||||
# 如果没有记录,则尝试查找(fallback)
|
||||
# 如果没有记录,直接返回失败
|
||||
if not plugin_dir:
|
||||
plugin_dir = self._find_plugin_directory(plugin_class)
|
||||
if plugin_dir:
|
||||
self.plugin_paths[plugin_name] = plugin_dir # 更新路径
|
||||
else:
|
||||
return False, 1
|
||||
|
||||
plugin_instance = plugin_class(plugin_dir=plugin_dir) # 实例化插件(可能因为缺少manifest而失败)
|
||||
@@ -360,24 +358,14 @@ class PluginManager:
|
||||
|
||||
logger.debug(f"正在扫描插件根目录: {directory}")
|
||||
|
||||
# 遍历目录中的所有Python文件和包
|
||||
# 遍历目录中的所有包
|
||||
for item in os.listdir(directory):
|
||||
item_path = os.path.join(directory, item)
|
||||
|
||||
if os.path.isfile(item_path) and item.endswith(".py") and item != "__init__.py":
|
||||
# 单文件插件
|
||||
plugin_name = Path(item_path).stem
|
||||
if self._load_plugin_module_file(item_path, plugin_name, directory):
|
||||
loaded_count += 1
|
||||
else:
|
||||
failed_count += 1
|
||||
|
||||
elif os.path.isdir(item_path) and not item.startswith(".") and not item.startswith("__"):
|
||||
# 插件包
|
||||
if os.path.isdir(item_path) and not item.startswith(".") and not item.startswith("__"):
|
||||
plugin_file = os.path.join(item_path, "plugin.py")
|
||||
if os.path.exists(plugin_file):
|
||||
plugin_name = item # 使用目录名作为插件名
|
||||
if self._load_plugin_module_file(plugin_file, plugin_name, item_path):
|
||||
if self._load_plugin_module_file(plugin_file):
|
||||
loaded_count += 1
|
||||
else:
|
||||
failed_count += 1
|
||||
@@ -387,14 +375,16 @@ class PluginManager:
|
||||
def _find_plugin_directory(self, plugin_class: Type[PluginBase]) -> Optional[str]:
|
||||
"""查找插件类对应的目录路径"""
|
||||
try:
|
||||
module = getmodule(plugin_class)
|
||||
if module and hasattr(module, "__file__") and module.__file__:
|
||||
return os.path.dirname(module.__file__)
|
||||
# module = getmodule(plugin_class)
|
||||
# if module and hasattr(module, "__file__") and module.__file__:
|
||||
# return os.path.dirname(module.__file__)
|
||||
file_path = inspect.getfile(plugin_class)
|
||||
return os.path.dirname(file_path)
|
||||
except Exception as e:
|
||||
logger.debug(f"通过inspect获取插件目录失败: {e}")
|
||||
return None
|
||||
|
||||
def _load_plugin_module_file(self, plugin_file: str, plugin_name: str, plugin_dir: str) -> bool:
|
||||
def _load_plugin_module_file(self, plugin_file: str) -> bool:
|
||||
# sourcery skip: extract-method
|
||||
"""加载单个插件模块文件
|
||||
|
||||
@@ -405,12 +395,7 @@ class PluginManager:
|
||||
"""
|
||||
# 生成模块名
|
||||
plugin_path = Path(plugin_file)
|
||||
if plugin_path.parent.name != "plugins":
|
||||
# 插件包格式:parent_dir.plugin
|
||||
module_name = f"plugins.{plugin_path.parent.name}.plugin"
|
||||
else:
|
||||
# 单文件格式:plugins.filename
|
||||
module_name = f"plugins.{plugin_path.stem}"
|
||||
module_name = ".".join(plugin_path.parent.parts)
|
||||
|
||||
try:
|
||||
# 动态导入插件模块
|
||||
@@ -422,16 +407,13 @@ class PluginManager:
|
||||
module = module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
# 记录插件名和目录路径的映射
|
||||
self.plugin_paths[plugin_name] = plugin_dir
|
||||
|
||||
logger.debug(f"插件模块加载成功: {plugin_file}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"加载插件模块 {plugin_file} 失败: {e}"
|
||||
logger.error(error_msg)
|
||||
self.failed_plugins[plugin_name] = error_msg
|
||||
self.failed_plugins[module_name] = error_msg
|
||||
return False
|
||||
|
||||
def _check_plugin_version_compatibility(self, plugin_name: str, manifest_data: Dict[str, Any]) -> Tuple[bool, str]:
|
||||
@@ -475,13 +457,14 @@ class PluginManager:
|
||||
stats = component_registry.get_registry_stats()
|
||||
action_count = stats.get("action_components", 0)
|
||||
command_count = stats.get("command_components", 0)
|
||||
event_handler_count = stats.get("event_handlers", 0)
|
||||
total_components = stats.get("total_components", 0)
|
||||
|
||||
# 📋 显示插件加载总览
|
||||
if total_registered > 0:
|
||||
logger.info("🎉 插件系统加载完成!")
|
||||
logger.info(
|
||||
f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count})"
|
||||
f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, EventHandler: {event_handler_count})"
|
||||
)
|
||||
|
||||
# 显示详细的插件列表
|
||||
@@ -510,8 +493,9 @@ class PluginManager:
|
||||
|
||||
# 组件列表
|
||||
if plugin_info.components:
|
||||
action_components = [c for c in plugin_info.components if c.component_type.name == "ACTION"]
|
||||
command_components = [c for c in plugin_info.components if c.component_type.name == "COMMAND"]
|
||||
action_components = [c for c in plugin_info.components if c.component_type == ComponentType.ACTION]
|
||||
command_components = [c for c in plugin_info.components if c.component_type == ComponentType.COMMAND]
|
||||
event_handler_components = [c for c in plugin_info.components if c.component_type == ComponentType.EVENT_HANDLER]
|
||||
|
||||
if action_components:
|
||||
action_names = [c.name for c in action_components]
|
||||
@@ -521,6 +505,10 @@ class PluginManager:
|
||||
command_names = [c.name for c in command_components]
|
||||
logger.info(f" ⚡ Command组件: {', '.join(command_names)}")
|
||||
|
||||
if event_handler_components:
|
||||
event_handler_names = [c.name for c in event_handler_components]
|
||||
logger.info(f" 📢 EventHandler组件: {', '.join(event_handler_names)}")
|
||||
|
||||
# 依赖信息
|
||||
if plugin_info.dependencies:
|
||||
logger.info(f" 🔗 依赖: {', '.join(plugin_info.dependencies)}")
|
||||
@@ -530,6 +518,12 @@ class PluginManager:
|
||||
config_status = "✅" if self.plugin_paths.get(plugin_name) else "❌"
|
||||
logger.info(f" ⚙️ 配置: {plugin_info.config_file} {config_status}")
|
||||
|
||||
root_path = Path(__file__)
|
||||
|
||||
# 查找项目根目录
|
||||
while not (root_path / "pyproject.toml").exists() and root_path.parent != root_path:
|
||||
root_path = root_path.parent
|
||||
|
||||
# 显示目录统计
|
||||
logger.info("📂 加载目录统计:")
|
||||
for directory in self.plugin_directories:
|
||||
@@ -537,7 +531,11 @@ class PluginManager:
|
||||
plugins_in_dir = []
|
||||
for plugin_name in self.loaded_plugins.keys():
|
||||
plugin_path = self.plugin_paths.get(plugin_name, "")
|
||||
if plugin_path.startswith(directory):
|
||||
if (
|
||||
Path(plugin_path)
|
||||
.resolve()
|
||||
.is_relative_to(Path(os.path.join(str(root_path), directory)).resolve())
|
||||
):
|
||||
plugins_in_dir.append(plugin_name)
|
||||
|
||||
if plugins_in_dir:
|
||||
|
||||
@@ -80,8 +80,9 @@ class ReplyAction(BaseAction):
|
||||
logger.info(f"{self.log_prefix} 回复目标: {reply_to}")
|
||||
|
||||
try:
|
||||
prepared_reply = self.action_data.get("prepared_reply", "")
|
||||
if not prepared_reply:
|
||||
if prepared_reply := self.action_data.get("prepared_reply", ""):
|
||||
reply_text = prepared_reply
|
||||
else:
|
||||
try:
|
||||
success, reply_set, _ = await asyncio.wait_for(
|
||||
generator_api.generate_reply(
|
||||
@@ -109,9 +110,6 @@ class ReplyAction(BaseAction):
|
||||
logger.info(
|
||||
f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,{'使用' if need_reply else '不使用'}引用回复"
|
||||
)
|
||||
else:
|
||||
reply_text = prepared_reply
|
||||
|
||||
# 构建回复文本
|
||||
reply_text = ""
|
||||
first_replied = False
|
||||
@@ -120,8 +118,9 @@ class ReplyAction(BaseAction):
|
||||
data = reply_seg[1]
|
||||
if not first_replied:
|
||||
if need_reply:
|
||||
await self.send_text(content=data, reply_to=reply_to, reply_to_platform_id=reply_to_platform_id, typing=False)
|
||||
first_replied = True
|
||||
await self.send_text(
|
||||
content=data, reply_to=reply_to, reply_to_platform_id=reply_to_platform_id, typing=False
|
||||
)
|
||||
else:
|
||||
await self.send_text(content=data, reply_to_platform_id=reply_to_platform_id, typing=False)
|
||||
first_replied = True
|
||||
@@ -190,17 +189,15 @@ class CoreActionsPlugin(BasePlugin):
|
||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||
"""返回插件包含的组件列表"""
|
||||
|
||||
# --- 从配置动态设置Action/Command ---
|
||||
emoji_chance = global_config.emoji.emoji_chance
|
||||
if global_config.emoji.emoji_activate_type == "random":
|
||||
EmojiAction.random_activation_probability = emoji_chance
|
||||
EmojiAction.focus_activation_type = ActionActivationType.RANDOM
|
||||
EmojiAction.normal_activation_type = ActionActivationType.RANDOM
|
||||
elif global_config.emoji.emoji_activate_type == "llm":
|
||||
if global_config.emoji.emoji_activate_type == "llm":
|
||||
EmojiAction.random_activation_probability = 0.0
|
||||
EmojiAction.focus_activation_type = ActionActivationType.LLM_JUDGE
|
||||
EmojiAction.normal_activation_type = ActionActivationType.LLM_JUDGE
|
||||
|
||||
elif global_config.emoji.emoji_activate_type == "random":
|
||||
EmojiAction.random_activation_probability = global_config.emoji.emoji_chance
|
||||
EmojiAction.focus_activation_type = ActionActivationType.RANDOM
|
||||
EmojiAction.normal_activation_type = ActionActivationType.RANDOM
|
||||
# --- 根据配置注册组件 ---
|
||||
components = []
|
||||
if self.get_config("components.enable_reply", True):
|
||||
|
||||
@@ -33,7 +33,7 @@ class SearchKnowledgeFromLPMMTool(BaseTool):
|
||||
Dict: 工具执行结果
|
||||
"""
|
||||
try:
|
||||
query = function_args.get("query")
|
||||
query: str = function_args.get("query") # type: ignore
|
||||
# threshold = function_args.get("threshold", 0.4)
|
||||
|
||||
# 检查LPMM知识库是否启用
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[inner]
|
||||
version = "4.4.3"
|
||||
version = "4.4.4"
|
||||
|
||||
#----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
|
||||
#如果你想要修改配置文件,请在修改后将version的值进行变更
|
||||
@@ -87,6 +87,7 @@ talk_frequency_adjust = [
|
||||
# - 时间支持跨天,例如 "00:10,0.3" 表示从凌晨0:10开始使用频率0.3
|
||||
# - 系统会自动将 "platform:id:type" 转换为内部的哈希chat_id进行匹配
|
||||
|
||||
enable_asr = false # 是否启用语音识别,启用后麦麦可以通过语音输入进行对话,启用该功能需要配置语音识别模型[model.voice]
|
||||
|
||||
[message_receive]
|
||||
# 以下是消息过滤,可以根据规则过滤特定消息,将不会读取这些消息
|
||||
@@ -294,6 +295,12 @@ provider = "SILICONFLOW"
|
||||
pri_in = 0.35
|
||||
pri_out = 0.35
|
||||
|
||||
[model.voice] # 语音识别模型
|
||||
name = "FunAudioLLM/SenseVoiceSmall"
|
||||
provider = "SILICONFLOW"
|
||||
pri_in = 0
|
||||
pri_out = 0
|
||||
|
||||
[model.tool_use] #工具调用模型,需要使用支持工具调用的模型
|
||||
name = "Qwen/Qwen3-14B"
|
||||
provider = "SILICONFLOW"
|
||||
|
||||
Reference in New Issue
Block a user