diff --git a/changes.md b/changes.md index 1f53d7e50..0d6b507b9 100644 --- a/changes.md +++ b/changes.md @@ -1,26 +1,60 @@ # 插件API与规范修改 -1. 现在`plugin_system`的`__init__.py`文件中包含了所有插件API的导入,用户可以直接使用`from plugin_system import *`来导入所有API。 +1. 现在`plugin_system`的`__init__.py`文件中包含了所有插件API的导入,用户可以直接使用`from src.plugin_system import *`来导入所有API。 -2. register_plugin函数现在转移到了`plugin_system.apis.plugin_register_api`模块中,用户可以通过`from plugin_system.apis.plugin_register_api import register_plugin`来导入。 +2. register_plugin函数现在转移到了`plugin_system.apis.plugin_register_api`模块中,用户可以通过`from src.plugin_system.apis.plugin_register_api import register_plugin`来导入。 + - 顺便一提,按照1中说法,你可以这么用: + ```python + from src.plugin_system import register_plugin + ``` -3. 现在强制要求的property如下: - - `plugin_name`: 插件名称,必须是唯一的。(与文件夹相同) - - `enable_plugin`: 是否启用插件,默认为`True`。 - - `dependencies`: 插件依赖的其他插件列表,默认为空。**现在并不检查(也许)** - - `python_dependencies`: 插件依赖的Python包列表,默认为空。**现在并不检查** - - `config_file_name`: 插件配置文件名,默认为`config.toml`。 - - `config_schema`: 插件配置文件的schema,用于自动生成配置文件。 +3. 现在强制要求的property如下,即你必须覆盖的属性有: + - `plugin_name`: 插件名称,必须是唯一的。(与文件夹相同) + - `enable_plugin`: 是否启用插件,默认为`True`。 + - `dependencies`: 插件依赖的其他插件列表,默认为空。**现在并不检查(也许)** + - `python_dependencies`: 插件依赖的Python包列表,默认为空。**现在并不检查** + - `config_file_name`: 插件配置文件名,默认为`config.toml`。 + - `config_schema`: 插件配置文件的schema,用于自动生成配置文件。 +4. 部分API的参数类型和返回值进行了调整 + - `chat_api.py`中获取流的参数中可以使用一个特殊的枚举类型来获得所有平台的 ChatStream 了。 + - `config_api.py`中的`get_global_config`和`get_plugin_config`方法现在支持嵌套访问的配置键名。 + - `database_api.py`中的`db_query`方法调整了参数顺序以增强参数限制的同时,保证了typing正确;`db_get`方法增加了`single_result`参数,与`db_query`保持一致。 +5. 增加了`logging_api`,可以用`get_logger`来获取日志记录器。 # 插件系统修改 1. 现在所有的匹配模式不再是关键字了,而是枚举类。**(可能有遗漏)** 2. 修复了一下显示插件信息不显示的问题。同时精简了一下显示内容 -3. 修复了插件系统混用了`plugin_name`和`display_name`的问题。现在所有的插件信息都使用`display_name`来显示,而内部标识仍然使用`plugin_name`。**(可能有遗漏)** -3. 部分API的参数类型和返回值进行了调整 - - `chat_api.py`中获取流的参数中可以使用一个特殊的枚举类型来获得所有平台的 ChatStream 了。 - - `config_api.py`中的`get_global_config`和`get_plugin_config`方法现在支持嵌套访问的配置键名。 - - `database_api.py`中的`db_query`方法调整了参数顺序以增强参数限制的同时,保证了typing正确;`db_get`方法增加了`single_result`参数,与`db_query`保持一致。 +3. 修复了插件系统混用了`plugin_name`和`display_name`的问题。现在所有的插件信息都使用`display_name`来显示,而内部标识仍然使用`plugin_name`。 4. 现在增加了参数类型检查,完善了对应注释 5. 现在插件抽象出了总基类 `PluginBase` - - 基于`Action`和`Command`的插件基类现在为`BasePlugin`,它继承自`PluginBase`,由`register_plugin`装饰器注册。 - - 基于`Event`的插件基类现在为`BaseEventPlugin`,它也继承自`PluginBase`,由`register_event_plugin`装饰器注册。 \ No newline at end of file + - 基于`Action`和`Command`的插件基类现在为`BasePlugin`。 + - 基于`Event`的插件基类现在为`BaseEventPlugin`。 + - 基于`Action`,`Command`和`Event`的插件基类现在为`BasePlugin`,所有插件都应该继承此基类。 + - `BasePlugin`继承自`PluginBase`。 + - 所有的插件类都由`register_plugin`装饰器注册。 +6. 现在我们终于可以让插件有自定义的名字了! + - 真正实现了插件的`plugin_name`**不受文件夹名称限制**的功能。(吐槽:可乐你的某个小小细节导致我搞了好久……) + - 通过在插件类中定义`plugin_name`属性来指定插件内部标识符。 + - 由于此更改一个文件中现在可以有多个插件类,但每个插件类必须有**唯一的**`plugin_name`。 + - 在某些插件加载失败时,现在会显示包名而不是插件内部标识符。 + - 例如:`MaiMBot.plugins.example_plugin`而不是`example_plugin`。 + - 仅在插件 import 失败时会如此,正常注册过程中失败的插件不会显示包名,而是显示插件内部标识符。(这是特性,但是基本上不可能出现这个情况) +7. 现在不支持单文件插件了,加载方式已经完全删除。 +8. 把`BaseEventPlugin`合并到了`BasePlugin`中,所有插件都应该继承自`BasePlugin`。 + + +# 吐槽 +```python +plugin_path = Path(plugin_file) +if plugin_path.parent.name != "plugins": + # 插件包格式:parent_dir.plugin + module_name = f"plugins.{plugin_path.parent.name}.plugin" +else: + # 单文件格式:plugins.filename + module_name = f"plugins.{plugin_path.stem}" +``` +```python +plugin_path = Path(plugin_file) +module_name = ".".join(plugin_path.parent.parts) +``` +这两个区别很大的。 \ No newline at end of file diff --git a/plugins/hello_world_plugin/plugin.py b/plugins/hello_world_plugin/plugin.py index dc9b8571c..2d645616b 100644 --- a/plugins/hello_world_plugin/plugin.py +++ b/plugins/hello_world_plugin/plugin.py @@ -7,11 +7,13 @@ from src.plugin_system import ( ComponentInfo, ActionActivationType, ConfigField, + BaseEventHandler, + EventType, + MaiMessages, ) + # ===== Action组件 ===== - - class HelloAction(BaseAction): """问候Action - 简单的问候动作""" @@ -82,7 +84,7 @@ class TimeCommand(BaseCommand): import datetime # 获取当前时间 - time_format = self.get_config("time.format", "%Y-%m-%d %H:%M:%S") + time_format: str = self.get_config("time.format", "%Y-%m-%d %H:%M:%S") # type: ignore now = datetime.datetime.now() time_str = now.strftime(time_format) @@ -93,6 +95,20 @@ class TimeCommand(BaseCommand): return True, f"显示了当前时间: {time_str}" +class PrintMessage(BaseEventHandler): + """打印消息事件处理器 - 处理打印消息事件""" + + event_type = EventType.ON_MESSAGE + handler_name = "print_message_handler" + handler_description = "打印接收到的消息" + + async def execute(self, message: MaiMessages) -> Tuple[bool, str | None]: + """执行打印消息事件处理""" + # 打印接收到的消息 + print(f"接收到消息: {message.raw_message}") + return True, "消息已打印" + + # ===== 插件注册 ===== @@ -122,6 +138,7 @@ class HelloWorldPlugin(BasePlugin): "enable_emoji": ConfigField(type=bool, default=True, description="是否启用表情符号"), }, "time": {"format": ConfigField(type=str, default="%Y-%m-%d %H:%M:%S", description="时间显示格式")}, + "print_message": {"enabled": ConfigField(type=bool, default=True, description="是否启用打印")}, } def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: @@ -129,4 +146,27 @@ class HelloWorldPlugin(BasePlugin): (HelloAction.get_action_info(), HelloAction), (ByeAction.get_action_info(), ByeAction), # 添加告别Action (TimeCommand.get_command_info(), TimeCommand), + (PrintMessage.get_handler_info(), PrintMessage), ] + + +# @register_plugin +# class HelloWorldEventPlugin(BaseEPlugin): +# """Hello World事件插件 - 处理问候和告别事件""" + +# plugin_name = "hello_world_event_plugin" +# enable_plugin = False +# dependencies = [] +# python_dependencies = [] +# config_file_name = "event_config.toml" + +# config_schema = { +# "plugin": { +# "name": ConfigField(type=str, default="hello_world_event_plugin", description="插件名称"), +# "version": ConfigField(type=str, default="1.0.0", description="插件版本"), +# "enabled": ConfigField(type=bool, default=True, description="是否启用插件"), +# }, +# } + +# def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: +# return [(PrintMessage.get_handler_info(), PrintMessage)] diff --git a/plugins/take_picture_plugin/plugin.py b/plugins/take_picture_plugin/plugin(deprecated).py similarity index 100% rename from plugins/take_picture_plugin/plugin.py rename to plugins/take_picture_plugin/plugin(deprecated).py diff --git a/src/chat/focus_chat/heartFC_chat.py b/src/chat/focus_chat/heartFC_chat.py index dd2d5374a..738479754 100644 --- a/src/chat/focus_chat/heartFC_chat.py +++ b/src/chat/focus_chat/heartFC_chat.py @@ -179,8 +179,7 @@ class HeartFChatting: await asyncio.sleep(10) if self.loop_mode == ChatMode.NORMAL: self.energy_value -= 0.3 - if self.energy_value <= 0.3: - self.energy_value = 0.3 + self.energy_value = max(self.energy_value, 0.3) def print_cycle_info(self, cycle_timers): # 记录循环信息和计时器结果 @@ -257,6 +256,7 @@ class HeartFChatting: return f"{person_name}:{message_data.get('processed_plain_text')}" async def _observe(self, message_data: Optional[Dict[str, Any]] = None): + # sourcery skip: hoist-statement-from-if, merge-comparisons, reintroduce-else if not message_data: message_data = {} action_type = "no_action" @@ -462,7 +462,7 @@ class HeartFChatting: 在"兴趣"模式下,判断是否回复并生成内容。 """ - interested_rate = message_data.get("interest_value", 0.0) * self.willing_amplifier + interested_rate = (message_data.get("interest_value") or 0.0) * self.willing_amplifier self.willing_manager.setup(message_data, self.chat_stream) diff --git a/src/chat/knowledge/embedding_store.py b/src/chat/knowledge/embedding_store.py index 808b8013b..d732683ae 100644 --- a/src/chat/knowledge/embedding_store.py +++ b/src/chat/knowledge/embedding_store.py @@ -106,10 +106,10 @@ class EmbeddingStore: asyncio.get_running_loop() # 如果在事件循环中,使用线程池执行 import concurrent.futures - + def run_in_thread(): return asyncio.run(get_embedding(s)) - + with concurrent.futures.ThreadPoolExecutor() as executor: future = executor.submit(run_in_thread) result = future.result() @@ -294,10 +294,10 @@ class EmbeddingStore: """ if self.faiss_index is None: logger.debug("FaissIndex尚未构建,返回None") - return None + return [] if self.idx2hash is None: logger.warning("idx2hash尚未构建,返回None") - return None + return [] # L2归一化 faiss.normalize_L2(np.array([query], dtype=np.float32)) @@ -318,15 +318,15 @@ class EmbeddingStore: class EmbeddingManager: def __init__(self): self.paragraphs_embedding_store = EmbeddingStore( - local_storage['pg_namespace'], + local_storage["pg_namespace"], # type: ignore EMBEDDING_DATA_DIR_STR, ) self.entities_embedding_store = EmbeddingStore( - local_storage['pg_namespace'], + local_storage["pg_namespace"], # type: ignore EMBEDDING_DATA_DIR_STR, ) self.relation_embedding_store = EmbeddingStore( - local_storage['pg_namespace'], + local_storage["pg_namespace"], # type: ignore EMBEDDING_DATA_DIR_STR, ) self.stored_pg_hashes = set() diff --git a/src/chat/knowledge/kg_manager.py b/src/chat/knowledge/kg_manager.py index e18a7da80..083a741d6 100644 --- a/src/chat/knowledge/kg_manager.py +++ b/src/chat/knowledge/kg_manager.py @@ -30,20 +30,20 @@ def _get_kg_dir(): """ 安全地获取KG数据目录路径 """ - root_path = local_storage['root_path'] + root_path: str = local_storage["root_path"] if root_path is None: # 如果 local_storage 中没有 root_path,使用当前文件的相对路径作为备用 current_dir = os.path.dirname(os.path.abspath(__file__)) root_path = os.path.abspath(os.path.join(current_dir, "..", "..", "..")) logger.warning(f"local_storage 中未找到 root_path,使用备用路径: {root_path}") - + # 获取RAG数据目录 - rag_data_dir = global_config["persistence"]["rag_data_dir"] + rag_data_dir: str = global_config["persistence"]["rag_data_dir"] if rag_data_dir is None: kg_dir = os.path.join(root_path, "data/rag") else: kg_dir = os.path.join(root_path, rag_data_dir) - + return str(kg_dir).replace("\\", "/") @@ -65,9 +65,9 @@ class KGManager: # 持久化相关 - 使用延迟初始化的路径 self.dir_path = get_kg_dir_str() - self.graph_data_path = self.dir_path + "/" + local_storage['rag_graph_namespace'] + ".graphml" - self.ent_cnt_data_path = self.dir_path + "/" + local_storage['rag_ent_cnt_namespace'] + ".parquet" - self.pg_hash_file_path = self.dir_path + "/" + local_storage['rag_pg_hash_namespace'] + ".json" + self.graph_data_path = self.dir_path + "/" + local_storage["rag_graph_namespace"] + ".graphml" + self.ent_cnt_data_path = self.dir_path + "/" + local_storage["rag_ent_cnt_namespace"] + ".parquet" + self.pg_hash_file_path = self.dir_path + "/" + local_storage["rag_pg_hash_namespace"] + ".json" def save_to_file(self): """将KG数据保存到文件""" @@ -91,11 +91,11 @@ class KGManager: """从文件加载KG数据""" # 确保文件存在 if not os.path.exists(self.pg_hash_file_path): - raise Exception(f"KG段落hash文件{self.pg_hash_file_path}不存在") + raise FileNotFoundError(f"KG段落hash文件{self.pg_hash_file_path}不存在") if not os.path.exists(self.ent_cnt_data_path): - raise Exception(f"KG实体计数文件{self.ent_cnt_data_path}不存在") + raise FileNotFoundError(f"KG实体计数文件{self.ent_cnt_data_path}不存在") if not os.path.exists(self.graph_data_path): - raise Exception(f"KG图文件{self.graph_data_path}不存在") + raise FileNotFoundError(f"KG图文件{self.graph_data_path}不存在") # 加载段落hash with open(self.pg_hash_file_path, "r", encoding="utf-8") as f: @@ -122,8 +122,8 @@ class KGManager: # 避免自连接 continue # 一个triple就是一条边(同时构建双向联系) - hash_key1 = local_storage['ent_namespace'] + "-" + get_sha256(triple[0]) - hash_key2 = local_storage['ent_namespace'] + "-" + get_sha256(triple[2]) + hash_key1 = local_storage["ent_namespace"] + "-" + get_sha256(triple[0]) + hash_key2 = local_storage["ent_namespace"] + "-" + get_sha256(triple[2]) node_to_node[(hash_key1, hash_key2)] = node_to_node.get((hash_key1, hash_key2), 0) + 1.0 node_to_node[(hash_key2, hash_key1)] = node_to_node.get((hash_key2, hash_key1), 0) + 1.0 entity_set.add(hash_key1) @@ -141,8 +141,8 @@ class KGManager: """构建实体节点与文段节点之间的关系""" for idx in triple_list_data: for triple in triple_list_data[idx]: - ent_hash_key = local_storage['ent_namespace'] + "-" + get_sha256(triple[0]) - pg_hash_key = local_storage['pg_namespace'] + "-" + str(idx) + ent_hash_key = local_storage["ent_namespace"] + "-" + get_sha256(triple[0]) + pg_hash_key = local_storage["pg_namespace"] + "-" + str(idx) node_to_node[(ent_hash_key, pg_hash_key)] = node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0 @staticmethod @@ -157,8 +157,8 @@ class KGManager: ent_hash_list = set() for triple_list in triple_list_data.values(): for triple in triple_list: - ent_hash_list.add(local_storage['ent_namespace'] + "-" + get_sha256(triple[0])) - ent_hash_list.add(local_storage['ent_namespace'] + "-" + get_sha256(triple[2])) + ent_hash_list.add(local_storage["ent_namespace"] + "-" + get_sha256(triple[0])) + ent_hash_list.add(local_storage["ent_namespace"] + "-" + get_sha256(triple[2])) ent_hash_list = list(ent_hash_list) synonym_hash_set = set() @@ -263,7 +263,7 @@ class KGManager: for src_tgt in node_to_node.keys(): for node_hash in src_tgt: if node_hash not in existed_nodes: - if node_hash.startswith(local_storage['ent_namespace']): + if node_hash.startswith(local_storage["ent_namespace"]): # 新增实体节点 node = embedding_manager.entities_embedding_store.store.get(node_hash) if node is None: @@ -275,7 +275,7 @@ class KGManager: node_item["type"] = "ent" node_item["create_time"] = now_time self.graph.update_node(node_item) - elif node_hash.startswith(local_storage['pg_namespace']): + elif node_hash.startswith(local_storage["pg_namespace"]): # 新增文段节点 node = embedding_manager.paragraphs_embedding_store.store.get(node_hash) if node is None: @@ -359,7 +359,7 @@ class KGManager: # 关系三元组 triple = relation[2:-2].split("', '") for ent in [(triple[0]), (triple[2])]: - ent_hash = local_storage['ent_namespace'] + "-" + get_sha256(ent) + ent_hash = local_storage["ent_namespace"] + "-" + get_sha256(ent) if ent_hash in existed_nodes: # 该实体需在KG中存在 if ent_hash not in ent_sim_scores: # 尚未记录的实体 ent_sim_scores[ent_hash] = [] @@ -437,7 +437,9 @@ class KGManager: # 获取最终结果 # 从搜索结果中提取文段节点的结果 passage_node_res = [ - (node_key, score) for node_key, score in ppr_res.items() if node_key.startswith(local_storage['pg_namespace']) + (node_key, score) + for node_key, score in ppr_res.items() + if node_key.startswith(local_storage["pg_namespace"]) ] del ppr_res diff --git a/src/chat/knowledge/knowledge_lib.py b/src/chat/knowledge/knowledge_lib.py index 180a16ca1..1e87d3824 100644 --- a/src/chat/knowledge/knowledge_lib.py +++ b/src/chat/knowledge/knowledge_lib.py @@ -33,6 +33,7 @@ RAG_PG_HASH_NAMESPACE = "rag-pg-hash" ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) DATA_PATH = os.path.join(ROOT_PATH, "data") + def _initialize_knowledge_local_storage(): """ 初始化知识库相关的本地存储配置 @@ -41,55 +42,58 @@ def _initialize_knowledge_local_storage(): # 定义所有需要初始化的配置项 default_configs = { # 路径配置 - 'root_path': ROOT_PATH, - 'data_path': f"{ROOT_PATH}/data", - + "root_path": ROOT_PATH, + "data_path": f"{ROOT_PATH}/data", # 实体和命名空间配置 - 'lpmm_invalid_entity': INVALID_ENTITY, - 'pg_namespace': PG_NAMESPACE, - 'ent_namespace': ENT_NAMESPACE, - 'rel_namespace': REL_NAMESPACE, - + "lpmm_invalid_entity": INVALID_ENTITY, + "pg_namespace": PG_NAMESPACE, + "ent_namespace": ENT_NAMESPACE, + "rel_namespace": REL_NAMESPACE, # RAG相关命名空间配置 - 'rag_graph_namespace': RAG_GRAPH_NAMESPACE, - 'rag_ent_cnt_namespace': RAG_ENT_CNT_NAMESPACE, - 'rag_pg_hash_namespace': RAG_PG_HASH_NAMESPACE + "rag_graph_namespace": RAG_GRAPH_NAMESPACE, + "rag_ent_cnt_namespace": RAG_ENT_CNT_NAMESPACE, + "rag_pg_hash_namespace": RAG_PG_HASH_NAMESPACE, } - + # 日志级别映射:重要配置用info,其他用debug - important_configs = {'root_path', 'data_path'} - + important_configs = {"root_path", "data_path"} + # 批量设置配置项 initialized_count = 0 for key, default_value in default_configs.items(): if local_storage[key] is None: local_storage[key] = default_value - + # 根据重要性选择日志级别 if key in important_configs: logger.info(f"设置{key}: {default_value}") else: logger.debug(f"设置{key}: {default_value}") - + initialized_count += 1 - + if initialized_count > 0: logger.info(f"知识库本地存储初始化完成,共设置 {initialized_count} 项配置") else: logger.debug("知识库本地存储配置已存在,跳过初始化") - + + # 初始化本地存储路径 +# sourcery skip: dict-comprehension _initialize_knowledge_local_storage() +qa_manager = None +inspire_manager = None + # 检查LPMM知识库是否启用 if bot_global_config.lpmm_knowledge.enable: logger.info("正在初始化Mai-LPMM") logger.info("创建LLM客户端") - llm_client_list = dict() + llm_client_list = {} for key in global_config["llm_providers"]: llm_client_list[key] = LLMClient( - global_config["llm_providers"][key]["base_url"], - global_config["llm_providers"][key]["api_key"], + global_config["llm_providers"][key]["base_url"], # type: ignore + global_config["llm_providers"][key]["api_key"], # type: ignore ) # 初始化Embedding库 @@ -98,7 +102,7 @@ if bot_global_config.lpmm_knowledge.enable: try: embed_manager.load_from_file() except Exception as e: - logger.warning("此消息不会影响正常使用:从文件加载Embedding库时,{}".format(e)) + logger.warning(f"此消息不会影响正常使用:从文件加载Embedding库时,{e}") # logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误") logger.info("Embedding库加载完成") # 初始化KG @@ -107,7 +111,7 @@ if bot_global_config.lpmm_knowledge.enable: try: kg_manager.load_from_file() except Exception as e: - logger.warning("此消息不会影响正常使用:从文件加载KG时,{}".format(e)) + logger.warning(f"此消息不会影响正常使用:从文件加载KG时,{e}") # logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误") logger.info("KG加载完成") @@ -116,7 +120,7 @@ if bot_global_config.lpmm_knowledge.enable: # 数据比对:Embedding库与KG的段落hash集合 for pg_hash in kg_manager.stored_paragraph_hashes: - key = PG_NAMESPACE + "-" + pg_hash + key = f"{PG_NAMESPACE}-{pg_hash}" if key not in embed_manager.stored_pg_hashes: logger.warning(f"KG中存在Embedding库中不存在的段落:{key}") @@ -134,5 +138,3 @@ if bot_global_config.lpmm_knowledge.enable: else: logger.info("LPMM知识库已禁用,跳过初始化") # 创建空的占位符对象,避免导入错误 - qa_manager = None - inspire_manager = None diff --git a/src/chat/knowledge/prompt_template.py b/src/chat/knowledge/prompt_template.py index fe5a293c0..485103aad 100644 --- a/src/chat/knowledge/prompt_template.py +++ b/src/chat/knowledge/prompt_template.py @@ -1,5 +1,3 @@ -from .llm_client import LLMMessage - entity_extract_system_prompt = """你是一个性能优异的实体提取系统。请从段落中提取出所有实体,并以JSON列表的形式输出。 输出格式示例: @@ -63,10 +61,10 @@ qa_system_prompt = """ """ -def build_qa_context(question: str, knowledge: list[tuple[str, str, str]]) -> list[LLMMessage]: - knowledge = "\n".join([f"{i + 1}. 相关性:{k[0]}\n{k[1]}" for i, k in enumerate(knowledge)]) - messages = [ - LLMMessage("system", qa_system_prompt).to_dict(), - LLMMessage("user", f"问题:\n{question}\n\n可能有帮助的信息:\n{knowledge}").to_dict(), - ] - return messages +# def build_qa_context(question: str, knowledge: list[tuple[str, str, str]]) -> list[LLMMessage]: +# knowledge = "\n".join([f"{i + 1}. 相关性:{k[0]}\n{k[1]}" for i, k in enumerate(knowledge)]) +# messages = [ +# LLMMessage("system", qa_system_prompt).to_dict(), +# LLMMessage("user", f"问题:\n{question}\n\n可能有帮助的信息:\n{knowledge}").to_dict(), +# ] +# return messages diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index b179a3098..36737eb77 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -9,6 +9,7 @@ from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase from src.common.logger import get_logger from src.chat.utils.utils_image import get_image_manager +from src.chat.utils.utils_voice import get_voice_text from .chat_stream import ChatStream install(extra_lines=3) @@ -106,6 +107,7 @@ class MessageRecv(Message): self.has_emoji = False self.is_picid = False self.has_picid = False + self.is_voice = False self.is_mentioned = None self.is_command = False @@ -153,17 +155,27 @@ class MessageRecv(Message): self.has_emoji = True self.is_emoji = True self.is_picid = False + self.is_voice = False if isinstance(segment.data, str): return await get_image_manager().get_emoji_description(segment.data) return "[发了一个表情包,网卡了加载不出来]" + elif segment.type == "voice": + self.is_picid = False + self.is_emoji = False + self.is_voice = True + if isinstance(segment.data, str): + return await get_voice_text(segment.data) + return "[发了一段语音,网卡了加载不出来]" elif segment.type == "mention_bot": self.is_picid = False self.is_emoji = False + self.is_voice = False self.is_mentioned = float(segment.data) # type: ignore return "" elif segment.type == "priority_info": self.is_picid = False self.is_emoji = False + self.is_voice = False if isinstance(segment.data, dict): # 处理优先级信息 self.priority_mode = "priority" @@ -212,10 +224,12 @@ class MessageRecvS4U(MessageRecv): """ try: if segment.type == "text": + self.is_voice = False self.is_picid = False self.is_emoji = False return segment.data # type: ignore elif segment.type == "image": + self.is_voice = False # 如果是base64图片数据 if isinstance(segment.data, str): self.has_picid = True @@ -233,12 +247,22 @@ class MessageRecvS4U(MessageRecv): if isinstance(segment.data, str): return await get_image_manager().get_emoji_description(segment.data) return "[发了一个表情包,网卡了加载不出来]" + elif segment.type == "voice": + self.has_picid = False + self.is_picid = False + self.is_emoji = False + self.is_voice = True + if isinstance(segment.data, str): + return await get_voice_text(segment.data) + return "[发了一段语音,网卡了加载不出来]" elif segment.type == "mention_bot": + self.is_voice = False self.is_picid = False self.is_emoji = False self.is_mentioned = float(segment.data) # type: ignore return "" elif segment.type == "priority_info": + self.is_voice = False self.is_picid = False self.is_emoji = False if isinstance(segment.data, dict): @@ -253,6 +277,7 @@ class MessageRecvS4U(MessageRecv): """ return "" elif segment.type == "gift": + self.is_voice = False self.is_gift = True # 解析gift_info,格式为"名称:数量" name, count = segment.data.split(":", 1) # type: ignore @@ -343,6 +368,10 @@ class MessageProcessBase(Message): if isinstance(seg.data, str): return await get_image_manager().get_emoji_description(seg.data) return "[表情,网卡了加载不出来]" + elif seg.type == "voice": + if isinstance(seg.data, str): + return await get_voice_text(seg.data) + return "[发了一段语音,网卡了加载不出来]" elif seg.type == "at": return f"[@{seg.data}]" elif seg.type == "reply": @@ -455,25 +484,25 @@ class MessageSending(MessageProcessBase): if self.message_segment: self.processed_plain_text = await self._process_message_segments(self.message_segment) - @classmethod - def from_thinking( - cls, - thinking: MessageThinking, - message_segment: Seg, - is_head: bool = False, - is_emoji: bool = False, - ) -> "MessageSending": - """从思考状态消息创建发送状态消息""" - return cls( - message_id=thinking.message_info.message_id, # type: ignore - chat_stream=thinking.chat_stream, - message_segment=message_segment, - bot_user_info=thinking.message_info.user_info, # type: ignore - reply=thinking.reply, - is_head=is_head, - is_emoji=is_emoji, - sender_info=None, - ) + # @classmethod + # def from_thinking( + # cls, + # thinking: MessageThinking, + # message_segment: Seg, + # is_head: bool = False, + # is_emoji: bool = False, + # ) -> "MessageSending": + # """从思考状态消息创建发送状态消息""" + # return cls( + # message_id=thinking.message_info.message_id, # type: ignore + # chat_stream=thinking.chat_stream, + # message_segment=message_segment, + # bot_user_info=thinking.message_info.user_info, # type: ignore + # reply=thinking.reply, + # is_head=is_head, + # is_emoji=is_emoji, + # sender_info=None, + # ) def to_dict(self): ret = super().to_dict() diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py index a4876a46d..a2f4c37bd 100644 --- a/src/chat/planner_actions/action_manager.py +++ b/src/chat/planner_actions/action_manager.py @@ -262,4 +262,4 @@ class ActionManager: """ from src.plugin_system.core.component_registry import component_registry - return component_registry.get_component_class(action_name) # type: ignore + return component_registry.get_component_class(action_name, ComponentType.ACTION) # type: ignore diff --git a/src/chat/utils/utils_voice.py b/src/chat/utils/utils_voice.py new file mode 100644 index 000000000..1bc3e7dda --- /dev/null +++ b/src/chat/utils/utils_voice.py @@ -0,0 +1,35 @@ +import base64 + +from src.config.config import global_config +from src.llm_models.utils_model import LLMRequest + +from src.common.logger import get_logger +from rich.traceback import install +install(extra_lines=3) + +logger = get_logger("chat_voice") + +async def get_voice_text(voice_base64: str) -> str: + """获取音频文件描述""" + if not global_config.chat.enable_asr: + logger.warning("语音识别未启用,无法处理语音消息") + return "[语音]" + try: + # 解码base64音频数据 + # 确保base64字符串只包含ASCII字符 + if isinstance(voice_base64, str): + voice_base64 = voice_base64.encode("ascii", errors="ignore").decode("ascii") + voice_bytes = base64.b64decode(voice_base64) + _llm = LLMRequest(model=global_config.model.voice, request_type="voice") + text = await _llm.generate_response_for_voice(voice_bytes) + if text is None: + logger.warning("未能生成语音文本") + return "[语音(文本生成失败)]" + + logger.debug(f"描述是{text}") + + return f"[语音:{text}]" + except Exception as e: + logger.error(f"语音转文字失败: {str(e)}") + return "[语音]" + diff --git a/src/chat/willing/mode_classical.py b/src/chat/willing/mode_classical.py index d63ba0a23..e15272332 100644 --- a/src/chat/willing/mode_classical.py +++ b/src/chat/willing/mode_classical.py @@ -21,6 +21,7 @@ class ClassicalWillingManager(BaseWillingManager): self._decay_task = asyncio.create_task(self._decay_reply_willing()) async def get_reply_probability(self, message_id): + # sourcery skip: inline-immediately-returned-variable willing_info = self.ongoing_messages[message_id] chat_id = willing_info.chat_id current_willing = self.chat_reply_willing.get(chat_id, 0) diff --git a/src/chat/willing/mode_mxp.py b/src/chat/willing/mode_mxp.py index 7b9e55568..5a13a628a 100644 --- a/src/chat/willing/mode_mxp.py +++ b/src/chat/willing/mode_mxp.py @@ -25,6 +25,8 @@ import asyncio import time import math +from src.chat.message_receive.chat_stream import ChatStream + class MxpWillingManager(BaseWillingManager): """Mxp意愿管理器""" @@ -76,7 +78,7 @@ class MxpWillingManager(BaseWillingManager): self.chat_bot_message_time[w_info.chat_id].append(current_time) if len(self.chat_bot_message_time[w_info.chat_id]) == int(self.fatigue_messages_triggered_num): time_interval = 60 - (current_time - self.chat_bot_message_time[w_info.chat_id].pop(0)) - self.chat_fatigue_punishment_list[w_info.chat_id].append([current_time, time_interval * 2]) + self.chat_fatigue_punishment_list[w_info.chat_id].append((current_time, time_interval * 2)) async def after_generate_reply_handle(self, message_id: str): """回复后处理""" @@ -87,12 +89,14 @@ class MxpWillingManager(BaseWillingManager): # rel_level = self._get_relationship_level_num(rel_value) # self.chat_person_reply_willing[w_info.chat_id][w_info.person_id] += rel_level * 0.05 - now_chat_new_person = self.last_response_person.get(w_info.chat_id, [w_info.person_id, 0]) + now_chat_new_person = self.last_response_person.get(w_info.chat_id, (w_info.person_id, 0)) if now_chat_new_person[0] == w_info.person_id: if now_chat_new_person[1] < 3: - now_chat_new_person[1] += 1 + tmp_list = list(now_chat_new_person) + tmp_list[1] += 1 # type: ignore + self.last_response_person[w_info.chat_id] = tuple(tmp_list) # type: ignore else: - self.last_response_person[w_info.chat_id] = [w_info.person_id, 0] + self.last_response_person[w_info.chat_id] = (w_info.person_id, 0) async def not_reply_handle(self, message_id: str): """不回复处理""" @@ -108,11 +112,12 @@ class MxpWillingManager(BaseWillingManager): self.chat_person_reply_willing[w_info.chat_id][w_info.person_id] += self.single_chat_gain * ( 2 * self.last_response_person[w_info.chat_id][1] - 1 ) - now_chat_new_person = self.last_response_person.get(w_info.chat_id, ["", 0]) + now_chat_new_person = self.last_response_person.get(w_info.chat_id, ("", 0)) if now_chat_new_person[0] != w_info.person_id: - self.last_response_person[w_info.chat_id] = [w_info.person_id, 0] + self.last_response_person[w_info.chat_id] = (w_info.person_id, 0) async def get_reply_probability(self, message_id: str): + # sourcery skip: merge-duplicate-blocks, remove-redundant-if """获取回复概率""" async with self.lock: w_info = self.ongoing_messages[message_id] @@ -121,17 +126,16 @@ class MxpWillingManager(BaseWillingManager): self.logger.debug(f"基础意愿值:{current_willing}") if w_info.is_mentioned_bot: - current_willing_ = self.mention_willing_gain / (int(current_willing) + 1) - current_willing += current_willing_ + willing_gain = self.mention_willing_gain / (int(current_willing) + 1) + current_willing += willing_gain if self.is_debug: - self.logger.debug(f"提及增益:{current_willing_}") + self.logger.debug(f"提及增益:{willing_gain}") if w_info.interested_rate > 0: - current_willing += math.atan(w_info.interested_rate / 2) / math.pi * 2 * self.interest_willing_gain + willing_gain = math.atan(w_info.interested_rate / 2) / math.pi * 2 * self.interest_willing_gain + current_willing += willing_gain if self.is_debug: - self.logger.debug( - f"兴趣增益:{math.atan(w_info.interested_rate / 2) / math.pi * 2 * self.interest_willing_gain}" - ) + self.logger.debug(f"兴趣增益:{willing_gain}") self.chat_person_reply_willing[w_info.chat_id][w_info.person_id] = current_willing @@ -152,8 +156,8 @@ class MxpWillingManager(BaseWillingManager): self.logger.debug(f"疲劳衰减:{self.chat_fatigue_willing_attenuation.get(w_info.chat_id, 0)}") chat_ongoing_messages = [msg for msg in self.ongoing_messages.values() if msg.chat_id == w_info.chat_id] - chat_person_ogoing_messages = [msg for msg in chat_ongoing_messages if msg.person_id == w_info.person_id] - if len(chat_person_ogoing_messages) >= 2: + chat_person_ongoing_messages = [msg for msg in chat_ongoing_messages if msg.person_id == w_info.person_id] + if len(chat_person_ongoing_messages) >= 2: current_willing = 0 if self.is_debug: self.logger.debug("进行中消息惩罚:归0") @@ -191,34 +195,33 @@ class MxpWillingManager(BaseWillingManager): basic_willing + (willing - basic_willing) * self.intention_decay_rate ) - def setup(self, message, chat, is_mentioned_bot, interested_rate): - super().setup(message, chat, is_mentioned_bot, interested_rate) - - self.chat_reply_willing[chat.stream_id] = self.chat_reply_willing.get( - chat.stream_id, self.basic_maximum_willing - ) - self.chat_person_reply_willing[chat.stream_id] = self.chat_person_reply_willing.get(chat.stream_id, {}) - self.chat_person_reply_willing[chat.stream_id][ - self.ongoing_messages[message.message_info.message_id].person_id - ] = self.chat_person_reply_willing[chat.stream_id].get( - self.ongoing_messages[message.message_info.message_id].person_id, self.chat_reply_willing[chat.stream_id] + def setup(self, message: dict, chat_stream: ChatStream): + super().setup(message, chat_stream) + stream_id = chat_stream.stream_id + self.chat_reply_willing[stream_id] = self.chat_reply_willing.get(stream_id, self.basic_maximum_willing) + self.chat_person_reply_willing[stream_id] = self.chat_person_reply_willing.get(stream_id, {}) + self.chat_person_reply_willing[stream_id][self.ongoing_messages[message.get("message_id", "")].person_id] = ( + self.chat_person_reply_willing[stream_id].get( + self.ongoing_messages[message.get("message_id", "")].person_id, + self.chat_reply_willing[stream_id], + ) ) current_time = time.time() - if chat.stream_id not in self.chat_new_message_time: - self.chat_new_message_time[chat.stream_id] = [] - self.chat_new_message_time[chat.stream_id].append(current_time) - if len(self.chat_new_message_time[chat.stream_id]) > self.number_of_message_storage: - self.chat_new_message_time[chat.stream_id].pop(0) + if stream_id not in self.chat_new_message_time: + self.chat_new_message_time[stream_id] = [] + self.chat_new_message_time[stream_id].append(current_time) + if len(self.chat_new_message_time[stream_id]) > self.number_of_message_storage: + self.chat_new_message_time[stream_id].pop(0) - if chat.stream_id not in self.chat_fatigue_punishment_list: - self.chat_fatigue_punishment_list[chat.stream_id] = [ + if stream_id not in self.chat_fatigue_punishment_list: + self.chat_fatigue_punishment_list[stream_id] = [ ( current_time, self.number_of_message_storage * self.basic_maximum_willing / self.expected_replies_per_min * 60, ) ] - self.chat_fatigue_willing_attenuation[chat.stream_id] = ( + self.chat_fatigue_willing_attenuation[stream_id] = ( -2 * self.basic_maximum_willing * self.fatigue_coefficient ) @@ -227,12 +230,11 @@ class MxpWillingManager(BaseWillingManager): """意愿值转化为概率""" willing = max(0, willing) if willing < 2: - probability = math.atan(willing * 2) / math.pi * 2 + return math.atan(willing * 2) / math.pi * 2 elif willing < 2.5: - probability = math.atan(willing * 4) / math.pi * 2 + return math.atan(willing * 4) / math.pi * 2 else: - probability = 1 - return probability + return 1 async def _chat_new_message_to_change_basic_willing(self): """聊天流新消息改变基础意愿""" @@ -259,7 +261,7 @@ class MxpWillingManager(BaseWillingManager): update_time = 20 elif len(message_times) == self.number_of_message_storage: time_interval = current_time - message_times[0] - basic_willing = self._basic_willing_culculate(time_interval) + basic_willing = self._basic_willing_calculate(time_interval) self.chat_reply_willing[chat_id] = basic_willing update_time = 17 * basic_willing / self.basic_maximum_willing + 3 else: @@ -268,7 +270,7 @@ class MxpWillingManager(BaseWillingManager): if self.is_debug: self.logger.debug(f"聊天流意愿值更新:{self.chat_reply_willing}") - def _basic_willing_culculate(self, t: float) -> float: + def _basic_willing_calculate(self, t: float) -> float: """基础意愿值计算""" return math.tan(t * self.expected_replies_per_min * math.pi / 120 / self.number_of_message_storage) / 2 diff --git a/src/chat/willing/willing_manager.py b/src/chat/willing/willing_manager.py index 6c53273f5..31ea49399 100644 --- a/src/chat/willing/willing_manager.py +++ b/src/chat/willing/willing_manager.py @@ -104,7 +104,7 @@ class BaseWillingManager(ABC): is_mentioned_bot=message.get("is_mentioned", False), is_emoji=message.get("is_emoji", False), is_picid=message.get("is_picid", False), - interested_rate=message.get("interest_value", 0), + interested_rate = message.get("interest_value") or 0.0, ) def delete(self, message_id: str): diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 67b314f7f..be3ac1834 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -106,6 +106,9 @@ class ChatConfig(ConfigBase): focus_value: float = 1.0 """麦麦的专注思考能力,越低越容易专注,消耗token也越多""" + enable_asr: bool = False + """是否启用语音识别""" + def get_current_talk_frequency(self, chat_stream_id: Optional[str] = None) -> float: """ 根据当前时间和聊天流获取对应的 talk_frequency @@ -630,6 +633,9 @@ class ModelConfig(ConfigBase): vlm: dict[str, Any] = field(default_factory=lambda: {}) """视觉语言模型配置""" + voice: dict[str, Any] = field(default_factory=lambda: {}) + """语音识别模型配置""" + tool_use: dict[str, Any] = field(default_factory=lambda: {}) """专注工具使用模型配置""" diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index 1077cfa09..2e1d426a6 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -216,6 +216,8 @@ class LLMRequest: prompt: str = None, image_base64: str = None, image_format: str = None, + file_bytes: bytes = None, + file_format: str = None, payload: dict = None, retry_policy: dict = None, ) -> Dict[str, Any]: @@ -225,6 +227,8 @@ class LLMRequest: prompt: prompt文本 image_base64: 图片的base64编码 image_format: 图片格式 + file_bytes: 文件的二进制数据 + file_format: 文件格式 payload: 请求体数据 retry_policy: 自定义重试策略 request_type: 请求类型 @@ -246,30 +250,33 @@ class LLMRequest: # 构建请求体 if image_base64: payload = await self._build_payload(prompt, image_base64, image_format) + elif file_bytes: + payload = await self._build_formdata_payload(file_bytes, file_format) elif payload is None: payload = await self._build_payload(prompt) - if stream_mode: - payload["stream"] = stream_mode + if not file_bytes: + if stream_mode: + payload["stream"] = stream_mode - if self.temp != 0.7: - payload["temperature"] = self.temp + if self.temp != 0.7: + payload["temperature"] = self.temp - # 添加enable_thinking参数(如果不是默认值False) - if not self.enable_thinking: - payload["enable_thinking"] = False + # 添加enable_thinking参数(如果不是默认值False) + if not self.enable_thinking: + payload["enable_thinking"] = False - if self.thinking_budget != 4096: - payload["thinking_budget"] = self.thinking_budget + if self.thinking_budget != 4096: + payload["thinking_budget"] = self.thinking_budget - if self.max_tokens: - payload["max_tokens"] = self.max_tokens + if self.max_tokens: + payload["max_tokens"] = self.max_tokens - # if "max_tokens" not in payload and "max_completion_tokens" not in payload: - # payload["max_tokens"] = global_config.model.model_max_output_length - # 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查 - if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload: - payload["max_completion_tokens"] = payload.pop("max_tokens") + # if "max_tokens" not in payload and "max_completion_tokens" not in payload: + # payload["max_tokens"] = global_config.model.model_max_output_length + # 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查 + if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload: + payload["max_completion_tokens"] = payload.pop("max_tokens") return { "policy": policy, @@ -278,6 +285,8 @@ class LLMRequest: "stream_mode": stream_mode, "image_base64": image_base64, # 保留必要的exception处理所需的原始数据 "image_format": image_format, + "file_bytes": file_bytes, + "file_format": file_format, "prompt": prompt, } @@ -287,6 +296,8 @@ class LLMRequest: prompt: str = None, image_base64: str = None, image_format: str = None, + file_bytes: bytes = None, + file_format: str = None, payload: dict = None, retry_policy: dict = None, response_handler: callable = None, @@ -299,6 +310,8 @@ class LLMRequest: prompt: prompt文本 image_base64: 图片的base64编码 image_format: 图片格式 + file_bytes: 文件的二进制数据 + file_format: 文件格式 payload: 请求体数据 retry_policy: 自定义重试策略 response_handler: 自定义响应处理器 @@ -307,25 +320,36 @@ class LLMRequest: """ # 获取请求配置 request_content = await self._prepare_request( - endpoint, prompt, image_base64, image_format, payload, retry_policy + endpoint, prompt, image_base64, image_format, file_bytes, file_format, payload, retry_policy ) if request_type is None: request_type = self.request_type for retry in range(request_content["policy"]["max_retries"]): try: # 使用上下文管理器处理会话 - headers = await self._build_headers() + if file_bytes: + headers = await self._build_headers(is_formdata=True) + else: + headers = await self._build_headers(is_formdata=False) # 似乎是openai流式必须要的东西,不过阿里云的qwq-plus加了这个没有影响 if request_content["stream_mode"]: headers["Accept"] = "text/event-stream" async with aiohttp.ClientSession(connector=await get_tcp_connector()) as session: - async with session.post( - request_content["api_url"], headers=headers, json=request_content["payload"] + post_kwargs = {"headers": headers} + #form-data数据上传方式不同 + if file_bytes: + post_kwargs["data"] = request_content["payload"] + else: + post_kwargs["json"] = request_content["payload"] + + async with session.post( + request_content["api_url"], **post_kwargs ) as response: handled_result = await self._handle_response( response, request_content, retry, response_handler, user_id, request_type, endpoint ) - return handled_result + return handled_result + except Exception as e: handled_payload, count_delta = await self._handle_exception(e, retry, request_content) retry += count_delta # 降级不计入重试次数 @@ -605,7 +629,7 @@ class LLMRequest: ) # 安全地检查和记录请求详情 handled_payload = await _safely_record(request_content, payload) - logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {handled_payload[:100]}") + logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {str(handled_payload)[:100]}") raise RuntimeError( f"模型 {self.model_name} API请求失败: 状态码 {exception.status}, {exception.message}" ) @@ -619,7 +643,7 @@ class LLMRequest: logger.critical(f"模型 {self.model_name} 请求失败: {str(exception)}") # 安全地检查和记录请求详情 handled_payload = await _safely_record(request_content, payload) - logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {handled_payload[:100]}") + logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {str(handled_payload)[:100]}") raise RuntimeError(f"模型 {self.model_name} API请求失败: {str(exception)}") async def _transform_parameters(self, params: dict) -> dict: @@ -640,6 +664,33 @@ class LLMRequest: new_params["max_completion_tokens"] = new_params.pop("max_tokens") return new_params + async def _build_formdata_payload(self, file_bytes: bytes, file_format: str) -> aiohttp.FormData: + """构建form-data请求体""" + # 目前只适配了音频文件 + # 如果后续要支持其他类型的文件,可以在这里添加更多的处理逻辑 + data = aiohttp.FormData() + content_type_list = { + "wav": "audio/wav", + "mp3": "audio/mpeg", + "ogg": "audio/ogg", + "flac": "audio/flac", + "aac": "audio/aac", + } + + content_type = content_type_list.get(file_format) + if not content_type: + logger.warning(f"暂不支持的文件类型: {file_format}") + + data.add_field( + "file",io.BytesIO(file_bytes), + filename=f"file.{file_format}", + content_type=f'{content_type}' # 根据实际文件类型设置 + ) + data.add_field( + "model", self.model_name + ) + return data + async def _build_payload(self, prompt: str, image_base64: str = None, image_format: str = None) -> dict: """构建请求体""" # 复制一份参数,避免直接修改 self.params @@ -725,7 +776,8 @@ class LLMRequest: return content, reasoning_content, tool_calls else: return content, reasoning_content - + elif "text" in result and result["text"]: + return result["text"] return "没有返回结果", "" @staticmethod @@ -739,11 +791,15 @@ class LLMRequest: reasoning = "" return content, reasoning - async def _build_headers(self, no_key: bool = False) -> dict: + async def _build_headers(self, no_key: bool = False, is_formdata: bool = False) -> dict: """构建请求头""" if no_key: + if is_formdata: + return {"Authorization": "Bearer **********"} return {"Authorization": "Bearer **********", "Content-Type": "application/json"} else: + if is_formdata: + return {"Authorization": f"Bearer {self.api_key}"} return {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"} # 防止小朋友们截图自己的key @@ -761,6 +817,11 @@ class LLMRequest: content, reasoning_content = response return content, reasoning_content + async def generate_response_for_voice(self, voice_bytes: bytes) -> Tuple: + """根据输入的语音文件生成模型的异步响应""" + response = await self._execute_request(endpoint="/audio/transcriptions",file_bytes=voice_bytes, file_format='wav') + return response + async def generate_response_async(self, prompt: str, **kwargs) -> Union[str, Tuple]: """异步方式根据输入的提示生成模型的响应""" # 构建请求体,不硬编码max_tokens diff --git a/src/plugin_system/__init__.py b/src/plugin_system/__init__.py index 59e240811..491da7c1c 100644 --- a/src/plugin_system/__init__.py +++ b/src/plugin_system/__init__.py @@ -18,11 +18,16 @@ from .base import ( CommandInfo, PluginInfo, PythonDependency, + BaseEventHandler, + EventHandlerInfo, + EventType, + MaiMessages, ) -from .core.plugin_manager import ( +from .core import ( plugin_manager, component_registry, dependency_manager, + events_manager, ) # 导入工具模块 @@ -33,7 +38,7 @@ from .utils import ( # generate_plugin_manifest, ) -from .apis.plugin_register_api import register_plugin +from .apis import register_plugin, get_logger __version__ = "1.0.0" @@ -43,6 +48,7 @@ __all__ = [ "BasePlugin", "BaseAction", "BaseCommand", + "BaseEventHandler", # 类型定义 "ComponentType", "ActionActivationType", @@ -52,15 +58,21 @@ __all__ = [ "CommandInfo", "PluginInfo", "PythonDependency", + "EventHandlerInfo", + "EventType", + # 消息 + "MaiMessages", # 管理器 "plugin_manager", "component_registry", "dependency_manager", + "events_manager", # 装饰器 "register_plugin", "ConfigField", # 工具函数 "ManifestValidator", + "get_logger", # "ManifestGenerator", # "validate_plugin_manifest", # "generate_plugin_manifest", diff --git a/src/plugin_system/apis/__init__.py b/src/plugin_system/apis/__init__.py index 15ef547ef..05cc62c72 100644 --- a/src/plugin_system/apis/__init__.py +++ b/src/plugin_system/apis/__init__.py @@ -18,7 +18,8 @@ from src.plugin_system.apis import ( utils_api, plugin_register_api, ) - +from .logging_api import get_logger +from .plugin_register_api import register_plugin # 导出所有API模块,使它们可以通过 apis.xxx 方式访问 __all__ = [ "chat_api", @@ -32,4 +33,6 @@ __all__ = [ "send_api", "utils_api", "plugin_register_api", + "get_logger", + "register_plugin", ] diff --git a/src/plugin_system/apis/logging_api.py b/src/plugin_system/apis/logging_api.py new file mode 100644 index 000000000..7aeec4133 --- /dev/null +++ b/src/plugin_system/apis/logging_api.py @@ -0,0 +1,3 @@ +from src.common.logger import get_logger + +__all__ = ["get_logger"] diff --git a/src/plugin_system/apis/plugin_register_api.py b/src/plugin_system/apis/plugin_register_api.py index 7970f3421..879c09b32 100644 --- a/src/plugin_system/apis/plugin_register_api.py +++ b/src/plugin_system/apis/plugin_register_api.py @@ -1,6 +1,8 @@ +from pathlib import Path + from src.common.logger import get_logger -logger = get_logger("plugin_register") +logger = get_logger("plugin_manager") # 复用plugin_manager名称 def register_plugin(cls): @@ -22,16 +24,23 @@ def register_plugin(cls): # 只是注册插件类,不立即实例化 # 插件管理器会负责实例化和注册 - plugin_name = cls.plugin_name or cls.__name__ - plugin_manager.plugin_classes[plugin_name] = cls # type: ignore - logger.debug(f"插件类已注册: {plugin_name}") + plugin_name: str = cls.plugin_name # type: ignore + if "." in plugin_name: + logger.error(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代") + raise ValueError(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代") + plugin_manager.plugin_classes[plugin_name] = cls + splitted_name = cls.__module__.split(".") + root_path = Path(__file__) + + # 查找项目根目录 + while not (root_path / "pyproject.toml").exists() and root_path.parent != root_path: + root_path = root_path.parent + + if not (root_path / "pyproject.toml").exists(): + logger.error(f"注册 {plugin_name} 无法找到项目根目录") + return cls + + plugin_manager.plugin_paths[plugin_name] = str(Path(root_path, *splitted_name).resolve()) + logger.debug(f"插件类已注册: {plugin_name}, 路径: {plugin_manager.plugin_paths[plugin_name]}") return cls - -def register_event_plugin(cls, *args, **kwargs): - - """事件插件注册装饰器 - - 用法: - @register_event_plugin - """ \ No newline at end of file diff --git a/src/plugin_system/base/__init__.py b/src/plugin_system/base/__init__.py index bff325948..a95e05aed 100644 --- a/src/plugin_system/base/__init__.py +++ b/src/plugin_system/base/__init__.py @@ -7,6 +7,7 @@ from .base_plugin import BasePlugin from .base_action import BaseAction from .base_command import BaseCommand +from .base_events_handler import BaseEventHandler from .component_types import ( ComponentType, ActionActivationType, @@ -16,6 +17,9 @@ from .component_types import ( CommandInfo, PluginInfo, PythonDependency, + EventHandlerInfo, + EventType, + MaiMessages, ) from .config_types import ConfigField @@ -32,4 +36,8 @@ __all__ = [ "PluginInfo", "PythonDependency", "ConfigField", + "EventHandlerInfo", + "EventType", + "BaseEventHandler", + "MaiMessages", ] diff --git a/src/plugin_system/base/base_action.py b/src/plugin_system/base/base_action.py index 74ab22e67..a61a0339d 100644 --- a/src/plugin_system/base/base_action.py +++ b/src/plugin_system/base/base_action.py @@ -41,6 +41,7 @@ class BaseAction(ABC): action_message: Optional[dict] = None, **kwargs, ): + # sourcery skip: hoist-similar-statement-from-if, merge-else-if-into-elif, move-assign-in-block, swap-if-else-branches, swap-nested-ifs """初始化Action组件 Args: @@ -355,7 +356,9 @@ class BaseAction(ABC): # 从类属性读取名称,如果没有定义则使用类名自动生成 name = getattr(cls, "action_name", cls.__name__.lower().replace("action", "")) - + if "." in name: + logger.error(f"Action名称 '{name}' 包含非法字符 '.',请使用下划线替代") + raise ValueError(f"Action名称 '{name}' 包含非法字符 '.',请使用下划线替代") # 获取focus_activation_type和normal_activation_type focus_activation_type = getattr(cls, "focus_activation_type", ActionActivationType.ALWAYS) normal_activation_type = getattr(cls, "normal_activation_type", ActionActivationType.ALWAYS) diff --git a/src/plugin_system/base/base_command.py b/src/plugin_system/base/base_command.py index caf68567b..5387e01dd 100644 --- a/src/plugin_system/base/base_command.py +++ b/src/plugin_system/base/base_command.py @@ -219,7 +219,9 @@ class BaseCommand(ABC): Returns: CommandInfo: 生成的Command信息对象 """ - + if "." in cls.command_name: + logger.error(f"Command名称 '{cls.command_name}' 包含非法字符 '.',请使用下划线替代") + raise ValueError(f"Command名称 '{cls.command_name}' 包含非法字符 '.',请使用下划线替代") return CommandInfo( name=cls.command_name, component_type=ComponentType.COMMAND, diff --git a/src/plugin_system/base/base_event_plugin.py b/src/plugin_system/base/base_event_plugin.py deleted file mode 100644 index 859d43f06..000000000 --- a/src/plugin_system/base/base_event_plugin.py +++ /dev/null @@ -1,14 +0,0 @@ -from abc import abstractmethod - -from .plugin_base import PluginBase -from src.common.logger import get_logger - - -class BaseEventPlugin(PluginBase): - """基于事件的插件基类 - - 所有事件类型的插件都应该继承这个基类 - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) diff --git a/src/plugin_system/base/base_events_handler.py b/src/plugin_system/base/base_events_handler.py new file mode 100644 index 000000000..db6c20b62 --- /dev/null +++ b/src/plugin_system/base/base_events_handler.py @@ -0,0 +1,51 @@ +from abc import ABC, abstractmethod +from typing import Tuple, Optional + +from src.common.logger import get_logger +from .component_types import MaiMessages, EventType, EventHandlerInfo, ComponentType + +logger = get_logger("base_event_handler") + + +class BaseEventHandler(ABC): + """事件处理器基类 + + 所有事件处理器都应该继承这个基类,提供事件处理的基本接口 + """ + + event_type: EventType = EventType.UNKNOWN # 事件类型,默认为未知 + handler_name: str = "" # 处理器名称 + handler_description: str = "" + weight: int = 0 # 权重,数值越大优先级越高 + intercept_message: bool = False # 是否拦截消息,默认为否 + + def __init__(self): + self.log_prefix = "[EventHandler]" + if self.event_type == EventType.UNKNOWN: + raise NotImplementedError("事件处理器必须指定 event_type") + + @abstractmethod + async def execute(self, message: MaiMessages) -> Tuple[bool, Optional[str]]: + """执行事件处理的抽象方法,子类必须实现 + + Returns: + Tuple[bool, Optional[str]]: (是否执行成功, 可选的返回消息) + """ + raise NotImplementedError("子类必须实现 execute 方法") + + @classmethod + def get_handler_info(cls) -> "EventHandlerInfo": + """获取事件处理器的信息""" + # 从类属性读取名称,如果没有定义则使用类名自动生成 + name: str = getattr(cls, "handler_name", cls.__name__.lower().replace("handler", "")) + if "." in name: + logger.error(f"事件处理器名称 '{name}' 包含非法字符 '.',请使用下划线替代") + raise ValueError(f"事件处理器名称 '{name}' 包含非法字符 '.',请使用下划线替代") + return EventHandlerInfo( + name=name, + component_type=ComponentType.EVENT_HANDLER, + description=getattr(cls, "handler_description", "events处理器"), + event_type=cls.event_type, + weight=cls.weight, + intercept_message=cls.intercept_message, + ) diff --git a/src/plugin_system/base/base_plugin.py b/src/plugin_system/base/base_plugin.py index a93de5fab..1e6841eba 100644 --- a/src/plugin_system/base/base_plugin.py +++ b/src/plugin_system/base/base_plugin.py @@ -1,9 +1,12 @@ from abc import abstractmethod -from typing import List, Type +from typing import List, Type, Tuple, Union from .plugin_base import PluginBase from src.common.logger import get_logger -from src.plugin_system.base.component_types import ComponentInfo +from src.plugin_system.base.component_types import ComponentInfo, ActionInfo, CommandInfo, EventHandlerInfo +from .base_action import BaseAction +from .base_command import BaseCommand +from .base_events_handler import BaseEventHandler logger = get_logger("base_plugin") @@ -21,7 +24,15 @@ class BasePlugin(PluginBase): super().__init__(*args, **kwargs) @abstractmethod - def get_plugin_components(self) -> List[tuple[ComponentInfo, Type]]: + def get_plugin_components( + self, + ) -> List[ + Union[ + Tuple[ActionInfo, Type[BaseAction]], + Tuple[CommandInfo, Type[BaseCommand]], + Tuple[EventHandlerInfo, Type[BaseEventHandler]], + ] + ]: """获取插件包含的组件列表 子类必须实现此方法,返回组件信息和组件类的列表 diff --git a/src/plugin_system/base/component_types.py b/src/plugin_system/base/component_types.py index 14025ed99..774daa598 100644 --- a/src/plugin_system/base/component_types.py +++ b/src/plugin_system/base/component_types.py @@ -1,6 +1,7 @@ from enum import Enum -from typing import Dict, Any, List +from typing import Dict, Any, List, Optional from dataclasses import dataclass, field +from maim_message import Seg # 组件类型枚举 @@ -10,7 +11,10 @@ class ComponentType(Enum): ACTION = "action" # 动作组件 COMMAND = "command" # 命令组件 SCHEDULER = "scheduler" # 定时任务组件(预留) - LISTENER = "listener" # 事件监听组件(预留) + EVENT_HANDLER = "event_handler" # 事件处理组件(预留) + + def __str__(self) -> str: + return self.value # 动作激活类型枚举 @@ -46,12 +50,17 @@ class EventType(Enum): 事件类型枚举类 """ + ON_START = "on_start" # 启动事件,用于调用按时任务 ON_MESSAGE = "on_message" ON_PLAN = "on_plan" POST_LLM = "post_llm" AFTER_LLM = "after_llm" POST_SEND = "post_send" AFTER_SEND = "after_send" + UNKNOWN = "unknown" # 未知事件类型 + + def __str__(self) -> str: + return self.value @dataclass @@ -142,6 +151,19 @@ class CommandInfo(ComponentInfo): self.component_type = ComponentType.COMMAND +@dataclass +class EventHandlerInfo(ComponentInfo): + """事件处理器组件信息""" + + event_type: EventType = EventType.ON_MESSAGE # 监听事件类型 + intercept_message: bool = False # 是否拦截消息处理(默认不拦截) + weight: int = 0 # 事件处理器权重,决定执行顺序 + + def __post_init__(self): + super().__post_init__() + self.component_type = ComponentType.EVENT_HANDLER + + @dataclass class PluginInfo: """插件信息""" @@ -198,3 +220,42 @@ class PluginInfo: def get_pip_requirements(self) -> List[str]: """获取所有pip安装格式的依赖""" return [dep.get_pip_requirement() for dep in self.python_dependencies] + + +@dataclass +class MaiMessages: + """MaiM插件消息""" + + message_segments: List[Seg] = field(default_factory=list) + """消息段列表,支持多段消息""" + + message_base_info: Dict[str, Any] = field(default_factory=dict) + """消息基本信息,包含平台,用户信息等数据""" + + plain_text: str = "" + """纯文本消息内容""" + + raw_message: Optional[str] = None + """原始消息内容""" + + is_group_message: bool = False + """是否为群组消息""" + + is_private_message: bool = False + """是否为私聊消息""" + + stream_id: Optional[str] = None + """流ID,用于标识消息流""" + + llm_prompt: Optional[str] = None + """LLM提示词""" + + llm_response: Optional[str] = None + """LLM响应内容""" + + additional_data: Dict[Any, Any] = field(default_factory=dict) + """附加数据,可以存储额外信息""" + + def __post_init__(self): + if self.message_segments is None: + self.message_segments = [] diff --git a/src/plugin_system/core/__init__.py b/src/plugin_system/core/__init__.py index 50537b903..c6041ece7 100644 --- a/src/plugin_system/core/__init__.py +++ b/src/plugin_system/core/__init__.py @@ -7,9 +7,11 @@ from src.plugin_system.core.plugin_manager import plugin_manager from src.plugin_system.core.component_registry import component_registry from src.plugin_system.core.dependency_manager import dependency_manager +from src.plugin_system.core.events_manager import events_manager __all__ = [ "plugin_manager", "component_registry", "dependency_manager", + "events_manager", ] diff --git a/src/plugin_system/core/component_registry.py b/src/plugin_system/core/component_registry.py index 917069e11..29a45c604 100644 --- a/src/plugin_system/core/component_registry.py +++ b/src/plugin_system/core/component_registry.py @@ -1,16 +1,19 @@ -from typing import Dict, List, Optional, Any, Pattern, Tuple, Union, Type import re + +from typing import Dict, List, Optional, Any, Pattern, Tuple, Union, Type + from src.common.logger import get_logger from src.plugin_system.base.component_types import ( ComponentInfo, ActionInfo, CommandInfo, + EventHandlerInfo, PluginInfo, ComponentType, ) - from src.plugin_system.base.base_command import BaseCommand from src.plugin_system.base.base_action import BaseAction +from src.plugin_system.base.base_events_handler import BaseEventHandler logger = get_logger("component_registry") @@ -23,12 +26,11 @@ class ComponentRegistry: def __init__(self): # 组件注册表 - self._components: Dict[str, ComponentInfo] = {} # 组件名 -> 组件信息 - self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = { - ComponentType.ACTION: {}, - ComponentType.COMMAND: {}, - } - self._component_classes: Dict[str, Union[Type[BaseCommand], Type[BaseAction]]] = {} # 组件名 -> 组件类 + self._components: Dict[str, ComponentInfo] = {} # 命名空间式组件名 -> 组件信息 + # 类型 -> 命名空间式名称 -> 组件信息 + self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = {types: {} for types in ComponentType} + # 命名空间式组件名 -> 组件类 + self._components_classes: Dict[str, Type[Union[BaseCommand, BaseAction, BaseEventHandler]]] = {} # 插件注册表 self._plugins: Dict[str, PluginInfo] = {} # 插件名 -> 插件信息 @@ -39,20 +41,43 @@ class ComponentRegistry: # Command特定注册表 self._command_registry: Dict[str, Type[BaseCommand]] = {} # command名 -> command类 - self._command_patterns: Dict[Pattern, Type[BaseCommand]] = {} # 编译后的正则 -> command类 + self._command_patterns: Dict[Pattern, str] = {} # 编译后的正则 -> command名 + + # EventHandler特定注册表 + self._event_handler_registry: Dict[str, Type[BaseEventHandler]] = {} # event_handler名 -> event_handler类 + self._enabled_event_handlers: Dict[str, Type[BaseEventHandler]] = {} # 启用的事件处理器 logger.info("组件注册中心初始化完成") - # === 通用组件注册方法 === + # == 注册方法 == + + def register_plugin(self, plugin_info: PluginInfo) -> bool: + """注册插件 + + Args: + plugin_info: 插件信息 + + Returns: + bool: 是否注册成功 + """ + plugin_name = plugin_info.name + + if plugin_name in self._plugins: + logger.warning(f"插件 {plugin_name} 已存在,跳过注册") + return False + + self._plugins[plugin_name] = plugin_info + logger.debug(f"已注册插件: {plugin_name} (组件数量: {len(plugin_info.components)})") + return True def register_component( - self, component_info: ComponentInfo, component_class: Union[Type[BaseCommand], Type[BaseAction]] + self, component_info: ComponentInfo, component_class: Type[Union[BaseCommand, BaseAction, BaseEventHandler]] ) -> bool: """注册组件 Args: - component_info: 组件信息 - component_class: 组件类 + component_info (ComponentInfo): 组件信息 + component_class (Type[Union[BaseCommand, BaseAction, BaseEventHandler]]): 组件类 Returns: bool: 是否注册成功 @@ -60,68 +85,110 @@ class ComponentRegistry: component_name = component_info.name component_type = component_info.component_type plugin_name = getattr(component_info, "plugin_name", "unknown") + if "." in component_name: + logger.error(f"组件名称 '{component_name}' 包含非法字符 '.',请使用下划线替代") + return False + if "." in plugin_name: + logger.error(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代") + return False - # 🔥 系统级别自动区分:为不同类型的组件添加命名空间前缀 - if component_type == ComponentType.ACTION: - namespaced_name = f"action.{component_name}" - elif component_type == ComponentType.COMMAND: - namespaced_name = f"command.{component_name}" - else: - # 未来扩展的组件类型 - namespaced_name = f"{component_type.value}.{component_name}" + namespaced_name = f"{component_type}.{component_name}" - # 检查命名空间化的名称是否冲突 if namespaced_name in self._components: existing_info = self._components[namespaced_name] existing_plugin = getattr(existing_info, "plugin_name", "unknown") logger.warning( - f"组件冲突: {component_type.value}组件 '{component_name}' " - f"已被插件 '{existing_plugin}' 注册,跳过插件 '{plugin_name}' 的注册" + f"组件名冲突: '{plugin_name}' 插件的 {component_type} 类型组件 '{component_name}' 已被插件 '{existing_plugin}' 注册,跳过此组件注册" ) return False - # 注册到通用注册表(使用命名空间化的名称) - self._components[namespaced_name] = component_info + self._components[namespaced_name] = component_info # 注册到通用注册表(使用命名空间化的名称) self._components_by_type[component_type][component_name] = component_info # 类型内部仍使用原名 - self._component_classes[namespaced_name] = component_class + self._components_classes[namespaced_name] = component_class # 根据组件类型进行特定注册(使用原始名称) - if component_type == ComponentType.ACTION: - self._register_action_component(component_info, component_class) # type: ignore - elif component_type == ComponentType.COMMAND: - self._register_command_component(component_info, component_class) # type: ignore + match component_type: + case ComponentType.ACTION: + ret = self._register_action_component(component_info, component_class) # type: ignore + case ComponentType.COMMAND: + ret = self._register_command_component(component_info, component_class) # type: ignore + case ComponentType.EVENT_HANDLER: + ret = self._register_event_handler_component(component_info, component_class) # type: ignore + case _: + logger.warning(f"未知组件类型: {component_type}") + if not ret: + return False logger.debug( - f"已注册{component_type.value}组件: '{component_name}' -> '{namespaced_name}' " + f"已注册{component_type}组件: '{component_name}' -> '{namespaced_name}' " f"({component_class.__name__}) [插件: {plugin_name}]" ) return True - def _register_action_component(self, action_info: ActionInfo, action_class: Type[BaseAction]): - # -------------------------------- NEED REFACTORING -------------------------------- - # -------------------------------- LOGIC ERROR ------------------------------------- + def _register_action_component(self, action_info: ActionInfo, action_class: Type[BaseAction]) -> bool: """注册Action组件到Action特定注册表""" - action_name = action_info.name + if not (action_name := action_info.name): + logger.error(f"Action组件 {action_class.__name__} 必须指定名称") + return False + if not isinstance(action_info, ActionInfo) or not issubclass(action_class, BaseAction): + logger.error(f"注册失败: {action_name} 不是有效的Action") + return False + self._action_registry[action_name] = action_class # 如果启用,添加到默认动作集 if action_info.enabled: self._default_actions[action_name] = action_info - def _register_command_component(self, command_info: CommandInfo, command_class: Type[BaseCommand]): + return True + + def _register_command_component(self, command_info: CommandInfo, command_class: Type[BaseCommand]) -> bool: """注册Command组件到Command特定注册表""" - command_name = command_info.name + if not (command_name := command_info.name): + logger.error(f"Command组件 {command_class.__name__} 必须指定名称") + return False + if not isinstance(command_info, CommandInfo) or not issubclass(command_class, BaseCommand): + logger.error(f"注册失败: {command_name} 不是有效的Command") + return False + self._command_registry[command_name] = command_class - # 编译正则表达式并注册 - if command_info.command_pattern: + # 如果启用了且有匹配模式 + if command_info.enabled and command_info.command_pattern: pattern = re.compile(command_info.command_pattern, re.IGNORECASE | re.DOTALL) - self._command_patterns[pattern] = command_class + if pattern not in self._command_patterns: + self._command_patterns[pattern] = command_name + + logger.warning(f"'{command_name}' 对应的命令模式与 '{self._command_patterns[pattern]}' 重复,忽略此命令") + + return True + + def _register_event_handler_component( + self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler] + ) -> bool: + if not (handler_name := handler_info.name): + logger.error(f"EventHandler组件 {handler_class.__name__} 必须指定名称") + return False + if not isinstance(handler_info, EventHandlerInfo) or not issubclass(handler_class, BaseEventHandler): + logger.error(f"注册失败: {handler_name} 不是有效的EventHandler") + return False + + self._event_handler_registry[handler_name] = handler_class + + from .events_manager import events_manager # 延迟导入防止循环导入问题 + + if events_manager.register_event_subscriber(handler_info, handler_class): + self._enabled_event_handlers[handler_name] = handler_class + return True + else: + logger.error(f"注册事件处理器 {handler_name} 失败") + return False # === 组件查询方法 === - - def get_component_info(self, component_name: str, component_type: ComponentType = None) -> Optional[ComponentInfo]: # type: ignore + def get_component_info( + self, component_name: str, component_type: Optional[ComponentType] = None + ) -> Optional[ComponentInfo]: # sourcery skip: class-extract-method """获取组件信息,支持自动命名空间解析 @@ -138,18 +205,12 @@ class ComponentRegistry: # 2. 如果指定了组件类型,构造命名空间化的名称查找 if component_type: - if component_type == ComponentType.ACTION: - namespaced_name = f"action.{component_name}" - elif component_type == ComponentType.COMMAND: - namespaced_name = f"command.{component_name}" - else: - namespaced_name = f"{component_type.value}.{component_name}" - + namespaced_name = f"{component_type}.{component_name}" return self._components.get(namespaced_name) # 3. 如果没有指定类型,尝试在所有命名空间中查找 candidates = [] - for namespace_prefix in ["action", "command"]: + for namespace_prefix in [types.value for types in ComponentType]: namespaced_name = f"{namespace_prefix}.{component_name}" if component_info := self._components.get(namespaced_name): candidates.append((namespace_prefix, namespaced_name, component_info)) @@ -171,8 +232,8 @@ class ComponentRegistry: def get_component_class( self, component_name: str, - component_type: ComponentType = None, # type: ignore - ) -> Optional[Union[Type[BaseCommand], Type[BaseAction]]]: + component_type: Optional[ComponentType] = None, + ) -> Optional[Union[Type[BaseCommand], Type[BaseAction], Type[BaseEventHandler]]]: """获取组件类,支持自动命名空间解析 Args: @@ -184,29 +245,23 @@ class ComponentRegistry: """ # 1. 如果已经是命名空间化的名称,直接查找 if "." in component_name: - return self._component_classes.get(component_name) + return self._components_classes.get(component_name) # 2. 如果指定了组件类型,构造命名空间化的名称查找 if component_type: - if component_type == ComponentType.ACTION: - namespaced_name = f"action.{component_name}" - elif component_type == ComponentType.COMMAND: - namespaced_name = f"command.{component_name}" - else: - namespaced_name = f"{component_type.value}.{component_name}" - - return self._component_classes.get(namespaced_name) + namespaced_name = f"{component_type.value}.{component_name}" + return self._components_classes.get(namespaced_name) # 3. 如果没有指定类型,尝试在所有命名空间中查找 candidates = [] - for namespace_prefix in ["action", "command"]: + for namespace_prefix in [types.value for types in ComponentType]: namespaced_name = f"{namespace_prefix}.{component_name}" - if component_class := self._component_classes.get(namespaced_name): + if component_class := self._components_classes.get(namespaced_name): candidates.append((namespace_prefix, namespaced_name, component_class)) if len(candidates) == 1: # 只有一个匹配,直接返回 - namespace, full_name, cls = candidates[0] + _, full_name, cls = candidates[0] logger.debug(f"自动解析组件: '{component_name}' -> '{full_name}'") return cls elif len(candidates) > 1: @@ -235,7 +290,7 @@ class ComponentRegistry: """获取Action注册表(用于兼容现有系统)""" return self._action_registry.copy() - def get_action_info(self, action_name: str) -> Optional[ActionInfo]: + def get_registered_action_info(self, action_name: str) -> Optional[ActionInfo]: """获取Action信息""" info = self.get_component_info(action_name, ComponentType.ACTION) return info if isinstance(info, ActionInfo) else None @@ -247,18 +302,18 @@ class ComponentRegistry: # === Command特定查询方法 === def get_command_registry(self) -> Dict[str, Type[BaseCommand]]: - """获取Command注册表(用于兼容现有系统)""" + """获取Command注册表""" return self._command_registry.copy() - def get_command_patterns(self) -> Dict[Pattern, Type[BaseCommand]]: - """获取Command模式注册表(用于兼容现有系统)""" - return self._command_patterns.copy() - - def get_command_info(self, command_name: str) -> Optional[CommandInfo]: + def get_registered_command_info(self, command_name: str) -> Optional[CommandInfo]: """获取Command信息""" info = self.get_component_info(command_name, ComponentType.COMMAND) return info if isinstance(info, CommandInfo) else None + def get_command_patterns(self) -> Dict[Pattern, str]: + """获取Command模式注册表""" + return self._command_patterns.copy() + def find_command_by_text(self, text: str) -> Optional[Tuple[Type[BaseCommand], dict, bool, str]]: # sourcery skip: use-named-expression, use-next """根据文本查找匹配的命令 @@ -270,47 +325,36 @@ class ComponentRegistry: Tuple: (命令类, 匹配的命名组, 是否拦截消息, 插件名) 或 None """ - for pattern, command_class in self._command_patterns.items(): - if match := pattern.match(text): - command_name = None - # 查找对应的组件信息 - for name, cls in self._command_registry.items(): - if cls == command_class: - command_name = name - break + candidates = [pattern for pattern in self._command_patterns if pattern.match(text)] + if not candidates: + return None + if len(candidates) > 1: + logger.warning(f"文本 '{text}' 匹配到多个命令模式: {candidates},使用第一个匹配") + command_name = self._command_patterns[candidates[0]] + command_info: CommandInfo = self.get_registered_command_info(command_name) # type: ignore + return ( + self._command_registry[command_name], + candidates[0].match(text).groupdict(), # type: ignore + command_info.intercept_message, + command_info.plugin_name, + ) - # 检查命令是否启用 - if command_name: - command_info = self.get_command_info(command_name) - if command_info and command_info.enabled: - return ( - command_class, - match.groupdict(), - command_info.intercept_message, - command_info.plugin_name, - ) - return None + # === 事件处理器特定查询方法 === - # === 插件管理方法 === + def get_event_handler_registry(self) -> Dict[str, Type[BaseEventHandler]]: + """获取事件处理器注册表""" + return self._event_handler_registry.copy() - def register_plugin(self, plugin_info: PluginInfo) -> bool: - """注册插件 + def get_registered_event_handler_info(self, handler_name: str) -> Optional[EventHandlerInfo]: + """获取事件处理器信息""" + info = self.get_component_info(handler_name, ComponentType.EVENT_HANDLER) + return info if isinstance(info, EventHandlerInfo) else None - Args: - plugin_info: 插件信息 + def get_enabled_event_handlers(self) -> Dict[str, Type[BaseEventHandler]]: + """获取启用的事件处理器""" + return self._enabled_event_handlers.copy() - Returns: - bool: 是否注册成功 - """ - plugin_name = plugin_info.name - - if plugin_name in self._plugins: - logger.warning(f"插件 {plugin_name} 已存在,跳过注册") - return False - - self._plugins[plugin_name] = plugin_info - logger.debug(f"已注册插件: {plugin_name} (组件数量: {len(plugin_info.components)})") - return True + # === 插件查询方法 === def get_plugin_info(self, plugin_name: str) -> Optional[PluginInfo]: """获取插件信息""" @@ -344,82 +388,22 @@ class ComponentRegistry: plugin_instance = plugin_manager.get_plugin_instance(plugin_name) return plugin_instance.config if plugin_instance else None - # === 状态管理方法 === - - # def enable_component(self, component_name: str, component_type: ComponentType = None) -> bool: - # # -------------------------------- NEED REFACTORING -------------------------------- - # # -------------------------------- LOGIC ERROR ------------------------------------- - # """启用组件,支持命名空间解析""" - # # 首先尝试找到正确的命名空间化名称 - # component_info = self.get_component_info(component_name, component_type) - # if not component_info: - # return False - - # # 根据组件类型构造正确的命名空间化名称 - # if component_info.component_type == ComponentType.ACTION: - # namespaced_name = f"action.{component_name}" if "." not in component_name else component_name - # elif component_info.component_type == ComponentType.COMMAND: - # namespaced_name = f"command.{component_name}" if "." not in component_name else component_name - # else: - # namespaced_name = ( - # f"{component_info.component_type.value}.{component_name}" - # if "." not in component_name - # else component_name - # ) - - # if namespaced_name in self._components: - # self._components[namespaced_name].enabled = True - # # 如果是Action,更新默认动作集 - # # ---- HERE ---- - # # if isinstance(component_info, ActionInfo): - # # self._action_descriptions[component_name] = component_info.description - # logger.debug(f"已启用组件: {component_name} -> {namespaced_name}") - # return True - # return False - - # def disable_component(self, component_name: str, component_type: ComponentType = None) -> bool: - # # -------------------------------- NEED REFACTORING -------------------------------- - # # -------------------------------- LOGIC ERROR ------------------------------------- - # """禁用组件,支持命名空间解析""" - # # 首先尝试找到正确的命名空间化名称 - # component_info = self.get_component_info(component_name, component_type) - # if not component_info: - # return False - - # # 根据组件类型构造正确的命名空间化名称 - # if component_info.component_type == ComponentType.ACTION: - # namespaced_name = f"action.{component_name}" if "." not in component_name else component_name - # elif component_info.component_type == ComponentType.COMMAND: - # namespaced_name = f"command.{component_name}" if "." not in component_name else component_name - # else: - # namespaced_name = ( - # f"{component_info.component_type.value}.{component_name}" - # if "." not in component_name - # else component_name - # ) - - # if namespaced_name in self._components: - # self._components[namespaced_name].enabled = False - # # 如果是Action,从默认动作集中移除 - # # ---- HERE ---- - # # if component_name in self._action_descriptions: - # # del self._action_descriptions[component_name] - # logger.debug(f"已禁用组件: {component_name} -> {namespaced_name}") - # return True - # return False - def get_registry_stats(self) -> Dict[str, Any]: """获取注册中心统计信息""" action_components: int = 0 command_components: int = 0 + events_handlers: int = 0 for component in self._components.values(): if component.component_type == ComponentType.ACTION: action_components += 1 elif component.component_type == ComponentType.COMMAND: command_components += 1 + elif component.component_type == ComponentType.EVENT_HANDLER: + events_handlers += 1 return { "action_components": action_components, "command_components": command_components, + "event_handlers": events_handlers, "total_components": len(self._components), "total_plugins": len(self._plugins), "components_by_type": { @@ -430,5 +414,4 @@ class ComponentRegistry: } -# 全局组件注册中心实例 component_registry = ComponentRegistry() diff --git a/src/plugin_system/core/events_manager.py b/src/plugin_system/core/events_manager.py index 1b96da44c..2c48f9d6d 100644 --- a/src/plugin_system/core/events_manager.py +++ b/src/plugin_system/core/events_manager.py @@ -1,11 +1,136 @@ -from typing import List, Dict, Type +import asyncio +from typing import List, Dict, Optional, Type -from src.plugin_system.base.component_types import EventType +from src.chat.message_receive.message import MessageRecv +from src.common.logger import get_logger +from src.plugin_system.base.component_types import EventType, EventHandlerInfo, MaiMessages +from src.plugin_system.base.base_events_handler import BaseEventHandler + +logger = get_logger("events_manager") class EventsManager: def __init__(self): # 有权重的 events 订阅者注册表 - self.events_subscribers: Dict[EventType, List[Dict[int, Type]]] = {event: [] for event in EventType} + self.events_subscribers: Dict[EventType, List[BaseEventHandler]] = {event: [] for event in EventType} + self.handler_mapping: Dict[str, Type[BaseEventHandler]] = {} # 事件处理器映射表 -events_manager = EventsManager() \ No newline at end of file + def register_event_subscriber(self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler]) -> bool: + """注册事件处理器 + + Args: + handler_info (EventHandlerInfo): 事件处理器信息 + handler_class (Type[BaseEventHandler]): 事件处理器类 + + Returns: + bool: 是否注册成功 + """ + handler_name = handler_info.name + plugin_name = getattr(handler_info, "plugin_name", "unknown") + + namespace_name = f"{plugin_name}.{handler_name}" + if namespace_name in self.handler_mapping: + logger.warning(f"事件处理器 {namespace_name} 已存在,跳过注册") + return False + + if not issubclass(handler_class, BaseEventHandler): + logger.error(f"类 {handler_class.__name__} 不是 BaseEventHandler 的子类") + return False + + self.handler_mapping[namespace_name] = handler_class + + return self._insert_event_handler(handler_class) + + async def handler_mai_events( + self, + event_type: EventType, + message: MessageRecv, + llm_prompt: Optional[str] = None, + llm_response: Optional[str] = None, + ) -> None: + """处理 events""" + transformed_message = self._transform_event_message(message, llm_prompt, llm_response) + for handler in self.events_subscribers.get(event_type, []): + if handler.intercept_message: + await handler.execute(transformed_message) + else: + asyncio.create_task(handler.execute(transformed_message)) + + def _insert_event_handler(self, handler_class: Type[BaseEventHandler]) -> bool: + """插入事件处理器到对应的事件类型列表中""" + if handler_class.event_type == EventType.UNKNOWN: + logger.error(f"事件处理器 {handler_class.__name__} 的事件类型未知,无法注册") + return False + + self.events_subscribers[handler_class.event_type].append(handler_class()) + self.events_subscribers[handler_class.event_type].sort(key=lambda x: x.weight, reverse=True) + + return True + + def _remove_event_handler(self, handler_class: Type[BaseEventHandler]) -> bool: + """从事件类型列表中移除事件处理器""" + if handler_class.event_type == EventType.UNKNOWN: + logger.warning(f"事件处理器 {handler_class.__name__} 的事件类型未知,不存在于处理器列表中") + return False + + handlers = self.events_subscribers[handler_class.event_type] + for i, handler in enumerate(handlers): + if isinstance(handler, handler_class): + del handlers[i] + logger.debug(f"事件处理器 {handler_class.__name__} 已移除") + return True + + logger.warning(f"未找到事件处理器 {handler_class.__name__},无法移除") + return False + + def _transform_event_message( + self, message: MessageRecv, llm_prompt: Optional[str] = None, llm_response: Optional[str] = None + ) -> MaiMessages: + """转换事件消息格式""" + # 直接赋值部分内容 + transformed_message = MaiMessages( + llm_prompt=llm_prompt, + llm_response=llm_response, + raw_message=message.raw_message, + additional_data=message.message_info.additional_config or {}, + ) + + # 消息段处理 + if message.message_segment.type == "seglist": + transformed_message.message_segments = list(message.message_segment.data) # type: ignore + else: + transformed_message.message_segments = [message.message_segment] + + # stream_id 处理 + if hasattr(message, "chat_stream"): + transformed_message.stream_id = message.chat_stream.stream_id + + # 处理后文本 + transformed_message.plain_text = message.processed_plain_text + + # 基本信息 + if message.message_info.platform: + transformed_message.message_base_info["platform"] = message.message_info.platform + if message.message_info.group_info: + transformed_message.is_group_message = True + transformed_message.message_base_info.update( + { + "group_id": message.message_info.group_info.group_id, + "group_name": message.message_info.group_info.group_name, + } + ) + if message.message_info.user_info: + if not transformed_message.is_group_message: + transformed_message.is_private_message = True + transformed_message.message_base_info.update( + { + "user_id": message.message_info.user_info.user_id, + "user_cardname": message.message_info.user_info.user_cardname, # 用户群昵称 + "user_nickname": message.message_info.user_info.user_nickname, # 用户昵称(用户名) + } + ) + + return transformed_message + + +events_manager = EventsManager() diff --git a/src/plugin_system/core/plugin_manager.py b/src/plugin_system/core/plugin_manager.py index b4050794f..3ce9c9e52 100644 --- a/src/plugin_system/core/plugin_manager.py +++ b/src/plugin_system/core/plugin_manager.py @@ -1,10 +1,12 @@ -from typing import Dict, List, Optional, Tuple, Type, Any import os -from importlib.util import spec_from_file_location, module_from_spec -from inspect import getmodule -from pathlib import Path +import inspect import traceback +from typing import Dict, List, Optional, Tuple, Type, Any +from importlib.util import spec_from_file_location, module_from_spec +from pathlib import Path + + from src.common.logger import get_logger from src.plugin_system.core.component_registry import component_registry from src.plugin_system.core.dependency_manager import dependency_manager @@ -28,7 +30,7 @@ class PluginManager: self.plugin_paths: Dict[str, str] = {} # 记录插件名到目录路径的映射,插件名 -> 目录路径 self.loaded_plugins: Dict[str, PluginBase] = {} # 已加载的插件类实例注册表,插件名 -> 插件类实例 - self.failed_plugins: Dict[str, str] = {} # 记录加载失败的插件类及其错误信息,插件名 -> 错误信息 + self.failed_plugins: Dict[str, str] = {} # 记录加载失败的插件文件及其错误信息,插件名 -> 错误信息 # 确保插件目录存在 self._ensure_plugin_directories() @@ -107,13 +109,9 @@ class PluginManager: # 使用记录的插件目录路径 plugin_dir = self.plugin_paths.get(plugin_name) - # 如果没有记录,则尝试查找(fallback) + # 如果没有记录,直接返回失败 if not plugin_dir: - plugin_dir = self._find_plugin_directory(plugin_class) - if plugin_dir: - self.plugin_paths[plugin_name] = plugin_dir # 更新路径 - else: - return False, 1 + return False, 1 plugin_instance = plugin_class(plugin_dir=plugin_dir) # 实例化插件(可能因为缺少manifest而失败) if not plugin_instance: @@ -360,24 +358,14 @@ class PluginManager: logger.debug(f"正在扫描插件根目录: {directory}") - # 遍历目录中的所有Python文件和包 + # 遍历目录中的所有包 for item in os.listdir(directory): item_path = os.path.join(directory, item) - if os.path.isfile(item_path) and item.endswith(".py") and item != "__init__.py": - # 单文件插件 - plugin_name = Path(item_path).stem - if self._load_plugin_module_file(item_path, plugin_name, directory): - loaded_count += 1 - else: - failed_count += 1 - - elif os.path.isdir(item_path) and not item.startswith(".") and not item.startswith("__"): - # 插件包 + if os.path.isdir(item_path) and not item.startswith(".") and not item.startswith("__"): plugin_file = os.path.join(item_path, "plugin.py") if os.path.exists(plugin_file): - plugin_name = item # 使用目录名作为插件名 - if self._load_plugin_module_file(plugin_file, plugin_name, item_path): + if self._load_plugin_module_file(plugin_file): loaded_count += 1 else: failed_count += 1 @@ -387,14 +375,16 @@ class PluginManager: def _find_plugin_directory(self, plugin_class: Type[PluginBase]) -> Optional[str]: """查找插件类对应的目录路径""" try: - module = getmodule(plugin_class) - if module and hasattr(module, "__file__") and module.__file__: - return os.path.dirname(module.__file__) + # module = getmodule(plugin_class) + # if module and hasattr(module, "__file__") and module.__file__: + # return os.path.dirname(module.__file__) + file_path = inspect.getfile(plugin_class) + return os.path.dirname(file_path) except Exception as e: logger.debug(f"通过inspect获取插件目录失败: {e}") return None - def _load_plugin_module_file(self, plugin_file: str, plugin_name: str, plugin_dir: str) -> bool: + def _load_plugin_module_file(self, plugin_file: str) -> bool: # sourcery skip: extract-method """加载单个插件模块文件 @@ -405,12 +395,7 @@ class PluginManager: """ # 生成模块名 plugin_path = Path(plugin_file) - if plugin_path.parent.name != "plugins": - # 插件包格式:parent_dir.plugin - module_name = f"plugins.{plugin_path.parent.name}.plugin" - else: - # 单文件格式:plugins.filename - module_name = f"plugins.{plugin_path.stem}" + module_name = ".".join(plugin_path.parent.parts) try: # 动态导入插件模块 @@ -422,16 +407,13 @@ class PluginManager: module = module_from_spec(spec) spec.loader.exec_module(module) - # 记录插件名和目录路径的映射 - self.plugin_paths[plugin_name] = plugin_dir - logger.debug(f"插件模块加载成功: {plugin_file}") return True except Exception as e: error_msg = f"加载插件模块 {plugin_file} 失败: {e}" logger.error(error_msg) - self.failed_plugins[plugin_name] = error_msg + self.failed_plugins[module_name] = error_msg return False def _check_plugin_version_compatibility(self, plugin_name: str, manifest_data: Dict[str, Any]) -> Tuple[bool, str]: @@ -475,13 +457,14 @@ class PluginManager: stats = component_registry.get_registry_stats() action_count = stats.get("action_components", 0) command_count = stats.get("command_components", 0) + event_handler_count = stats.get("event_handlers", 0) total_components = stats.get("total_components", 0) # 📋 显示插件加载总览 if total_registered > 0: logger.info("🎉 插件系统加载完成!") logger.info( - f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count})" + f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, EventHandler: {event_handler_count})" ) # 显示详细的插件列表 @@ -510,8 +493,9 @@ class PluginManager: # 组件列表 if plugin_info.components: - action_components = [c for c in plugin_info.components if c.component_type.name == "ACTION"] - command_components = [c for c in plugin_info.components if c.component_type.name == "COMMAND"] + action_components = [c for c in plugin_info.components if c.component_type == ComponentType.ACTION] + command_components = [c for c in plugin_info.components if c.component_type == ComponentType.COMMAND] + event_handler_components = [c for c in plugin_info.components if c.component_type == ComponentType.EVENT_HANDLER] if action_components: action_names = [c.name for c in action_components] @@ -520,6 +504,10 @@ class PluginManager: if command_components: command_names = [c.name for c in command_components] logger.info(f" ⚡ Command组件: {', '.join(command_names)}") + + if event_handler_components: + event_handler_names = [c.name for c in event_handler_components] + logger.info(f" 📢 EventHandler组件: {', '.join(event_handler_names)}") # 依赖信息 if plugin_info.dependencies: @@ -530,6 +518,12 @@ class PluginManager: config_status = "✅" if self.plugin_paths.get(plugin_name) else "❌" logger.info(f" ⚙️ 配置: {plugin_info.config_file} {config_status}") + root_path = Path(__file__) + + # 查找项目根目录 + while not (root_path / "pyproject.toml").exists() and root_path.parent != root_path: + root_path = root_path.parent + # 显示目录统计 logger.info("📂 加载目录统计:") for directory in self.plugin_directories: @@ -537,7 +531,11 @@ class PluginManager: plugins_in_dir = [] for plugin_name in self.loaded_plugins.keys(): plugin_path = self.plugin_paths.get(plugin_name, "") - if plugin_path.startswith(directory): + if ( + Path(plugin_path) + .resolve() + .is_relative_to(Path(os.path.join(str(root_path), directory)).resolve()) + ): plugins_in_dir.append(plugin_name) if plugins_in_dir: diff --git a/src/plugins/built_in/core_actions/plugin.py b/src/plugins/built_in/core_actions/plugin.py index 123ab2a16..82484f7fb 100644 --- a/src/plugins/built_in/core_actions/plugin.py +++ b/src/plugins/built_in/core_actions/plugin.py @@ -80,8 +80,9 @@ class ReplyAction(BaseAction): logger.info(f"{self.log_prefix} 回复目标: {reply_to}") try: - prepared_reply = self.action_data.get("prepared_reply", "") - if not prepared_reply: + if prepared_reply := self.action_data.get("prepared_reply", ""): + reply_text = prepared_reply + else: try: success, reply_set, _ = await asyncio.wait_for( generator_api.generate_reply( @@ -109,9 +110,6 @@ class ReplyAction(BaseAction): logger.info( f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,{'使用' if need_reply else '不使用'}引用回复" ) - else: - reply_text = prepared_reply - # 构建回复文本 reply_text = "" first_replied = False @@ -120,11 +118,12 @@ class ReplyAction(BaseAction): data = reply_seg[1] if not first_replied: if need_reply: - await self.send_text(content=data, reply_to=reply_to, reply_to_platform_id=reply_to_platform_id, typing=False) - first_replied = True + await self.send_text( + content=data, reply_to=reply_to, reply_to_platform_id=reply_to_platform_id, typing=False + ) else: await self.send_text(content=data, reply_to_platform_id=reply_to_platform_id, typing=False) - first_replied = True + first_replied = True else: await self.send_text(content=data, reply_to_platform_id=reply_to_platform_id, typing=True) reply_text += data @@ -190,17 +189,15 @@ class CoreActionsPlugin(BasePlugin): def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: """返回插件包含的组件列表""" - # --- 从配置动态设置Action/Command --- - emoji_chance = global_config.emoji.emoji_chance - if global_config.emoji.emoji_activate_type == "random": - EmojiAction.random_activation_probability = emoji_chance - EmojiAction.focus_activation_type = ActionActivationType.RANDOM - EmojiAction.normal_activation_type = ActionActivationType.RANDOM - elif global_config.emoji.emoji_activate_type == "llm": + if global_config.emoji.emoji_activate_type == "llm": EmojiAction.random_activation_probability = 0.0 EmojiAction.focus_activation_type = ActionActivationType.LLM_JUDGE EmojiAction.normal_activation_type = ActionActivationType.LLM_JUDGE + elif global_config.emoji.emoji_activate_type == "random": + EmojiAction.random_activation_probability = global_config.emoji.emoji_chance + EmojiAction.focus_activation_type = ActionActivationType.RANDOM + EmojiAction.normal_activation_type = ActionActivationType.RANDOM # --- 根据配置注册组件 --- components = [] if self.get_config("components.enable_reply", True): diff --git a/src/tools/not_using/lpmm_get_knowledge.py b/src/tools/not_using/lpmm_get_knowledge.py index 180c5e699..467db6ed1 100644 --- a/src/tools/not_using/lpmm_get_knowledge.py +++ b/src/tools/not_using/lpmm_get_knowledge.py @@ -33,7 +33,7 @@ class SearchKnowledgeFromLPMMTool(BaseTool): Dict: 工具执行结果 """ try: - query = function_args.get("query") + query: str = function_args.get("query") # type: ignore # threshold = function_args.get("threshold", 0.4) # 检查LPMM知识库是否启用 diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index fbb816621..69d21b564 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -1,5 +1,5 @@ [inner] -version = "4.4.3" +version = "4.4.4" #----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读---- #如果你想要修改配置文件,请在修改后将version的值进行变更 @@ -87,6 +87,7 @@ talk_frequency_adjust = [ # - 时间支持跨天,例如 "00:10,0.3" 表示从凌晨0:10开始使用频率0.3 # - 系统会自动将 "platform:id:type" 转换为内部的哈希chat_id进行匹配 +enable_asr = false # 是否启用语音识别,启用后麦麦可以通过语音输入进行对话,启用该功能需要配置语音识别模型[model.voice] [message_receive] # 以下是消息过滤,可以根据规则过滤特定消息,将不会读取这些消息 @@ -294,6 +295,12 @@ provider = "SILICONFLOW" pri_in = 0.35 pri_out = 0.35 +[model.voice] # 语音识别模型 +name = "FunAudioLLM/SenseVoiceSmall" +provider = "SILICONFLOW" +pri_in = 0 +pri_out = 0 + [model.tool_use] #工具调用模型,需要使用支持工具调用的模型 name = "Qwen/Qwen3-14B" provider = "SILICONFLOW"