This commit is contained in:
SengokuCola
2025-07-20 14:24:49 +08:00
34 changed files with 913 additions and 442 deletions

View File

@@ -1,26 +1,60 @@
# 插件API与规范修改 # 插件API与规范修改
1. 现在`plugin_system``__init__.py`文件中包含了所有插件API的导入用户可以直接使用`from plugin_system import *`来导入所有API。 1. 现在`plugin_system``__init__.py`文件中包含了所有插件API的导入用户可以直接使用`from src.plugin_system import *`来导入所有API。
2. register_plugin函数现在转移到了`plugin_system.apis.plugin_register_api`模块中,用户可以通过`from plugin_system.apis.plugin_register_api import register_plugin`来导入。 2. register_plugin函数现在转移到了`plugin_system.apis.plugin_register_api`模块中,用户可以通过`from src.plugin_system.apis.plugin_register_api import register_plugin`来导入。
- 顺便一提按照1中说法你可以这么用
```python
from src.plugin_system import register_plugin
```
3. 现在强制要求的property如下 3. 现在强制要求的property如下,即你必须覆盖的属性有
- `plugin_name`: 插件名称,必须是唯一的。(与文件夹相同) - `plugin_name`: 插件名称,必须是唯一的。(与文件夹相同)
- `enable_plugin`: 是否启用插件,默认为`True` - `enable_plugin`: 是否启用插件,默认为`True`。
- `dependencies`: 插件依赖的其他插件列表,默认为空。**现在并不检查(也许)** - `dependencies`: 插件依赖的其他插件列表,默认为空。**现在并不检查(也许)**
- `python_dependencies`: 插件依赖的Python包列表默认为空。**现在并不检查** - `python_dependencies`: 插件依赖的Python包列表默认为空。**现在并不检查**
- `config_file_name`: 插件配置文件名,默认为`config.toml` - `config_file_name`: 插件配置文件名,默认为`config.toml`。
- `config_schema`: 插件配置文件的schema用于自动生成配置文件。 - `config_schema`: 插件配置文件的schema用于自动生成配置文件。
4. 部分API的参数类型和返回值进行了调整
- `chat_api.py`中获取流的参数中可以使用一个特殊的枚举类型来获得所有平台的 ChatStream 了。
- `config_api.py`中的`get_global_config`和`get_plugin_config`方法现在支持嵌套访问的配置键名。
- `database_api.py`中的`db_query`方法调整了参数顺序以增强参数限制的同时保证了typing正确`db_get`方法增加了`single_result`参数,与`db_query`保持一致。
5. 增加了`logging_api`,可以用`get_logger`来获取日志记录器。
# 插件系统修改 # 插件系统修改
1. 现在所有的匹配模式不再是关键字了,而是枚举类。**(可能有遗漏)** 1. 现在所有的匹配模式不再是关键字了,而是枚举类。**(可能有遗漏)**
2. 修复了一下显示插件信息不显示的问题。同时精简了一下显示内容 2. 修复了一下显示插件信息不显示的问题。同时精简了一下显示内容
3. 修复了插件系统混用了`plugin_name``display_name`的问题。现在所有的插件信息都使用`display_name`来显示,而内部标识仍然使用`plugin_name`。**(可能有遗漏)** 3. 修复了插件系统混用了`plugin_name`和`display_name`的问题。现在所有的插件信息都使用`display_name`来显示,而内部标识仍然使用`plugin_name`。
3. 部分API的参数类型和返回值进行了调整
- `chat_api.py`中获取流的参数中可以使用一个特殊的枚举类型来获得所有平台的 ChatStream 了。
- `config_api.py`中的`get_global_config``get_plugin_config`方法现在支持嵌套访问的配置键名。
- `database_api.py`中的`db_query`方法调整了参数顺序以增强参数限制的同时保证了typing正确`db_get`方法增加了`single_result`参数,与`db_query`保持一致。
4. 现在增加了参数类型检查,完善了对应注释 4. 现在增加了参数类型检查,完善了对应注释
5. 现在插件抽象出了总基类 `PluginBase` 5. 现在插件抽象出了总基类 `PluginBase`
- 基于`Action``Command`的插件基类现在为`BasePlugin`,它继承自`PluginBase`,由`register_plugin`装饰器注册。 - <del>基于`Action`和`Command`的插件基类现在为`BasePlugin`。</del>
- 基于`Event`的插件基类现在为`BaseEventPlugin`,它也继承自`PluginBase`,由`register_event_plugin`装饰器注册。 - <del>基于`Event`的插件基类现在为`BaseEventPlugin`。</del>
- 基于`Action``Command`和`Event`的插件基类现在为`BasePlugin`,所有插件都应该继承此基类。
- `BasePlugin`继承自`PluginBase`。
- 所有的插件类都由`register_plugin`装饰器注册。
6. 现在我们终于可以让插件有自定义的名字了!
- 真正实现了插件的`plugin_name`**不受文件夹名称限制**的功能。(吐槽:可乐你的某个小小细节导致我搞了好久……)
- 通过在插件类中定义`plugin_name`属性来指定插件内部标识符。
- 由于此更改一个文件中现在可以有多个插件类,但每个插件类必须有**唯一的**`plugin_name`。
- 在某些插件加载失败时,现在会显示包名而不是插件内部标识符。
- 例如:`MaiMBot.plugins.example_plugin`而不是`example_plugin`。
- 仅在插件 import 失败时会如此,正常注册过程中失败的插件不会显示包名,而是显示插件内部标识符。(这是特性,但是基本上不可能出现这个情况)
7. 现在不支持单文件插件了,加载方式已经完全删除。
8. 把`BaseEventPlugin`合并到了`BasePlugin`中,所有插件都应该继承自`BasePlugin`。
# 吐槽
```python
plugin_path = Path(plugin_file)
if plugin_path.parent.name != "plugins":
# 插件包格式parent_dir.plugin
module_name = f"plugins.{plugin_path.parent.name}.plugin"
else:
# 单文件格式plugins.filename
module_name = f"plugins.{plugin_path.stem}"
```
```python
plugin_path = Path(plugin_file)
module_name = ".".join(plugin_path.parent.parts)
```
这两个区别很大的。

View File

@@ -7,11 +7,13 @@ from src.plugin_system import (
ComponentInfo, ComponentInfo,
ActionActivationType, ActionActivationType,
ConfigField, ConfigField,
BaseEventHandler,
EventType,
MaiMessages,
) )
# ===== Action组件 ===== # ===== Action组件 =====
class HelloAction(BaseAction): class HelloAction(BaseAction):
"""问候Action - 简单的问候动作""" """问候Action - 简单的问候动作"""
@@ -82,7 +84,7 @@ class TimeCommand(BaseCommand):
import datetime import datetime
# 获取当前时间 # 获取当前时间
time_format = self.get_config("time.format", "%Y-%m-%d %H:%M:%S") time_format: str = self.get_config("time.format", "%Y-%m-%d %H:%M:%S") # type: ignore
now = datetime.datetime.now() now = datetime.datetime.now()
time_str = now.strftime(time_format) time_str = now.strftime(time_format)
@@ -93,6 +95,20 @@ class TimeCommand(BaseCommand):
return True, f"显示了当前时间: {time_str}" return True, f"显示了当前时间: {time_str}"
class PrintMessage(BaseEventHandler):
"""打印消息事件处理器 - 处理打印消息事件"""
event_type = EventType.ON_MESSAGE
handler_name = "print_message_handler"
handler_description = "打印接收到的消息"
async def execute(self, message: MaiMessages) -> Tuple[bool, str | None]:
"""执行打印消息事件处理"""
# 打印接收到的消息
print(f"接收到消息: {message.raw_message}")
return True, "消息已打印"
# ===== 插件注册 ===== # ===== 插件注册 =====
@@ -122,6 +138,7 @@ class HelloWorldPlugin(BasePlugin):
"enable_emoji": ConfigField(type=bool, default=True, description="是否启用表情符号"), "enable_emoji": ConfigField(type=bool, default=True, description="是否启用表情符号"),
}, },
"time": {"format": ConfigField(type=str, default="%Y-%m-%d %H:%M:%S", description="时间显示格式")}, "time": {"format": ConfigField(type=str, default="%Y-%m-%d %H:%M:%S", description="时间显示格式")},
"print_message": {"enabled": ConfigField(type=bool, default=True, description="是否启用打印")},
} }
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
@@ -129,4 +146,27 @@ class HelloWorldPlugin(BasePlugin):
(HelloAction.get_action_info(), HelloAction), (HelloAction.get_action_info(), HelloAction),
(ByeAction.get_action_info(), ByeAction), # 添加告别Action (ByeAction.get_action_info(), ByeAction), # 添加告别Action
(TimeCommand.get_command_info(), TimeCommand), (TimeCommand.get_command_info(), TimeCommand),
(PrintMessage.get_handler_info(), PrintMessage),
] ]
# @register_plugin
# class HelloWorldEventPlugin(BaseEPlugin):
# """Hello World事件插件 - 处理问候和告别事件"""
# plugin_name = "hello_world_event_plugin"
# enable_plugin = False
# dependencies = []
# python_dependencies = []
# config_file_name = "event_config.toml"
# config_schema = {
# "plugin": {
# "name": ConfigField(type=str, default="hello_world_event_plugin", description="插件名称"),
# "version": ConfigField(type=str, default="1.0.0", description="插件版本"),
# "enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
# },
# }
# def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
# return [(PrintMessage.get_handler_info(), PrintMessage)]

View File

@@ -179,8 +179,7 @@ class HeartFChatting:
await asyncio.sleep(10) await asyncio.sleep(10)
if self.loop_mode == ChatMode.NORMAL: if self.loop_mode == ChatMode.NORMAL:
self.energy_value -= 0.3 self.energy_value -= 0.3
if self.energy_value <= 0.3: self.energy_value = max(self.energy_value, 0.3)
self.energy_value = 0.3
def print_cycle_info(self, cycle_timers): def print_cycle_info(self, cycle_timers):
# 记录循环信息和计时器结果 # 记录循环信息和计时器结果
@@ -257,6 +256,7 @@ class HeartFChatting:
return f"{person_name}:{message_data.get('processed_plain_text')}" return f"{person_name}:{message_data.get('processed_plain_text')}"
async def _observe(self, message_data: Optional[Dict[str, Any]] = None): async def _observe(self, message_data: Optional[Dict[str, Any]] = None):
# sourcery skip: hoist-statement-from-if, merge-comparisons, reintroduce-else
if not message_data: if not message_data:
message_data = {} message_data = {}
action_type = "no_action" action_type = "no_action"
@@ -462,7 +462,7 @@ class HeartFChatting:
"兴趣"模式下,判断是否回复并生成内容。 "兴趣"模式下,判断是否回复并生成内容。
""" """
interested_rate = message_data.get("interest_value", 0.0) * self.willing_amplifier interested_rate = (message_data.get("interest_value") or 0.0) * self.willing_amplifier
self.willing_manager.setup(message_data, self.chat_stream) self.willing_manager.setup(message_data, self.chat_stream)

View File

@@ -106,10 +106,10 @@ class EmbeddingStore:
asyncio.get_running_loop() asyncio.get_running_loop()
# 如果在事件循环中,使用线程池执行 # 如果在事件循环中,使用线程池执行
import concurrent.futures import concurrent.futures
def run_in_thread(): def run_in_thread():
return asyncio.run(get_embedding(s)) return asyncio.run(get_embedding(s))
with concurrent.futures.ThreadPoolExecutor() as executor: with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(run_in_thread) future = executor.submit(run_in_thread)
result = future.result() result = future.result()
@@ -294,10 +294,10 @@ class EmbeddingStore:
""" """
if self.faiss_index is None: if self.faiss_index is None:
logger.debug("FaissIndex尚未构建,返回None") logger.debug("FaissIndex尚未构建,返回None")
return None return []
if self.idx2hash is None: if self.idx2hash is None:
logger.warning("idx2hash尚未构建,返回None") logger.warning("idx2hash尚未构建,返回None")
return None return []
# L2归一化 # L2归一化
faiss.normalize_L2(np.array([query], dtype=np.float32)) faiss.normalize_L2(np.array([query], dtype=np.float32))
@@ -318,15 +318,15 @@ class EmbeddingStore:
class EmbeddingManager: class EmbeddingManager:
def __init__(self): def __init__(self):
self.paragraphs_embedding_store = EmbeddingStore( self.paragraphs_embedding_store = EmbeddingStore(
local_storage['pg_namespace'], local_storage["pg_namespace"], # type: ignore
EMBEDDING_DATA_DIR_STR, EMBEDDING_DATA_DIR_STR,
) )
self.entities_embedding_store = EmbeddingStore( self.entities_embedding_store = EmbeddingStore(
local_storage['pg_namespace'], local_storage["pg_namespace"], # type: ignore
EMBEDDING_DATA_DIR_STR, EMBEDDING_DATA_DIR_STR,
) )
self.relation_embedding_store = EmbeddingStore( self.relation_embedding_store = EmbeddingStore(
local_storage['pg_namespace'], local_storage["pg_namespace"], # type: ignore
EMBEDDING_DATA_DIR_STR, EMBEDDING_DATA_DIR_STR,
) )
self.stored_pg_hashes = set() self.stored_pg_hashes = set()

View File

@@ -30,20 +30,20 @@ def _get_kg_dir():
""" """
安全地获取KG数据目录路径 安全地获取KG数据目录路径
""" """
root_path = local_storage['root_path'] root_path: str = local_storage["root_path"]
if root_path is None: if root_path is None:
# 如果 local_storage 中没有 root_path使用当前文件的相对路径作为备用 # 如果 local_storage 中没有 root_path使用当前文件的相对路径作为备用
current_dir = os.path.dirname(os.path.abspath(__file__)) current_dir = os.path.dirname(os.path.abspath(__file__))
root_path = os.path.abspath(os.path.join(current_dir, "..", "..", "..")) root_path = os.path.abspath(os.path.join(current_dir, "..", "..", ".."))
logger.warning(f"local_storage 中未找到 root_path使用备用路径: {root_path}") logger.warning(f"local_storage 中未找到 root_path使用备用路径: {root_path}")
# 获取RAG数据目录 # 获取RAG数据目录
rag_data_dir = global_config["persistence"]["rag_data_dir"] rag_data_dir: str = global_config["persistence"]["rag_data_dir"]
if rag_data_dir is None: if rag_data_dir is None:
kg_dir = os.path.join(root_path, "data/rag") kg_dir = os.path.join(root_path, "data/rag")
else: else:
kg_dir = os.path.join(root_path, rag_data_dir) kg_dir = os.path.join(root_path, rag_data_dir)
return str(kg_dir).replace("\\", "/") return str(kg_dir).replace("\\", "/")
@@ -65,9 +65,9 @@ class KGManager:
# 持久化相关 - 使用延迟初始化的路径 # 持久化相关 - 使用延迟初始化的路径
self.dir_path = get_kg_dir_str() self.dir_path = get_kg_dir_str()
self.graph_data_path = self.dir_path + "/" + local_storage['rag_graph_namespace'] + ".graphml" self.graph_data_path = self.dir_path + "/" + local_storage["rag_graph_namespace"] + ".graphml"
self.ent_cnt_data_path = self.dir_path + "/" + local_storage['rag_ent_cnt_namespace'] + ".parquet" self.ent_cnt_data_path = self.dir_path + "/" + local_storage["rag_ent_cnt_namespace"] + ".parquet"
self.pg_hash_file_path = self.dir_path + "/" + local_storage['rag_pg_hash_namespace'] + ".json" self.pg_hash_file_path = self.dir_path + "/" + local_storage["rag_pg_hash_namespace"] + ".json"
def save_to_file(self): def save_to_file(self):
"""将KG数据保存到文件""" """将KG数据保存到文件"""
@@ -91,11 +91,11 @@ class KGManager:
"""从文件加载KG数据""" """从文件加载KG数据"""
# 确保文件存在 # 确保文件存在
if not os.path.exists(self.pg_hash_file_path): if not os.path.exists(self.pg_hash_file_path):
raise Exception(f"KG段落hash文件{self.pg_hash_file_path}不存在") raise FileNotFoundError(f"KG段落hash文件{self.pg_hash_file_path}不存在")
if not os.path.exists(self.ent_cnt_data_path): if not os.path.exists(self.ent_cnt_data_path):
raise Exception(f"KG实体计数文件{self.ent_cnt_data_path}不存在") raise FileNotFoundError(f"KG实体计数文件{self.ent_cnt_data_path}不存在")
if not os.path.exists(self.graph_data_path): if not os.path.exists(self.graph_data_path):
raise Exception(f"KG图文件{self.graph_data_path}不存在") raise FileNotFoundError(f"KG图文件{self.graph_data_path}不存在")
# 加载段落hash # 加载段落hash
with open(self.pg_hash_file_path, "r", encoding="utf-8") as f: with open(self.pg_hash_file_path, "r", encoding="utf-8") as f:
@@ -122,8 +122,8 @@ class KGManager:
# 避免自连接 # 避免自连接
continue continue
# 一个triple就是一条边同时构建双向联系 # 一个triple就是一条边同时构建双向联系
hash_key1 = local_storage['ent_namespace'] + "-" + get_sha256(triple[0]) hash_key1 = local_storage["ent_namespace"] + "-" + get_sha256(triple[0])
hash_key2 = local_storage['ent_namespace'] + "-" + get_sha256(triple[2]) hash_key2 = local_storage["ent_namespace"] + "-" + get_sha256(triple[2])
node_to_node[(hash_key1, hash_key2)] = node_to_node.get((hash_key1, hash_key2), 0) + 1.0 node_to_node[(hash_key1, hash_key2)] = node_to_node.get((hash_key1, hash_key2), 0) + 1.0
node_to_node[(hash_key2, hash_key1)] = node_to_node.get((hash_key2, hash_key1), 0) + 1.0 node_to_node[(hash_key2, hash_key1)] = node_to_node.get((hash_key2, hash_key1), 0) + 1.0
entity_set.add(hash_key1) entity_set.add(hash_key1)
@@ -141,8 +141,8 @@ class KGManager:
"""构建实体节点与文段节点之间的关系""" """构建实体节点与文段节点之间的关系"""
for idx in triple_list_data: for idx in triple_list_data:
for triple in triple_list_data[idx]: for triple in triple_list_data[idx]:
ent_hash_key = local_storage['ent_namespace'] + "-" + get_sha256(triple[0]) ent_hash_key = local_storage["ent_namespace"] + "-" + get_sha256(triple[0])
pg_hash_key = local_storage['pg_namespace'] + "-" + str(idx) pg_hash_key = local_storage["pg_namespace"] + "-" + str(idx)
node_to_node[(ent_hash_key, pg_hash_key)] = node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0 node_to_node[(ent_hash_key, pg_hash_key)] = node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0
@staticmethod @staticmethod
@@ -157,8 +157,8 @@ class KGManager:
ent_hash_list = set() ent_hash_list = set()
for triple_list in triple_list_data.values(): for triple_list in triple_list_data.values():
for triple in triple_list: for triple in triple_list:
ent_hash_list.add(local_storage['ent_namespace'] + "-" + get_sha256(triple[0])) ent_hash_list.add(local_storage["ent_namespace"] + "-" + get_sha256(triple[0]))
ent_hash_list.add(local_storage['ent_namespace'] + "-" + get_sha256(triple[2])) ent_hash_list.add(local_storage["ent_namespace"] + "-" + get_sha256(triple[2]))
ent_hash_list = list(ent_hash_list) ent_hash_list = list(ent_hash_list)
synonym_hash_set = set() synonym_hash_set = set()
@@ -263,7 +263,7 @@ class KGManager:
for src_tgt in node_to_node.keys(): for src_tgt in node_to_node.keys():
for node_hash in src_tgt: for node_hash in src_tgt:
if node_hash not in existed_nodes: if node_hash not in existed_nodes:
if node_hash.startswith(local_storage['ent_namespace']): if node_hash.startswith(local_storage["ent_namespace"]):
# 新增实体节点 # 新增实体节点
node = embedding_manager.entities_embedding_store.store.get(node_hash) node = embedding_manager.entities_embedding_store.store.get(node_hash)
if node is None: if node is None:
@@ -275,7 +275,7 @@ class KGManager:
node_item["type"] = "ent" node_item["type"] = "ent"
node_item["create_time"] = now_time node_item["create_time"] = now_time
self.graph.update_node(node_item) self.graph.update_node(node_item)
elif node_hash.startswith(local_storage['pg_namespace']): elif node_hash.startswith(local_storage["pg_namespace"]):
# 新增文段节点 # 新增文段节点
node = embedding_manager.paragraphs_embedding_store.store.get(node_hash) node = embedding_manager.paragraphs_embedding_store.store.get(node_hash)
if node is None: if node is None:
@@ -359,7 +359,7 @@ class KGManager:
# 关系三元组 # 关系三元组
triple = relation[2:-2].split("', '") triple = relation[2:-2].split("', '")
for ent in [(triple[0]), (triple[2])]: for ent in [(triple[0]), (triple[2])]:
ent_hash = local_storage['ent_namespace'] + "-" + get_sha256(ent) ent_hash = local_storage["ent_namespace"] + "-" + get_sha256(ent)
if ent_hash in existed_nodes: # 该实体需在KG中存在 if ent_hash in existed_nodes: # 该实体需在KG中存在
if ent_hash not in ent_sim_scores: # 尚未记录的实体 if ent_hash not in ent_sim_scores: # 尚未记录的实体
ent_sim_scores[ent_hash] = [] ent_sim_scores[ent_hash] = []
@@ -437,7 +437,9 @@ class KGManager:
# 获取最终结果 # 获取最终结果
# 从搜索结果中提取文段节点的结果 # 从搜索结果中提取文段节点的结果
passage_node_res = [ passage_node_res = [
(node_key, score) for node_key, score in ppr_res.items() if node_key.startswith(local_storage['pg_namespace']) (node_key, score)
for node_key, score in ppr_res.items()
if node_key.startswith(local_storage["pg_namespace"])
] ]
del ppr_res del ppr_res

View File

@@ -33,6 +33,7 @@ RAG_PG_HASH_NAMESPACE = "rag-pg-hash"
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
DATA_PATH = os.path.join(ROOT_PATH, "data") DATA_PATH = os.path.join(ROOT_PATH, "data")
def _initialize_knowledge_local_storage(): def _initialize_knowledge_local_storage():
""" """
初始化知识库相关的本地存储配置 初始化知识库相关的本地存储配置
@@ -41,55 +42,58 @@ def _initialize_knowledge_local_storage():
# 定义所有需要初始化的配置项 # 定义所有需要初始化的配置项
default_configs = { default_configs = {
# 路径配置 # 路径配置
'root_path': ROOT_PATH, "root_path": ROOT_PATH,
'data_path': f"{ROOT_PATH}/data", "data_path": f"{ROOT_PATH}/data",
# 实体和命名空间配置 # 实体和命名空间配置
'lpmm_invalid_entity': INVALID_ENTITY, "lpmm_invalid_entity": INVALID_ENTITY,
'pg_namespace': PG_NAMESPACE, "pg_namespace": PG_NAMESPACE,
'ent_namespace': ENT_NAMESPACE, "ent_namespace": ENT_NAMESPACE,
'rel_namespace': REL_NAMESPACE, "rel_namespace": REL_NAMESPACE,
# RAG相关命名空间配置 # RAG相关命名空间配置
'rag_graph_namespace': RAG_GRAPH_NAMESPACE, "rag_graph_namespace": RAG_GRAPH_NAMESPACE,
'rag_ent_cnt_namespace': RAG_ENT_CNT_NAMESPACE, "rag_ent_cnt_namespace": RAG_ENT_CNT_NAMESPACE,
'rag_pg_hash_namespace': RAG_PG_HASH_NAMESPACE "rag_pg_hash_namespace": RAG_PG_HASH_NAMESPACE,
} }
# 日志级别映射重要配置用info其他用debug # 日志级别映射重要配置用info其他用debug
important_configs = {'root_path', 'data_path'} important_configs = {"root_path", "data_path"}
# 批量设置配置项 # 批量设置配置项
initialized_count = 0 initialized_count = 0
for key, default_value in default_configs.items(): for key, default_value in default_configs.items():
if local_storage[key] is None: if local_storage[key] is None:
local_storage[key] = default_value local_storage[key] = default_value
# 根据重要性选择日志级别 # 根据重要性选择日志级别
if key in important_configs: if key in important_configs:
logger.info(f"设置{key}: {default_value}") logger.info(f"设置{key}: {default_value}")
else: else:
logger.debug(f"设置{key}: {default_value}") logger.debug(f"设置{key}: {default_value}")
initialized_count += 1 initialized_count += 1
if initialized_count > 0: if initialized_count > 0:
logger.info(f"知识库本地存储初始化完成,共设置 {initialized_count} 项配置") logger.info(f"知识库本地存储初始化完成,共设置 {initialized_count} 项配置")
else: else:
logger.debug("知识库本地存储配置已存在,跳过初始化") logger.debug("知识库本地存储配置已存在,跳过初始化")
# 初始化本地存储路径 # 初始化本地存储路径
# sourcery skip: dict-comprehension
_initialize_knowledge_local_storage() _initialize_knowledge_local_storage()
qa_manager = None
inspire_manager = None
# 检查LPMM知识库是否启用 # 检查LPMM知识库是否启用
if bot_global_config.lpmm_knowledge.enable: if bot_global_config.lpmm_knowledge.enable:
logger.info("正在初始化Mai-LPMM") logger.info("正在初始化Mai-LPMM")
logger.info("创建LLM客户端") logger.info("创建LLM客户端")
llm_client_list = dict() llm_client_list = {}
for key in global_config["llm_providers"]: for key in global_config["llm_providers"]:
llm_client_list[key] = LLMClient( llm_client_list[key] = LLMClient(
global_config["llm_providers"][key]["base_url"], global_config["llm_providers"][key]["base_url"], # type: ignore
global_config["llm_providers"][key]["api_key"], global_config["llm_providers"][key]["api_key"], # type: ignore
) )
# 初始化Embedding库 # 初始化Embedding库
@@ -98,7 +102,7 @@ if bot_global_config.lpmm_knowledge.enable:
try: try:
embed_manager.load_from_file() embed_manager.load_from_file()
except Exception as e: except Exception as e:
logger.warning("此消息不会影响正常使用从文件加载Embedding库时{}".format(e)) logger.warning(f"此消息不会影响正常使用从文件加载Embedding库时{e}")
# logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误") # logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误")
logger.info("Embedding库加载完成") logger.info("Embedding库加载完成")
# 初始化KG # 初始化KG
@@ -107,7 +111,7 @@ if bot_global_config.lpmm_knowledge.enable:
try: try:
kg_manager.load_from_file() kg_manager.load_from_file()
except Exception as e: except Exception as e:
logger.warning("此消息不会影响正常使用从文件加载KG时{}".format(e)) logger.warning(f"此消息不会影响正常使用从文件加载KG时{e}")
# logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误") # logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误")
logger.info("KG加载完成") logger.info("KG加载完成")
@@ -116,7 +120,7 @@ if bot_global_config.lpmm_knowledge.enable:
# 数据比对Embedding库与KG的段落hash集合 # 数据比对Embedding库与KG的段落hash集合
for pg_hash in kg_manager.stored_paragraph_hashes: for pg_hash in kg_manager.stored_paragraph_hashes:
key = PG_NAMESPACE + "-" + pg_hash key = f"{PG_NAMESPACE}-{pg_hash}"
if key not in embed_manager.stored_pg_hashes: if key not in embed_manager.stored_pg_hashes:
logger.warning(f"KG中存在Embedding库中不存在的段落{key}") logger.warning(f"KG中存在Embedding库中不存在的段落{key}")
@@ -134,5 +138,3 @@ if bot_global_config.lpmm_knowledge.enable:
else: else:
logger.info("LPMM知识库已禁用跳过初始化") logger.info("LPMM知识库已禁用跳过初始化")
# 创建空的占位符对象,避免导入错误 # 创建空的占位符对象,避免导入错误
qa_manager = None
inspire_manager = None

View File

@@ -1,5 +1,3 @@
from .llm_client import LLMMessage
entity_extract_system_prompt = """你是一个性能优异的实体提取系统。请从段落中提取出所有实体并以JSON列表的形式输出。 entity_extract_system_prompt = """你是一个性能优异的实体提取系统。请从段落中提取出所有实体并以JSON列表的形式输出。
输出格式示例: 输出格式示例:
@@ -63,10 +61,10 @@ qa_system_prompt = """
""" """
def build_qa_context(question: str, knowledge: list[tuple[str, str, str]]) -> list[LLMMessage]: # def build_qa_context(question: str, knowledge: list[tuple[str, str, str]]) -> list[LLMMessage]:
knowledge = "\n".join([f"{i + 1}. 相关性:{k[0]}\n{k[1]}" for i, k in enumerate(knowledge)]) # knowledge = "\n".join([f"{i + 1}. 相关性:{k[0]}\n{k[1]}" for i, k in enumerate(knowledge)])
messages = [ # messages = [
LLMMessage("system", qa_system_prompt).to_dict(), # LLMMessage("system", qa_system_prompt).to_dict(),
LLMMessage("user", f"问题:\n{question}\n\n可能有帮助的信息:\n{knowledge}").to_dict(), # LLMMessage("user", f"问题:\n{question}\n\n可能有帮助的信息\n{knowledge}").to_dict(),
] # ]
return messages # return messages

View File

@@ -9,6 +9,7 @@ from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
from src.common.logger import get_logger from src.common.logger import get_logger
from src.chat.utils.utils_image import get_image_manager from src.chat.utils.utils_image import get_image_manager
from src.chat.utils.utils_voice import get_voice_text
from .chat_stream import ChatStream from .chat_stream import ChatStream
install(extra_lines=3) install(extra_lines=3)
@@ -106,6 +107,7 @@ class MessageRecv(Message):
self.has_emoji = False self.has_emoji = False
self.is_picid = False self.is_picid = False
self.has_picid = False self.has_picid = False
self.is_voice = False
self.is_mentioned = None self.is_mentioned = None
self.is_command = False self.is_command = False
@@ -153,17 +155,27 @@ class MessageRecv(Message):
self.has_emoji = True self.has_emoji = True
self.is_emoji = True self.is_emoji = True
self.is_picid = False self.is_picid = False
self.is_voice = False
if isinstance(segment.data, str): if isinstance(segment.data, str):
return await get_image_manager().get_emoji_description(segment.data) return await get_image_manager().get_emoji_description(segment.data)
return "[发了一个表情包,网卡了加载不出来]" return "[发了一个表情包,网卡了加载不出来]"
elif segment.type == "voice":
self.is_picid = False
self.is_emoji = False
self.is_voice = True
if isinstance(segment.data, str):
return await get_voice_text(segment.data)
return "[发了一段语音,网卡了加载不出来]"
elif segment.type == "mention_bot": elif segment.type == "mention_bot":
self.is_picid = False self.is_picid = False
self.is_emoji = False self.is_emoji = False
self.is_voice = False
self.is_mentioned = float(segment.data) # type: ignore self.is_mentioned = float(segment.data) # type: ignore
return "" return ""
elif segment.type == "priority_info": elif segment.type == "priority_info":
self.is_picid = False self.is_picid = False
self.is_emoji = False self.is_emoji = False
self.is_voice = False
if isinstance(segment.data, dict): if isinstance(segment.data, dict):
# 处理优先级信息 # 处理优先级信息
self.priority_mode = "priority" self.priority_mode = "priority"
@@ -212,10 +224,12 @@ class MessageRecvS4U(MessageRecv):
""" """
try: try:
if segment.type == "text": if segment.type == "text":
self.is_voice = False
self.is_picid = False self.is_picid = False
self.is_emoji = False self.is_emoji = False
return segment.data # type: ignore return segment.data # type: ignore
elif segment.type == "image": elif segment.type == "image":
self.is_voice = False
# 如果是base64图片数据 # 如果是base64图片数据
if isinstance(segment.data, str): if isinstance(segment.data, str):
self.has_picid = True self.has_picid = True
@@ -233,12 +247,22 @@ class MessageRecvS4U(MessageRecv):
if isinstance(segment.data, str): if isinstance(segment.data, str):
return await get_image_manager().get_emoji_description(segment.data) return await get_image_manager().get_emoji_description(segment.data)
return "[发了一个表情包,网卡了加载不出来]" return "[发了一个表情包,网卡了加载不出来]"
elif segment.type == "voice":
self.has_picid = False
self.is_picid = False
self.is_emoji = False
self.is_voice = True
if isinstance(segment.data, str):
return await get_voice_text(segment.data)
return "[发了一段语音,网卡了加载不出来]"
elif segment.type == "mention_bot": elif segment.type == "mention_bot":
self.is_voice = False
self.is_picid = False self.is_picid = False
self.is_emoji = False self.is_emoji = False
self.is_mentioned = float(segment.data) # type: ignore self.is_mentioned = float(segment.data) # type: ignore
return "" return ""
elif segment.type == "priority_info": elif segment.type == "priority_info":
self.is_voice = False
self.is_picid = False self.is_picid = False
self.is_emoji = False self.is_emoji = False
if isinstance(segment.data, dict): if isinstance(segment.data, dict):
@@ -253,6 +277,7 @@ class MessageRecvS4U(MessageRecv):
""" """
return "" return ""
elif segment.type == "gift": elif segment.type == "gift":
self.is_voice = False
self.is_gift = True self.is_gift = True
# 解析gift_info格式为"名称:数量" # 解析gift_info格式为"名称:数量"
name, count = segment.data.split(":", 1) # type: ignore name, count = segment.data.split(":", 1) # type: ignore
@@ -343,6 +368,10 @@ class MessageProcessBase(Message):
if isinstance(seg.data, str): if isinstance(seg.data, str):
return await get_image_manager().get_emoji_description(seg.data) return await get_image_manager().get_emoji_description(seg.data)
return "[表情,网卡了加载不出来]" return "[表情,网卡了加载不出来]"
elif seg.type == "voice":
if isinstance(seg.data, str):
return await get_voice_text(seg.data)
return "[发了一段语音,网卡了加载不出来]"
elif seg.type == "at": elif seg.type == "at":
return f"[@{seg.data}]" return f"[@{seg.data}]"
elif seg.type == "reply": elif seg.type == "reply":
@@ -455,25 +484,25 @@ class MessageSending(MessageProcessBase):
if self.message_segment: if self.message_segment:
self.processed_plain_text = await self._process_message_segments(self.message_segment) self.processed_plain_text = await self._process_message_segments(self.message_segment)
@classmethod # @classmethod
def from_thinking( # def from_thinking(
cls, # cls,
thinking: MessageThinking, # thinking: MessageThinking,
message_segment: Seg, # message_segment: Seg,
is_head: bool = False, # is_head: bool = False,
is_emoji: bool = False, # is_emoji: bool = False,
) -> "MessageSending": # ) -> "MessageSending":
"""从思考状态消息创建发送状态消息""" # """从思考状态消息创建发送状态消息"""
return cls( # return cls(
message_id=thinking.message_info.message_id, # type: ignore # message_id=thinking.message_info.message_id, # type: ignore
chat_stream=thinking.chat_stream, # chat_stream=thinking.chat_stream,
message_segment=message_segment, # message_segment=message_segment,
bot_user_info=thinking.message_info.user_info, # type: ignore # bot_user_info=thinking.message_info.user_info, # type: ignore
reply=thinking.reply, # reply=thinking.reply,
is_head=is_head, # is_head=is_head,
is_emoji=is_emoji, # is_emoji=is_emoji,
sender_info=None, # sender_info=None,
) # )
def to_dict(self): def to_dict(self):
ret = super().to_dict() ret = super().to_dict()

View File

@@ -262,4 +262,4 @@ class ActionManager:
""" """
from src.plugin_system.core.component_registry import component_registry from src.plugin_system.core.component_registry import component_registry
return component_registry.get_component_class(action_name) # type: ignore return component_registry.get_component_class(action_name, ComponentType.ACTION) # type: ignore

View File

@@ -0,0 +1,35 @@
import base64
from src.config.config import global_config
from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
from rich.traceback import install
install(extra_lines=3)
logger = get_logger("chat_voice")
async def get_voice_text(voice_base64: str) -> str:
"""获取音频文件描述"""
if not global_config.chat.enable_asr:
logger.warning("语音识别未启用,无法处理语音消息")
return "[语音]"
try:
# 解码base64音频数据
# 确保base64字符串只包含ASCII字符
if isinstance(voice_base64, str):
voice_base64 = voice_base64.encode("ascii", errors="ignore").decode("ascii")
voice_bytes = base64.b64decode(voice_base64)
_llm = LLMRequest(model=global_config.model.voice, request_type="voice")
text = await _llm.generate_response_for_voice(voice_bytes)
if text is None:
logger.warning("未能生成语音文本")
return "[语音(文本生成失败)]"
logger.debug(f"描述是{text}")
return f"[语音:{text}]"
except Exception as e:
logger.error(f"语音转文字失败: {str(e)}")
return "[语音]"

View File

@@ -21,6 +21,7 @@ class ClassicalWillingManager(BaseWillingManager):
self._decay_task = asyncio.create_task(self._decay_reply_willing()) self._decay_task = asyncio.create_task(self._decay_reply_willing())
async def get_reply_probability(self, message_id): async def get_reply_probability(self, message_id):
# sourcery skip: inline-immediately-returned-variable
willing_info = self.ongoing_messages[message_id] willing_info = self.ongoing_messages[message_id]
chat_id = willing_info.chat_id chat_id = willing_info.chat_id
current_willing = self.chat_reply_willing.get(chat_id, 0) current_willing = self.chat_reply_willing.get(chat_id, 0)

View File

@@ -25,6 +25,8 @@ import asyncio
import time import time
import math import math
from src.chat.message_receive.chat_stream import ChatStream
class MxpWillingManager(BaseWillingManager): class MxpWillingManager(BaseWillingManager):
"""Mxp意愿管理器""" """Mxp意愿管理器"""
@@ -76,7 +78,7 @@ class MxpWillingManager(BaseWillingManager):
self.chat_bot_message_time[w_info.chat_id].append(current_time) self.chat_bot_message_time[w_info.chat_id].append(current_time)
if len(self.chat_bot_message_time[w_info.chat_id]) == int(self.fatigue_messages_triggered_num): if len(self.chat_bot_message_time[w_info.chat_id]) == int(self.fatigue_messages_triggered_num):
time_interval = 60 - (current_time - self.chat_bot_message_time[w_info.chat_id].pop(0)) time_interval = 60 - (current_time - self.chat_bot_message_time[w_info.chat_id].pop(0))
self.chat_fatigue_punishment_list[w_info.chat_id].append([current_time, time_interval * 2]) self.chat_fatigue_punishment_list[w_info.chat_id].append((current_time, time_interval * 2))
async def after_generate_reply_handle(self, message_id: str): async def after_generate_reply_handle(self, message_id: str):
"""回复后处理""" """回复后处理"""
@@ -87,12 +89,14 @@ class MxpWillingManager(BaseWillingManager):
# rel_level = self._get_relationship_level_num(rel_value) # rel_level = self._get_relationship_level_num(rel_value)
# self.chat_person_reply_willing[w_info.chat_id][w_info.person_id] += rel_level * 0.05 # self.chat_person_reply_willing[w_info.chat_id][w_info.person_id] += rel_level * 0.05
now_chat_new_person = self.last_response_person.get(w_info.chat_id, [w_info.person_id, 0]) now_chat_new_person = self.last_response_person.get(w_info.chat_id, (w_info.person_id, 0))
if now_chat_new_person[0] == w_info.person_id: if now_chat_new_person[0] == w_info.person_id:
if now_chat_new_person[1] < 3: if now_chat_new_person[1] < 3:
now_chat_new_person[1] += 1 tmp_list = list(now_chat_new_person)
tmp_list[1] += 1 # type: ignore
self.last_response_person[w_info.chat_id] = tuple(tmp_list) # type: ignore
else: else:
self.last_response_person[w_info.chat_id] = [w_info.person_id, 0] self.last_response_person[w_info.chat_id] = (w_info.person_id, 0)
async def not_reply_handle(self, message_id: str): async def not_reply_handle(self, message_id: str):
"""不回复处理""" """不回复处理"""
@@ -108,11 +112,12 @@ class MxpWillingManager(BaseWillingManager):
self.chat_person_reply_willing[w_info.chat_id][w_info.person_id] += self.single_chat_gain * ( self.chat_person_reply_willing[w_info.chat_id][w_info.person_id] += self.single_chat_gain * (
2 * self.last_response_person[w_info.chat_id][1] - 1 2 * self.last_response_person[w_info.chat_id][1] - 1
) )
now_chat_new_person = self.last_response_person.get(w_info.chat_id, ["", 0]) now_chat_new_person = self.last_response_person.get(w_info.chat_id, ("", 0))
if now_chat_new_person[0] != w_info.person_id: if now_chat_new_person[0] != w_info.person_id:
self.last_response_person[w_info.chat_id] = [w_info.person_id, 0] self.last_response_person[w_info.chat_id] = (w_info.person_id, 0)
async def get_reply_probability(self, message_id: str): async def get_reply_probability(self, message_id: str):
# sourcery skip: merge-duplicate-blocks, remove-redundant-if
"""获取回复概率""" """获取回复概率"""
async with self.lock: async with self.lock:
w_info = self.ongoing_messages[message_id] w_info = self.ongoing_messages[message_id]
@@ -121,17 +126,16 @@ class MxpWillingManager(BaseWillingManager):
self.logger.debug(f"基础意愿值:{current_willing}") self.logger.debug(f"基础意愿值:{current_willing}")
if w_info.is_mentioned_bot: if w_info.is_mentioned_bot:
current_willing_ = self.mention_willing_gain / (int(current_willing) + 1) willing_gain = self.mention_willing_gain / (int(current_willing) + 1)
current_willing += current_willing_ current_willing += willing_gain
if self.is_debug: if self.is_debug:
self.logger.debug(f"提及增益:{current_willing_}") self.logger.debug(f"提及增益:{willing_gain}")
if w_info.interested_rate > 0: if w_info.interested_rate > 0:
current_willing += math.atan(w_info.interested_rate / 2) / math.pi * 2 * self.interest_willing_gain willing_gain = math.atan(w_info.interested_rate / 2) / math.pi * 2 * self.interest_willing_gain
current_willing += willing_gain
if self.is_debug: if self.is_debug:
self.logger.debug( self.logger.debug(f"兴趣增益:{willing_gain}")
f"兴趣增益:{math.atan(w_info.interested_rate / 2) / math.pi * 2 * self.interest_willing_gain}"
)
self.chat_person_reply_willing[w_info.chat_id][w_info.person_id] = current_willing self.chat_person_reply_willing[w_info.chat_id][w_info.person_id] = current_willing
@@ -152,8 +156,8 @@ class MxpWillingManager(BaseWillingManager):
self.logger.debug(f"疲劳衰减:{self.chat_fatigue_willing_attenuation.get(w_info.chat_id, 0)}") self.logger.debug(f"疲劳衰减:{self.chat_fatigue_willing_attenuation.get(w_info.chat_id, 0)}")
chat_ongoing_messages = [msg for msg in self.ongoing_messages.values() if msg.chat_id == w_info.chat_id] chat_ongoing_messages = [msg for msg in self.ongoing_messages.values() if msg.chat_id == w_info.chat_id]
chat_person_ogoing_messages = [msg for msg in chat_ongoing_messages if msg.person_id == w_info.person_id] chat_person_ongoing_messages = [msg for msg in chat_ongoing_messages if msg.person_id == w_info.person_id]
if len(chat_person_ogoing_messages) >= 2: if len(chat_person_ongoing_messages) >= 2:
current_willing = 0 current_willing = 0
if self.is_debug: if self.is_debug:
self.logger.debug("进行中消息惩罚归0") self.logger.debug("进行中消息惩罚归0")
@@ -191,34 +195,33 @@ class MxpWillingManager(BaseWillingManager):
basic_willing + (willing - basic_willing) * self.intention_decay_rate basic_willing + (willing - basic_willing) * self.intention_decay_rate
) )
def setup(self, message, chat, is_mentioned_bot, interested_rate): def setup(self, message: dict, chat_stream: ChatStream):
super().setup(message, chat, is_mentioned_bot, interested_rate) super().setup(message, chat_stream)
stream_id = chat_stream.stream_id
self.chat_reply_willing[chat.stream_id] = self.chat_reply_willing.get( self.chat_reply_willing[stream_id] = self.chat_reply_willing.get(stream_id, self.basic_maximum_willing)
chat.stream_id, self.basic_maximum_willing self.chat_person_reply_willing[stream_id] = self.chat_person_reply_willing.get(stream_id, {})
) self.chat_person_reply_willing[stream_id][self.ongoing_messages[message.get("message_id", "")].person_id] = (
self.chat_person_reply_willing[chat.stream_id] = self.chat_person_reply_willing.get(chat.stream_id, {}) self.chat_person_reply_willing[stream_id].get(
self.chat_person_reply_willing[chat.stream_id][ self.ongoing_messages[message.get("message_id", "")].person_id,
self.ongoing_messages[message.message_info.message_id].person_id self.chat_reply_willing[stream_id],
] = self.chat_person_reply_willing[chat.stream_id].get( )
self.ongoing_messages[message.message_info.message_id].person_id, self.chat_reply_willing[chat.stream_id]
) )
current_time = time.time() current_time = time.time()
if chat.stream_id not in self.chat_new_message_time: if stream_id not in self.chat_new_message_time:
self.chat_new_message_time[chat.stream_id] = [] self.chat_new_message_time[stream_id] = []
self.chat_new_message_time[chat.stream_id].append(current_time) self.chat_new_message_time[stream_id].append(current_time)
if len(self.chat_new_message_time[chat.stream_id]) > self.number_of_message_storage: if len(self.chat_new_message_time[stream_id]) > self.number_of_message_storage:
self.chat_new_message_time[chat.stream_id].pop(0) self.chat_new_message_time[stream_id].pop(0)
if chat.stream_id not in self.chat_fatigue_punishment_list: if stream_id not in self.chat_fatigue_punishment_list:
self.chat_fatigue_punishment_list[chat.stream_id] = [ self.chat_fatigue_punishment_list[stream_id] = [
( (
current_time, current_time,
self.number_of_message_storage * self.basic_maximum_willing / self.expected_replies_per_min * 60, self.number_of_message_storage * self.basic_maximum_willing / self.expected_replies_per_min * 60,
) )
] ]
self.chat_fatigue_willing_attenuation[chat.stream_id] = ( self.chat_fatigue_willing_attenuation[stream_id] = (
-2 * self.basic_maximum_willing * self.fatigue_coefficient -2 * self.basic_maximum_willing * self.fatigue_coefficient
) )
@@ -227,12 +230,11 @@ class MxpWillingManager(BaseWillingManager):
"""意愿值转化为概率""" """意愿值转化为概率"""
willing = max(0, willing) willing = max(0, willing)
if willing < 2: if willing < 2:
probability = math.atan(willing * 2) / math.pi * 2 return math.atan(willing * 2) / math.pi * 2
elif willing < 2.5: elif willing < 2.5:
probability = math.atan(willing * 4) / math.pi * 2 return math.atan(willing * 4) / math.pi * 2
else: else:
probability = 1 return 1
return probability
async def _chat_new_message_to_change_basic_willing(self): async def _chat_new_message_to_change_basic_willing(self):
"""聊天流新消息改变基础意愿""" """聊天流新消息改变基础意愿"""
@@ -259,7 +261,7 @@ class MxpWillingManager(BaseWillingManager):
update_time = 20 update_time = 20
elif len(message_times) == self.number_of_message_storage: elif len(message_times) == self.number_of_message_storage:
time_interval = current_time - message_times[0] time_interval = current_time - message_times[0]
basic_willing = self._basic_willing_culculate(time_interval) basic_willing = self._basic_willing_calculate(time_interval)
self.chat_reply_willing[chat_id] = basic_willing self.chat_reply_willing[chat_id] = basic_willing
update_time = 17 * basic_willing / self.basic_maximum_willing + 3 update_time = 17 * basic_willing / self.basic_maximum_willing + 3
else: else:
@@ -268,7 +270,7 @@ class MxpWillingManager(BaseWillingManager):
if self.is_debug: if self.is_debug:
self.logger.debug(f"聊天流意愿值更新:{self.chat_reply_willing}") self.logger.debug(f"聊天流意愿值更新:{self.chat_reply_willing}")
def _basic_willing_culculate(self, t: float) -> float: def _basic_willing_calculate(self, t: float) -> float:
"""基础意愿值计算""" """基础意愿值计算"""
return math.tan(t * self.expected_replies_per_min * math.pi / 120 / self.number_of_message_storage) / 2 return math.tan(t * self.expected_replies_per_min * math.pi / 120 / self.number_of_message_storage) / 2

View File

@@ -104,7 +104,7 @@ class BaseWillingManager(ABC):
is_mentioned_bot=message.get("is_mentioned", False), is_mentioned_bot=message.get("is_mentioned", False),
is_emoji=message.get("is_emoji", False), is_emoji=message.get("is_emoji", False),
is_picid=message.get("is_picid", False), is_picid=message.get("is_picid", False),
interested_rate=message.get("interest_value", 0), interested_rate = message.get("interest_value") or 0.0,
) )
def delete(self, message_id: str): def delete(self, message_id: str):

View File

@@ -106,6 +106,9 @@ class ChatConfig(ConfigBase):
focus_value: float = 1.0 focus_value: float = 1.0
"""麦麦的专注思考能力越低越容易专注消耗token也越多""" """麦麦的专注思考能力越低越容易专注消耗token也越多"""
enable_asr: bool = False
"""是否启用语音识别"""
def get_current_talk_frequency(self, chat_stream_id: Optional[str] = None) -> float: def get_current_talk_frequency(self, chat_stream_id: Optional[str] = None) -> float:
""" """
根据当前时间和聊天流获取对应的 talk_frequency 根据当前时间和聊天流获取对应的 talk_frequency
@@ -630,6 +633,9 @@ class ModelConfig(ConfigBase):
vlm: dict[str, Any] = field(default_factory=lambda: {}) vlm: dict[str, Any] = field(default_factory=lambda: {})
"""视觉语言模型配置""" """视觉语言模型配置"""
voice: dict[str, Any] = field(default_factory=lambda: {})
"""语音识别模型配置"""
tool_use: dict[str, Any] = field(default_factory=lambda: {}) tool_use: dict[str, Any] = field(default_factory=lambda: {})
"""专注工具使用模型配置""" """专注工具使用模型配置"""

View File

@@ -216,6 +216,8 @@ class LLMRequest:
prompt: str = None, prompt: str = None,
image_base64: str = None, image_base64: str = None,
image_format: str = None, image_format: str = None,
file_bytes: bytes = None,
file_format: str = None,
payload: dict = None, payload: dict = None,
retry_policy: dict = None, retry_policy: dict = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
@@ -225,6 +227,8 @@ class LLMRequest:
prompt: prompt文本 prompt: prompt文本
image_base64: 图片的base64编码 image_base64: 图片的base64编码
image_format: 图片格式 image_format: 图片格式
file_bytes: 文件的二进制数据
file_format: 文件格式
payload: 请求体数据 payload: 请求体数据
retry_policy: 自定义重试策略 retry_policy: 自定义重试策略
request_type: 请求类型 request_type: 请求类型
@@ -246,30 +250,33 @@ class LLMRequest:
# 构建请求体 # 构建请求体
if image_base64: if image_base64:
payload = await self._build_payload(prompt, image_base64, image_format) payload = await self._build_payload(prompt, image_base64, image_format)
elif file_bytes:
payload = await self._build_formdata_payload(file_bytes, file_format)
elif payload is None: elif payload is None:
payload = await self._build_payload(prompt) payload = await self._build_payload(prompt)
if stream_mode: if not file_bytes:
payload["stream"] = stream_mode if stream_mode:
payload["stream"] = stream_mode
if self.temp != 0.7: if self.temp != 0.7:
payload["temperature"] = self.temp payload["temperature"] = self.temp
# 添加enable_thinking参数如果不是默认值False # 添加enable_thinking参数如果不是默认值False
if not self.enable_thinking: if not self.enable_thinking:
payload["enable_thinking"] = False payload["enable_thinking"] = False
if self.thinking_budget != 4096: if self.thinking_budget != 4096:
payload["thinking_budget"] = self.thinking_budget payload["thinking_budget"] = self.thinking_budget
if self.max_tokens: if self.max_tokens:
payload["max_tokens"] = self.max_tokens payload["max_tokens"] = self.max_tokens
# if "max_tokens" not in payload and "max_completion_tokens" not in payload: # if "max_tokens" not in payload and "max_completion_tokens" not in payload:
# payload["max_tokens"] = global_config.model.model_max_output_length # payload["max_tokens"] = global_config.model.model_max_output_length
# 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查 # 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查
if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload: if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
payload["max_completion_tokens"] = payload.pop("max_tokens") payload["max_completion_tokens"] = payload.pop("max_tokens")
return { return {
"policy": policy, "policy": policy,
@@ -278,6 +285,8 @@ class LLMRequest:
"stream_mode": stream_mode, "stream_mode": stream_mode,
"image_base64": image_base64, # 保留必要的exception处理所需的原始数据 "image_base64": image_base64, # 保留必要的exception处理所需的原始数据
"image_format": image_format, "image_format": image_format,
"file_bytes": file_bytes,
"file_format": file_format,
"prompt": prompt, "prompt": prompt,
} }
@@ -287,6 +296,8 @@ class LLMRequest:
prompt: str = None, prompt: str = None,
image_base64: str = None, image_base64: str = None,
image_format: str = None, image_format: str = None,
file_bytes: bytes = None,
file_format: str = None,
payload: dict = None, payload: dict = None,
retry_policy: dict = None, retry_policy: dict = None,
response_handler: callable = None, response_handler: callable = None,
@@ -299,6 +310,8 @@ class LLMRequest:
prompt: prompt文本 prompt: prompt文本
image_base64: 图片的base64编码 image_base64: 图片的base64编码
image_format: 图片格式 image_format: 图片格式
file_bytes: 文件的二进制数据
file_format: 文件格式
payload: 请求体数据 payload: 请求体数据
retry_policy: 自定义重试策略 retry_policy: 自定义重试策略
response_handler: 自定义响应处理器 response_handler: 自定义响应处理器
@@ -307,25 +320,36 @@ class LLMRequest:
""" """
# 获取请求配置 # 获取请求配置
request_content = await self._prepare_request( request_content = await self._prepare_request(
endpoint, prompt, image_base64, image_format, payload, retry_policy endpoint, prompt, image_base64, image_format, file_bytes, file_format, payload, retry_policy
) )
if request_type is None: if request_type is None:
request_type = self.request_type request_type = self.request_type
for retry in range(request_content["policy"]["max_retries"]): for retry in range(request_content["policy"]["max_retries"]):
try: try:
# 使用上下文管理器处理会话 # 使用上下文管理器处理会话
headers = await self._build_headers() if file_bytes:
headers = await self._build_headers(is_formdata=True)
else:
headers = await self._build_headers(is_formdata=False)
# 似乎是openai流式必须要的东西,不过阿里云的qwq-plus加了这个没有影响 # 似乎是openai流式必须要的东西,不过阿里云的qwq-plus加了这个没有影响
if request_content["stream_mode"]: if request_content["stream_mode"]:
headers["Accept"] = "text/event-stream" headers["Accept"] = "text/event-stream"
async with aiohttp.ClientSession(connector=await get_tcp_connector()) as session: async with aiohttp.ClientSession(connector=await get_tcp_connector()) as session:
async with session.post( post_kwargs = {"headers": headers}
request_content["api_url"], headers=headers, json=request_content["payload"] #form-data数据上传方式不同
if file_bytes:
post_kwargs["data"] = request_content["payload"]
else:
post_kwargs["json"] = request_content["payload"]
async with session.post(
request_content["api_url"], **post_kwargs
) as response: ) as response:
handled_result = await self._handle_response( handled_result = await self._handle_response(
response, request_content, retry, response_handler, user_id, request_type, endpoint response, request_content, retry, response_handler, user_id, request_type, endpoint
) )
return handled_result return handled_result
except Exception as e: except Exception as e:
handled_payload, count_delta = await self._handle_exception(e, retry, request_content) handled_payload, count_delta = await self._handle_exception(e, retry, request_content)
retry += count_delta # 降级不计入重试次数 retry += count_delta # 降级不计入重试次数
@@ -605,7 +629,7 @@ class LLMRequest:
) )
# 安全地检查和记录请求详情 # 安全地检查和记录请求详情
handled_payload = await _safely_record(request_content, payload) handled_payload = await _safely_record(request_content, payload)
logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {handled_payload[:100]}") logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {str(handled_payload)[:100]}")
raise RuntimeError( raise RuntimeError(
f"模型 {self.model_name} API请求失败: 状态码 {exception.status}, {exception.message}" f"模型 {self.model_name} API请求失败: 状态码 {exception.status}, {exception.message}"
) )
@@ -619,7 +643,7 @@ class LLMRequest:
logger.critical(f"模型 {self.model_name} 请求失败: {str(exception)}") logger.critical(f"模型 {self.model_name} 请求失败: {str(exception)}")
# 安全地检查和记录请求详情 # 安全地检查和记录请求详情
handled_payload = await _safely_record(request_content, payload) handled_payload = await _safely_record(request_content, payload)
logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {handled_payload[:100]}") logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {str(handled_payload)[:100]}")
raise RuntimeError(f"模型 {self.model_name} API请求失败: {str(exception)}") raise RuntimeError(f"模型 {self.model_name} API请求失败: {str(exception)}")
async def _transform_parameters(self, params: dict) -> dict: async def _transform_parameters(self, params: dict) -> dict:
@@ -640,6 +664,33 @@ class LLMRequest:
new_params["max_completion_tokens"] = new_params.pop("max_tokens") new_params["max_completion_tokens"] = new_params.pop("max_tokens")
return new_params return new_params
async def _build_formdata_payload(self, file_bytes: bytes, file_format: str) -> aiohttp.FormData:
"""构建form-data请求体"""
# 目前只适配了音频文件
# 如果后续要支持其他类型的文件,可以在这里添加更多的处理逻辑
data = aiohttp.FormData()
content_type_list = {
"wav": "audio/wav",
"mp3": "audio/mpeg",
"ogg": "audio/ogg",
"flac": "audio/flac",
"aac": "audio/aac",
}
content_type = content_type_list.get(file_format)
if not content_type:
logger.warning(f"暂不支持的文件类型: {file_format}")
data.add_field(
"file",io.BytesIO(file_bytes),
filename=f"file.{file_format}",
content_type=f'{content_type}' # 根据实际文件类型设置
)
data.add_field(
"model", self.model_name
)
return data
async def _build_payload(self, prompt: str, image_base64: str = None, image_format: str = None) -> dict: async def _build_payload(self, prompt: str, image_base64: str = None, image_format: str = None) -> dict:
"""构建请求体""" """构建请求体"""
# 复制一份参数,避免直接修改 self.params # 复制一份参数,避免直接修改 self.params
@@ -725,7 +776,8 @@ class LLMRequest:
return content, reasoning_content, tool_calls return content, reasoning_content, tool_calls
else: else:
return content, reasoning_content return content, reasoning_content
elif "text" in result and result["text"]:
return result["text"]
return "没有返回结果", "" return "没有返回结果", ""
@staticmethod @staticmethod
@@ -739,11 +791,15 @@ class LLMRequest:
reasoning = "" reasoning = ""
return content, reasoning return content, reasoning
async def _build_headers(self, no_key: bool = False) -> dict: async def _build_headers(self, no_key: bool = False, is_formdata: bool = False) -> dict:
"""构建请求头""" """构建请求头"""
if no_key: if no_key:
if is_formdata:
return {"Authorization": "Bearer **********"}
return {"Authorization": "Bearer **********", "Content-Type": "application/json"} return {"Authorization": "Bearer **********", "Content-Type": "application/json"}
else: else:
if is_formdata:
return {"Authorization": f"Bearer {self.api_key}"}
return {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"} return {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
# 防止小朋友们截图自己的key # 防止小朋友们截图自己的key
@@ -761,6 +817,11 @@ class LLMRequest:
content, reasoning_content = response content, reasoning_content = response
return content, reasoning_content return content, reasoning_content
async def generate_response_for_voice(self, voice_bytes: bytes) -> Tuple:
"""根据输入的语音文件生成模型的异步响应"""
response = await self._execute_request(endpoint="/audio/transcriptions",file_bytes=voice_bytes, file_format='wav')
return response
async def generate_response_async(self, prompt: str, **kwargs) -> Union[str, Tuple]: async def generate_response_async(self, prompt: str, **kwargs) -> Union[str, Tuple]:
"""异步方式根据输入的提示生成模型的响应""" """异步方式根据输入的提示生成模型的响应"""
# 构建请求体不硬编码max_tokens # 构建请求体不硬编码max_tokens

View File

@@ -18,11 +18,16 @@ from .base import (
CommandInfo, CommandInfo,
PluginInfo, PluginInfo,
PythonDependency, PythonDependency,
BaseEventHandler,
EventHandlerInfo,
EventType,
MaiMessages,
) )
from .core.plugin_manager import ( from .core import (
plugin_manager, plugin_manager,
component_registry, component_registry,
dependency_manager, dependency_manager,
events_manager,
) )
# 导入工具模块 # 导入工具模块
@@ -33,7 +38,7 @@ from .utils import (
# generate_plugin_manifest, # generate_plugin_manifest,
) )
from .apis.plugin_register_api import register_plugin from .apis import register_plugin, get_logger
__version__ = "1.0.0" __version__ = "1.0.0"
@@ -43,6 +48,7 @@ __all__ = [
"BasePlugin", "BasePlugin",
"BaseAction", "BaseAction",
"BaseCommand", "BaseCommand",
"BaseEventHandler",
# 类型定义 # 类型定义
"ComponentType", "ComponentType",
"ActionActivationType", "ActionActivationType",
@@ -52,15 +58,21 @@ __all__ = [
"CommandInfo", "CommandInfo",
"PluginInfo", "PluginInfo",
"PythonDependency", "PythonDependency",
"EventHandlerInfo",
"EventType",
# 消息
"MaiMessages",
# 管理器 # 管理器
"plugin_manager", "plugin_manager",
"component_registry", "component_registry",
"dependency_manager", "dependency_manager",
"events_manager",
# 装饰器 # 装饰器
"register_plugin", "register_plugin",
"ConfigField", "ConfigField",
# 工具函数 # 工具函数
"ManifestValidator", "ManifestValidator",
"get_logger",
# "ManifestGenerator", # "ManifestGenerator",
# "validate_plugin_manifest", # "validate_plugin_manifest",
# "generate_plugin_manifest", # "generate_plugin_manifest",

View File

@@ -18,7 +18,8 @@ from src.plugin_system.apis import (
utils_api, utils_api,
plugin_register_api, plugin_register_api,
) )
from .logging_api import get_logger
from .plugin_register_api import register_plugin
# 导出所有API模块使它们可以通过 apis.xxx 方式访问 # 导出所有API模块使它们可以通过 apis.xxx 方式访问
__all__ = [ __all__ = [
"chat_api", "chat_api",
@@ -32,4 +33,6 @@ __all__ = [
"send_api", "send_api",
"utils_api", "utils_api",
"plugin_register_api", "plugin_register_api",
"get_logger",
"register_plugin",
] ]

View File

@@ -0,0 +1,3 @@
from src.common.logger import get_logger
__all__ = ["get_logger"]

View File

@@ -1,6 +1,8 @@
from pathlib import Path
from src.common.logger import get_logger from src.common.logger import get_logger
logger = get_logger("plugin_register") logger = get_logger("plugin_manager") # 复用plugin_manager名称
def register_plugin(cls): def register_plugin(cls):
@@ -22,16 +24,23 @@ def register_plugin(cls):
# 只是注册插件类,不立即实例化 # 只是注册插件类,不立即实例化
# 插件管理器会负责实例化和注册 # 插件管理器会负责实例化和注册
plugin_name = cls.plugin_name or cls.__name__ plugin_name: str = cls.plugin_name # type: ignore
plugin_manager.plugin_classes[plugin_name] = cls # type: ignore if "." in plugin_name:
logger.debug(f"插件类已注册: {plugin_name}") logger.error(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代")
raise ValueError(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代")
plugin_manager.plugin_classes[plugin_name] = cls
splitted_name = cls.__module__.split(".")
root_path = Path(__file__)
# 查找项目根目录
while not (root_path / "pyproject.toml").exists() and root_path.parent != root_path:
root_path = root_path.parent
if not (root_path / "pyproject.toml").exists():
logger.error(f"注册 {plugin_name} 无法找到项目根目录")
return cls
plugin_manager.plugin_paths[plugin_name] = str(Path(root_path, *splitted_name).resolve())
logger.debug(f"插件类已注册: {plugin_name}, 路径: {plugin_manager.plugin_paths[plugin_name]}")
return cls return cls
def register_event_plugin(cls, *args, **kwargs):
"""事件插件注册装饰器
用法:
@register_event_plugin
"""

View File

@@ -7,6 +7,7 @@
from .base_plugin import BasePlugin from .base_plugin import BasePlugin
from .base_action import BaseAction from .base_action import BaseAction
from .base_command import BaseCommand from .base_command import BaseCommand
from .base_events_handler import BaseEventHandler
from .component_types import ( from .component_types import (
ComponentType, ComponentType,
ActionActivationType, ActionActivationType,
@@ -16,6 +17,9 @@ from .component_types import (
CommandInfo, CommandInfo,
PluginInfo, PluginInfo,
PythonDependency, PythonDependency,
EventHandlerInfo,
EventType,
MaiMessages,
) )
from .config_types import ConfigField from .config_types import ConfigField
@@ -32,4 +36,8 @@ __all__ = [
"PluginInfo", "PluginInfo",
"PythonDependency", "PythonDependency",
"ConfigField", "ConfigField",
"EventHandlerInfo",
"EventType",
"BaseEventHandler",
"MaiMessages",
] ]

View File

@@ -41,6 +41,7 @@ class BaseAction(ABC):
action_message: Optional[dict] = None, action_message: Optional[dict] = None,
**kwargs, **kwargs,
): ):
# sourcery skip: hoist-similar-statement-from-if, merge-else-if-into-elif, move-assign-in-block, swap-if-else-branches, swap-nested-ifs
"""初始化Action组件 """初始化Action组件
Args: Args:
@@ -355,7 +356,9 @@ class BaseAction(ABC):
# 从类属性读取名称,如果没有定义则使用类名自动生成 # 从类属性读取名称,如果没有定义则使用类名自动生成
name = getattr(cls, "action_name", cls.__name__.lower().replace("action", "")) name = getattr(cls, "action_name", cls.__name__.lower().replace("action", ""))
if "." in name:
logger.error(f"Action名称 '{name}' 包含非法字符 '.',请使用下划线替代")
raise ValueError(f"Action名称 '{name}' 包含非法字符 '.',请使用下划线替代")
# 获取focus_activation_type和normal_activation_type # 获取focus_activation_type和normal_activation_type
focus_activation_type = getattr(cls, "focus_activation_type", ActionActivationType.ALWAYS) focus_activation_type = getattr(cls, "focus_activation_type", ActionActivationType.ALWAYS)
normal_activation_type = getattr(cls, "normal_activation_type", ActionActivationType.ALWAYS) normal_activation_type = getattr(cls, "normal_activation_type", ActionActivationType.ALWAYS)

View File

@@ -219,7 +219,9 @@ class BaseCommand(ABC):
Returns: Returns:
CommandInfo: 生成的Command信息对象 CommandInfo: 生成的Command信息对象
""" """
if "." in cls.command_name:
logger.error(f"Command名称 '{cls.command_name}' 包含非法字符 '.',请使用下划线替代")
raise ValueError(f"Command名称 '{cls.command_name}' 包含非法字符 '.',请使用下划线替代")
return CommandInfo( return CommandInfo(
name=cls.command_name, name=cls.command_name,
component_type=ComponentType.COMMAND, component_type=ComponentType.COMMAND,

View File

@@ -1,14 +0,0 @@
from abc import abstractmethod
from .plugin_base import PluginBase
from src.common.logger import get_logger
class BaseEventPlugin(PluginBase):
"""基于事件的插件基类
所有事件类型的插件都应该继承这个基类
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

View File

@@ -0,0 +1,51 @@
from abc import ABC, abstractmethod
from typing import Tuple, Optional
from src.common.logger import get_logger
from .component_types import MaiMessages, EventType, EventHandlerInfo, ComponentType
logger = get_logger("base_event_handler")
class BaseEventHandler(ABC):
"""事件处理器基类
所有事件处理器都应该继承这个基类,提供事件处理的基本接口
"""
event_type: EventType = EventType.UNKNOWN # 事件类型,默认为未知
handler_name: str = "" # 处理器名称
handler_description: str = ""
weight: int = 0 # 权重,数值越大优先级越高
intercept_message: bool = False # 是否拦截消息,默认为否
def __init__(self):
self.log_prefix = "[EventHandler]"
if self.event_type == EventType.UNKNOWN:
raise NotImplementedError("事件处理器必须指定 event_type")
@abstractmethod
async def execute(self, message: MaiMessages) -> Tuple[bool, Optional[str]]:
"""执行事件处理的抽象方法,子类必须实现
Returns:
Tuple[bool, Optional[str]]: (是否执行成功, 可选的返回消息)
"""
raise NotImplementedError("子类必须实现 execute 方法")
@classmethod
def get_handler_info(cls) -> "EventHandlerInfo":
"""获取事件处理器的信息"""
# 从类属性读取名称,如果没有定义则使用类名自动生成
name: str = getattr(cls, "handler_name", cls.__name__.lower().replace("handler", ""))
if "." in name:
logger.error(f"事件处理器名称 '{name}' 包含非法字符 '.',请使用下划线替代")
raise ValueError(f"事件处理器名称 '{name}' 包含非法字符 '.',请使用下划线替代")
return EventHandlerInfo(
name=name,
component_type=ComponentType.EVENT_HANDLER,
description=getattr(cls, "handler_description", "events处理器"),
event_type=cls.event_type,
weight=cls.weight,
intercept_message=cls.intercept_message,
)

View File

@@ -1,9 +1,12 @@
from abc import abstractmethod from abc import abstractmethod
from typing import List, Type from typing import List, Type, Tuple, Union
from .plugin_base import PluginBase from .plugin_base import PluginBase
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system.base.component_types import ComponentInfo from src.plugin_system.base.component_types import ComponentInfo, ActionInfo, CommandInfo, EventHandlerInfo
from .base_action import BaseAction
from .base_command import BaseCommand
from .base_events_handler import BaseEventHandler
logger = get_logger("base_plugin") logger = get_logger("base_plugin")
@@ -21,7 +24,15 @@ class BasePlugin(PluginBase):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@abstractmethod @abstractmethod
def get_plugin_components(self) -> List[tuple[ComponentInfo, Type]]: def get_plugin_components(
self,
) -> List[
Union[
Tuple[ActionInfo, Type[BaseAction]],
Tuple[CommandInfo, Type[BaseCommand]],
Tuple[EventHandlerInfo, Type[BaseEventHandler]],
]
]:
"""获取插件包含的组件列表 """获取插件包含的组件列表
子类必须实现此方法,返回组件信息和组件类的列表 子类必须实现此方法,返回组件信息和组件类的列表

View File

@@ -1,6 +1,7 @@
from enum import Enum from enum import Enum
from typing import Dict, Any, List from typing import Dict, Any, List, Optional
from dataclasses import dataclass, field from dataclasses import dataclass, field
from maim_message import Seg
# 组件类型枚举 # 组件类型枚举
@@ -10,7 +11,10 @@ class ComponentType(Enum):
ACTION = "action" # 动作组件 ACTION = "action" # 动作组件
COMMAND = "command" # 命令组件 COMMAND = "command" # 命令组件
SCHEDULER = "scheduler" # 定时任务组件(预留) SCHEDULER = "scheduler" # 定时任务组件(预留)
LISTENER = "listener" # 事件监听组件(预留) EVENT_HANDLER = "event_handler" # 事件处理组件(预留)
def __str__(self) -> str:
return self.value
# 动作激活类型枚举 # 动作激活类型枚举
@@ -46,12 +50,17 @@ class EventType(Enum):
事件类型枚举类 事件类型枚举类
""" """
ON_START = "on_start" # 启动事件,用于调用按时任务
ON_MESSAGE = "on_message" ON_MESSAGE = "on_message"
ON_PLAN = "on_plan" ON_PLAN = "on_plan"
POST_LLM = "post_llm" POST_LLM = "post_llm"
AFTER_LLM = "after_llm" AFTER_LLM = "after_llm"
POST_SEND = "post_send" POST_SEND = "post_send"
AFTER_SEND = "after_send" AFTER_SEND = "after_send"
UNKNOWN = "unknown" # 未知事件类型
def __str__(self) -> str:
return self.value
@dataclass @dataclass
@@ -142,6 +151,19 @@ class CommandInfo(ComponentInfo):
self.component_type = ComponentType.COMMAND self.component_type = ComponentType.COMMAND
@dataclass
class EventHandlerInfo(ComponentInfo):
"""事件处理器组件信息"""
event_type: EventType = EventType.ON_MESSAGE # 监听事件类型
intercept_message: bool = False # 是否拦截消息处理(默认不拦截)
weight: int = 0 # 事件处理器权重,决定执行顺序
def __post_init__(self):
super().__post_init__()
self.component_type = ComponentType.EVENT_HANDLER
@dataclass @dataclass
class PluginInfo: class PluginInfo:
"""插件信息""" """插件信息"""
@@ -198,3 +220,42 @@ class PluginInfo:
def get_pip_requirements(self) -> List[str]: def get_pip_requirements(self) -> List[str]:
"""获取所有pip安装格式的依赖""" """获取所有pip安装格式的依赖"""
return [dep.get_pip_requirement() for dep in self.python_dependencies] return [dep.get_pip_requirement() for dep in self.python_dependencies]
@dataclass
class MaiMessages:
"""MaiM插件消息"""
message_segments: List[Seg] = field(default_factory=list)
"""消息段列表,支持多段消息"""
message_base_info: Dict[str, Any] = field(default_factory=dict)
"""消息基本信息,包含平台,用户信息等数据"""
plain_text: str = ""
"""纯文本消息内容"""
raw_message: Optional[str] = None
"""原始消息内容"""
is_group_message: bool = False
"""是否为群组消息"""
is_private_message: bool = False
"""是否为私聊消息"""
stream_id: Optional[str] = None
"""流ID用于标识消息流"""
llm_prompt: Optional[str] = None
"""LLM提示词"""
llm_response: Optional[str] = None
"""LLM响应内容"""
additional_data: Dict[Any, Any] = field(default_factory=dict)
"""附加数据,可以存储额外信息"""
def __post_init__(self):
if self.message_segments is None:
self.message_segments = []

View File

@@ -7,9 +7,11 @@
from src.plugin_system.core.plugin_manager import plugin_manager from src.plugin_system.core.plugin_manager import plugin_manager
from src.plugin_system.core.component_registry import component_registry from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.core.dependency_manager import dependency_manager from src.plugin_system.core.dependency_manager import dependency_manager
from src.plugin_system.core.events_manager import events_manager
__all__ = [ __all__ = [
"plugin_manager", "plugin_manager",
"component_registry", "component_registry",
"dependency_manager", "dependency_manager",
"events_manager",
] ]

View File

@@ -1,16 +1,19 @@
from typing import Dict, List, Optional, Any, Pattern, Tuple, Union, Type
import re import re
from typing import Dict, List, Optional, Any, Pattern, Tuple, Union, Type
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system.base.component_types import ( from src.plugin_system.base.component_types import (
ComponentInfo, ComponentInfo,
ActionInfo, ActionInfo,
CommandInfo, CommandInfo,
EventHandlerInfo,
PluginInfo, PluginInfo,
ComponentType, ComponentType,
) )
from src.plugin_system.base.base_command import BaseCommand from src.plugin_system.base.base_command import BaseCommand
from src.plugin_system.base.base_action import BaseAction from src.plugin_system.base.base_action import BaseAction
from src.plugin_system.base.base_events_handler import BaseEventHandler
logger = get_logger("component_registry") logger = get_logger("component_registry")
@@ -23,12 +26,11 @@ class ComponentRegistry:
def __init__(self): def __init__(self):
# 组件注册表 # 组件注册表
self._components: Dict[str, ComponentInfo] = {} # 组件名 -> 组件信息 self._components: Dict[str, ComponentInfo] = {} # 命名空间式组件名 -> 组件信息
self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = { # 类型 -> 命名空间式名称 -> 组件信息
ComponentType.ACTION: {}, self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = {types: {} for types in ComponentType}
ComponentType.COMMAND: {}, # 命名空间式组件名 -> 组件类
} self._components_classes: Dict[str, Type[Union[BaseCommand, BaseAction, BaseEventHandler]]] = {}
self._component_classes: Dict[str, Union[Type[BaseCommand], Type[BaseAction]]] = {} # 组件名 -> 组件类
# 插件注册表 # 插件注册表
self._plugins: Dict[str, PluginInfo] = {} # 插件名 -> 插件信息 self._plugins: Dict[str, PluginInfo] = {} # 插件名 -> 插件信息
@@ -39,20 +41,43 @@ class ComponentRegistry:
# Command特定注册表 # Command特定注册表
self._command_registry: Dict[str, Type[BaseCommand]] = {} # command名 -> command类 self._command_registry: Dict[str, Type[BaseCommand]] = {} # command名 -> command类
self._command_patterns: Dict[Pattern, Type[BaseCommand]] = {} # 编译后的正则 -> command self._command_patterns: Dict[Pattern, str] = {} # 编译后的正则 -> command
# EventHandler特定注册表
self._event_handler_registry: Dict[str, Type[BaseEventHandler]] = {} # event_handler名 -> event_handler类
self._enabled_event_handlers: Dict[str, Type[BaseEventHandler]] = {} # 启用的事件处理器
logger.info("组件注册中心初始化完成") logger.info("组件注册中心初始化完成")
# === 通用组件注册方法 === # == 注册方法 ==
def register_plugin(self, plugin_info: PluginInfo) -> bool:
"""注册插件
Args:
plugin_info: 插件信息
Returns:
bool: 是否注册成功
"""
plugin_name = plugin_info.name
if plugin_name in self._plugins:
logger.warning(f"插件 {plugin_name} 已存在,跳过注册")
return False
self._plugins[plugin_name] = plugin_info
logger.debug(f"已注册插件: {plugin_name} (组件数量: {len(plugin_info.components)})")
return True
def register_component( def register_component(
self, component_info: ComponentInfo, component_class: Union[Type[BaseCommand], Type[BaseAction]] self, component_info: ComponentInfo, component_class: Type[Union[BaseCommand, BaseAction, BaseEventHandler]]
) -> bool: ) -> bool:
"""注册组件 """注册组件
Args: Args:
component_info: 组件信息 component_info (ComponentInfo): 组件信息
component_class: 组件类 component_class (Type[Union[BaseCommand, BaseAction, BaseEventHandler]]): 组件类
Returns: Returns:
bool: 是否注册成功 bool: 是否注册成功
@@ -60,68 +85,110 @@ class ComponentRegistry:
component_name = component_info.name component_name = component_info.name
component_type = component_info.component_type component_type = component_info.component_type
plugin_name = getattr(component_info, "plugin_name", "unknown") plugin_name = getattr(component_info, "plugin_name", "unknown")
if "." in component_name:
logger.error(f"组件名称 '{component_name}' 包含非法字符 '.',请使用下划线替代")
return False
if "." in plugin_name:
logger.error(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代")
return False
# 🔥 系统级别自动区分:为不同类型的组件添加命名空间前缀 namespaced_name = f"{component_type}.{component_name}"
if component_type == ComponentType.ACTION:
namespaced_name = f"action.{component_name}"
elif component_type == ComponentType.COMMAND:
namespaced_name = f"command.{component_name}"
else:
# 未来扩展的组件类型
namespaced_name = f"{component_type.value}.{component_name}"
# 检查命名空间化的名称是否冲突
if namespaced_name in self._components: if namespaced_name in self._components:
existing_info = self._components[namespaced_name] existing_info = self._components[namespaced_name]
existing_plugin = getattr(existing_info, "plugin_name", "unknown") existing_plugin = getattr(existing_info, "plugin_name", "unknown")
logger.warning( logger.warning(
f"组件冲突: {component_type.value}组件 '{component_name}' " f"组件冲突: '{plugin_name}' 插件的 {component_type} 类型组件 '{component_name}' 已被插件 '{existing_plugin}' 注册,跳过此组件注册"
f"已被插件 '{existing_plugin}' 注册,跳过插件 '{plugin_name}' 的注册"
) )
return False return False
# 注册到通用注册表(使用命名空间化的名称) self._components[namespaced_name] = component_info # 注册到通用注册表(使用命名空间化的名称)
self._components[namespaced_name] = component_info
self._components_by_type[component_type][component_name] = component_info # 类型内部仍使用原名 self._components_by_type[component_type][component_name] = component_info # 类型内部仍使用原名
self._component_classes[namespaced_name] = component_class self._components_classes[namespaced_name] = component_class
# 根据组件类型进行特定注册(使用原始名称) # 根据组件类型进行特定注册(使用原始名称)
if component_type == ComponentType.ACTION: match component_type:
self._register_action_component(component_info, component_class) # type: ignore case ComponentType.ACTION:
elif component_type == ComponentType.COMMAND: ret = self._register_action_component(component_info, component_class) # type: ignore
self._register_command_component(component_info, component_class) # type: ignore case ComponentType.COMMAND:
ret = self._register_command_component(component_info, component_class) # type: ignore
case ComponentType.EVENT_HANDLER:
ret = self._register_event_handler_component(component_info, component_class) # type: ignore
case _:
logger.warning(f"未知组件类型: {component_type}")
if not ret:
return False
logger.debug( logger.debug(
f"已注册{component_type.value}组件: '{component_name}' -> '{namespaced_name}' " f"已注册{component_type}组件: '{component_name}' -> '{namespaced_name}' "
f"({component_class.__name__}) [插件: {plugin_name}]" f"({component_class.__name__}) [插件: {plugin_name}]"
) )
return True return True
def _register_action_component(self, action_info: ActionInfo, action_class: Type[BaseAction]): def _register_action_component(self, action_info: ActionInfo, action_class: Type[BaseAction]) -> bool:
# -------------------------------- NEED REFACTORING --------------------------------
# -------------------------------- LOGIC ERROR -------------------------------------
"""注册Action组件到Action特定注册表""" """注册Action组件到Action特定注册表"""
action_name = action_info.name if not (action_name := action_info.name):
logger.error(f"Action组件 {action_class.__name__} 必须指定名称")
return False
if not isinstance(action_info, ActionInfo) or not issubclass(action_class, BaseAction):
logger.error(f"注册失败: {action_name} 不是有效的Action")
return False
self._action_registry[action_name] = action_class self._action_registry[action_name] = action_class
# 如果启用,添加到默认动作集 # 如果启用,添加到默认动作集
if action_info.enabled: if action_info.enabled:
self._default_actions[action_name] = action_info self._default_actions[action_name] = action_info
def _register_command_component(self, command_info: CommandInfo, command_class: Type[BaseCommand]): return True
def _register_command_component(self, command_info: CommandInfo, command_class: Type[BaseCommand]) -> bool:
"""注册Command组件到Command特定注册表""" """注册Command组件到Command特定注册表"""
command_name = command_info.name if not (command_name := command_info.name):
logger.error(f"Command组件 {command_class.__name__} 必须指定名称")
return False
if not isinstance(command_info, CommandInfo) or not issubclass(command_class, BaseCommand):
logger.error(f"注册失败: {command_name} 不是有效的Command")
return False
self._command_registry[command_name] = command_class self._command_registry[command_name] = command_class
# 编译正则表达式并注册 # 如果启用了且有匹配模式
if command_info.command_pattern: if command_info.enabled and command_info.command_pattern:
pattern = re.compile(command_info.command_pattern, re.IGNORECASE | re.DOTALL) pattern = re.compile(command_info.command_pattern, re.IGNORECASE | re.DOTALL)
self._command_patterns[pattern] = command_class if pattern not in self._command_patterns:
self._command_patterns[pattern] = command_name
logger.warning(f"'{command_name}' 对应的命令模式与 '{self._command_patterns[pattern]}' 重复,忽略此命令")
return True
def _register_event_handler_component(
self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler]
) -> bool:
if not (handler_name := handler_info.name):
logger.error(f"EventHandler组件 {handler_class.__name__} 必须指定名称")
return False
if not isinstance(handler_info, EventHandlerInfo) or not issubclass(handler_class, BaseEventHandler):
logger.error(f"注册失败: {handler_name} 不是有效的EventHandler")
return False
self._event_handler_registry[handler_name] = handler_class
from .events_manager import events_manager # 延迟导入防止循环导入问题
if events_manager.register_event_subscriber(handler_info, handler_class):
self._enabled_event_handlers[handler_name] = handler_class
return True
else:
logger.error(f"注册事件处理器 {handler_name} 失败")
return False
# === 组件查询方法 === # === 组件查询方法 ===
def get_component_info(
def get_component_info(self, component_name: str, component_type: ComponentType = None) -> Optional[ComponentInfo]: # type: ignore self, component_name: str, component_type: Optional[ComponentType] = None
) -> Optional[ComponentInfo]:
# sourcery skip: class-extract-method # sourcery skip: class-extract-method
"""获取组件信息,支持自动命名空间解析 """获取组件信息,支持自动命名空间解析
@@ -138,18 +205,12 @@ class ComponentRegistry:
# 2. 如果指定了组件类型,构造命名空间化的名称查找 # 2. 如果指定了组件类型,构造命名空间化的名称查找
if component_type: if component_type:
if component_type == ComponentType.ACTION: namespaced_name = f"{component_type}.{component_name}"
namespaced_name = f"action.{component_name}"
elif component_type == ComponentType.COMMAND:
namespaced_name = f"command.{component_name}"
else:
namespaced_name = f"{component_type.value}.{component_name}"
return self._components.get(namespaced_name) return self._components.get(namespaced_name)
# 3. 如果没有指定类型,尝试在所有命名空间中查找 # 3. 如果没有指定类型,尝试在所有命名空间中查找
candidates = [] candidates = []
for namespace_prefix in ["action", "command"]: for namespace_prefix in [types.value for types in ComponentType]:
namespaced_name = f"{namespace_prefix}.{component_name}" namespaced_name = f"{namespace_prefix}.{component_name}"
if component_info := self._components.get(namespaced_name): if component_info := self._components.get(namespaced_name):
candidates.append((namespace_prefix, namespaced_name, component_info)) candidates.append((namespace_prefix, namespaced_name, component_info))
@@ -171,8 +232,8 @@ class ComponentRegistry:
def get_component_class( def get_component_class(
self, self,
component_name: str, component_name: str,
component_type: ComponentType = None, # type: ignore component_type: Optional[ComponentType] = None,
) -> Optional[Union[Type[BaseCommand], Type[BaseAction]]]: ) -> Optional[Union[Type[BaseCommand], Type[BaseAction], Type[BaseEventHandler]]]:
"""获取组件类,支持自动命名空间解析 """获取组件类,支持自动命名空间解析
Args: Args:
@@ -184,29 +245,23 @@ class ComponentRegistry:
""" """
# 1. 如果已经是命名空间化的名称,直接查找 # 1. 如果已经是命名空间化的名称,直接查找
if "." in component_name: if "." in component_name:
return self._component_classes.get(component_name) return self._components_classes.get(component_name)
# 2. 如果指定了组件类型,构造命名空间化的名称查找 # 2. 如果指定了组件类型,构造命名空间化的名称查找
if component_type: if component_type:
if component_type == ComponentType.ACTION: namespaced_name = f"{component_type.value}.{component_name}"
namespaced_name = f"action.{component_name}" return self._components_classes.get(namespaced_name)
elif component_type == ComponentType.COMMAND:
namespaced_name = f"command.{component_name}"
else:
namespaced_name = f"{component_type.value}.{component_name}"
return self._component_classes.get(namespaced_name)
# 3. 如果没有指定类型,尝试在所有命名空间中查找 # 3. 如果没有指定类型,尝试在所有命名空间中查找
candidates = [] candidates = []
for namespace_prefix in ["action", "command"]: for namespace_prefix in [types.value for types in ComponentType]:
namespaced_name = f"{namespace_prefix}.{component_name}" namespaced_name = f"{namespace_prefix}.{component_name}"
if component_class := self._component_classes.get(namespaced_name): if component_class := self._components_classes.get(namespaced_name):
candidates.append((namespace_prefix, namespaced_name, component_class)) candidates.append((namespace_prefix, namespaced_name, component_class))
if len(candidates) == 1: if len(candidates) == 1:
# 只有一个匹配,直接返回 # 只有一个匹配,直接返回
namespace, full_name, cls = candidates[0] _, full_name, cls = candidates[0]
logger.debug(f"自动解析组件: '{component_name}' -> '{full_name}'") logger.debug(f"自动解析组件: '{component_name}' -> '{full_name}'")
return cls return cls
elif len(candidates) > 1: elif len(candidates) > 1:
@@ -235,7 +290,7 @@ class ComponentRegistry:
"""获取Action注册表用于兼容现有系统""" """获取Action注册表用于兼容现有系统"""
return self._action_registry.copy() return self._action_registry.copy()
def get_action_info(self, action_name: str) -> Optional[ActionInfo]: def get_registered_action_info(self, action_name: str) -> Optional[ActionInfo]:
"""获取Action信息""" """获取Action信息"""
info = self.get_component_info(action_name, ComponentType.ACTION) info = self.get_component_info(action_name, ComponentType.ACTION)
return info if isinstance(info, ActionInfo) else None return info if isinstance(info, ActionInfo) else None
@@ -247,18 +302,18 @@ class ComponentRegistry:
# === Command特定查询方法 === # === Command特定查询方法 ===
def get_command_registry(self) -> Dict[str, Type[BaseCommand]]: def get_command_registry(self) -> Dict[str, Type[BaseCommand]]:
"""获取Command注册表(用于兼容现有系统)""" """获取Command注册表"""
return self._command_registry.copy() return self._command_registry.copy()
def get_command_patterns(self) -> Dict[Pattern, Type[BaseCommand]]: def get_registered_command_info(self, command_name: str) -> Optional[CommandInfo]:
"""获取Command模式注册表用于兼容现有系统"""
return self._command_patterns.copy()
def get_command_info(self, command_name: str) -> Optional[CommandInfo]:
"""获取Command信息""" """获取Command信息"""
info = self.get_component_info(command_name, ComponentType.COMMAND) info = self.get_component_info(command_name, ComponentType.COMMAND)
return info if isinstance(info, CommandInfo) else None return info if isinstance(info, CommandInfo) else None
def get_command_patterns(self) -> Dict[Pattern, str]:
"""获取Command模式注册表"""
return self._command_patterns.copy()
def find_command_by_text(self, text: str) -> Optional[Tuple[Type[BaseCommand], dict, bool, str]]: def find_command_by_text(self, text: str) -> Optional[Tuple[Type[BaseCommand], dict, bool, str]]:
# sourcery skip: use-named-expression, use-next # sourcery skip: use-named-expression, use-next
"""根据文本查找匹配的命令 """根据文本查找匹配的命令
@@ -270,47 +325,36 @@ class ComponentRegistry:
Tuple: (命令类, 匹配的命名组, 是否拦截消息, 插件名) 或 None Tuple: (命令类, 匹配的命名组, 是否拦截消息, 插件名) 或 None
""" """
for pattern, command_class in self._command_patterns.items(): candidates = [pattern for pattern in self._command_patterns if pattern.match(text)]
if match := pattern.match(text): if not candidates:
command_name = None return None
# 查找对应的组件信息 if len(candidates) > 1:
for name, cls in self._command_registry.items(): logger.warning(f"文本 '{text}' 匹配到多个命令模式: {candidates},使用第一个匹配")
if cls == command_class: command_name = self._command_patterns[candidates[0]]
command_name = name command_info: CommandInfo = self.get_registered_command_info(command_name) # type: ignore
break return (
self._command_registry[command_name],
candidates[0].match(text).groupdict(), # type: ignore
command_info.intercept_message,
command_info.plugin_name,
)
# 检查命令是否启用 # === 事件处理器特定查询方法 ===
if command_name:
command_info = self.get_command_info(command_name)
if command_info and command_info.enabled:
return (
command_class,
match.groupdict(),
command_info.intercept_message,
command_info.plugin_name,
)
return None
# === 插件管理方法 === def get_event_handler_registry(self) -> Dict[str, Type[BaseEventHandler]]:
"""获取事件处理器注册表"""
return self._event_handler_registry.copy()
def register_plugin(self, plugin_info: PluginInfo) -> bool: def get_registered_event_handler_info(self, handler_name: str) -> Optional[EventHandlerInfo]:
"""注册插件 """获取事件处理器信息"""
info = self.get_component_info(handler_name, ComponentType.EVENT_HANDLER)
return info if isinstance(info, EventHandlerInfo) else None
Args: def get_enabled_event_handlers(self) -> Dict[str, Type[BaseEventHandler]]:
plugin_info: 插件信息 """获取启用的事件处理器"""
return self._enabled_event_handlers.copy()
Returns: # === 插件查询方法 ===
bool: 是否注册成功
"""
plugin_name = plugin_info.name
if plugin_name in self._plugins:
logger.warning(f"插件 {plugin_name} 已存在,跳过注册")
return False
self._plugins[plugin_name] = plugin_info
logger.debug(f"已注册插件: {plugin_name} (组件数量: {len(plugin_info.components)})")
return True
def get_plugin_info(self, plugin_name: str) -> Optional[PluginInfo]: def get_plugin_info(self, plugin_name: str) -> Optional[PluginInfo]:
"""获取插件信息""" """获取插件信息"""
@@ -344,82 +388,22 @@ class ComponentRegistry:
plugin_instance = plugin_manager.get_plugin_instance(plugin_name) plugin_instance = plugin_manager.get_plugin_instance(plugin_name)
return plugin_instance.config if plugin_instance else None return plugin_instance.config if plugin_instance else None
# === 状态管理方法 ===
# def enable_component(self, component_name: str, component_type: ComponentType = None) -> bool:
# # -------------------------------- NEED REFACTORING --------------------------------
# # -------------------------------- LOGIC ERROR -------------------------------------
# """启用组件,支持命名空间解析"""
# # 首先尝试找到正确的命名空间化名称
# component_info = self.get_component_info(component_name, component_type)
# if not component_info:
# return False
# # 根据组件类型构造正确的命名空间化名称
# if component_info.component_type == ComponentType.ACTION:
# namespaced_name = f"action.{component_name}" if "." not in component_name else component_name
# elif component_info.component_type == ComponentType.COMMAND:
# namespaced_name = f"command.{component_name}" if "." not in component_name else component_name
# else:
# namespaced_name = (
# f"{component_info.component_type.value}.{component_name}"
# if "." not in component_name
# else component_name
# )
# if namespaced_name in self._components:
# self._components[namespaced_name].enabled = True
# # 如果是Action更新默认动作集
# # ---- HERE ----
# # if isinstance(component_info, ActionInfo):
# # self._action_descriptions[component_name] = component_info.description
# logger.debug(f"已启用组件: {component_name} -> {namespaced_name}")
# return True
# return False
# def disable_component(self, component_name: str, component_type: ComponentType = None) -> bool:
# # -------------------------------- NEED REFACTORING --------------------------------
# # -------------------------------- LOGIC ERROR -------------------------------------
# """禁用组件,支持命名空间解析"""
# # 首先尝试找到正确的命名空间化名称
# component_info = self.get_component_info(component_name, component_type)
# if not component_info:
# return False
# # 根据组件类型构造正确的命名空间化名称
# if component_info.component_type == ComponentType.ACTION:
# namespaced_name = f"action.{component_name}" if "." not in component_name else component_name
# elif component_info.component_type == ComponentType.COMMAND:
# namespaced_name = f"command.{component_name}" if "." not in component_name else component_name
# else:
# namespaced_name = (
# f"{component_info.component_type.value}.{component_name}"
# if "." not in component_name
# else component_name
# )
# if namespaced_name in self._components:
# self._components[namespaced_name].enabled = False
# # 如果是Action从默认动作集中移除
# # ---- HERE ----
# # if component_name in self._action_descriptions:
# # del self._action_descriptions[component_name]
# logger.debug(f"已禁用组件: {component_name} -> {namespaced_name}")
# return True
# return False
def get_registry_stats(self) -> Dict[str, Any]: def get_registry_stats(self) -> Dict[str, Any]:
"""获取注册中心统计信息""" """获取注册中心统计信息"""
action_components: int = 0 action_components: int = 0
command_components: int = 0 command_components: int = 0
events_handlers: int = 0
for component in self._components.values(): for component in self._components.values():
if component.component_type == ComponentType.ACTION: if component.component_type == ComponentType.ACTION:
action_components += 1 action_components += 1
elif component.component_type == ComponentType.COMMAND: elif component.component_type == ComponentType.COMMAND:
command_components += 1 command_components += 1
elif component.component_type == ComponentType.EVENT_HANDLER:
events_handlers += 1
return { return {
"action_components": action_components, "action_components": action_components,
"command_components": command_components, "command_components": command_components,
"event_handlers": events_handlers,
"total_components": len(self._components), "total_components": len(self._components),
"total_plugins": len(self._plugins), "total_plugins": len(self._plugins),
"components_by_type": { "components_by_type": {
@@ -430,5 +414,4 @@ class ComponentRegistry:
} }
# 全局组件注册中心实例
component_registry = ComponentRegistry() component_registry = ComponentRegistry()

View File

@@ -1,11 +1,136 @@
from typing import List, Dict, Type import asyncio
from typing import List, Dict, Optional, Type
from src.plugin_system.base.component_types import EventType from src.chat.message_receive.message import MessageRecv
from src.common.logger import get_logger
from src.plugin_system.base.component_types import EventType, EventHandlerInfo, MaiMessages
from src.plugin_system.base.base_events_handler import BaseEventHandler
logger = get_logger("events_manager")
class EventsManager: class EventsManager:
def __init__(self): def __init__(self):
# 有权重的 events 订阅者注册表 # 有权重的 events 订阅者注册表
self.events_subscribers: Dict[EventType, List[Dict[int, Type]]] = {event: [] for event in EventType} self.events_subscribers: Dict[EventType, List[BaseEventHandler]] = {event: [] for event in EventType}
self.handler_mapping: Dict[str, Type[BaseEventHandler]] = {} # 事件处理器映射表
events_manager = EventsManager() def register_event_subscriber(self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler]) -> bool:
"""注册事件处理器
Args:
handler_info (EventHandlerInfo): 事件处理器信息
handler_class (Type[BaseEventHandler]): 事件处理器类
Returns:
bool: 是否注册成功
"""
handler_name = handler_info.name
plugin_name = getattr(handler_info, "plugin_name", "unknown")
namespace_name = f"{plugin_name}.{handler_name}"
if namespace_name in self.handler_mapping:
logger.warning(f"事件处理器 {namespace_name} 已存在,跳过注册")
return False
if not issubclass(handler_class, BaseEventHandler):
logger.error(f"{handler_class.__name__} 不是 BaseEventHandler 的子类")
return False
self.handler_mapping[namespace_name] = handler_class
return self._insert_event_handler(handler_class)
async def handler_mai_events(
self,
event_type: EventType,
message: MessageRecv,
llm_prompt: Optional[str] = None,
llm_response: Optional[str] = None,
) -> None:
"""处理 events"""
transformed_message = self._transform_event_message(message, llm_prompt, llm_response)
for handler in self.events_subscribers.get(event_type, []):
if handler.intercept_message:
await handler.execute(transformed_message)
else:
asyncio.create_task(handler.execute(transformed_message))
def _insert_event_handler(self, handler_class: Type[BaseEventHandler]) -> bool:
"""插入事件处理器到对应的事件类型列表中"""
if handler_class.event_type == EventType.UNKNOWN:
logger.error(f"事件处理器 {handler_class.__name__} 的事件类型未知,无法注册")
return False
self.events_subscribers[handler_class.event_type].append(handler_class())
self.events_subscribers[handler_class.event_type].sort(key=lambda x: x.weight, reverse=True)
return True
def _remove_event_handler(self, handler_class: Type[BaseEventHandler]) -> bool:
"""从事件类型列表中移除事件处理器"""
if handler_class.event_type == EventType.UNKNOWN:
logger.warning(f"事件处理器 {handler_class.__name__} 的事件类型未知,不存在于处理器列表中")
return False
handlers = self.events_subscribers[handler_class.event_type]
for i, handler in enumerate(handlers):
if isinstance(handler, handler_class):
del handlers[i]
logger.debug(f"事件处理器 {handler_class.__name__} 已移除")
return True
logger.warning(f"未找到事件处理器 {handler_class.__name__},无法移除")
return False
def _transform_event_message(
self, message: MessageRecv, llm_prompt: Optional[str] = None, llm_response: Optional[str] = None
) -> MaiMessages:
"""转换事件消息格式"""
# 直接赋值部分内容
transformed_message = MaiMessages(
llm_prompt=llm_prompt,
llm_response=llm_response,
raw_message=message.raw_message,
additional_data=message.message_info.additional_config or {},
)
# 消息段处理
if message.message_segment.type == "seglist":
transformed_message.message_segments = list(message.message_segment.data) # type: ignore
else:
transformed_message.message_segments = [message.message_segment]
# stream_id 处理
if hasattr(message, "chat_stream"):
transformed_message.stream_id = message.chat_stream.stream_id
# 处理后文本
transformed_message.plain_text = message.processed_plain_text
# 基本信息
if message.message_info.platform:
transformed_message.message_base_info["platform"] = message.message_info.platform
if message.message_info.group_info:
transformed_message.is_group_message = True
transformed_message.message_base_info.update(
{
"group_id": message.message_info.group_info.group_id,
"group_name": message.message_info.group_info.group_name,
}
)
if message.message_info.user_info:
if not transformed_message.is_group_message:
transformed_message.is_private_message = True
transformed_message.message_base_info.update(
{
"user_id": message.message_info.user_info.user_id,
"user_cardname": message.message_info.user_info.user_cardname, # 用户群昵称
"user_nickname": message.message_info.user_info.user_nickname, # 用户昵称(用户名)
}
)
return transformed_message
events_manager = EventsManager()

View File

@@ -1,10 +1,12 @@
from typing import Dict, List, Optional, Tuple, Type, Any
import os import os
from importlib.util import spec_from_file_location, module_from_spec import inspect
from inspect import getmodule
from pathlib import Path
import traceback import traceback
from typing import Dict, List, Optional, Tuple, Type, Any
from importlib.util import spec_from_file_location, module_from_spec
from pathlib import Path
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system.core.component_registry import component_registry from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.core.dependency_manager import dependency_manager from src.plugin_system.core.dependency_manager import dependency_manager
@@ -28,7 +30,7 @@ class PluginManager:
self.plugin_paths: Dict[str, str] = {} # 记录插件名到目录路径的映射,插件名 -> 目录路径 self.plugin_paths: Dict[str, str] = {} # 记录插件名到目录路径的映射,插件名 -> 目录路径
self.loaded_plugins: Dict[str, PluginBase] = {} # 已加载的插件类实例注册表,插件名 -> 插件类实例 self.loaded_plugins: Dict[str, PluginBase] = {} # 已加载的插件类实例注册表,插件名 -> 插件类实例
self.failed_plugins: Dict[str, str] = {} # 记录加载失败的插件及其错误信息,插件名 -> 错误信息 self.failed_plugins: Dict[str, str] = {} # 记录加载失败的插件文件及其错误信息,插件名 -> 错误信息
# 确保插件目录存在 # 确保插件目录存在
self._ensure_plugin_directories() self._ensure_plugin_directories()
@@ -107,13 +109,9 @@ class PluginManager:
# 使用记录的插件目录路径 # 使用记录的插件目录路径
plugin_dir = self.plugin_paths.get(plugin_name) plugin_dir = self.plugin_paths.get(plugin_name)
# 如果没有记录,则尝试查找fallback # 如果没有记录,直接返回失败
if not plugin_dir: if not plugin_dir:
plugin_dir = self._find_plugin_directory(plugin_class) return False, 1
if plugin_dir:
self.plugin_paths[plugin_name] = plugin_dir # 更新路径
else:
return False, 1
plugin_instance = plugin_class(plugin_dir=plugin_dir) # 实例化插件可能因为缺少manifest而失败 plugin_instance = plugin_class(plugin_dir=plugin_dir) # 实例化插件可能因为缺少manifest而失败
if not plugin_instance: if not plugin_instance:
@@ -360,24 +358,14 @@ class PluginManager:
logger.debug(f"正在扫描插件根目录: {directory}") logger.debug(f"正在扫描插件根目录: {directory}")
# 遍历目录中的所有Python文件和 # 遍历目录中的所有包
for item in os.listdir(directory): for item in os.listdir(directory):
item_path = os.path.join(directory, item) item_path = os.path.join(directory, item)
if os.path.isfile(item_path) and item.endswith(".py") and item != "__init__.py": if os.path.isdir(item_path) and not item.startswith(".") and not item.startswith("__"):
# 单文件插件
plugin_name = Path(item_path).stem
if self._load_plugin_module_file(item_path, plugin_name, directory):
loaded_count += 1
else:
failed_count += 1
elif os.path.isdir(item_path) and not item.startswith(".") and not item.startswith("__"):
# 插件包
plugin_file = os.path.join(item_path, "plugin.py") plugin_file = os.path.join(item_path, "plugin.py")
if os.path.exists(plugin_file): if os.path.exists(plugin_file):
plugin_name = item # 使用目录名作为插件名 if self._load_plugin_module_file(plugin_file):
if self._load_plugin_module_file(plugin_file, plugin_name, item_path):
loaded_count += 1 loaded_count += 1
else: else:
failed_count += 1 failed_count += 1
@@ -387,14 +375,16 @@ class PluginManager:
def _find_plugin_directory(self, plugin_class: Type[PluginBase]) -> Optional[str]: def _find_plugin_directory(self, plugin_class: Type[PluginBase]) -> Optional[str]:
"""查找插件类对应的目录路径""" """查找插件类对应的目录路径"""
try: try:
module = getmodule(plugin_class) # module = getmodule(plugin_class)
if module and hasattr(module, "__file__") and module.__file__: # if module and hasattr(module, "__file__") and module.__file__:
return os.path.dirname(module.__file__) # return os.path.dirname(module.__file__)
file_path = inspect.getfile(plugin_class)
return os.path.dirname(file_path)
except Exception as e: except Exception as e:
logger.debug(f"通过inspect获取插件目录失败: {e}") logger.debug(f"通过inspect获取插件目录失败: {e}")
return None return None
def _load_plugin_module_file(self, plugin_file: str, plugin_name: str, plugin_dir: str) -> bool: def _load_plugin_module_file(self, plugin_file: str) -> bool:
# sourcery skip: extract-method # sourcery skip: extract-method
"""加载单个插件模块文件 """加载单个插件模块文件
@@ -405,12 +395,7 @@ class PluginManager:
""" """
# 生成模块名 # 生成模块名
plugin_path = Path(plugin_file) plugin_path = Path(plugin_file)
if plugin_path.parent.name != "plugins": module_name = ".".join(plugin_path.parent.parts)
# 插件包格式parent_dir.plugin
module_name = f"plugins.{plugin_path.parent.name}.plugin"
else:
# 单文件格式plugins.filename
module_name = f"plugins.{plugin_path.stem}"
try: try:
# 动态导入插件模块 # 动态导入插件模块
@@ -422,16 +407,13 @@ class PluginManager:
module = module_from_spec(spec) module = module_from_spec(spec)
spec.loader.exec_module(module) spec.loader.exec_module(module)
# 记录插件名和目录路径的映射
self.plugin_paths[plugin_name] = plugin_dir
logger.debug(f"插件模块加载成功: {plugin_file}") logger.debug(f"插件模块加载成功: {plugin_file}")
return True return True
except Exception as e: except Exception as e:
error_msg = f"加载插件模块 {plugin_file} 失败: {e}" error_msg = f"加载插件模块 {plugin_file} 失败: {e}"
logger.error(error_msg) logger.error(error_msg)
self.failed_plugins[plugin_name] = error_msg self.failed_plugins[module_name] = error_msg
return False return False
def _check_plugin_version_compatibility(self, plugin_name: str, manifest_data: Dict[str, Any]) -> Tuple[bool, str]: def _check_plugin_version_compatibility(self, plugin_name: str, manifest_data: Dict[str, Any]) -> Tuple[bool, str]:
@@ -475,13 +457,14 @@ class PluginManager:
stats = component_registry.get_registry_stats() stats = component_registry.get_registry_stats()
action_count = stats.get("action_components", 0) action_count = stats.get("action_components", 0)
command_count = stats.get("command_components", 0) command_count = stats.get("command_components", 0)
event_handler_count = stats.get("event_handlers", 0)
total_components = stats.get("total_components", 0) total_components = stats.get("total_components", 0)
# 📋 显示插件加载总览 # 📋 显示插件加载总览
if total_registered > 0: if total_registered > 0:
logger.info("🎉 插件系统加载完成!") logger.info("🎉 插件系统加载完成!")
logger.info( logger.info(
f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count})" f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, EventHandler: {event_handler_count})"
) )
# 显示详细的插件列表 # 显示详细的插件列表
@@ -510,8 +493,9 @@ class PluginManager:
# 组件列表 # 组件列表
if plugin_info.components: if plugin_info.components:
action_components = [c for c in plugin_info.components if c.component_type.name == "ACTION"] action_components = [c for c in plugin_info.components if c.component_type == ComponentType.ACTION]
command_components = [c for c in plugin_info.components if c.component_type.name == "COMMAND"] command_components = [c for c in plugin_info.components if c.component_type == ComponentType.COMMAND]
event_handler_components = [c for c in plugin_info.components if c.component_type == ComponentType.EVENT_HANDLER]
if action_components: if action_components:
action_names = [c.name for c in action_components] action_names = [c.name for c in action_components]
@@ -520,6 +504,10 @@ class PluginManager:
if command_components: if command_components:
command_names = [c.name for c in command_components] command_names = [c.name for c in command_components]
logger.info(f" ⚡ Command组件: {', '.join(command_names)}") logger.info(f" ⚡ Command组件: {', '.join(command_names)}")
if event_handler_components:
event_handler_names = [c.name for c in event_handler_components]
logger.info(f" 📢 EventHandler组件: {', '.join(event_handler_names)}")
# 依赖信息 # 依赖信息
if plugin_info.dependencies: if plugin_info.dependencies:
@@ -530,6 +518,12 @@ class PluginManager:
config_status = "" if self.plugin_paths.get(plugin_name) else "" config_status = "" if self.plugin_paths.get(plugin_name) else ""
logger.info(f" ⚙️ 配置: {plugin_info.config_file} {config_status}") logger.info(f" ⚙️ 配置: {plugin_info.config_file} {config_status}")
root_path = Path(__file__)
# 查找项目根目录
while not (root_path / "pyproject.toml").exists() and root_path.parent != root_path:
root_path = root_path.parent
# 显示目录统计 # 显示目录统计
logger.info("📂 加载目录统计:") logger.info("📂 加载目录统计:")
for directory in self.plugin_directories: for directory in self.plugin_directories:
@@ -537,7 +531,11 @@ class PluginManager:
plugins_in_dir = [] plugins_in_dir = []
for plugin_name in self.loaded_plugins.keys(): for plugin_name in self.loaded_plugins.keys():
plugin_path = self.plugin_paths.get(plugin_name, "") plugin_path = self.plugin_paths.get(plugin_name, "")
if plugin_path.startswith(directory): if (
Path(plugin_path)
.resolve()
.is_relative_to(Path(os.path.join(str(root_path), directory)).resolve())
):
plugins_in_dir.append(plugin_name) plugins_in_dir.append(plugin_name)
if plugins_in_dir: if plugins_in_dir:

View File

@@ -80,8 +80,9 @@ class ReplyAction(BaseAction):
logger.info(f"{self.log_prefix} 回复目标: {reply_to}") logger.info(f"{self.log_prefix} 回复目标: {reply_to}")
try: try:
prepared_reply = self.action_data.get("prepared_reply", "") if prepared_reply := self.action_data.get("prepared_reply", ""):
if not prepared_reply: reply_text = prepared_reply
else:
try: try:
success, reply_set, _ = await asyncio.wait_for( success, reply_set, _ = await asyncio.wait_for(
generator_api.generate_reply( generator_api.generate_reply(
@@ -109,9 +110,6 @@ class ReplyAction(BaseAction):
logger.info( logger.info(
f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,{'使用' if need_reply else '不使用'}引用回复" f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,{'使用' if need_reply else '不使用'}引用回复"
) )
else:
reply_text = prepared_reply
# 构建回复文本 # 构建回复文本
reply_text = "" reply_text = ""
first_replied = False first_replied = False
@@ -120,11 +118,12 @@ class ReplyAction(BaseAction):
data = reply_seg[1] data = reply_seg[1]
if not first_replied: if not first_replied:
if need_reply: if need_reply:
await self.send_text(content=data, reply_to=reply_to, reply_to_platform_id=reply_to_platform_id, typing=False) await self.send_text(
first_replied = True content=data, reply_to=reply_to, reply_to_platform_id=reply_to_platform_id, typing=False
)
else: else:
await self.send_text(content=data, reply_to_platform_id=reply_to_platform_id, typing=False) await self.send_text(content=data, reply_to_platform_id=reply_to_platform_id, typing=False)
first_replied = True first_replied = True
else: else:
await self.send_text(content=data, reply_to_platform_id=reply_to_platform_id, typing=True) await self.send_text(content=data, reply_to_platform_id=reply_to_platform_id, typing=True)
reply_text += data reply_text += data
@@ -190,17 +189,15 @@ class CoreActionsPlugin(BasePlugin):
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
"""返回插件包含的组件列表""" """返回插件包含的组件列表"""
# --- 从配置动态设置Action/Command --- if global_config.emoji.emoji_activate_type == "llm":
emoji_chance = global_config.emoji.emoji_chance
if global_config.emoji.emoji_activate_type == "random":
EmojiAction.random_activation_probability = emoji_chance
EmojiAction.focus_activation_type = ActionActivationType.RANDOM
EmojiAction.normal_activation_type = ActionActivationType.RANDOM
elif global_config.emoji.emoji_activate_type == "llm":
EmojiAction.random_activation_probability = 0.0 EmojiAction.random_activation_probability = 0.0
EmojiAction.focus_activation_type = ActionActivationType.LLM_JUDGE EmojiAction.focus_activation_type = ActionActivationType.LLM_JUDGE
EmojiAction.normal_activation_type = ActionActivationType.LLM_JUDGE EmojiAction.normal_activation_type = ActionActivationType.LLM_JUDGE
elif global_config.emoji.emoji_activate_type == "random":
EmojiAction.random_activation_probability = global_config.emoji.emoji_chance
EmojiAction.focus_activation_type = ActionActivationType.RANDOM
EmojiAction.normal_activation_type = ActionActivationType.RANDOM
# --- 根据配置注册组件 --- # --- 根据配置注册组件 ---
components = [] components = []
if self.get_config("components.enable_reply", True): if self.get_config("components.enable_reply", True):

View File

@@ -33,7 +33,7 @@ class SearchKnowledgeFromLPMMTool(BaseTool):
Dict: 工具执行结果 Dict: 工具执行结果
""" """
try: try:
query = function_args.get("query") query: str = function_args.get("query") # type: ignore
# threshold = function_args.get("threshold", 0.4) # threshold = function_args.get("threshold", 0.4)
# 检查LPMM知识库是否启用 # 检查LPMM知识库是否启用

View File

@@ -1,5 +1,5 @@
[inner] [inner]
version = "4.4.3" version = "4.4.4"
#----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读---- #----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
#如果你想要修改配置文件请在修改后将version的值进行变更 #如果你想要修改配置文件请在修改后将version的值进行变更
@@ -87,6 +87,7 @@ talk_frequency_adjust = [
# - 时间支持跨天,例如 "00:10,0.3" 表示从凌晨0:10开始使用频率0.3 # - 时间支持跨天,例如 "00:10,0.3" 表示从凌晨0:10开始使用频率0.3
# - 系统会自动将 "platform:id:type" 转换为内部的哈希chat_id进行匹配 # - 系统会自动将 "platform:id:type" 转换为内部的哈希chat_id进行匹配
enable_asr = false # 是否启用语音识别,启用后麦麦可以通过语音输入进行对话,启用该功能需要配置语音识别模型[model.voice]
[message_receive] [message_receive]
# 以下是消息过滤,可以根据规则过滤特定消息,将不会读取这些消息 # 以下是消息过滤,可以根据规则过滤特定消息,将不会读取这些消息
@@ -294,6 +295,12 @@ provider = "SILICONFLOW"
pri_in = 0.35 pri_in = 0.35
pri_out = 0.35 pri_out = 0.35
[model.voice] # 语音识别模型
name = "FunAudioLLM/SenseVoiceSmall"
provider = "SILICONFLOW"
pri_in = 0
pri_out = 0
[model.tool_use] #工具调用模型,需要使用支持工具调用的模型 [model.tool_use] #工具调用模型,需要使用支持工具调用的模型
name = "Qwen/Qwen3-14B" name = "Qwen/Qwen3-14B"
provider = "SILICONFLOW" provider = "SILICONFLOW"