diff --git a/src/chat/actions/base_action.py b/src/chat/actions/base_action.py index 3b56a5a3d..624f163ea 100644 --- a/src/chat/actions/base_action.py +++ b/src/chat/actions/base_action.py @@ -8,19 +8,22 @@ logger = get_logger("base_action") _ACTION_REGISTRY: Dict[str, Type["BaseAction"]] = {} _DEFAULT_ACTIONS: Dict[str, str] = {} + # 动作激活类型枚举 class ActionActivationType: ALWAYS = "always" # 默认参与到planner - LLM_JUDGE = "llm_judge" # LLM判定是否启动该action到planner + LLM_JUDGE = "llm_judge" # LLM判定是否启动该action到planner RANDOM = "random" # 随机启用action到planner KEYWORD = "keyword" # 关键词触发启用action到planner + # 聊天模式枚举 class ChatMode: FOCUS = "focus" # Focus聊天模式 NORMAL = "normal" # Normal聊天模式 ALL = "all" # 所有聊天模式 + def register_action(cls): """ 动作注册装饰器 @@ -81,13 +84,13 @@ class BaseAction(ABC): self.action_description: str = "基础动作" self.action_parameters: dict = {} self.action_require: list[str] = [] - + # 动作激活类型设置 # Focus模式下的激活类型,默认为always self.focus_activation_type: str = ActionActivationType.ALWAYS - # Normal模式下的激活类型,默认为always + # Normal模式下的激活类型,默认为always self.normal_activation_type: str = ActionActivationType.ALWAYS - + # 随机激活的概率(0.0-1.0),用于RANDOM激活类型 self.random_activation_probability: float = 0.3 # LLM判定的提示词,用于LLM_JUDGE激活类型 diff --git a/src/chat/actions/default_actions/__init__.py b/src/chat/actions/default_actions/__init__.py index 47a679520..537090dc1 100644 --- a/src/chat/actions/default_actions/__init__.py +++ b/src/chat/actions/default_actions/__init__.py @@ -4,4 +4,4 @@ from . import no_reply_action # noqa from . import exit_focus_chat_action # noqa from . import emoji_action # noqa -# 在此处添加更多动作模块导入 \ No newline at end of file +# 在此处添加更多动作模块导入 diff --git a/src/chat/actions/default_actions/emoji_action.py b/src/chat/actions/default_actions/emoji_action.py index 1e9571808..df99ded26 100644 --- a/src/chat/actions/default_actions/emoji_action.py +++ b/src/chat/actions/default_actions/emoji_action.py @@ -22,29 +22,26 @@ class EmojiAction(BaseAction): action_parameters: dict[str:str] = { "description": "文字描述你想要发送的表情包内容", } - action_require: list[str] = [ - "表达情绪时可以选择使用", - "重点:不要连续发,如果你已经发过[表情包],就不要选择此动作"] + action_require: list[str] = ["表达情绪时可以选择使用", "重点:不要连续发,如果你已经发过[表情包],就不要选择此动作"] associated_types: list[str] = ["emoji"] enable_plugin = True - + focus_activation_type = ActionActivationType.LLM_JUDGE normal_activation_type = ActionActivationType.RANDOM - + random_activation_probability = global_config.normal_chat.emoji_chance - + parallel_action = True - - + llm_judge_prompt = """ 判定是否需要使用表情动作的条件: 1. 用户明确要求使用表情包 2. 这是一个适合表达强烈情绪的场合 3. 不要发送太多表情包,如果你已经发送过多个表情包 """ - + # 模式启用设置 - 表情动作只在Focus模式下使用 mode_enable = ChatMode.ALL @@ -147,4 +144,4 @@ class EmojiAction(BaseAction): elif type == "emoji": reply_text += data - return success, reply_text \ No newline at end of file + return success, reply_text diff --git a/src/chat/actions/default_actions/exit_focus_chat_action.py b/src/chat/actions/default_actions/exit_focus_chat_action.py index 8aa9976ae..d1a54328b 100644 --- a/src/chat/actions/default_actions/exit_focus_chat_action.py +++ b/src/chat/actions/default_actions/exit_focus_chat_action.py @@ -27,7 +27,7 @@ class ExitFocusChatAction(BaseAction): ] # 退出专注聊天是系统核心功能,不是插件,但默认不启用(需要特定条件触发) enable_plugin = False - + # 模式启用设置 - 退出专注聊天动作只在Focus模式下使用 mode_enable = ChatMode.FOCUS diff --git a/src/chat/actions/default_actions/no_reply_action.py b/src/chat/actions/default_actions/no_reply_action.py index b7ac95497..a319eedb1 100644 --- a/src/chat/actions/default_actions/no_reply_action.py +++ b/src/chat/actions/default_actions/no_reply_action.py @@ -29,10 +29,10 @@ class NoReplyAction(BaseAction): "想要休息一下", ] enable_plugin = True - + # 激活类型设置 focus_activation_type = ActionActivationType.ALWAYS - + # 模式启用设置 - no_reply动作只在Focus模式下使用 mode_enable = ChatMode.FOCUS diff --git a/src/chat/actions/default_actions/reply_action.py b/src/chat/actions/default_actions/reply_action.py index 571c1887f..5e9c236e1 100644 --- a/src/chat/actions/default_actions/reply_action.py +++ b/src/chat/actions/default_actions/reply_action.py @@ -31,16 +31,16 @@ class ReplyAction(BaseAction): action_require: list[str] = [ "你想要闲聊或者随便附和", "有人提到你", - "如果你刚刚进行了回复,不要对同一个话题重复回应" + "如果你刚刚进行了回复,不要对同一个话题重复回应", ] associated_types: list[str] = ["text"] enable_plugin = True - + # 激活类型设置 focus_activation_type = ActionActivationType.ALWAYS - + # 模式启用设置 - 回复动作只在Focus模式下使用 mode_enable = ChatMode.FOCUS @@ -89,12 +89,12 @@ class ReplyAction(BaseAction): cycle_timers=self.cycle_timers, thinking_id=self.thinking_id, ) - + await self.store_action_info( action_build_into_prompt=False, action_prompt_display=f"{reply_text}", ) - + return success, reply_text async def _handle_reply( @@ -115,22 +115,22 @@ class ReplyAction(BaseAction): chatting_observation: ChattingObservation = next( obs for obs in self.observations if isinstance(obs, ChattingObservation) ) - + reply_to = reply_data.get("reply_to", "none") - + # sender = "" target = "" if ":" in reply_to or ":" in reply_to: # 使用正则表达式匹配中文或英文冒号 - parts = re.split(pattern=r'[::]', string=reply_to, maxsplit=1) + parts = re.split(pattern=r"[::]", string=reply_to, maxsplit=1) if len(parts) == 2: # sender = parts[0].strip() target = parts[1].strip() anchor_message = chatting_observation.search_message_by_text(target) else: anchor_message = None - - if anchor_message: + + if anchor_message: anchor_message.update_chat_stream(self.chat_stream) else: logger.info(f"{self.log_prefix} 未找到锚点消息,创建占位符") @@ -138,7 +138,6 @@ class ReplyAction(BaseAction): self.chat_stream.platform, self.chat_stream.group_info, self.chat_stream ) - success, reply_set = await self.replyer.deal_reply( cycle_timers=cycle_timers, action_data=reply_data, @@ -158,8 +157,9 @@ class ReplyAction(BaseAction): return success, reply_text - - async def store_action_info(self, action_build_into_prompt: bool = False, action_prompt_display: str = "", action_done: bool = True) -> None: + async def store_action_info( + self, action_build_into_prompt: bool = False, action_prompt_display: str = "", action_done: bool = True + ) -> None: """存储action执行信息到数据库 Args: @@ -188,9 +188,9 @@ class ReplyAction(BaseAction): chat_info_platform=chat_stream.platform, user_id=chat_stream.user_info.user_id if chat_stream.user_info else "", user_nickname=chat_stream.user_info.user_nickname if chat_stream.user_info else "", - user_cardname=chat_stream.user_info.user_cardname if chat_stream.user_info else "" + user_cardname=chat_stream.user_info.user_cardname if chat_stream.user_info else "", ) logger.debug(f"{self.log_prefix} 已存储action信息: {action_prompt_display}") except Exception as e: logger.error(f"{self.log_prefix} 存储action信息时出错: {e}") - traceback.print_exc() \ No newline at end of file + traceback.print_exc() diff --git a/src/chat/actions/plugin_action.py b/src/chat/actions/plugin_action.py index ceda4adb8..04f9a545c 100644 --- a/src/chat/actions/plugin_action.py +++ b/src/chat/actions/plugin_action.py @@ -1,10 +1,6 @@ -import traceback -from typing import Tuple, Dict, List, Any, Optional, Union, Type +from typing import Tuple, Dict, Any, Optional from src.chat.actions.base_action import BaseAction, register_action, ActionActivationType, ChatMode # noqa F401 -from src.chat.heart_flow.observation.chatting_observation import ChattingObservation -from src.chat.focus_chat.hfc_utils import create_empty_anchor_message from src.common.logger_manager import get_logger -from src.config.config import global_config import os import inspect import toml # 导入 toml 库 @@ -20,10 +16,10 @@ from src.chat.actions.plugin_api.stream_api import StreamAPI from src.chat.actions.plugin_api.hearflow_api import HearflowAPI # 以下为类型注解需要 -from src.chat.message_receive.chat_stream import ChatStream # noqa -from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor # noqa -from src.chat.focus_chat.replyer.default_replyer import DefaultReplyer # noqa -from src.chat.focus_chat.info.obs_info import ObsInfo # noqa +from src.chat.message_receive.chat_stream import ChatStream # noqa +from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor # noqa +from src.chat.focus_chat.replyer.default_replyer import DefaultReplyer # noqa +from src.chat.focus_chat.info.obs_info import ObsInfo # noqa logger = get_logger("plugin_action") @@ -35,7 +31,7 @@ class PluginAction(BaseAction, MessageAPI, LLMAPI, DatabaseAPI, ConfigAPI, Utils """ action_config_file_name: Optional[str] = None # 插件可以覆盖此属性来指定配置文件名 - + # 默认激活类型设置,插件可以覆盖 focus_activation_type = ActionActivationType.ALWAYS normal_activation_type = ActionActivationType.ALWAYS @@ -43,7 +39,7 @@ class PluginAction(BaseAction, MessageAPI, LLMAPI, DatabaseAPI, ConfigAPI, Utils llm_judge_prompt: str = "" activation_keywords: list[str] = [] keyword_case_sensitive: bool = False - + # 默认模式启用设置 - 插件动作默认在所有模式下可用,插件可以覆盖 mode_enable = ChatMode.ALL diff --git a/src/chat/actions/plugin_api/__init__.py b/src/chat/actions/plugin_api/__init__.py index 93c59c01e..db85ee2f2 100644 --- a/src/chat/actions/plugin_api/__init__.py +++ b/src/chat/actions/plugin_api/__init__.py @@ -7,11 +7,11 @@ from src.chat.actions.plugin_api.stream_api import StreamAPI from src.chat.actions.plugin_api.hearflow_api import HearflowAPI __all__ = [ - 'MessageAPI', - 'LLMAPI', - 'DatabaseAPI', - 'ConfigAPI', - 'UtilsAPI', - 'StreamAPI', - 'HearflowAPI', -] \ No newline at end of file + "MessageAPI", + "LLMAPI", + "DatabaseAPI", + "ConfigAPI", + "UtilsAPI", + "StreamAPI", + "HearflowAPI", +] diff --git a/src/chat/actions/plugin_api/config_api.py b/src/chat/actions/plugin_api/config_api.py index f136cea7e..0ca617bb4 100644 --- a/src/chat/actions/plugin_api/config_api.py +++ b/src/chat/actions/plugin_api/config_api.py @@ -5,32 +5,33 @@ from src.person_info.person_info import person_info_manager logger = get_logger("config_api") + class ConfigAPI: """配置API模块 - + 提供了配置读取和用户信息获取等功能 """ - + def get_global_config(self, key: str, default: Any = None) -> Any: """ 安全地从全局配置中获取一个值。 插件应使用此方法读取全局配置,以保证只读和隔离性。 - + Args: key: 配置键名 default: 如果配置不存在时返回的默认值 - + Returns: Any: 配置值或默认值 """ return global_config.get(key, default) - + async def get_user_id_by_person_name(self, person_name: str) -> tuple[str, str]: """根据用户名获取用户ID - + Args: person_name: 用户名 - + Returns: tuple[str, str]: (平台, 用户ID) """ @@ -38,16 +39,16 @@ class ConfigAPI: user_id = await person_info_manager.get_value(person_id, "user_id") platform = await person_info_manager.get_value(person_id, "platform") return platform, user_id - + async def get_person_info(self, person_id: str, key: str, default: Any = None) -> Any: """获取用户信息 - + Args: person_id: 用户ID key: 信息键名 default: 默认值 - + Returns: Any: 用户信息值或默认值 """ - return await person_info_manager.get_value(person_id, key, default) \ No newline at end of file + return await person_info_manager.get_value(person_id, key, default) diff --git a/src/chat/actions/plugin_api/database_api.py b/src/chat/actions/plugin_api/database_api.py index 3342a3d6c..d9c7703bf 100644 --- a/src/chat/actions/plugin_api/database_api.py +++ b/src/chat/actions/plugin_api/database_api.py @@ -8,13 +8,16 @@ from peewee import Model, DoesNotExist logger = get_logger("database_api") + class DatabaseAPI: """数据库API模块 - + 提供了数据库操作相关的功能 """ - - async def store_action_info(self, action_build_into_prompt: bool = False, action_prompt_display: str = "", action_done: bool = True) -> None: + + async def store_action_info( + self, action_build_into_prompt: bool = False, action_prompt_display: str = "", action_done: bool = True + ) -> None: """存储action执行信息到数据库 Args: @@ -44,13 +47,13 @@ class DatabaseAPI: chat_info_platform=chat_stream.platform, user_id=chat_stream.user_info.user_id if chat_stream.user_info else "", user_nickname=chat_stream.user_info.user_nickname if chat_stream.user_info else "", - user_cardname=chat_stream.user_info.user_cardname if chat_stream.user_info else "" + user_cardname=chat_stream.user_info.user_cardname if chat_stream.user_info else "", ) logger.debug(f"{self.log_prefix} 已存储action信息: {action_prompt_display}") except Exception as e: logger.error(f"{self.log_prefix} 存储action信息时出错: {e}") traceback.print_exc() - + async def db_query( self, model_class: Type[Model], @@ -59,12 +62,12 @@ class DatabaseAPI: data: Dict[str, Any] = None, limit: int = None, order_by: List[str] = None, - single_result: bool = False + single_result: bool = False, ) -> Union[List[Dict[str, Any]], Dict[str, Any], None]: """执行数据库查询操作 - + 这个方法提供了一个通用接口来执行数据库操作,包括查询、创建、更新和删除记录。 - + Args: model_class: Peewee 模型类,例如 ActionRecords, Messages 等 query_type: 查询类型,可选值: "get", "create", "update", "delete", "count" @@ -73,7 +76,7 @@ class DatabaseAPI: limit: 限制结果数量 order_by: 排序字段列表,使用字段名,前缀'-'表示降序 single_result: 是否只返回单个结果 - + Returns: 根据查询类型返回不同的结果: - "get": 返回查询结果列表或单个结果(如果 single_result=True) @@ -81,24 +84,24 @@ class DatabaseAPI: - "update": 返回受影响的行数 - "delete": 返回受影响的行数 - "count": 返回记录数量 - + 示例: # 查询最近10条消息 messages = await self.db_query( - Messages, + Messages, query_type="get", filters={"chat_id": chat_stream.stream_id}, limit=10, order_by=["-time"] ) - + # 创建一条记录 new_record = await self.db_query( ActionRecords, query_type="create", data={"action_id": "123", "time": time.time(), "action_name": "TestAction"} ) - + # 更新记录 updated_count = await self.db_query( ActionRecords, @@ -106,14 +109,14 @@ class DatabaseAPI: filters={"action_id": "123"}, data={"action_done": True} ) - + # 删除记录 deleted_count = await self.db_query( ActionRecords, query_type="delete", filters={"action_id": "123"} ) - + # 计数 count = await self.db_query( Messages, @@ -125,12 +128,12 @@ class DatabaseAPI: # 构建基本查询 if query_type in ["get", "update", "delete", "count"]: query = model_class.select() - + # 应用过滤条件 if filters: for field, value in filters.items(): query = query.where(getattr(model_class, field) == value) - + # 执行查询 if query_type == "get": # 应用排序 @@ -140,56 +143,56 @@ class DatabaseAPI: query = query.order_by(getattr(model_class, field[1:]).desc()) else: query = query.order_by(getattr(model_class, field)) - + # 应用限制 if limit: query = query.limit(limit) - + # 执行查询 results = list(query.dicts()) - + # 返回结果 if single_result: return results[0] if results else None return results - + elif query_type == "create": if not data: raise ValueError("创建记录需要提供data参数") - + # 创建记录 record = model_class.create(**data) # 返回创建的记录 return model_class.select().where(model_class.id == record.id).dicts().get() - + elif query_type == "update": if not data: raise ValueError("更新记录需要提供data参数") - + # 更新记录 return query.update(**data).execute() - + elif query_type == "delete": # 删除记录 return query.delete().execute() - + elif query_type == "count": # 计数 return query.count() - + else: raise ValueError(f"不支持的查询类型: {query_type}") - + except DoesNotExist: # 记录不存在 if query_type == "get" and single_result: return None return [] - + except Exception as e: logger.error(f"{self.log_prefix} 数据库操作出错: {e}") traceback.print_exc() - + # 根据查询类型返回合适的默认值 if query_type == "get": return None if single_result else [] @@ -198,21 +201,18 @@ class DatabaseAPI: raise "unknown query type" async def db_raw_query( - self, - sql: str, - params: List[Any] = None, - fetch_results: bool = True + self, sql: str, params: List[Any] = None, fetch_results: bool = True ) -> Union[List[Dict[str, Any]], int, None]: """执行原始SQL查询 - + 警告: 使用此方法需要小心,确保SQL语句已正确构造以避免SQL注入风险。 - + Args: sql: 原始SQL查询字符串 params: 查询参数列表,用于替换SQL中的占位符 fetch_results: 是否获取查询结果,对于SELECT查询设为True,对于 UPDATE/INSERT/DELETE等操作设为False - + Returns: 如果fetch_results为True,返回查询结果列表; 如果fetch_results为False,返回受影响的行数; @@ -220,55 +220,51 @@ class DatabaseAPI: """ try: cursor = db.execute_sql(sql, params or []) - + if fetch_results: # 获取列名 columns = [col[0] for col in cursor.description] - + # 构建结果字典列表 results = [] for row in cursor.fetchall(): results.append(dict(zip(columns, row))) - + return results else: # 返回受影响的行数 return cursor.rowcount - + except Exception as e: logger.error(f"{self.log_prefix} 执行原始SQL查询出错: {e}") traceback.print_exc() return None - + async def db_save( - self, - model_class: Type[Model], - data: Dict[str, Any], - key_field: str = None, - key_value: Any = None + self, model_class: Type[Model], data: Dict[str, Any], key_field: str = None, key_value: Any = None ) -> Union[Dict[str, Any], None]: """保存数据到数据库(创建或更新) - + 如果提供了key_field和key_value,会先尝试查找匹配的记录进行更新; 如果没有找到匹配记录,或未提供key_field和key_value,则创建新记录。 - + Args: model_class: Peewee模型类,如ActionRecords, Messages等 data: 要保存的数据字典 key_field: 用于查找现有记录的字段名,例如"action_id" key_value: 用于查找现有记录的字段值 - + Returns: Dict[str, Any]: 保存后的记录数据 None: 如果操作失败 - + 示例: # 创建或更新一条记录 record = await self.db_save( ActionRecords, { - "action_id": "123", - "time": time.time(), + "action_id": "123", + "time": time.time(), "action_name": "TestAction", "action_done": True }, @@ -280,58 +276,50 @@ class DatabaseAPI: # 如果提供了key_field和key_value,尝试更新现有记录 if key_field and key_value is not None: # 查找现有记录 - existing_records = list(model_class.select().where( - getattr(model_class, key_field) == key_value - ).limit(1)) - + existing_records = list( + model_class.select().where(getattr(model_class, key_field) == key_value).limit(1) + ) + if existing_records: # 更新现有记录 existing_record = existing_records[0] for field, value in data.items(): setattr(existing_record, field, value) existing_record.save() - + # 返回更新后的记录 - updated_record = model_class.select().where( - model_class.id == existing_record.id - ).dicts().get() + updated_record = model_class.select().where(model_class.id == existing_record.id).dicts().get() return updated_record - + # 如果没有找到现有记录或未提供key_field和key_value,创建新记录 new_record = model_class.create(**data) - + # 返回创建的记录 - created_record = model_class.select().where( - model_class.id == new_record.id - ).dicts().get() + created_record = model_class.select().where(model_class.id == new_record.id).dicts().get() return created_record - + except Exception as e: logger.error(f"{self.log_prefix} 保存数据库记录出错: {e}") traceback.print_exc() return None - + async def db_get( - self, - model_class: Type[Model], - filters: Dict[str, Any] = None, - order_by: str = None, - limit: int = None + self, model_class: Type[Model], filters: Dict[str, Any] = None, order_by: str = None, limit: int = None ) -> Union[List[Dict[str, Any]], Dict[str, Any], None]: """从数据库获取记录 - + 这是db_query方法的简化版本,专注于数据检索操作。 - + Args: model_class: Peewee模型类 filters: 过滤条件,字段名和值的字典 order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间降序 limit: 结果数量限制,如果为1则返回单个记录而不是列表 - + Returns: 如果limit=1,返回单个记录字典或None; 否则返回记录字典列表或空列表。 - + 示例: # 获取单个记录 record = await self.db_get( @@ -339,7 +327,7 @@ class DatabaseAPI: filters={"action_id": "123"}, limit=1 ) - + # 获取最近10条记录 records = await self.db_get( Messages, @@ -351,32 +339,32 @@ class DatabaseAPI: try: # 构建查询 query = model_class.select() - + # 应用过滤条件 if filters: for field, value in filters.items(): query = query.where(getattr(model_class, field) == value) - + # 应用排序 if order_by: if order_by.startswith("-"): query = query.order_by(getattr(model_class, order_by[1:]).desc()) else: query = query.order_by(getattr(model_class, order_by)) - + # 应用限制 if limit: query = query.limit(limit) - + # 执行查询 results = list(query.dicts()) - + # 返回结果 if limit == 1: return results[0] if results else None return results - + except Exception as e: logger.error(f"{self.log_prefix} 获取数据库记录出错: {e}") traceback.print_exc() - return None if limit == 1 else [] \ No newline at end of file + return None if limit == 1 else [] diff --git a/src/chat/actions/plugin_api/hearflow_api.py b/src/chat/actions/plugin_api/hearflow_api.py index c7d0452a2..2c26ce768 100644 --- a/src/chat/actions/plugin_api/hearflow_api.py +++ b/src/chat/actions/plugin_api/hearflow_api.py @@ -1,4 +1,4 @@ -from typing import Optional, List, Any, Tuple +from typing import Optional, List, Any from src.common.logger_manager import get_logger from src.chat.heart_flow.heartflow import heartflow from src.chat.heart_flow.sub_heartflow import SubHeartflow, ChatState @@ -8,16 +8,16 @@ logger = get_logger("hearflow_api") class HearflowAPI: """心流API模块 - + 提供与心流和子心流相关的操作接口 """ - + async def get_sub_hearflow_by_chat_id(self, chat_id: str) -> Optional[SubHeartflow]: """根据chat_id获取指定的sub_hearflow实例 - + Args: chat_id: 聊天ID,与sub_hearflow的subheartflow_id相同 - + Returns: Optional[SubHeartflow]: sub_hearflow实例,如果不存在则返回None """ @@ -35,11 +35,10 @@ class HearflowAPI: except Exception as e: logger.error(f"{self.log_prefix} 获取子心流实例时出错: {e}") return None - - + def get_all_sub_hearflow_ids(self) -> List[str]: """获取所有子心流的ID列表 - + Returns: List[str]: 所有子心流的ID列表 """ @@ -51,10 +50,10 @@ class HearflowAPI: except Exception as e: logger.error(f"{self.log_prefix} 获取子心流ID列表时出错: {e}") return [] - + def get_all_sub_hearflows(self) -> List[SubHeartflow]: """获取所有子心流实例 - + Returns: List[SubHeartflow]: 所有活跃的子心流实例列表 """ @@ -66,13 +65,13 @@ class HearflowAPI: except Exception as e: logger.error(f"{self.log_prefix} 获取子心流实例列表时出错: {e}") return [] - + async def get_sub_hearflow_chat_state(self, chat_id: str) -> Optional[ChatState]: """获取指定子心流的聊天状态 - + Args: chat_id: 聊天ID - + Returns: Optional[ChatState]: 聊天状态,如果子心流不存在则返回None """ @@ -84,14 +83,14 @@ class HearflowAPI: except Exception as e: logger.error(f"{self.log_prefix} 获取子心流聊天状态时出错: {e}") return None - + async def set_sub_hearflow_chat_state(self, chat_id: str, target_state: ChatState) -> bool: """设置指定子心流的聊天状态 - + Args: chat_id: 聊天ID target_state: 目标状态 - + Returns: bool: 是否设置成功 """ @@ -100,13 +99,13 @@ class HearflowAPI: except Exception as e: logger.error(f"{self.log_prefix} 设置子心流聊天状态时出错: {e}") return False - + async def get_sub_hearflow_replyer(self, chat_id: str) -> Optional[Any]: """根据chat_id获取指定子心流的replyer实例 - + Args: chat_id: 聊天ID - + Returns: Optional[Any]: replyer实例,如果不存在则返回None """ @@ -116,13 +115,13 @@ class HearflowAPI: except Exception as e: logger.error(f"{self.log_prefix} 获取子心流replyer时出错: {e}") return None - + async def get_sub_hearflow_expressor(self, chat_id: str) -> Optional[Any]: """根据chat_id获取指定子心流的expressor实例 - + Args: chat_id: 聊天ID - + Returns: Optional[Any]: expressor实例,如果不存在则返回None """ @@ -131,4 +130,4 @@ class HearflowAPI: return expressor except Exception as e: logger.error(f"{self.log_prefix} 获取子心流expressor时出错: {e}") - return None \ No newline at end of file + return None diff --git a/src/chat/actions/plugin_api/llm_api.py b/src/chat/actions/plugin_api/llm_api.py index 0e80e897b..743aac748 100644 --- a/src/chat/actions/plugin_api/llm_api.py +++ b/src/chat/actions/plugin_api/llm_api.py @@ -5,12 +5,13 @@ from src.config.config import global_config logger = get_logger("llm_api") + class LLMAPI: """LLM API模块 - + 提供了与LLM模型交互的功能 """ - + def get_available_models(self) -> Dict[str, Any]: """获取所有可用的模型配置 @@ -20,17 +21,13 @@ class LLMAPI: if not hasattr(global_config, "model"): logger.error(f"{self.log_prefix} 无法获取模型列表:全局配置中未找到 model 配置") return {} - + models = global_config.model - + return models async def generate_with_model( - self, - prompt: str, - model_config: Dict[str, Any], - request_type: str = "plugin.generate", - **kwargs + self, prompt: str, model_config: Dict[str, Any], request_type: str = "plugin.generate", **kwargs ) -> Tuple[bool, str, str, str]: """使用指定模型生成内容 @@ -45,17 +42,13 @@ class LLMAPI: """ try: logger.info(f"{self.log_prefix} 使用模型生成内容,提示词: {prompt[:100]}...") - - llm_request = LLMRequest( - model=model_config, - request_type=request_type, - **kwargs - ) - + + llm_request = LLMRequest(model=model_config, request_type=request_type, **kwargs) + response, (reasoning, model_name) = await llm_request.generate_response_async(prompt) return True, response, reasoning, model_name - + except Exception as e: error_msg = f"生成内容时出错: {str(e)}" logger.error(f"{self.log_prefix} {error_msg}") - return False, error_msg, "", "" \ No newline at end of file + return False, error_msg, "", "" diff --git a/src/chat/actions/plugin_api/message_api.py b/src/chat/actions/plugin_api/message_api.py index 00af27665..ca4c7e1cf 100644 --- a/src/chat/actions/plugin_api/message_api.py +++ b/src/chat/actions/plugin_api/message_api.py @@ -14,17 +14,18 @@ from src.chat.focus_chat.info.obs_info import ObsInfo # 新增导入 from src.chat.focus_chat.heartFC_sender import HeartFCSender from src.chat.message_receive.message import MessageSending -from maim_message import Seg, UserInfo, GroupInfo +from maim_message import Seg, UserInfo from src.config.config import global_config logger = get_logger("message_api") + class MessageAPI: """消息API模块 - + 提供了发送消息、获取消息历史等功能 """ - + async def send_message_to_target( self, message_type: str, @@ -35,7 +36,7 @@ class MessageAPI: display_message: str = "", ) -> bool: """直接向指定目标发送消息 - + Args: message_type: 消息类型,如"text"、"image"、"emoji"等 content: 消息内容 @@ -43,7 +44,7 @@ class MessageAPI: target_id: 目标ID(群ID或用户ID) is_group: 是否为群聊,True为群聊,False为私聊 display_message: 显示消息(可选) - + Returns: bool: 是否发送成功 """ @@ -53,12 +54,14 @@ class MessageAPI: # 群聊:从数据库查找对应的聊天流 target_stream = None for stream_id, stream in chat_manager.streams.items(): - if (stream.group_info and - str(stream.group_info.group_id) == str(target_id) and - stream.platform == platform): + if ( + stream.group_info + and str(stream.group_info.group_id) == str(target_id) + and stream.platform == platform + ): target_stream = stream break - + if not target_stream: logger.error(f"{getattr(self, 'log_prefix', '')} 未找到群ID为 {target_id} 的聊天流") return False @@ -66,39 +69,39 @@ class MessageAPI: # 私聊:从数据库查找对应的聊天流 target_stream = None for stream_id, stream in chat_manager.streams.items(): - if (not stream.group_info and - str(stream.user_info.user_id) == str(target_id) and - stream.platform == platform): + if ( + not stream.group_info + and str(stream.user_info.user_id) == str(target_id) + and stream.platform == platform + ): target_stream = stream break - + if not target_stream: logger.error(f"{getattr(self, 'log_prefix', '')} 未找到用户ID为 {target_id} 的私聊流") return False - + # 创建HeartFCSender实例 heart_fc_sender = HeartFCSender() - + # 生成消息ID和thinking_id current_time = time.time() message_id = f"plugin_msg_{int(current_time * 1000)}" thinking_id = f"plugin_thinking_{int(current_time * 1000)}" - + # 构建机器人用户信息 bot_user_info = UserInfo( user_id=global_config.bot.qq_account, user_nickname=global_config.bot.nickname, platform=platform, ) - + # 创建消息段 message_segment = Seg(type=message_type, data=content) - + # 创建空锚点消息(用于回复) - anchor_message = await create_empty_anchor_message( - platform, target_stream.group_info, target_stream - ) - + anchor_message = await create_empty_anchor_message(platform, target_stream.group_info, target_stream) + # 构建发送消息对象 bot_message = MessageSending( message_id=message_id, @@ -112,22 +115,17 @@ class MessageAPI: is_emoji=(message_type == "emoji"), thinking_start_time=current_time, ) - + # 发送消息 - sent_msg = await heart_fc_sender.send_message( - bot_message, - has_thinking=True, - typing=False, - set_reply=False - ) - + sent_msg = await heart_fc_sender.send_message(bot_message, has_thinking=True, typing=False, set_reply=False) + if sent_msg: logger.info(f"{getattr(self, 'log_prefix', '')} 成功发送消息到 {platform}:{target_id}") return True else: logger.error(f"{getattr(self, 'log_prefix', '')} 发送消息失败") return False - + except Exception as e: logger.error(f"{getattr(self, 'log_prefix', '')} 向目标发送消息时出错: {e}") traceback.print_exc() @@ -135,42 +133,34 @@ class MessageAPI: async def send_text_to_group(self, text: str, group_id: str, platform: str = "qq") -> bool: """便捷方法:向指定群聊发送文本消息 - + Args: text: 要发送的文本内容 group_id: 群聊ID platform: 平台,默认为"qq" - + Returns: bool: 是否发送成功 """ return await self.send_message_to_target( - message_type="text", - content=text, - platform=platform, - target_id=group_id, - is_group=True + message_type="text", content=text, platform=platform, target_id=group_id, is_group=True ) async def send_text_to_user(self, text: str, user_id: str, platform: str = "qq") -> bool: """便捷方法:向指定用户发送私聊文本消息 - + Args: text: 要发送的文本内容 user_id: 用户ID platform: 平台,默认为"qq" - + Returns: bool: 是否发送成功 """ return await self.send_message_to_target( - message_type="text", - content=text, - platform=platform, - target_id=user_id, - is_group=False + message_type="text", content=text, platform=platform, target_id=user_id, is_group=False ) - + async def send_message(self, type: str, data: str, target: Optional[str] = "", display_message: str = "") -> bool: """发送消息的简化方法 @@ -288,7 +278,9 @@ class MessageAPI: return success - async def send_message_by_replyer(self, target: Optional[str] = None, extra_info_block: Optional[str] = None) -> bool: + async def send_message_by_replyer( + self, target: Optional[str] = None, extra_info_block: Optional[str] = None + ) -> bool: """通过replyer发送消息的简化方法 Args: @@ -381,4 +373,4 @@ class MessageAPI: } messages.append(simple_msg) - return messages \ No newline at end of file + return messages diff --git a/src/chat/actions/plugin_api/stream_api.py b/src/chat/actions/plugin_api/stream_api.py index ea282dfdb..e8db18279 100644 --- a/src/chat/actions/plugin_api/stream_api.py +++ b/src/chat/actions/plugin_api/stream_api.py @@ -1,147 +1,142 @@ -import hashlib from typing import Optional, List, Dict, Any from src.common.logger_manager import get_logger from src.chat.message_receive.chat_stream import ChatManager, ChatStream -from maim_message import GroupInfo, UserInfo logger = get_logger("stream_api") class StreamAPI: """聊天流API模块 - + 提供了获取聊天流、通过群ID查找聊天流等功能 """ - + def get_chat_stream_by_group_id(self, group_id: str, platform: str = "qq") -> Optional[ChatStream]: """通过QQ群ID获取聊天流 - + Args: group_id: QQ群ID platform: 平台标识,默认为"qq" - + Returns: Optional[ChatStream]: 找到的聊天流对象,如果未找到则返回None """ try: chat_manager = ChatManager() - + # 遍历所有已加载的聊天流,查找匹配的群ID for stream_id, stream in chat_manager.streams.items(): - if (stream.group_info and - str(stream.group_info.group_id) == str(group_id) and - stream.platform == platform): + if ( + stream.group_info + and str(stream.group_info.group_id) == str(group_id) + and stream.platform == platform + ): logger.info(f"{self.log_prefix} 通过群ID {group_id} 找到聊天流: {stream_id}") return stream - + logger.warning(f"{self.log_prefix} 未找到群ID为 {group_id} 的聊天流") return None - + except Exception as e: logger.error(f"{self.log_prefix} 通过群ID获取聊天流时出错: {e}") return None - + def get_all_group_chat_streams(self, platform: str = "qq") -> List[ChatStream]: """获取所有群聊的聊天流 - + Args: platform: 平台标识,默认为"qq" - + Returns: List[ChatStream]: 所有群聊的聊天流列表 """ try: chat_manager = ChatManager() group_streams = [] - + for stream in chat_manager.streams.values(): - if (stream.group_info and - stream.platform == platform): + if stream.group_info and stream.platform == platform: group_streams.append(stream) - + logger.info(f"{self.log_prefix} 找到 {len(group_streams)} 个群聊聊天流") return group_streams - + except Exception as e: logger.error(f"{self.log_prefix} 获取所有群聊聊天流时出错: {e}") return [] - + def get_chat_stream_by_user_id(self, user_id: str, platform: str = "qq") -> Optional[ChatStream]: """通过用户ID获取私聊聊天流 - + Args: user_id: 用户ID platform: 平台标识,默认为"qq" - + Returns: Optional[ChatStream]: 找到的私聊聊天流对象,如果未找到则返回None """ try: chat_manager = ChatManager() - + # 遍历所有已加载的聊天流,查找匹配的用户ID(私聊) for stream_id, stream in chat_manager.streams.items(): - if (not stream.group_info and # 私聊没有群信息 - stream.user_info and - str(stream.user_info.user_id) == str(user_id) and - stream.platform == platform): + if ( + not stream.group_info # 私聊没有群信息 + and stream.user_info + and str(stream.user_info.user_id) == str(user_id) + and stream.platform == platform + ): logger.info(f"{self.log_prefix} 通过用户ID {user_id} 找到私聊聊天流: {stream_id}") return stream - + logger.warning(f"{self.log_prefix} 未找到用户ID为 {user_id} 的私聊聊天流") return None - + except Exception as e: logger.error(f"{self.log_prefix} 通过用户ID获取私聊聊天流时出错: {e}") return None - + def get_chat_streams_info(self) -> List[Dict[str, Any]]: """获取所有聊天流的基本信息 - + Returns: List[Dict[str, Any]]: 包含聊天流基本信息的字典列表 """ try: chat_manager = ChatManager() streams_info = [] - + for stream_id, stream in chat_manager.streams.items(): info = { "stream_id": stream_id, "platform": stream.platform, "chat_type": "group" if stream.group_info else "private", "create_time": stream.create_time, - "last_active_time": stream.last_active_time + "last_active_time": stream.last_active_time, } - + if stream.group_info: - info.update({ - "group_id": stream.group_info.group_id, - "group_name": stream.group_info.group_name - }) - + info.update({"group_id": stream.group_info.group_id, "group_name": stream.group_info.group_name}) + if stream.user_info: - info.update({ - "user_id": stream.user_info.user_id, - "user_nickname": stream.user_info.user_nickname - }) - + info.update({"user_id": stream.user_info.user_id, "user_nickname": stream.user_info.user_nickname}) + streams_info.append(info) - + logger.info(f"{self.log_prefix} 获取到 {len(streams_info)} 个聊天流信息") return streams_info - + except Exception as e: logger.error(f"{self.log_prefix} 获取聊天流信息时出错: {e}") return [] - + async def get_chat_stream_by_group_id_async(self, group_id: str, platform: str = "qq") -> Optional[ChatStream]: """异步通过QQ群ID获取聊天流(包括从数据库搜索) - + Args: group_id: QQ群ID platform: 平台标识,默认为"qq" - + Returns: Optional[ChatStream]: 找到的聊天流对象,如果未找到则返回None """ @@ -150,15 +145,15 @@ class StreamAPI: stream = self.get_chat_stream_by_group_id(group_id, platform) if stream: return stream - + # 如果内存中没有,尝试从数据库加载所有聊天流后再查找 chat_manager = ChatManager() await chat_manager.load_all_streams() - + # 再次尝试从内存中查找 stream = self.get_chat_stream_by_group_id(group_id, platform) return stream - + except Exception as e: logger.error(f"{self.log_prefix} 异步通过群ID获取聊天流时出错: {e}") - return None \ No newline at end of file + return None diff --git a/src/chat/actions/plugin_api/utils_api.py b/src/chat/actions/plugin_api/utils_api.py index b5c476fa1..1cae23b03 100644 --- a/src/chat/actions/plugin_api/utils_api.py +++ b/src/chat/actions/plugin_api/utils_api.py @@ -1,35 +1,37 @@ import os import json import time -from typing import Any, Dict, List, Optional +from typing import Any, Optional from src.common.logger_manager import get_logger logger = get_logger("utils_api") + class UtilsAPI: """工具类API模块 - + 提供了各种辅助功能 """ - + def get_plugin_path(self) -> str: """获取当前插件的路径 - + Returns: str: 插件目录的绝对路径 """ import inspect + plugin_module_path = inspect.getfile(self.__class__) plugin_dir = os.path.dirname(plugin_module_path) return plugin_dir - + def read_json_file(self, file_path: str, default: Any = None) -> Any: """读取JSON文件 - + Args: file_path: 文件路径,可以是相对于插件目录的路径 default: 如果文件不存在或读取失败时返回的默认值 - + Returns: Any: JSON数据或默认值 """ @@ -37,25 +39,25 @@ class UtilsAPI: # 如果是相对路径,则相对于插件目录 if not os.path.isabs(file_path): file_path = os.path.join(self.get_plugin_path(), file_path) - + if not os.path.exists(file_path): logger.warning(f"{self.log_prefix} 文件不存在: {file_path}") return default - - with open(file_path, 'r', encoding='utf-8') as f: + + with open(file_path, "r", encoding="utf-8") as f: return json.load(f) except Exception as e: logger.error(f"{self.log_prefix} 读取JSON文件出错: {e}") return default - + def write_json_file(self, file_path: str, data: Any, indent: int = 2) -> bool: """写入JSON文件 - + Args: file_path: 文件路径,可以是相对于插件目录的路径 data: 要写入的数据 indent: JSON缩进 - + Returns: bool: 是否写入成功 """ @@ -63,59 +65,62 @@ class UtilsAPI: # 如果是相对路径,则相对于插件目录 if not os.path.isabs(file_path): file_path = os.path.join(self.get_plugin_path(), file_path) - + # 确保目录存在 os.makedirs(os.path.dirname(file_path), exist_ok=True) - - with open(file_path, 'w', encoding='utf-8') as f: + + with open(file_path, "w", encoding="utf-8") as f: json.dump(data, f, ensure_ascii=False, indent=indent) return True except Exception as e: logger.error(f"{self.log_prefix} 写入JSON文件出错: {e}") return False - + def get_timestamp(self) -> int: """获取当前时间戳 - + Returns: int: 当前时间戳(秒) """ return int(time.time()) - + def format_time(self, timestamp: Optional[int] = None, format_str: str = "%Y-%m-%d %H:%M:%S") -> str: """格式化时间 - + Args: timestamp: 时间戳,如果为None则使用当前时间 format_str: 时间格式字符串 - + Returns: str: 格式化后的时间字符串 """ import datetime + if timestamp is None: timestamp = time.time() return datetime.datetime.fromtimestamp(timestamp).strftime(format_str) - + def parse_time(self, time_str: str, format_str: str = "%Y-%m-%d %H:%M:%S") -> int: """解析时间字符串为时间戳 - + Args: time_str: 时间字符串 format_str: 时间格式字符串 - + Returns: int: 时间戳(秒) """ import datetime + dt = datetime.datetime.strptime(time_str, format_str) return int(dt.timestamp()) - + def generate_unique_id(self) -> str: """生成唯一ID - + Returns: str: 唯一ID """ import uuid - return str(uuid.uuid4()) \ No newline at end of file + + return str(uuid.uuid4()) diff --git a/src/chat/command/command_handler.py b/src/chat/command/command_handler.py index d15215d4d..07b452a6c 100644 --- a/src/chat/command/command_handler.py +++ b/src/chat/command/command_handler.py @@ -3,7 +3,6 @@ from abc import ABC, abstractmethod from typing import Dict, List, Type, Optional, Tuple, Pattern from src.common.logger_manager import get_logger from src.chat.message_receive.message import MessageRecv -from src.chat.actions.plugin_api.message_api import MessageAPI from src.chat.focus_chat.hfc_utils import create_empty_anchor_message from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor @@ -13,9 +12,10 @@ logger = get_logger("command_handler") _COMMAND_REGISTRY: Dict[str, Type["BaseCommand"]] = {} _COMMAND_PATTERNS: Dict[Pattern, Type["BaseCommand"]] = {} + class BaseCommand(ABC): """命令基类,所有自定义命令都应该继承这个类""" - + # 命令的基本属性 command_name: str = "" # 命令名称 command_description: str = "" # 命令描述 @@ -23,43 +23,43 @@ class BaseCommand(ABC): command_help: str = "" # 命令帮助信息 command_examples: List[str] = [] # 命令使用示例 enable_command: bool = True # 是否启用命令 - + def __init__(self, message: MessageRecv): """初始化命令处理器 - + Args: message: 接收到的消息对象 """ self.message = message self.matched_groups: Dict[str, str] = {} # 存储正则表达式匹配的命名组 self._services = {} # 存储内部服务 - + # 设置服务 self._services["chat_stream"] = message.chat_stream - + # 日志前缀 self.log_prefix = f"[Command:{self.command_name}]" - + @abstractmethod async def execute(self) -> Tuple[bool, Optional[str]]: """执行命令的抽象方法,需要被子类实现 - + Returns: Tuple[bool, Optional[str]]: (是否执行成功, 可选的回复消息) """ pass - + def set_matched_groups(self, groups: Dict[str, str]) -> None: """设置正则表达式匹配的命名组 - + Args: groups: 正则表达式匹配的命名组 """ self.matched_groups = groups - + async def send_reply(self, content: str) -> None: """发送回复消息 - + Args: content: 回复内容 """ @@ -69,43 +69,42 @@ class BaseCommand(ABC): if not chat_stream: logger.error(f"{self.log_prefix} 无法发送消息:缺少chat_stream") return - + # 创建空的锚定消息 anchor_message = await create_empty_anchor_message( - chat_stream.platform, - chat_stream.group_info, - chat_stream + chat_stream.platform, chat_stream.group_info, chat_stream ) - + # 创建表达器,传入chat_stream参数 expressor = DefaultExpressor(chat_stream) - + # 设置服务 self._services["expressor"] = expressor - + # 发送消息 response_set = [ ("text", content), ] - + # 调用表达器发送消息 await expressor.send_response_messages( anchor_message=anchor_message, response_set=response_set, display_message="", ) - + logger.info(f"{self.log_prefix} 命令回复消息发送成功: {content[:30]}...") except Exception as e: logger.error(f"{self.log_prefix} 发送命令回复消息失败: {e}") import traceback + logger.error(traceback.format_exc()) def register_command(cls): """ 命令注册装饰器 - + 用法: @register_command class MyCommand(BaseCommand): @@ -115,21 +114,25 @@ def register_command(cls): ... """ # 检查类是否有必要的属性 - if not hasattr(cls, "command_name") or not hasattr(cls, "command_description") or not hasattr(cls, "command_pattern"): + if ( + not hasattr(cls, "command_name") + or not hasattr(cls, "command_description") + or not hasattr(cls, "command_pattern") + ): logger.error(f"命令类 {cls.__name__} 缺少必要的属性: command_name, command_description 或 command_pattern") return cls - + command_name = cls.command_name command_pattern = cls.command_pattern is_enabled = getattr(cls, "enable_command", True) # 默认启用命令 - + if not command_name or not command_pattern: logger.error(f"命令类 {cls.__name__} 的 command_name 或 command_pattern 为空") return cls - + # 将命令类注册到全局注册表 _COMMAND_REGISTRY[command_name] = cls - + # 编译正则表达式并注册 try: pattern = re.compile(command_pattern, re.IGNORECASE | re.DOTALL) @@ -137,47 +140,47 @@ def register_command(cls): logger.info(f"已注册命令: {command_name} -> {cls.__name__},命令启用: {is_enabled}") except re.error as e: logger.error(f"命令 {command_name} 的正则表达式编译失败: {e}") - + return cls class CommandManager: """命令管理器,负责处理命令(不再负责加载,加载由统一的插件加载器处理)""" - + def __init__(self): """初始化命令管理器""" # 命令加载现在由统一的插件加载器处理,这里只需要初始化 logger.info("命令管理器初始化完成") - + async def process_command(self, message: MessageRecv) -> Tuple[bool, Optional[str], bool]: """处理消息中的命令 - + Args: message: 接收到的消息对象 - + Returns: Tuple[bool, Optional[str], bool]: (是否找到并执行了命令, 命令执行结果, 是否继续处理消息) """ if not message.processed_plain_text: await message.process() - + text = message.processed_plain_text - + # 检查是否匹配任何命令模式 for pattern, command_cls in _COMMAND_PATTERNS.items(): match = pattern.match(text) if match and getattr(command_cls, "enable_command", True): # 创建命令实例 command_instance = command_cls(message) - + # 提取命名组并设置 groups = match.groupdict() command_instance.set_matched_groups(groups) - + try: # 执行命令 success, response = await command_instance.execute() - + # 记录命令执行结果 if success: logger.info(f"命令 {command_cls.command_name} 执行成功") @@ -189,27 +192,28 @@ class CommandManager: if response: # 使用命令实例的send_reply方法发送错误信息 await command_instance.send_reply(f"命令执行失败: {response}") - + # 命令执行后不再继续处理消息 return True, response, False - + except Exception as e: logger.error(f"执行命令 {command_cls.command_name} 时出错: {e}") import traceback + logger.error(traceback.format_exc()) - + try: # 使用命令实例的send_reply方法发送错误信息 await command_instance.send_reply(f"命令执行出错: {str(e)}") except Exception as send_error: logger.error(f"发送错误消息失败: {send_error}") - + # 命令执行出错后不再继续处理消息 return True, str(e), False - + # 没有匹配到任何命令,继续处理消息 return False, None, True # 创建全局命令管理器实例 -command_manager = CommandManager() \ No newline at end of file +command_manager = CommandManager() diff --git a/src/chat/focus_chat/expressors/default_expressor.py b/src/chat/focus_chat/expressors/default_expressor.py index adb595f17..a0e85843b 100644 --- a/src/chat/focus_chat/expressors/default_expressor.py +++ b/src/chat/focus_chat/expressors/default_expressor.py @@ -227,8 +227,6 @@ class DefaultExpressor: logger.info(f"想要表达:{in_mind_reply}||理由:{reason}") logger.info(f"最终回复: {content}\n") - - except Exception as llm_e: # 精简报错信息 logger.error(f"{self.log_prefix}LLM 生成失败: {llm_e}") diff --git a/src/chat/focus_chat/expressors/exprssion_learner.py b/src/chat/focus_chat/expressors/exprssion_learner.py index b7de6ce6d..e210cf7ed 100644 --- a/src/chat/focus_chat/expressors/exprssion_learner.py +++ b/src/chat/focus_chat/expressors/exprssion_learner.py @@ -113,25 +113,25 @@ class ExpressionLearner: 同时对所有已存储的表达方式进行全局衰减 """ current_time = time.time() - + # 全局衰减所有已存储的表达方式 for type in ["style", "grammar"]: base_dir = os.path.join("data", "expression", f"learnt_{type}") if not os.path.exists(base_dir): continue - + for chat_id in os.listdir(base_dir): file_path = os.path.join(base_dir, chat_id, "expressions.json") if not os.path.exists(file_path): continue - + try: with open(file_path, "r", encoding="utf-8") as f: expressions = json.load(f) - + # 应用全局衰减 decayed_expressions = self.apply_decay_to_expressions(expressions, current_time) - + # 保存衰减后的结果 with open(file_path, "w", encoding="utf-8") as f: json.dump(decayed_expressions, f, ensure_ascii=False, indent=2) @@ -162,23 +162,25 @@ class ExpressionLearner: """ if time_diff_days <= 0 or time_diff_days >= DECAY_DAYS: return 0.001 - + # 使用二次函数进行插值 # 将7天作为顶点,0天和30天作为两个端点 # 使用顶点式:y = a(x-h)^2 + k,其中(h,k)为顶点 h = 7.0 # 顶点x坐标 k = 0.001 # 顶点y坐标 - + # 计算a值,使得x=0和x=30时y=0.001 # 0.001 = a(0-7)^2 + 0.001 # 解得a = 0 a = 0 - + # 计算衰减值 decay = a * (time_diff_days - h) ** 2 + k return min(0.001, decay) - def apply_decay_to_expressions(self, expressions: List[Dict[str, Any]], current_time: float) -> List[Dict[str, Any]]: + def apply_decay_to_expressions( + self, expressions: List[Dict[str, Any]], current_time: float + ) -> List[Dict[str, Any]]: """ 对表达式列表应用衰减 返回衰减后的表达式列表,移除count小于0的项 @@ -188,16 +190,16 @@ class ExpressionLearner: # 确保last_active_time存在,如果不存在则使用current_time if "last_active_time" not in expr: expr["last_active_time"] = current_time - + last_active = expr["last_active_time"] time_diff_days = (current_time - last_active) / (24 * 3600) # 转换为天 - + decay_value = self.calculate_decay_factor(time_diff_days) expr["count"] = max(0.01, expr.get("count", 1) - decay_value) - + if expr["count"] > 0: result.append(expr) - + return result async def learn_and_store(self, type: str, num: int = 10) -> List[Tuple[str, str, str]]: @@ -211,7 +213,7 @@ class ExpressionLearner: type_str = "句法特点" else: raise ValueError(f"Invalid type: {type}") - + res = await self.learn_expression(type, num) if res is None: @@ -238,15 +240,15 @@ class ExpressionLearner: if chat_id not in chat_dict: chat_dict[chat_id] = [] chat_dict[chat_id].append({"situation": situation, "style": style}) - + current_time = time.time() - + # 存储到/data/expression/对应chat_id/expressions.json for chat_id, expr_list in chat_dict.items(): dir_path = os.path.join("data", "expression", f"learnt_{type}", str(chat_id)) os.makedirs(dir_path, exist_ok=True) file_path = os.path.join(dir_path, "expressions.json") - + # 若已存在,先读出合并 old_data: List[Dict[str, Any]] = [] if os.path.exists(file_path): @@ -255,10 +257,10 @@ class ExpressionLearner: old_data = json.load(f) except Exception: old_data = [] - + # 应用衰减 # old_data = self.apply_decay_to_expressions(old_data, current_time) - + # 合并逻辑 for new_expr in expr_list: found = False @@ -278,43 +280,43 @@ class ExpressionLearner: new_expr["count"] = 1 new_expr["last_active_time"] = current_time old_data.append(new_expr) - + # 处理超限问题 if len(old_data) > MAX_EXPRESSION_COUNT: # 计算每个表达方式的权重(count的倒数,这样count越小的越容易被选中) weights = [1 / (expr.get("count", 1) + 0.1) for expr in old_data] - + # 随机选择要移除的表达方式,避免重复索引 remove_count = len(old_data) - MAX_EXPRESSION_COUNT - + # 使用一种不会选到重复索引的方法 indices = list(range(len(old_data))) - + # 方法1:使用numpy.random.choice # 把列表转成一个映射字典,保证不会有重复 remove_set = set() total_attempts = 0 - + # 尝试按权重随机选择,直到选够数量 while len(remove_set) < remove_count and total_attempts < len(old_data) * 2: idx = random.choices(indices, weights=weights, k=1)[0] remove_set.add(idx) total_attempts += 1 - + # 如果没选够,随机补充 if len(remove_set) < remove_count: remaining = set(indices) - remove_set remove_set.update(random.sample(list(remaining), remove_count - len(remove_set))) - + remove_indices = list(remove_set) - + # 从后往前删除,避免索引变化 for idx in sorted(remove_indices, reverse=True): old_data.pop(idx) - + with open(file_path, "w", encoding="utf-8") as f: json.dump(old_data, f, ensure_ascii=False, indent=2) - + return learnt_expressions async def learn_expression(self, type: str, num: int = 10) -> Optional[Tuple[List[Tuple[str, str, str]], str]]: diff --git a/src/chat/focus_chat/heartFC_Cycleinfo.py b/src/chat/focus_chat/heartFC_Cycleinfo.py index ec0c4f1c7..7900a16a2 100644 --- a/src/chat/focus_chat/heartFC_Cycleinfo.py +++ b/src/chat/focus_chat/heartFC_Cycleinfo.py @@ -97,7 +97,7 @@ class CycleDetail: ) # current_time_minute = time.strftime("%Y%m%d_%H%M", time.localtime()) - + # try: # self.log_cycle_to_file( # log_dir + self.prefix + f"/{current_time_minute}_cycle_" + str(self.cycle_id) + ".json" @@ -117,7 +117,6 @@ class CycleDetail: if dir_name and not os.path.exists(dir_name): os.makedirs(dir_name, exist_ok=True) # 写入文件 - file_path = os.path.join(dir_name, os.path.basename(file_path)) # print("file_path:", file_path) diff --git a/src/chat/focus_chat/heartFC_chat.py b/src/chat/focus_chat/heartFC_chat.py index 3137c1f23..4ab767a15 100644 --- a/src/chat/focus_chat/heartFC_chat.py +++ b/src/chat/focus_chat/heartFC_chat.py @@ -99,22 +99,23 @@ class HeartFChatting: self.stream_id: str = chat_id # 聊天流ID self.chat_stream = chat_manager.get_stream(self.stream_id) self.log_prefix = f"[{chat_manager.get_stream_name(self.stream_id) or self.stream_id}]" - + self.memory_activator = MemoryActivator() - + # 初始化观察器 self.observations: List[Observation] = [] self._register_observations() - + # 根据配置文件和默认规则确定启用的处理器 config_processor_settings = global_config.focus_chat_processor self.enabled_processor_names = [] - + for proc_name, (_proc_class, config_key) in PROCESSOR_CLASSES.items(): # 对于关系处理器,需要同时检查两个配置项 if proc_name == "RelationshipProcessor": - if (global_config.relationship.enable_relationship and - getattr(config_processor_settings, config_key, True)): + if global_config.relationship.enable_relationship and getattr( + config_processor_settings, config_key, True + ): self.enabled_processor_names.append(proc_name) else: # 其他处理器的原有逻辑 @@ -122,14 +123,13 @@ class HeartFChatting: self.enabled_processor_names.append(proc_name) # logger.info(f"{self.log_prefix} 将启用的处理器: {self.enabled_processor_names}") - + self.processors: List[BaseProcessor] = [] self._register_default_processors() self.expressor = DefaultExpressor(chat_stream=self.chat_stream) self.replyer = DefaultReplyer(chat_stream=self.chat_stream) - - + self.action_manager = ActionManager() self.action_planner = PlannerFactory.create_planner( log_prefix=self.log_prefix, action_manager=self.action_manager @@ -138,7 +138,6 @@ class HeartFChatting: self.action_observation = ActionObservation(observe_id=self.stream_id) self.action_observation.set_action_manager(self.action_manager) - self._processing_lock = asyncio.Lock() # 循环控制内部状态 @@ -182,7 +181,13 @@ class HeartFChatting: if processor_info: processor_actual_class = processor_info[0] # 获取实际的类定义 # 根据处理器类名判断是否需要 subheartflow_id - if name in ["MindProcessor", "ToolProcessor", "WorkingMemoryProcessor", "SelfProcessor", "RelationshipProcessor"]: + if name in [ + "MindProcessor", + "ToolProcessor", + "WorkingMemoryProcessor", + "SelfProcessor", + "RelationshipProcessor", + ]: self.processors.append(processor_actual_class(subheartflow_id=self.stream_id)) elif name == "ChattingInfoProcessor": self.processors.append(processor_actual_class()) @@ -203,9 +208,7 @@ class HeartFChatting: ) if self.processors: - logger.info( - f"{self.log_prefix} 已注册处理器: {[p.__class__.__name__ for p in self.processors]}" - ) + logger.info(f"{self.log_prefix} 已注册处理器: {[p.__class__.__name__ for p in self.processors]}") else: logger.warning(f"{self.log_prefix} 没有注册任何处理器。这可能是由于配置错误或所有处理器都被禁用了。") @@ -292,7 +295,9 @@ class HeartFChatting: self._current_cycle_detail.set_loop_info(loop_info) # 从observations列表中获取HFCloopObservation - hfcloop_observation = next((obs for obs in self.observations if isinstance(obs, HFCloopObservation)), None) + hfcloop_observation = next( + (obs for obs in self.observations if isinstance(obs, HFCloopObservation)), None + ) if hfcloop_observation: hfcloop_observation.add_loop_info(self._current_cycle_detail) else: @@ -451,19 +456,19 @@ class HeartFChatting: # 根据配置决定是否并行执行调整动作、回忆和处理器阶段 - # 并行执行调整动作、回忆和处理器阶段 + # 并行执行调整动作、回忆和处理器阶段 with Timer("并行调整动作、处理", cycle_timers): # 创建并行任务 - async def modify_actions_task(): + async def modify_actions_task(): # 调用完整的动作修改流程 await self.action_modifier.modify_actions( observations=self.observations, ) - + await self.action_observation.observe() self.observations.append(self.action_observation) return True - + # 创建三个并行任务 action_modify_task = asyncio.create_task(modify_actions_task()) memory_task = asyncio.create_task(self.memory_activator.activate_memory(self.observations)) @@ -474,9 +479,6 @@ class HeartFChatting: action_modify_task, memory_task, processor_task ) - - - loop_processor_info = { "all_plan_info": all_plan_info, "processor_time_costs": processor_time_costs, @@ -594,9 +596,7 @@ class HeartFChatting: else: success, reply_text = result command = "" - logger.debug( - f"{self.log_prefix} 麦麦执行了'{action}', 返回结果'{success}', '{reply_text}', '{command}'" - ) + logger.debug(f"{self.log_prefix} 麦麦执行了'{action}', 返回结果'{success}', '{reply_text}', '{command}'") return success, reply_text, command diff --git a/src/chat/focus_chat/heartflow_message_processor.py b/src/chat/focus_chat/heartflow_message_processor.py index c20f29a13..b09b72bdd 100644 --- a/src/chat/focus_chat/heartflow_message_processor.py +++ b/src/chat/focus_chat/heartflow_message_processor.py @@ -51,8 +51,8 @@ async def _process_relationship(message: MessageRecv) -> None: logger.info(f"首次认识用户: {nickname}") await relationship_manager.first_knowing_some_one(platform, user_id, nickname, cardname) # elif not await relationship_manager.is_qved_name(platform, user_id): - # logger.info(f"给用户({nickname},{cardname})取名: {nickname}") - # await relationship_manager.first_knowing_some_one(platform, user_id, nickname, cardname, "") + # logger.info(f"给用户({nickname},{cardname})取名: {nickname}") + # await relationship_manager.first_knowing_some_one(platform, user_id, nickname, cardname, "") async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]: @@ -74,7 +74,7 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]: fast_retrieval=True, ) logger.trace(f"记忆激活率: {interested_rate:.2f}") - + text_len = len(message.processed_plain_text) # 根据文本长度调整兴趣度,长度越大兴趣度越高,但增长率递减,最低0.01,最高0.05 # 采用对数函数实现递减增长 @@ -181,7 +181,6 @@ class HeartFCMessageReceiver: userinfo = message.message_info.user_info messageinfo = message.message_info - chat = await chat_manager.get_or_create_stream( platform=messageinfo.platform, user_info=userinfo, diff --git a/src/chat/focus_chat/info_processors/chattinginfo_processor.py b/src/chat/focus_chat/info_processors/chattinginfo_processor.py index e2ae41c0d..561b90f5d 100644 --- a/src/chat/focus_chat/info_processors/chattinginfo_processor.py +++ b/src/chat/focus_chat/info_processors/chattinginfo_processor.py @@ -11,7 +11,6 @@ from datetime import datetime from typing import Dict from src.llm_models.utils_model import LLMRequest from src.config.config import global_config -import asyncio logger = get_logger("processor") diff --git a/src/chat/focus_chat/info_processors/relationship_processor.py b/src/chat/focus_chat/info_processors/relationship_processor.py index 0436b5e50..9d25235c4 100644 --- a/src/chat/focus_chat/info_processors/relationship_processor.py +++ b/src/chat/focus_chat/info_processors/relationship_processor.py @@ -22,7 +22,7 @@ from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_ch # 配置常量:是否启用小模型即时信息提取 # 开启时:使用小模型并行即时提取,速度更快,但精度可能略低 # 关闭时:使用原来的异步模式,精度更高但速度较慢 -ENABLE_INSTANT_INFO_EXTRACTION = True +ENABLE_INSTANT_INFO_EXTRACTION = True logger = get_logger("processor") @@ -63,7 +63,7 @@ def init_prompt(): """ Prompt(relationship_prompt, "relationship_prompt") - + fetch_info_prompt = """ {name_block} @@ -84,7 +84,6 @@ def init_prompt(): Prompt(fetch_info_prompt, "fetch_info_prompt") - class RelationshipProcessor(BaseProcessor): log_prefix = "关系" @@ -92,8 +91,10 @@ class RelationshipProcessor(BaseProcessor): super().__init__() self.subheartflow_id = subheartflow_id - self.info_fetching_cache: List[Dict[str, any]] = [] - self.info_fetched_cache: Dict[str, Dict[str, any]] = {} # {person_id: {"info": str, "ttl": int, "start_time": float}} + self.info_fetching_cache: List[Dict[str, any]] = [] + self.info_fetched_cache: Dict[ + str, Dict[str, any] + ] = {} # {person_id: {"info": str, "ttl": int, "start_time": float}} self.person_engaged_cache: List[Dict[str, any]] = [] # [{person_id: str, start_time: float, rounds: int}] self.grace_period_rounds = 5 @@ -101,7 +102,7 @@ class RelationshipProcessor(BaseProcessor): model=global_config.model.relation, request_type="focus.relationship", ) - + # 小模型用于即时信息提取 if ENABLE_INSTANT_INFO_EXTRACTION: self.instant_llm_model = LLMRequest( @@ -156,26 +157,27 @@ class RelationshipProcessor(BaseProcessor): for record in list(self.person_engaged_cache): record["rounds"] += 1 time_elapsed = current_time - record["start_time"] - message_count = len(get_raw_msg_by_timestamp_with_chat(self.subheartflow_id, record["start_time"], current_time)) - + message_count = len( + get_raw_msg_by_timestamp_with_chat(self.subheartflow_id, record["start_time"], current_time) + ) + print(record) - + # 根据消息数量和时间设置不同的触发条件 should_trigger = ( - message_count >= 50 or # 50条消息必定满足 - (message_count >= 35 and time_elapsed >= 300) or # 35条且10分钟 - (message_count >= 25 and time_elapsed >= 900) or # 25条且30分钟 - (message_count >= 10 and time_elapsed >= 2000) # 10条且1小时 + message_count >= 50 # 50条消息必定满足 + or (message_count >= 35 and time_elapsed >= 300) # 35条且10分钟 + or (message_count >= 25 and time_elapsed >= 900) # 25条且30分钟 + or (message_count >= 10 and time_elapsed >= 2000) # 10条且1小时 ) - + if should_trigger: - logger.info(f"{self.log_prefix} 用户 {record['person_id']} 满足关系构建条件,开始构建关系。消息数:{message_count},时长:{time_elapsed:.0f}秒") + logger.info( + f"{self.log_prefix} 用户 {record['person_id']} 满足关系构建条件,开始构建关系。消息数:{message_count},时长:{time_elapsed:.0f}秒" + ) asyncio.create_task( self.update_impression_on_cache_expiry( - record["person_id"], - self.subheartflow_id, - record["start_time"], - current_time + record["person_id"], self.subheartflow_id, record["start_time"], current_time ) ) self.person_engaged_cache.remove(record) @@ -187,20 +189,24 @@ class RelationshipProcessor(BaseProcessor): if self.info_fetched_cache[person_id][info_type]["ttl"] <= 0: # 在删除前查找匹配的info_fetching_cache记录 matched_record = None - min_time_diff = float('inf') + min_time_diff = float("inf") for record in self.info_fetching_cache: - if (record["person_id"] == person_id and - record["info_type"] == info_type and - not record["forget"]): - time_diff = abs(record["start_time"] - self.info_fetched_cache[person_id][info_type]["start_time"]) + if ( + record["person_id"] == person_id + and record["info_type"] == info_type + and not record["forget"] + ): + time_diff = abs( + record["start_time"] - self.info_fetched_cache[person_id][info_type]["start_time"] + ) if time_diff < min_time_diff: min_time_diff = time_diff matched_record = record - + if matched_record: matched_record["forget"] = True logger.info(f"{self.log_prefix} 用户 {person_id} 的 {info_type} 信息已过期,标记为遗忘。") - + del self.info_fetched_cache[person_id][info_type] if not self.info_fetched_cache[person_id]: del self.info_fetched_cache[person_id] @@ -208,7 +214,7 @@ class RelationshipProcessor(BaseProcessor): # 5. 为需要处理的人员准备LLM prompt nickname_str = ",".join(global_config.bot.alias_names) name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。" - + info_cache_block = "" if self.info_fetching_cache: for info_fetching in self.info_fetching_cache: @@ -223,7 +229,7 @@ class RelationshipProcessor(BaseProcessor): chat_observe_info=chat_observe_info, info_cache_block=info_cache_block, ) - + try: logger.debug(f"{self.log_prefix} 人物信息prompt: \n{prompt}\n") content, _ = await self.llm_model.generate_response_async(prompt=prompt) @@ -234,45 +240,47 @@ class RelationshipProcessor(BaseProcessor): # 收集即时提取任务 instant_tasks = [] async_tasks = [] - + for person_name, info_type in content_json.items(): person_id = person_info_manager.get_person_id_by_person_name(person_name) if person_id: - self.info_fetching_cache.append({ - "person_id": person_id, - "person_name": person_name, - "info_type": info_type, - "start_time": time.time(), - "forget": False, - }) + self.info_fetching_cache.append( + { + "person_id": person_id, + "person_name": person_name, + "info_type": info_type, + "start_time": time.time(), + "forget": False, + } + ) if len(self.info_fetching_cache) > 20: self.info_fetching_cache.pop(0) else: logger.warning(f"{self.log_prefix} 未找到用户 {person_name} 的ID,跳过调取信息。") continue - + logger.info(f"{self.log_prefix} 调取用户 {person_name} 的 {info_type} 信息。") - + # 检查person_engaged_cache中是否已存在该person_id person_exists = any(record["person_id"] == person_id for record in self.person_engaged_cache) if not person_exists: - self.person_engaged_cache.append({ - "person_id": person_id, - "start_time": time.time(), - "rounds": 0 - }) - + self.person_engaged_cache.append( + {"person_id": person_id, "start_time": time.time(), "rounds": 0} + ) + if ENABLE_INSTANT_INFO_EXTRACTION: # 收集即时提取任务 instant_tasks.append((person_id, info_type, time.time())) else: # 使用原来的异步模式 - async_tasks.append(asyncio.create_task(self.fetch_person_info(person_id, [info_type], start_time=time.time()))) + async_tasks.append( + asyncio.create_task(self.fetch_person_info(person_id, [info_type], start_time=time.time())) + ) # 执行即时提取任务 if ENABLE_INSTANT_INFO_EXTRACTION and instant_tasks: await self._execute_instant_extraction_batch(instant_tasks) - + # 启动异步任务(如果不是即时模式) if async_tasks: # 异步任务不需要等待完成 @@ -300,7 +308,7 @@ class RelationshipProcessor(BaseProcessor): person_infos_str += f"你不了解{person_name}有关[{info_type}]的信息,不要胡乱回答,你可以直接说你不知道,或者你忘记了;" if person_infos_str: persons_infos_str += f"你对 {person_name} 的了解:{person_infos_str}\n" - + # 处理正在调取但还没有结果的项目(只在非即时提取模式下显示) if not ENABLE_INSTANT_INFO_EXTRACTION: pending_info_dict = {} @@ -312,50 +320,47 @@ class RelationshipProcessor(BaseProcessor): person_id = record["person_id"] person_name = record["person_name"] info_type = record["info_type"] - + # 检查是否已经在info_fetched_cache中有结果 - if (person_id in self.info_fetched_cache and - info_type in self.info_fetched_cache[person_id]): + if person_id in self.info_fetched_cache and info_type in self.info_fetched_cache[person_id]: continue - + # 按人物组织正在调取的信息 if person_name not in pending_info_dict: pending_info_dict[person_name] = [] pending_info_dict[person_name].append(info_type) - + # 添加正在调取的信息到返回字符串 for person_name, info_types in pending_info_dict.items(): info_types_str = "、".join(info_types) persons_infos_str += f"你正在识图回忆有关 {person_name} 的 {info_types_str} 信息,稍等一下再回答...\n" return persons_infos_str - + async def _execute_instant_extraction_batch(self, instant_tasks: list): """ 批量执行即时提取任务 """ if not instant_tasks: return - + logger.info(f"{self.log_prefix} [即时提取] 开始批量提取 {len(instant_tasks)} 个信息") - + # 创建所有提取任务 extraction_tasks = [] for person_id, info_type, start_time in instant_tasks: # 检查缓存中是否已存在且未过期的信息 - if (person_id in self.info_fetched_cache and - info_type in self.info_fetched_cache[person_id]): + if person_id in self.info_fetched_cache and info_type in self.info_fetched_cache[person_id]: logger.info(f"{self.log_prefix} 用户 {person_id} 的 {info_type} 信息已存在且未过期,跳过调取。") continue - + task = asyncio.create_task(self._fetch_single_info_instant(person_id, info_type, start_time)) extraction_tasks.append(task) - + # 并行执行所有提取任务并等待完成 if extraction_tasks: await asyncio.gather(*extraction_tasks, return_exceptions=True) logger.info(f"{self.log_prefix} [即时提取] 批量提取完成") - async def _fetch_single_info_instant(self, person_id: str, info_type: str, start_time: float): """ @@ -363,24 +368,21 @@ class RelationshipProcessor(BaseProcessor): """ nickname_str = ",".join(global_config.bot.alias_names) name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。" - + person_name = await person_info_manager.get_value(person_id, "person_name") - + person_impression = await person_info_manager.get_value(person_id, "impression") if not person_impression: impression_block = "你对ta没有什么深刻的印象" else: impression_block = f"{person_impression}" - + points = await person_info_manager.get_value(person_id, "points") if points: - points_text = "\n".join([ - f"{point[2]}:{point[0]}" - for point in points - ]) + points_text = "\n".join([f"{point[2]}:{point[0]}" for point in points]) else: points_text = "你不记得ta最近发生了什么" - + prompt = (await global_prompt_manager.get_prompt_async("fetch_info_prompt")).format( name_block=name_block, info_type=info_type, @@ -393,9 +395,9 @@ class RelationshipProcessor(BaseProcessor): try: # 使用小模型进行即时提取 content, _ = await self.instant_llm_model.generate_response_async(prompt=prompt) - + logger.info(f"{self.log_prefix} [即时提取] {person_name} 的 {info_type} 结果: {content}") - + if content: content_json = json.loads(repair_json(content)) if info_type in content_json: @@ -410,7 +412,9 @@ class RelationshipProcessor(BaseProcessor): "person_name": person_name, "unknow": False, } - logger.info(f"{self.log_prefix} [即时提取] 成功获取 {person_name} 的 {info_type}: {info_content}") + logger.info( + f"{self.log_prefix} [即时提取] 成功获取 {person_name} 的 {info_type}: {info_content}" + ) else: if person_id not in self.info_fetched_cache: self.info_fetched_cache[person_id] = {} @@ -423,59 +427,55 @@ class RelationshipProcessor(BaseProcessor): } logger.info(f"{self.log_prefix} [即时提取] {person_name} 的 {info_type} 信息不明确") else: - logger.warning(f"{self.log_prefix} [即时提取] 小模型返回空结果,获取 {person_name} 的 {info_type} 信息失败。") + logger.warning( + f"{self.log_prefix} [即时提取] 小模型返回空结果,获取 {person_name} 的 {info_type} 信息失败。" + ) except Exception as e: logger.error(f"{self.log_prefix} [即时提取] 执行小模型请求获取用户信息时出错: {e}") logger.error(traceback.format_exc()) - + async def fetch_person_info(self, person_id: str, info_types: list[str], start_time: float): """ 获取某个人的信息 """ # 检查缓存中是否已存在且未过期的信息 info_types_to_fetch = [] - + for info_type in info_types: - if (person_id in self.info_fetched_cache and - info_type in self.info_fetched_cache[person_id]): + if person_id in self.info_fetched_cache and info_type in self.info_fetched_cache[person_id]: logger.info(f"{self.log_prefix} 用户 {person_id} 的 {info_type} 信息已存在且未过期,跳过调取。") continue info_types_to_fetch.append(info_type) - + if not info_types_to_fetch: return - + nickname_str = ",".join(global_config.bot.alias_names) name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。" - + person_name = await person_info_manager.get_value(person_id, "person_name") - + info_type_str = "" info_json_str = "" for info_type in info_types_to_fetch: info_type_str += f"{info_type}," - info_json_str += f"\"{info_type}\": \"信息内容\"," + info_json_str += f'"{info_type}": "信息内容",' info_type_str = info_type_str[:-1] info_json_str = info_json_str[:-1] - + person_impression = await person_info_manager.get_value(person_id, "impression") if not person_impression: impression_block = "你对ta没有什么深刻的印象" else: impression_block = f"{person_impression}" - - + points = await person_info_manager.get_value(person_id, "points") if points: - points_text = "\n".join([ - f"{point[2]}:{point[0]}" - for point in points - ]) + points_text = "\n".join([f"{point[2]}:{point[0]}" for point in points]) else: points_text = "你不记得ta最近发生了什么" - - + prompt = (await global_prompt_manager.get_prompt_async("fetch_info_prompt")).format( name_block=name_block, info_type=info_type_str, @@ -487,10 +487,10 @@ class RelationshipProcessor(BaseProcessor): try: content, _ = await self.llm_model.generate_response_async(prompt=prompt) - + # logger.info(f"{self.log_prefix} fetch_person_info prompt: \n{prompt}\n") logger.info(f"{self.log_prefix} fetch_person_info 结果: {content}") - + if content: try: content_json = json.loads(repair_json(content)) @@ -508,9 +508,9 @@ class RelationshipProcessor(BaseProcessor): else: if person_id not in self.info_fetched_cache: self.info_fetched_cache[person_id] = {} - + self.info_fetched_cache[person_id][info_type] = { - "info":"unknow", + "info": "unknow", "ttl": 10, "start_time": start_time, "person_name": person_name, @@ -525,16 +525,12 @@ class RelationshipProcessor(BaseProcessor): logger.error(f"{self.log_prefix} 执行LLM请求获取用户信息时出错: {e}") logger.error(traceback.format_exc()) - async def update_impression_on_cache_expiry( - self, person_id: str, chat_id: str, start_time: float, end_time: float - ): + async def update_impression_on_cache_expiry(self, person_id: str, chat_id: str, start_time: float, end_time: float): """ 在缓存过期时,获取聊天记录并更新用户印象 """ logger.info(f"缓存过期,开始为 {person_id} 更新印象。时间范围:{start_time} -> {end_time}") try: - - impression_messages = get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time) if impression_messages: logger.info(f"为 {person_id} 获取到 {len(impression_messages)} 条消息用于印象更新。") diff --git a/src/chat/focus_chat/info_processors/self_processor.py b/src/chat/focus_chat/info_processors/self_processor.py index f21a1d3b1..0f75b6686 100644 --- a/src/chat/focus_chat/info_processors/self_processor.py +++ b/src/chat/focus_chat/info_processors/self_processor.py @@ -122,9 +122,7 @@ class SelfProcessor(BaseProcessor): ) # 获取聊天内容 chat_observe_info = observation.get_observe_info() - person_list = observation.person_list if isinstance(observation, HFCloopObservation): - # hfcloop_observe_info = observation.get_observe_info() pass nickname_str = "" @@ -133,9 +131,7 @@ class SelfProcessor(BaseProcessor): name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。" personality_block = individuality.get_personality_prompt(x_person=2, level=2) - - - + identity_block = individuality.get_identity_prompt(x_person=2, level=2) prompt = (await global_prompt_manager.get_prompt_async("indentify_prompt")).format( diff --git a/src/chat/focus_chat/info_processors/tool_processor.py b/src/chat/focus_chat/info_processors/tool_processor.py index cf31f4418..2f46fc8b2 100644 --- a/src/chat/focus_chat/info_processors/tool_processor.py +++ b/src/chat/focus_chat/info_processors/tool_processor.py @@ -118,7 +118,7 @@ class ToolProcessor(BaseProcessor): is_group_chat = observation.is_group_chat chat_observe_info = observation.get_observe_info() - person_list = observation.person_list + # person_list = observation.person_list memory_str = "" if running_memorys: @@ -141,9 +141,7 @@ class ToolProcessor(BaseProcessor): # 调用LLM,专注于工具使用 # logger.info(f"开始执行工具调用{prompt}") - response, other_info = await self.llm_model.generate_response_async( - prompt=prompt, tools=tools - ) + response, other_info = await self.llm_model.generate_response_async(prompt=prompt, tools=tools) if len(other_info) == 3: reasoning_content, model_name, tool_calls = other_info diff --git a/src/chat/focus_chat/info_processors/working_memory_processor.py b/src/chat/focus_chat/info_processors/working_memory_processor.py index 9eb848089..af016e7bb 100644 --- a/src/chat/focus_chat/info_processors/working_memory_processor.py +++ b/src/chat/focus_chat/info_processors/working_memory_processor.py @@ -118,9 +118,7 @@ class WorkingMemoryProcessor(BaseProcessor): memory_str=memory_choose_str, ) - # print(f"prompt: {prompt}") - # 调用LLM处理记忆 content = "" diff --git a/src/chat/focus_chat/memory_activator.py b/src/chat/focus_chat/memory_activator.py index 4f57286b8..26178d961 100644 --- a/src/chat/focus_chat/memory_activator.py +++ b/src/chat/focus_chat/memory_activator.py @@ -90,7 +90,7 @@ class MemoryActivator: # 如果记忆系统被禁用,直接返回空列表 if not global_config.memory.enable_memory: return [] - + obs_info_text = "" for observation in observations: if isinstance(observation, ChattingObservation): diff --git a/src/chat/focus_chat/planners/action_manager.py b/src/chat/focus_chat/planners/action_manager.py index b45300710..a848b5fd9 100644 --- a/src/chat/focus_chat/planners/action_manager.py +++ b/src/chat/focus_chat/planners/action_manager.py @@ -5,9 +5,6 @@ from src.chat.focus_chat.replyer.default_replyer import DefaultReplyer from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor from src.chat.message_receive.chat_stream import ChatStream from src.common.logger_manager import get_logger -import importlib -import pkgutil -import os # 不再需要导入动作类,因为已经在main.py中导入 # import src.chat.actions.default_actions # noqa @@ -41,7 +38,7 @@ class ActionManager: # 初始化时将默认动作加载到使用中的动作 self._using_actions = self._default_actions.copy() - + # 添加系统核心动作 self._add_system_core_actions() @@ -63,19 +60,19 @@ class ActionManager: action_require: list[str] = getattr(action_class, "action_require", []) associated_types: list[str] = getattr(action_class, "associated_types", []) is_enabled: bool = getattr(action_class, "enable_plugin", True) - + # 获取激活类型相关属性 focus_activation_type: str = getattr(action_class, "focus_activation_type", "always") normal_activation_type: str = getattr(action_class, "normal_activation_type", "always") - + random_probability: float = getattr(action_class, "random_activation_probability", 0.3) llm_judge_prompt: str = getattr(action_class, "llm_judge_prompt", "") activation_keywords: list[str] = getattr(action_class, "activation_keywords", []) keyword_case_sensitive: bool = getattr(action_class, "keyword_case_sensitive", False) - + # 获取模式启用属性 mode_enable: str = getattr(action_class, "mode_enable", "all") - + # 获取并行执行属性 parallel_action: bool = getattr(action_class, "parallel_action", False) @@ -114,13 +111,13 @@ class ActionManager: def _load_plugin_actions(self) -> None: """ 加载所有插件目录中的动作 - + 注意:插件动作的实际导入已经在main.py中完成,这里只需要从_ACTION_REGISTRY获取 """ try: # 插件动作已在main.py中加载,这里只需要从_ACTION_REGISTRY获取 self._load_registered_actions() - logger.info(f"从注册表加载插件动作成功") + logger.info("从注册表加载插件动作成功") except Exception as e: logger.error(f"加载插件动作失败: {e}") @@ -203,25 +200,25 @@ class ActionManager: def get_using_actions_for_mode(self, mode: str) -> Dict[str, ActionInfo]: """ 根据聊天模式获取可用的动作集合 - + Args: mode: 聊天模式 ("focus", "normal", "all") - + Returns: Dict[str, ActionInfo]: 在指定模式下可用的动作集合 """ filtered_actions = {} - + for action_name, action_info in self._using_actions.items(): action_mode = action_info.get("mode_enable", "all") - + # 检查动作是否在当前模式下启用 if action_mode == "all" or action_mode == mode: filtered_actions[action_name] = action_info logger.debug(f"动作 {action_name} 在模式 {mode} 下可用 (mode_enable: {action_mode})") else: logger.debug(f"动作 {action_name} 在模式 {mode} 下不可用 (mode_enable: {action_mode})") - + logger.debug(f"模式 {mode} 下可用动作: {list(filtered_actions.keys())}") return filtered_actions @@ -325,7 +322,7 @@ class ActionManager: 系统核心动作是那些enable_plugin为False但是系统必需的动作 """ system_core_actions = ["exit_focus_chat"] # 可以根据需要扩展 - + for action_name in system_core_actions: if action_name in self._registered_actions and action_name not in self._using_actions: self._using_actions[action_name] = self._registered_actions[action_name] @@ -334,10 +331,10 @@ class ActionManager: def add_system_action_if_needed(self, action_name: str) -> bool: """ 根据需要添加系统动作到使用集 - + Args: action_name: 动作名称 - + Returns: bool: 是否成功添加 """ diff --git a/src/chat/focus_chat/planners/modify_actions.py b/src/chat/focus_chat/planners/modify_actions.py index 5ab398a56..4be4af786 100644 --- a/src/chat/focus_chat/planners/modify_actions.py +++ b/src/chat/focus_chat/planners/modify_actions.py @@ -30,13 +30,13 @@ class ActionModifier: """初始化动作处理器""" self.action_manager = action_manager self.all_actions = self.action_manager.get_using_actions_for_mode(ChatMode.FOCUS) - + # 用于LLM判定的小模型 self.llm_judge = LLMRequest( model=global_config.model.utils_small, request_type="action.judge", ) - + # 缓存相关属性 self._llm_judge_cache = {} # 缓存LLM判定结果 self._cache_expiry_time = 30 # 缓存过期时间(秒) @@ -49,15 +49,15 @@ class ActionModifier: ): """ 完整的动作修改流程,整合传统观察处理和新的激活类型判定 - + 这个方法处理完整的动作管理流程: 1. 基于观察的传统动作修改(循环历史分析、类型匹配等) 2. 基于激活类型的智能动作判定,最终确定可用动作集 - + 处理后,ActionManager 将包含最终的可用动作集,供规划器直接使用 """ logger.debug(f"{self.log_prefix}开始完整动作修改流程") - + # === 第一阶段:传统观察处理 === if observations: hfc_obs = None @@ -86,7 +86,7 @@ class ActionModifier: merged_action_changes["add"].extend(action_changes["add"]) merged_action_changes["remove"].extend(action_changes["remove"]) reasons.append("基于循环历史分析") - + # 详细记录循环历史分析的变更原因 for action_name in action_changes["add"]: logger.info(f"{self.log_prefix}添加动作: {action_name},原因: 循环历史分析建议添加") @@ -106,7 +106,9 @@ class ActionModifier: if not chat_context.check_types(data["associated_types"]): type_mismatched_actions.append(action_name) associated_types_str = ", ".join(data["associated_types"]) - logger.info(f"{self.log_prefix}移除动作: {action_name},原因: 关联类型不匹配(需要: {associated_types_str})") + logger.info( + f"{self.log_prefix}移除动作: {action_name},原因: 关联类型不匹配(需要: {associated_types_str})" + ) if type_mismatched_actions: # 合并到移除列表中 @@ -123,17 +125,19 @@ class ActionModifier: self.action_manager.remove_action_from_using(action_name) logger.debug(f"{self.log_prefix}应用移除动作: {action_name},原因集合: {reasons}") - logger.info(f"{self.log_prefix}传统动作修改完成,当前使用动作: {list(self.action_manager.get_using_actions().keys())}") + logger.info( + f"{self.log_prefix}传统动作修改完成,当前使用动作: {list(self.action_manager.get_using_actions().keys())}" + ) # === 第二阶段:激活类型判定 === # 如果提供了聊天上下文,则进行激活类型判定 if chat_content is not None: logger.debug(f"{self.log_prefix}开始激活类型判定阶段") - + # 获取当前使用的动作集(经过第一阶段处理,且适用于FOCUS模式) current_using_actions = self.action_manager.get_using_actions() all_registered_actions = self.action_manager.get_using_actions_for_mode(ChatMode.FOCUS) - + # 构建完整的动作信息 current_actions_with_info = {} for action_name in current_using_actions.keys(): @@ -141,17 +145,17 @@ class ActionModifier: current_actions_with_info[action_name] = all_registered_actions[action_name] else: logger.warning(f"{self.log_prefix}使用中的动作 {action_name} 未在已注册动作中找到") - + # 应用激活类型判定 final_activated_actions = await self._apply_activation_type_filtering( current_actions_with_info, chat_content, ) - + # 更新ActionManager,移除未激活的动作 actions_to_remove = [] removal_reasons = {} - + for action_name in current_using_actions.keys(): if action_name not in final_activated_actions: actions_to_remove.append(action_name) @@ -159,7 +163,7 @@ class ActionModifier: if action_name in all_registered_actions: action_info = all_registered_actions[action_name] activation_type = action_info.get("focus_activation_type", ActionActivationType.ALWAYS) - + if activation_type == ActionActivationType.RANDOM: probability = action_info.get("random_probability", 0.3) removal_reasons[action_name] = f"RANDOM类型未触发(概率{probability})" @@ -172,15 +176,17 @@ class ActionModifier: removal_reasons[action_name] = "激活判定未通过" else: removal_reasons[action_name] = "动作信息不完整" - + for action_name in actions_to_remove: self.action_manager.remove_action_from_using(action_name) reason = removal_reasons.get(action_name, "未知原因") logger.info(f"{self.log_prefix}移除动作: {action_name},原因: {reason}") - + logger.info(f"{self.log_prefix}激活类型判定完成,最终可用动作: {list(final_activated_actions.keys())}") - - logger.info(f"{self.log_prefix}完整动作修改流程结束,最终动作集: {list(self.action_manager.get_using_actions().keys())}") + + logger.info( + f"{self.log_prefix}完整动作修改流程结束,最终动作集: {list(self.action_manager.get_using_actions().keys())}" + ) async def _apply_activation_type_filtering( self, @@ -189,27 +195,27 @@ class ActionModifier: ) -> Dict[str, Any]: """ 应用激活类型过滤逻辑,支持四种激活类型的并行处理 - + Args: actions_with_info: 带完整信息的动作字典 observed_messages_str: 观察到的聊天消息 chat_context: 聊天上下文信息 extra_context: 额外的上下文信息 - + Returns: Dict[str, Any]: 过滤后激活的actions字典 """ activated_actions = {} - + # 分类处理不同激活类型的actions always_actions = {} random_actions = {} llm_judge_actions = {} keyword_actions = {} - + for action_name, action_info in actions_with_info.items(): activation_type = action_info.get("focus_activation_type", ActionActivationType.ALWAYS) - + if activation_type == ActionActivationType.ALWAYS: always_actions[action_name] = action_info elif activation_type == ActionActivationType.RANDOM: @@ -220,12 +226,12 @@ class ActionModifier: keyword_actions[action_name] = action_info else: logger.warning(f"{self.log_prefix}未知的激活类型: {activation_type},跳过处理") - + # 1. 处理ALWAYS类型(直接激活) for action_name, action_info in always_actions.items(): activated_actions[action_name] = action_info logger.debug(f"{self.log_prefix}激活动作: {action_name},原因: ALWAYS类型直接激活") - + # 2. 处理RANDOM类型 for action_name, action_info in random_actions.items(): probability = action_info.get("random_probability", 0.3) @@ -235,7 +241,7 @@ class ActionModifier: logger.debug(f"{self.log_prefix}激活动作: {action_name},原因: RANDOM类型触发(概率{probability})") else: logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: RANDOM类型未触发(概率{probability})") - + # 3. 处理KEYWORD类型(快速判定) for action_name, action_info in keyword_actions.items(): should_activate = self._check_keyword_activation( @@ -250,7 +256,7 @@ class ActionModifier: else: keywords = action_info.get("activation_keywords", []) logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: KEYWORD类型未匹配关键词({keywords})") - + # 4. 处理LLM_JUDGE类型(并行判定) if llm_judge_actions: # 直接并行处理所有LLM判定actions @@ -258,7 +264,7 @@ class ActionModifier: llm_judge_actions, chat_content, ) - + # 添加激活的LLM判定actions for action_name, should_activate in llm_results.items(): if should_activate: @@ -266,46 +272,43 @@ class ActionModifier: logger.debug(f"{self.log_prefix}激活动作: {action_name},原因: LLM_JUDGE类型判定通过") else: logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: LLM_JUDGE类型判定未通过") - + logger.debug(f"{self.log_prefix}激活类型过滤完成: {list(activated_actions.keys())}") return activated_actions async def process_actions_for_planner( - self, - observed_messages_str: str = "", - chat_context: Optional[str] = None, - extra_context: Optional[str] = None + self, observed_messages_str: str = "", chat_context: Optional[str] = None, extra_context: Optional[str] = None ) -> Dict[str, Any]: """ [已废弃] 此方法现在已被整合到 modify_actions() 中 - + 为了保持向后兼容性而保留,但建议直接使用 ActionManager.get_using_actions() 规划器应该直接从 ActionManager 获取最终的可用动作集,而不是调用此方法 - + 新的架构: 1. 主循环调用 modify_actions() 处理完整的动作管理流程 2. 规划器直接使用 ActionManager.get_using_actions() 获取最终动作集 """ - logger.warning(f"{self.log_prefix}process_actions_for_planner() 已废弃,建议规划器直接使用 ActionManager.get_using_actions()") - + logger.warning( + f"{self.log_prefix}process_actions_for_planner() 已废弃,建议规划器直接使用 ActionManager.get_using_actions()" + ) + # 为了向后兼容,仍然返回当前使用的动作集 current_using_actions = self.action_manager.get_using_actions() all_registered_actions = self.action_manager.get_registered_actions() - + # 构建完整的动作信息 result = {} for action_name in current_using_actions.keys(): if action_name in all_registered_actions: result[action_name] = all_registered_actions[action_name] - + return result def _generate_context_hash(self, chat_content: str) -> str: """生成上下文的哈希值用于缓存""" context_content = f"{chat_content}" - return hashlib.md5(context_content.encode('utf-8')).hexdigest() - - + return hashlib.md5(context_content.encode("utf-8")).hexdigest() async def _process_llm_judge_actions_parallel( self, @@ -314,85 +317,85 @@ class ActionModifier: ) -> Dict[str, bool]: """ 并行处理LLM判定actions,支持智能缓存 - + Args: llm_judge_actions: 需要LLM判定的actions observed_messages_str: 观察到的聊天消息 chat_context: 聊天上下文 extra_context: 额外上下文 - + Returns: Dict[str, bool]: action名称到激活结果的映射 """ - + # 生成当前上下文的哈希值 current_context_hash = self._generate_context_hash(chat_content) current_time = time.time() - + results = {} tasks_to_run = {} - + # 检查缓存 for action_name, action_info in llm_judge_actions.items(): cache_key = f"{action_name}_{current_context_hash}" - + # 检查是否有有效的缓存 - if (cache_key in self._llm_judge_cache and - current_time - self._llm_judge_cache[cache_key]["timestamp"] < self._cache_expiry_time): - + if ( + cache_key in self._llm_judge_cache + and current_time - self._llm_judge_cache[cache_key]["timestamp"] < self._cache_expiry_time + ): results[action_name] = self._llm_judge_cache[cache_key]["result"] - logger.debug(f"{self.log_prefix}使用缓存结果 {action_name}: {'激活' if results[action_name] else '未激活'}") + logger.debug( + f"{self.log_prefix}使用缓存结果 {action_name}: {'激活' if results[action_name] else '未激活'}" + ) else: # 需要进行LLM判定 tasks_to_run[action_name] = action_info - + # 如果有需要运行的任务,并行执行 if tasks_to_run: logger.debug(f"{self.log_prefix}并行执行LLM判定,任务数: {len(tasks_to_run)}") - + # 创建并行任务 tasks = [] task_names = [] - + for action_name, action_info in tasks_to_run.items(): task = self._llm_judge_action( - action_name, - action_info, - chat_content, + action_name, + action_info, + chat_content, ) tasks.append(task) task_names.append(action_name) - + # 并行执行所有任务 try: task_results = await asyncio.gather(*tasks, return_exceptions=True) - + # 处理结果并更新缓存 - for i, (action_name, result) in enumerate(zip(task_names, task_results)): + for _, (action_name, result) in enumerate(zip(task_names, task_results)): if isinstance(result, Exception): logger.error(f"{self.log_prefix}LLM判定action {action_name} 时出错: {result}") results[action_name] = False else: results[action_name] = result - + # 更新缓存 cache_key = f"{action_name}_{current_context_hash}" - self._llm_judge_cache[cache_key] = { - "result": result, - "timestamp": current_time - } - + self._llm_judge_cache[cache_key] = {"result": result, "timestamp": current_time} + logger.debug(f"{self.log_prefix}并行LLM判定完成,耗时: {time.time() - current_time:.2f}s") - + except Exception as e: logger.error(f"{self.log_prefix}并行LLM判定失败: {e}") # 如果并行执行失败,为所有任务返回False for action_name in tasks_to_run.keys(): results[action_name] = False - + # 清理过期缓存 self._cleanup_expired_cache(current_time) - + return results def _cleanup_expired_cache(self, current_time: float): @@ -401,40 +404,39 @@ class ActionModifier: for cache_key, cache_data in self._llm_judge_cache.items(): if current_time - cache_data["timestamp"] > self._cache_expiry_time: expired_keys.append(cache_key) - + for key in expired_keys: del self._llm_judge_cache[key] - + if expired_keys: logger.debug(f"{self.log_prefix}清理了 {len(expired_keys)} 个过期缓存条目") async def _llm_judge_action( - self, - action_name: str, + self, + action_name: str, action_info: Dict[str, Any], chat_content: str = "", ) -> bool: """ 使用LLM判定是否应该激活某个action - + Args: action_name: 动作名称 action_info: 动作信息 observed_messages_str: 观察到的聊天消息 chat_context: 聊天上下文 extra_context: 额外上下文 - + Returns: bool: 是否应该激活此action """ - + try: # 构建判定提示词 action_description = action_info.get("description", "") action_require = action_info.get("require", []) custom_prompt = action_info.get("llm_judge_prompt", "") - - + # 构建基础判定提示词 base_prompt = f""" 你需要判断在当前聊天情况下,是否应该激活名为"{action_name}"的动作。 @@ -445,34 +447,34 @@ class ActionModifier: """ for req in action_require: base_prompt += f"- {req}\n" - + if custom_prompt: base_prompt += f"\n额外判定条件:\n{custom_prompt}\n" - + if chat_content: base_prompt += f"\n当前聊天记录:\n{chat_content}\n" - - + base_prompt += """ 请根据以上信息判断是否应该激活这个动作。 只需要回答"是"或"否",不要有其他内容。 """ - + # 调用LLM进行判定 response, _ = await self.llm_judge.generate_response_async(prompt=base_prompt) - + # 解析响应 response = response.strip().lower() - + # print(base_prompt) print(f"LLM判定动作 {action_name}:响应='{response}'") - - + should_activate = "是" in response or "yes" in response or "true" in response - - logger.debug(f"{self.log_prefix}LLM判定动作 {action_name}:响应='{response}',结果={'激活' if should_activate else '不激活'}") + + logger.debug( + f"{self.log_prefix}LLM判定动作 {action_name}:响应='{response}',结果={'激活' if should_activate else '不激活'}" + ) return should_activate - + except Exception as e: logger.error(f"{self.log_prefix}LLM判定动作 {action_name} 时出错: {e}") # 出错时默认不激活 @@ -486,45 +488,45 @@ class ActionModifier: ) -> bool: """ 检查是否匹配关键词触发条件 - + Args: action_name: 动作名称 action_info: 动作信息 observed_messages_str: 观察到的聊天消息 chat_context: 聊天上下文 extra_context: 额外上下文 - + Returns: bool: 是否应该激活此action """ - + activation_keywords = action_info.get("activation_keywords", []) case_sensitive = action_info.get("keyword_case_sensitive", False) - + if not activation_keywords: logger.warning(f"{self.log_prefix}动作 {action_name} 设置为关键词触发但未配置关键词") return False - + # 构建检索文本 search_text = "" if chat_content: search_text += chat_content # if chat_context: - # search_text += f" {chat_context}" + # search_text += f" {chat_context}" # if extra_context: - # search_text += f" {extra_context}" - + # search_text += f" {extra_context}" + # 如果不区分大小写,转换为小写 if not case_sensitive: search_text = search_text.lower() - + # 检查每个关键词 matched_keywords = [] for keyword in activation_keywords: check_keyword = keyword if case_sensitive else keyword.lower() if check_keyword in search_text: matched_keywords.append(keyword) - + if matched_keywords: logger.debug(f"{self.log_prefix}动作 {action_name} 匹配到关键词: {matched_keywords}") return True @@ -568,7 +570,9 @@ class ActionModifier: result["remove"].append("no_reply") result["remove"].append("reply") no_reply_ratio = no_reply_count / len(recent_cycles) - logger.info(f"{self.log_prefix}检测到高no_reply比例: {no_reply_ratio:.2f},达到退出聊天阈值,将添加exit_focus_chat并移除no_reply/reply动作") + logger.info( + f"{self.log_prefix}检测到高no_reply比例: {no_reply_ratio:.2f},达到退出聊天阈值,将添加exit_focus_chat并移除no_reply/reply动作" + ) # 计算连续回复的相关阈值 @@ -593,7 +597,7 @@ class ActionModifier: if len(last_max_reply_num) >= max_reply_num and all(last_max_reply_num): # 如果最近max_reply_num次都是reply,直接移除 result["remove"].append("reply") - reply_count = len(last_max_reply_num) - no_reply_count + # reply_count = len(last_max_reply_num) - no_reply_count logger.info( f"{self.log_prefix}移除reply动作,原因: 连续回复过多(最近{len(last_max_reply_num)}次全是reply,超过阈值{max_reply_num})" ) @@ -622,8 +626,6 @@ class ActionModifier: f"{self.log_prefix}连续回复检测:最近{one_thres_reply_num}次全是reply,{removal_probability:.2f}概率移除,未触发" ) else: - logger.debug( - f"{self.log_prefix}连续回复检测:无需移除reply动作,最近回复模式正常" - ) + logger.debug(f"{self.log_prefix}连续回复检测:无需移除reply动作,最近回复模式正常") return result diff --git a/src/chat/focus_chat/planners/planner_simple.py b/src/chat/focus_chat/planners/planner_simple.py index bfb0420fa..7154c7ecc 100644 --- a/src/chat/focus_chat/planners/planner_simple.py +++ b/src/chat/focus_chat/planners/planner_simple.py @@ -146,7 +146,7 @@ class ActionPlanner(BasePlanner): # 注意:动作的激活判定现在在主循环的modify_actions中完成 # 使用Focus模式过滤动作 current_available_actions_dict = self.action_manager.get_using_actions_for_mode(ChatMode.FOCUS) - + # 获取完整的动作信息 all_registered_actions = self.action_manager.get_registered_actions() current_available_actions = {} @@ -192,12 +192,11 @@ class ActionPlanner(BasePlanner): try: prompt = f"{prompt}" llm_content, (reasoning_content, _) = await self.planner_llm.generate_response_async(prompt=prompt) - + logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}") logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}") logger.info(f"{self.log_prefix}规划器推理: {reasoning_content}") - - + except Exception as req_e: logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}") reasoning = f"LLM 请求失败,你的模型出现问题: {req_e}" @@ -237,10 +236,10 @@ class ActionPlanner(BasePlanner): extra_info_block = "" action_data["extra_info_block"] = extra_info_block - + if relation_info: action_data["relation_info_block"] = relation_info - + # 对于reply动作不需要额外处理,因为相关字段已经在上面的循环中添加到action_data if extracted_action not in current_available_actions: @@ -303,12 +302,11 @@ class ActionPlanner(BasePlanner): ) -> str: """构建 Planner LLM 的提示词 (获取模板并填充数据)""" try: - if relation_info_block: relation_info_block = f"以下是你和别人的关系描述:\n{relation_info_block}" else: relation_info_block = "" - + memory_str = "" if running_memorys: memory_str = "以下是当前在聊天中,你回忆起的记忆:\n" @@ -331,9 +329,9 @@ class ActionPlanner(BasePlanner): # mind_info_block = "" # if current_mind: - # mind_info_block = f"对聊天的规划:{current_mind}" + # mind_info_block = f"对聊天的规划:{current_mind}" # else: - # mind_info_block = "你刚参与聊天" + # mind_info_block = "你刚参与聊天" personality_block = individuality.get_prompt(x_person=2, level=2) @@ -351,16 +349,14 @@ class ActionPlanner(BasePlanner): param_text = "\n" for param_name, param_description in using_actions_info["parameters"].items(): param_text += f' "{param_name}":"{param_description}"\n' - param_text = param_text.rstrip('\n') + param_text = param_text.rstrip("\n") else: param_text = "" - require_text = "" for require_item in using_actions_info["require"]: require_text += f"- {require_item}\n" - require_text = require_text.rstrip('\n') - + require_text = require_text.rstrip("\n") using_action_prompt = using_action_prompt.format( action_name=using_actions_name, diff --git a/src/chat/focus_chat/replyer/default_replyer.py b/src/chat/focus_chat/replyer/default_replyer.py index a9424a910..a591a26c5 100644 --- a/src/chat/focus_chat/replyer/default_replyer.py +++ b/src/chat/focus_chat/replyer/default_replyer.py @@ -93,7 +93,7 @@ class DefaultReplyer: self.chat_id = chat_stream.stream_id self.chat_stream = chat_stream - self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_id) + self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_id) async def _create_thinking_message(self, anchor_message: Optional[MessageRecv], thinking_id: str): """创建思考消息 (尝试锚定到 anchor_message)""" @@ -141,7 +141,7 @@ class DefaultReplyer: # text_part = action_data.get("text", []) # if text_part: sent_msg_list = [] - + with Timer("生成回复", cycle_timers): # 可以保留原有的文本处理逻辑或进行适当调整 reply = await self.reply( @@ -240,22 +240,21 @@ class DefaultReplyer: # current_temp = float(global_config.model.normal["temp"]) * arousal_multiplier # self.express_model.params["temperature"] = current_temp # 动态调整温度 - reply_to = action_data.get("reply_to", "none") - + sender = "" targer = "" if ":" in reply_to or ":" in reply_to: # 使用正则表达式匹配中文或英文冒号 - parts = re.split(pattern=r'[::]', string=reply_to, maxsplit=1) + parts = re.split(pattern=r"[::]", string=reply_to, maxsplit=1) if len(parts) == 2: sender = parts[0].strip() targer = parts[1].strip() - + identity = action_data.get("identity", "") extra_info_block = action_data.get("extra_info_block", "") relation_info_block = action_data.get("relation_info_block", "") - + # 3. 构建 Prompt with Timer("构建Prompt", {}): # 内部计时器,可选保留 prompt = await self.build_prompt_focus( @@ -374,8 +373,6 @@ class DefaultReplyer: style_habbits_str = "\n".join(style_habbits) grammar_habbits_str = "\n".join(grammar_habbits) - - # 关键词检测与反应 keywords_reaction_prompt = "" @@ -407,16 +404,15 @@ class DefaultReplyer: time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" # logger.debug("开始构建 focus prompt") - + if sender_name: - reply_target_block = f"现在{sender_name}说的:{target_message}。引起了你的注意,你想要在群里发言或者回复这条消息。" + reply_target_block = ( + f"现在{sender_name}说的:{target_message}。引起了你的注意,你想要在群里发言或者回复这条消息。" + ) elif target_message: reply_target_block = f"现在{target_message}引起了你的注意,你想要在群里发言或者回复这条消息。" else: reply_target_block = "现在,你想要在群里发言或者回复消息。" - - - # --- Choose template based on chat type --- if is_group_chat: @@ -665,30 +661,30 @@ def find_similar_expressions(input_text: str, expressions: List[Dict], top_k: in """使用TF-IDF和余弦相似度找出与输入文本最相似的top_k个表达方式""" if not expressions: return [] - + # 准备文本数据 - texts = [expr['situation'] for expr in expressions] + texts = [expr["situation"] for expr in expressions] texts.append(input_text) # 添加输入文本 - + # 使用TF-IDF向量化 vectorizer = TfidfVectorizer() tfidf_matrix = vectorizer.fit_transform(texts) - + # 计算余弦相似度 similarity_matrix = cosine_similarity(tfidf_matrix) - + # 获取输入文本的相似度分数(最后一行) scores = similarity_matrix[-1][:-1] # 排除与自身的相似度 - + # 获取top_k的索引 top_indices = np.argsort(scores)[::-1][:top_k] - + # 获取相似表达 similar_exprs = [] for idx in top_indices: if scores[idx] > 0: # 只保留有相似度的 similar_exprs.append(expressions[idx]) - + return similar_exprs diff --git a/src/chat/heart_flow/observation/chatting_observation.py b/src/chat/heart_flow/observation/chatting_observation.py index 593a238b5..72dbb596f 100644 --- a/src/chat/heart_flow/observation/chatting_observation.py +++ b/src/chat/heart_flow/observation/chatting_observation.py @@ -62,13 +62,12 @@ class ChattingObservation(Observation): self.oldest_messages = [] self.oldest_messages_str = "" self.compressor_prompt = "" - + initial_messages = get_raw_msg_before_timestamp_with_chat(self.chat_id, self.last_observe_time, 10) self.last_observe_time = initial_messages[-1]["time"] if initial_messages else self.last_observe_time self.talking_message = initial_messages self.talking_message_str = build_readable_messages(self.talking_message, show_actions=True) - def to_dict(self) -> dict: """将观察对象转换为可序列化的字典""" return { @@ -283,7 +282,7 @@ class ChattingObservation(Observation): show_actions=True, ) # print(f"构建中:self.talking_message_str_truncate: {self.talking_message_str_truncate}") - + self.person_list = await get_person_id_list(self.talking_message) # print(f"构建中:self.person_list: {self.person_list}") diff --git a/src/chat/heart_flow/sub_heartflow.py b/src/chat/heart_flow/sub_heartflow.py index f4cde94af..d94f94f75 100644 --- a/src/chat/heart_flow/sub_heartflow.py +++ b/src/chat/heart_flow/sub_heartflow.py @@ -42,9 +42,7 @@ class SubHeartflow: self.history_chat_state: List[Tuple[ChatState, float]] = [] self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_id) - self.log_prefix = ( - chat_manager.get_stream_name(self.subheartflow_id) or self.subheartflow_id - ) + self.log_prefix = chat_manager.get_stream_name(self.subheartflow_id) or self.subheartflow_id # 兴趣消息集合 self.interest_dict: Dict[str, tuple[MessageRecv, float, bool]] = {} @@ -199,7 +197,6 @@ class SubHeartflow: # 如果实例不存在,则创建并启动 logger.info(f"{log_prefix} 麦麦准备开始专注聊天...") try: - self.heart_fc_instance = HeartFChatting( chat_id=self.subheartflow_id, # observations=self.observations, diff --git a/src/chat/heart_flow/utils_chat.py b/src/chat/heart_flow/utils_chat.py index 527e6aafb..7289db1a8 100644 --- a/src/chat/heart_flow/utils_chat.py +++ b/src/chat/heart_flow/utils_chat.py @@ -23,7 +23,7 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]: chat_target_info = None try: - chat_stream = chat_manager.get_stream(chat_id) + chat_stream = chat_manager.get_stream(chat_id) if chat_stream: if chat_stream.group_info: diff --git a/src/chat/knowledge/raw_processing.py b/src/chat/knowledge/raw_processing.py index ffdcf814b..a5ac45dcc 100644 --- a/src/chat/knowledge/raw_processing.py +++ b/src/chat/knowledge/raw_processing.py @@ -3,7 +3,7 @@ import os from .global_logger import logger from .lpmmconfig import global_config -from src.chat.knowledge.utils import get_sha256 +from src.chat.knowledge.utils.hash import get_sha256 def load_raw_data(path: str = None) -> tuple[list[str], list[str]]: diff --git a/src/chat/memory_system/Hippocampus.py b/src/chat/memory_system/Hippocampus.py index 1a6c2bcf8..debb0e0ca 100644 --- a/src/chat/memory_system/Hippocampus.py +++ b/src/chat/memory_system/Hippocampus.py @@ -346,7 +346,9 @@ class Hippocampus: # 使用LLM提取关键词 topic_num = min(5, max(1, int(len(text) * 0.1))) # 根据文本长度动态调整关键词数量 # logger.info(f"提取关键词数量: {topic_num}") - topics_response, (reasoning_content, model_name) = await self.model_summary.generate_response_async(self.find_topic_llm(text, topic_num)) + topics_response, (reasoning_content, model_name) = await self.model_summary.generate_response_async( + self.find_topic_llm(text, topic_num) + ) # 提取关键词 keywords = re.findall(r"<([^>]+)>", topics_response) @@ -701,7 +703,9 @@ class Hippocampus: # 使用LLM提取关键词 topic_num = min(5, max(1, int(len(text) * 0.1))) # 根据文本长度动态调整关键词数量 # logger.info(f"提取关键词数量: {topic_num}") - topics_response, (reasoning_content, model_name) = await self.model_summary.generate_response_async(self.find_topic_llm(text, topic_num)) + topics_response, (reasoning_content, model_name) = await self.model_summary.generate_response_async( + self.find_topic_llm(text, topic_num) + ) # 提取关键词 keywords = re.findall(r"<([^>]+)>", topics_response) @@ -893,7 +897,7 @@ class EntorhinalCortex: # 获取数据库中所有节点和内存中所有节点 db_nodes = {node.concept: node for node in GraphNodes.select()} memory_nodes = list(self.memory_graph.G.nodes(data=True)) - + # 批量准备节点数据 nodes_to_create = [] nodes_to_update = [] @@ -929,22 +933,26 @@ class EntorhinalCortex: continue if concept not in db_nodes: - nodes_to_create.append({ - "concept": concept, - "memory_items": memory_items_json, - "hash": memory_hash, - "created_time": created_time, - "last_modified": last_modified, - }) - else: - db_node = db_nodes[concept] - if db_node.hash != memory_hash: - nodes_to_update.append({ + nodes_to_create.append( + { "concept": concept, "memory_items": memory_items_json, "hash": memory_hash, + "created_time": created_time, "last_modified": last_modified, - }) + } + ) + else: + db_node = db_nodes[concept] + if db_node.hash != memory_hash: + nodes_to_update.append( + { + "concept": concept, + "memory_items": memory_items_json, + "hash": memory_hash, + "last_modified": last_modified, + } + ) # 计算需要删除的节点 memory_concepts = {concept for concept, _ in memory_nodes} @@ -954,13 +962,13 @@ class EntorhinalCortex: if nodes_to_create: batch_size = 100 for i in range(0, len(nodes_to_create), batch_size): - batch = nodes_to_create[i:i + batch_size] + batch = nodes_to_create[i : i + batch_size] GraphNodes.insert_many(batch).execute() if nodes_to_update: batch_size = 100 for i in range(0, len(nodes_to_update), batch_size): - batch = nodes_to_update[i:i + batch_size] + batch = nodes_to_update[i : i + batch_size] for node_data in batch: GraphNodes.update(**{k: v for k, v in node_data.items() if k != "concept"}).where( GraphNodes.concept == node_data["concept"] @@ -992,22 +1000,26 @@ class EntorhinalCortex: last_modified = data.get("last_modified", current_time) if edge_key not in db_edge_dict: - edges_to_create.append({ - "source": source, - "target": target, - "strength": strength, - "hash": edge_hash, - "created_time": created_time, - "last_modified": last_modified, - }) + edges_to_create.append( + { + "source": source, + "target": target, + "strength": strength, + "hash": edge_hash, + "created_time": created_time, + "last_modified": last_modified, + } + ) elif db_edge_dict[edge_key]["hash"] != edge_hash: - edges_to_update.append({ - "source": source, - "target": target, - "strength": strength, - "hash": edge_hash, - "last_modified": last_modified, - }) + edges_to_update.append( + { + "source": source, + "target": target, + "strength": strength, + "hash": edge_hash, + "last_modified": last_modified, + } + ) # 计算需要删除的边 memory_edge_keys = {(source, target) for source, target, _ in memory_edges} @@ -1017,13 +1029,13 @@ class EntorhinalCortex: if edges_to_create: batch_size = 100 for i in range(0, len(edges_to_create), batch_size): - batch = edges_to_create[i:i + batch_size] + batch = edges_to_create[i : i + batch_size] GraphEdges.insert_many(batch).execute() if edges_to_update: batch_size = 100 for i in range(0, len(edges_to_update), batch_size): - batch = edges_to_update[i:i + batch_size] + batch = edges_to_update[i : i + batch_size] for edge_data in batch: GraphEdges.update(**{k: v for k, v in edge_data.items() if k not in ["source", "target"]}).where( (GraphEdges.source == edge_data["source"]) & (GraphEdges.target == edge_data["target"]) @@ -1031,9 +1043,7 @@ class EntorhinalCortex: if edges_to_delete: for source, target in edges_to_delete: - GraphEdges.delete().where( - (GraphEdges.source == source) & (GraphEdges.target == target) - ).execute() + GraphEdges.delete().where((GraphEdges.source == source) & (GraphEdges.target == target)).execute() end_time = time.time() logger.success(f"[同步] 总耗时: {end_time - start_time:.2f}秒") @@ -1069,13 +1079,15 @@ class EntorhinalCortex: if not memory_items_json: continue - nodes_data.append({ - "concept": concept, - "memory_items": memory_items_json, - "hash": self.hippocampus.calculate_node_hash(concept, memory_items), - "created_time": data.get("created_time", current_time), - "last_modified": data.get("last_modified", current_time), - }) + nodes_data.append( + { + "concept": concept, + "memory_items": memory_items_json, + "hash": self.hippocampus.calculate_node_hash(concept, memory_items), + "created_time": data.get("created_time", current_time), + "last_modified": data.get("last_modified", current_time), + } + ) except Exception as e: logger.error(f"准备节点 {concept} 数据时发生错误: {e}") continue @@ -1084,14 +1096,16 @@ class EntorhinalCortex: edges_data = [] for source, target, data in memory_edges: try: - edges_data.append({ - "source": source, - "target": target, - "strength": data.get("strength", 1), - "hash": self.hippocampus.calculate_edge_hash(source, target), - "created_time": data.get("created_time", current_time), - "last_modified": data.get("last_modified", current_time), - }) + edges_data.append( + { + "source": source, + "target": target, + "strength": data.get("strength", 1), + "hash": self.hippocampus.calculate_edge_hash(source, target), + "created_time": data.get("created_time", current_time), + "last_modified": data.get("last_modified", current_time), + } + ) except Exception as e: logger.error(f"准备边 {source}-{target} 数据时发生错误: {e}") continue @@ -1102,7 +1116,7 @@ class EntorhinalCortex: batch_size = 500 # 增加批量大小 with GraphNodes._meta.database.atomic(): for i in range(0, len(nodes_data), batch_size): - batch = nodes_data[i:i + batch_size] + batch = nodes_data[i : i + batch_size] GraphNodes.insert_many(batch).execute() node_end = time.time() logger.info(f"[数据库] 写入 {len(nodes_data)} 个节点耗时: {node_end - node_start:.2f}秒") @@ -1113,7 +1127,7 @@ class EntorhinalCortex: batch_size = 500 # 增加批量大小 with GraphEdges._meta.database.atomic(): for i in range(0, len(edges_data), batch_size): - batch = edges_data[i:i + batch_size] + batch = edges_data[i : i + batch_size] GraphEdges.insert_many(batch).execute() edge_end = time.time() logger.info(f"[数据库] 写入 {len(edges_data)} 条边耗时: {edge_end - edge_start:.2f}秒") diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index b7a292c41..29d571905 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -79,7 +79,7 @@ class ChatBot: group_info = message.message_info.group_info user_info = message.message_info.user_info chat_manager.register_message(message) - + # 创建聊天流 chat = await chat_manager.get_or_create_stream( platform=message.message_info.platform, @@ -87,13 +87,13 @@ class ChatBot: group_info=group_info, ) message.update_chat_stream(chat) - + # 处理消息内容,生成纯文本 await message.process() - + # 命令处理 - 在消息处理的早期阶段检查并处理命令 is_command, cmd_result, continue_process = await command_manager.process_command(message) - + # 如果是命令且不需要继续处理,则直接返回 if is_command and not continue_process: logger.info(f"命令处理完成,跳过后续消息处理: {cmd_result}") diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index 03b2e4361..8c05a9ab0 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -24,7 +24,6 @@ class MessageStorage: else: filtered_processed_plain_text = "" - if isinstance(message, MessageSending): display_message = message.display_message if display_message: diff --git a/src/chat/normal_chat/normal_chat.py b/src/chat/normal_chat/normal_chat.py index 2babc500b..7d37f7ead 100644 --- a/src/chat/normal_chat/normal_chat.py +++ b/src/chat/normal_chat/normal_chat.py @@ -13,8 +13,6 @@ from src.chat.utils.prompt_builder import global_prompt_manager from .normal_chat_generator import NormalChatGenerator from ..message_receive.message import MessageSending, MessageRecv, MessageThinking, MessageSet from src.chat.message_receive.message_sender import message_manager -from src.chat.utils.utils_image import image_path_to_base64 -from src.chat.emoji_system.emoji_manager import emoji_manager from src.chat.normal_chat.willing.willing_manager import willing_manager from src.chat.normal_chat.normal_chat_utils import get_recent_message_stats from src.config.config import global_config @@ -69,7 +67,7 @@ class NormalChat: self.on_switch_to_focus_callback = on_switch_to_focus_callback self._disabled = False # 增加停用标志 - + logger.debug(f"[{self.stream_name}] NormalChat 初始化完成 (异步部分)。") # 改为实例方法 @@ -193,7 +191,9 @@ class NormalChat: return timing_results = {} - reply_probability = 1.0 if is_mentioned and global_config.normal_chat.mentioned_bot_inevitable_reply else 0.0 # 如果被提及,且开启了提及必回复,则基础概率为1,否则需要意愿判断 + reply_probability = ( + 1.0 if is_mentioned and global_config.normal_chat.mentioned_bot_inevitable_reply else 0.0 + ) # 如果被提及,且开启了提及必回复,则基础概率为1,否则需要意愿判断 # 意愿管理器:设置当前message信息 willing_manager.setup(message, self.chat_stream, is_mentioned, interested_rate) @@ -267,13 +267,17 @@ class NormalChat: try: # 获取发送者名称(动作修改已在并行执行前完成) sender_name = self._get_sender_name(message) - + no_action = { - "action_result": {"action_type": "no_action", "action_data": {}, "reasoning": "规划器初始化默认", "is_parallel": True}, + "action_result": { + "action_type": "no_action", + "action_data": {}, + "reasoning": "规划器初始化默认", + "is_parallel": True, + }, "chat_context": "", "action_prompt": "", } - # 检查是否应该跳过规划 if self.action_modifier.should_skip_planning(): @@ -288,7 +292,9 @@ class NormalChat: reasoning = plan_result["action_result"]["reasoning"] is_parallel = plan_result["action_result"].get("is_parallel", False) - logger.info(f"[{self.stream_name}] Planner决策: {action_type}, 理由: {reasoning}, 并行执行: {is_parallel}") + logger.info( + f"[{self.stream_name}] Planner决策: {action_type}, 理由: {reasoning}, 并行执行: {is_parallel}" + ) self.action_type = action_type # 更新实例属性 self.is_parallel_action = is_parallel # 新增:保存并行执行标志 @@ -307,7 +313,12 @@ class NormalChat: else: logger.warning(f"[{self.stream_name}] 额外动作 {action_type} 执行失败") - return {"action_type": action_type, "action_data": action_data, "reasoning": reasoning, "is_parallel": is_parallel} + return { + "action_type": action_type, + "action_data": action_data, + "reasoning": reasoning, + "is_parallel": is_parallel, + } except Exception as e: logger.error(f"[{self.stream_name}] Planner执行失败: {e}") @@ -331,13 +342,19 @@ class NormalChat: logger.error(f"[{self.stream_name}] 动作规划异常: {plan_result}") elif plan_result: logger.debug(f"[{self.stream_name}] 额外动作处理完成: {self.action_type}") - + if not response_set or ( - self.enable_planner and self.action_type not in ["no_action", "change_to_focus_chat"] and not self.is_parallel_action + self.enable_planner + and self.action_type not in ["no_action", "change_to_focus_chat"] + and not self.is_parallel_action ): if not response_set: logger.info(f"[{self.stream_name}] 模型未生成回复内容") - elif self.enable_planner and self.action_type not in ["no_action", "change_to_focus_chat"] and not self.is_parallel_action: + elif ( + self.enable_planner + and self.action_type not in ["no_action", "change_to_focus_chat"] + and not self.is_parallel_action + ): logger.info(f"[{self.stream_name}] 模型选择其他动作(非并行动作)") # 如果模型未生成回复,移除思考消息 container = await message_manager.get_container(self.stream_id) # 使用 self.stream_id @@ -364,7 +381,6 @@ class NormalChat: # 检查 first_bot_msg 是否为 None (例如思考消息已被移除的情况) if first_bot_msg: - # 记录回复信息到最近回复列表中 reply_info = { "time": time.time(), @@ -396,7 +412,6 @@ class NormalChat: await self._check_switch_to_focus() pass - # with Timer("关系更新", timing_results): # await self._update_relationship(message, response_set) @@ -605,7 +620,7 @@ class NormalChat: # 执行动作 result = await action_handler.handle_action() success = False - + if result and isinstance(result, tuple) and len(result) >= 2: # handle_action返回 (success: bool, message: str) success = result[0] diff --git a/src/chat/normal_chat/normal_chat_action_modifier.py b/src/chat/normal_chat/normal_chat_action_modifier.py index 78593c1f5..b13c1ee41 100644 --- a/src/chat/normal_chat/normal_chat_action_modifier.py +++ b/src/chat/normal_chat/normal_chat_action_modifier.py @@ -35,7 +35,7 @@ class NormalChatActionModifier: **kwargs: Any, ): """为Normal Chat修改可用动作集合 - + 实现动作激活策略: 1. 基于关联类型的动态过滤 2. 基于激活类型的智能判定(LLM_JUDGE转为概率激活) @@ -49,7 +49,7 @@ class NormalChatActionModifier: reasons = [] merged_action_changes = {"add": [], "remove": []} type_mismatched_actions = [] # 在外层定义避免作用域问题 - + self.action_manager.restore_default_actions() # 第一阶段:基于关联类型的动态过滤 @@ -74,7 +74,7 @@ class NormalChatActionModifier: # 第二阶段:应用激活类型判定 # 构建聊天内容 - 使用与planner一致的方式 chat_content = "" - if chat_stream and hasattr(chat_stream, 'stream_id'): + if chat_stream and hasattr(chat_stream, "stream_id"): try: # 获取消息历史,使用与normal_chat_planner相同的方法 message_list_before_now = get_raw_msg_before_timestamp_with_chat( @@ -82,7 +82,7 @@ class NormalChatActionModifier: timestamp=time.time(), limit=global_config.focus_chat.observation_context_size, # 使用相同的配置 ) - + # 构建可读的聊天上下文 chat_content = build_readable_messages( message_list_before_now, @@ -92,39 +92,41 @@ class NormalChatActionModifier: read_mark=0.0, show_actions=True, ) - + logger.debug(f"{self.log_prefix} 成功构建聊天内容,长度: {len(chat_content)}") - + except Exception as e: logger.warning(f"{self.log_prefix} 构建聊天内容失败: {e}") chat_content = "" - + # 获取当前Normal模式下的动作集进行激活判定 current_actions = self.action_manager.get_using_actions_for_mode(ChatMode.NORMAL) # print(f"current_actions: {current_actions}") # print(f"chat_content: {chat_content}") final_activated_actions = await self._apply_normal_activation_filtering( - current_actions, - chat_content, - message_content + current_actions, chat_content, message_content ) # print(f"final_activated_actions: {final_activated_actions}") - + # 统一处理所有需要移除的动作,避免重复移除 all_actions_to_remove = set() # 使用set避免重复 - + # 添加关联类型不匹配的动作 if type_mismatched_actions: all_actions_to_remove.update(type_mismatched_actions) - + # 添加激活类型判定未通过的动作 for action_name in current_actions.keys(): if action_name not in final_activated_actions: all_actions_to_remove.add(action_name) - + # 统计移除原因(避免重复) - activation_failed_actions = [name for name in current_actions.keys() if name not in final_activated_actions and name not in type_mismatched_actions] + activation_failed_actions = [ + name + for name in current_actions.keys() + if name not in final_activated_actions and name not in type_mismatched_actions + ] if activation_failed_actions: reasons.append(f"移除{activation_failed_actions}(激活类型判定未通过)") @@ -146,7 +148,7 @@ class NormalChatActionModifier: # 记录变更原因 if reasons: logger.info(f"{self.log_prefix} 动作调整完成: {' | '.join(reasons)}") - + # 获取最终的Normal模式可用动作并记录 final_actions = self.action_manager.get_using_actions_for_mode(ChatMode.NORMAL) logger.debug(f"{self.log_prefix} 当前Normal模式可用动作: {list(final_actions.keys())}") @@ -159,31 +161,31 @@ class NormalChatActionModifier: ) -> Dict[str, Any]: """ 应用Normal模式的激活类型过滤逻辑 - + 与Focus模式的区别: 1. LLM_JUDGE类型转换为概率激活(避免LLM调用) 2. RANDOM类型保持概率激活 3. KEYWORD类型保持关键词匹配 4. ALWAYS类型直接激活 - + Args: actions_with_info: 带完整信息的动作字典 chat_content: 聊天内容 - + Returns: Dict[str, Any]: 过滤后激活的actions字典 """ activated_actions = {} - + # 分类处理不同激活类型的actions always_actions = {} random_actions = {} keyword_actions = {} - + for action_name, action_info in actions_with_info.items(): # 使用normal_activation_type activation_type = action_info.get("normal_activation_type", ActionActivationType.ALWAYS) - + if activation_type == ActionActivationType.ALWAYS: always_actions[action_name] = action_info elif activation_type == ActionActivationType.RANDOM or activation_type == ActionActivationType.LLM_JUDGE: @@ -192,12 +194,12 @@ class NormalChatActionModifier: keyword_actions[action_name] = action_info else: logger.warning(f"{self.log_prefix}未知的激活类型: {activation_type},跳过处理") - + # 1. 处理ALWAYS类型(直接激活) for action_name, action_info in always_actions.items(): activated_actions[action_name] = action_info logger.debug(f"{self.log_prefix}激活动作: {action_name},原因: ALWAYS类型直接激活") - + # 2. 处理RANDOM类型(概率激活) for action_name, action_info in random_actions.items(): probability = action_info.get("random_probability", 0.3) @@ -207,15 +209,10 @@ class NormalChatActionModifier: logger.debug(f"{self.log_prefix}激活动作: {action_name},原因: RANDOM类型触发(概率{probability})") else: logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: RANDOM类型未触发(概率{probability})") - + # 3. 处理KEYWORD类型(关键词匹配) for action_name, action_info in keyword_actions.items(): - should_activate = self._check_keyword_activation( - action_name, - action_info, - chat_content, - message_content - ) + should_activate = self._check_keyword_activation(action_name, action_info, chat_content, message_content) if should_activate: activated_actions[action_name] = action_info keywords = action_info.get("activation_keywords", []) @@ -225,7 +222,7 @@ class NormalChatActionModifier: logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: KEYWORD类型未匹配关键词({keywords})") # print(f"keywords: {keywords}") # print(f"chat_content: {chat_content}") - + logger.debug(f"{self.log_prefix}Normal模式激活类型过滤完成: {list(activated_actions.keys())}") return activated_actions @@ -238,41 +235,40 @@ class NormalChatActionModifier: ) -> bool: """ 检查是否匹配关键词触发条件 - + Args: action_name: 动作名称 action_info: 动作信息 chat_content: 聊天内容(已经是格式化后的可读消息) - + Returns: bool: 是否应该激活此action """ - + activation_keywords = action_info.get("activation_keywords", []) case_sensitive = action_info.get("keyword_case_sensitive", False) - + if not activation_keywords: logger.warning(f"{self.log_prefix}动作 {action_name} 设置为关键词触发但未配置关键词") return False - + # 使用构建好的聊天内容作为检索文本 - search_text = chat_content +message_content - + search_text = chat_content + message_content + # 如果不区分大小写,转换为小写 if not case_sensitive: search_text = search_text.lower() - + # 检查每个关键词 matched_keywords = [] for keyword in activation_keywords: check_keyword = keyword if case_sensitive else keyword.lower() if check_keyword in search_text: matched_keywords.append(keyword) - - + # print(f"search_text: {search_text}") # print(f"activation_keywords: {activation_keywords}") - + if matched_keywords: logger.debug(f"{self.log_prefix}动作 {action_name} 匹配到关键词: {matched_keywords}") return True diff --git a/src/chat/normal_chat/normal_chat_expressor.py b/src/chat/normal_chat/normal_chat_expressor.py index 45c0155f8..0f423259f 100644 --- a/src/chat/normal_chat/normal_chat_expressor.py +++ b/src/chat/normal_chat/normal_chat_expressor.py @@ -9,7 +9,7 @@ import time from typing import List, Optional, Tuple, Dict, Any from src.chat.message_receive.message import MessageRecv, MessageSending, MessageThinking, Seg from src.chat.message_receive.message import UserInfo -from src.chat.message_receive.chat_stream import ChatStream,chat_manager +from src.chat.message_receive.chat_stream import ChatStream, chat_manager from src.chat.message_receive.message_sender import message_manager from src.config.config import global_config from src.common.logger_manager import get_logger @@ -37,7 +37,7 @@ class NormalChatExpressor: self.chat_stream = chat_stream self.stream_name = chat_manager.get_stream_name(self.chat_stream.stream_id) or self.chat_stream.stream_id self.log_prefix = f"[{self.stream_name}]Normal表达器" - + logger.debug(f"{self.log_prefix} 初始化完成") async def create_thinking_message( diff --git a/src/chat/normal_chat/normal_chat_generator.py b/src/chat/normal_chat/normal_chat_generator.py index 06fb9cf77..41ac71492 100644 --- a/src/chat/normal_chat/normal_chat_generator.py +++ b/src/chat/normal_chat/normal_chat_generator.py @@ -25,9 +25,7 @@ class NormalChatGenerator: request_type="normal.chat_2", ) - self.model_sum = LLMRequest( - model=global_config.model.memory_summary, temperature=0.7, request_type="relation" - ) + self.model_sum = LLMRequest(model=global_config.model.memory_summary, temperature=0.7, request_type="relation") self.current_model_type = "r1" # 默认使用 R1 self.current_model_name = "unknown model" @@ -68,7 +66,6 @@ class NormalChatGenerator: enable_planner: bool = False, available_actions=None, ): - person_id = person_info_manager.get_person_id( message.chat_stream.user_info.platform, message.chat_stream.user_info.user_id ) @@ -103,7 +100,6 @@ class NormalChatGenerator: logger.info(f"对 {message.processed_plain_text} 的回复:{content}") - except Exception: logger.exception("生成回复时出错") return None diff --git a/src/chat/normal_chat/normal_chat_planner.py b/src/chat/normal_chat/normal_chat_planner.py index 0712d1c8d..eceb73d77 100644 --- a/src/chat/normal_chat/normal_chat_planner.py +++ b/src/chat/normal_chat/normal_chat_planner.py @@ -101,7 +101,7 @@ class NormalChatPlanner: # 获取当前可用的动作,使用Normal模式过滤 current_available_actions = self.action_manager.get_using_actions_for_mode(ChatMode.NORMAL) - + # 注意:动作的激活判定现在在 normal_chat_action_modifier 中完成 # 这里直接使用经过 action_modifier 处理后的最终动作集 # 符合职责分离原则:ActionModifier负责动作管理,Planner专注于决策 @@ -110,7 +110,12 @@ class NormalChatPlanner: if not current_available_actions: logger.debug(f"{self.log_prefix}规划器: 没有可用动作,返回no_action") return { - "action_result": {"action_type": action, "action_data": action_data, "reasoning": reasoning, "is_parallel": True}, + "action_result": { + "action_type": action, + "action_data": action_data, + "reasoning": reasoning, + "is_parallel": True, + }, "chat_context": "", "action_prompt": "", } @@ -121,7 +126,7 @@ class NormalChatPlanner: timestamp=time.time(), limit=global_config.focus_chat.observation_context_size, ) - + chat_context = build_readable_messages( message_list_before_now, replace_bot_name=True, @@ -130,7 +135,7 @@ class NormalChatPlanner: read_mark=0.0, show_actions=True, ) - + # 构建planner的prompt prompt = await self.build_planner_prompt( self_info_block=self_info, @@ -141,7 +146,12 @@ class NormalChatPlanner: if not prompt: logger.warning(f"{self.log_prefix}规划器: 构建提示词失败") return { - "action_result": {"action_type": action, "action_data": action_data, "reasoning": reasoning, "is_parallel": False}, + "action_result": { + "action_type": action, + "action_data": action_data, + "reasoning": reasoning, + "is_parallel": False, + }, "chat_context": chat_context, "action_prompt": "", } @@ -149,7 +159,7 @@ class NormalChatPlanner: # 使用LLM生成动作决策 try: content, (reasoning_content, model_name) = await self.planner_llm.generate_response_async(prompt) - + # logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}") logger.info(f"{self.log_prefix}规划器原始响应: {content}") logger.info(f"{self.log_prefix}规划器推理: {reasoning_content}") @@ -201,8 +211,10 @@ class NormalChatPlanner: if action in current_available_actions: action_info = current_available_actions[action] is_parallel = action_info.get("parallel_action", False) - - logger.debug(f"{self.log_prefix}规划器决策动作:{action}, 动作信息: '{action_data}', 理由: {reasoning}, 并行执行: {is_parallel}") + + logger.debug( + f"{self.log_prefix}规划器决策动作:{action}, 动作信息: '{action_data}', 理由: {reasoning}, 并行执行: {is_parallel}" + ) # 恢复到默认动作集 self.action_manager.restore_actions() @@ -216,15 +228,15 @@ class NormalChatPlanner: "action_data": action_data, "reasoning": reasoning, "timestamp": time.time(), - "model_name": model_name if 'model_name' in locals() else None + "model_name": model_name if "model_name" in locals() else None, } action_result = { - "action_type": action, - "action_data": action_data, + "action_type": action, + "action_data": action_data, "reasoning": reasoning, "is_parallel": is_parallel, - "action_record": json.dumps(action_record, ensure_ascii=False) + "action_record": json.dumps(action_record, ensure_ascii=False), } plan_result = { @@ -248,24 +260,19 @@ class NormalChatPlanner: # 添加特殊的change_to_focus_chat动作 action_options_text += "动作:change_to_focus_chat\n" - action_options_text += ( - "该动作的描述:当聊天变得热烈、自己回复条数很多或需要深入交流时使用,正常回复消息并切换到focus_chat模式\n" - ) + action_options_text += "该动作的描述:当聊天变得热烈、自己回复条数很多或需要深入交流时使用,正常回复消息并切换到focus_chat模式\n" action_options_text += "使用该动作的场景:\n" action_options_text += "- 聊天上下文中自己的回复条数较多(超过3-4条)\n" action_options_text += "- 对话进行得非常热烈活跃\n" action_options_text += "- 用户表现出深入交流的意图\n" action_options_text += "- 话题需要更专注和深入的讨论\n\n" - + action_options_text += "输出要求:\n" action_options_text += "{{" - action_options_text += " \"action\": \"change_to_focus_chat\"" + action_options_text += ' "action": "change_to_focus_chat"' action_options_text += "}}\n\n" - - - - + for action_name, action_info in current_available_actions.items(): action_description = action_info.get("description", "") action_parameters = action_info.get("parameters", {}) @@ -276,15 +283,14 @@ class NormalChatPlanner: print(action_parameters) for param_name, param_description in action_parameters.items(): param_text += f' "{param_name}":"{param_description}"\n' - param_text = param_text.rstrip('\n') + param_text = param_text.rstrip("\n") else: param_text = "" - require_text = "" for require_item in action_require: require_text += f"- {require_item}\n" - require_text = require_text.rstrip('\n') + require_text = require_text.rstrip("\n") # 构建单个动作的提示 action_prompt = await global_prompt_manager.format_prompt( @@ -316,6 +322,4 @@ class NormalChatPlanner: return "" - - init_prompt() diff --git a/src/chat/normal_chat/normal_prompt.py b/src/chat/normal_chat/normal_prompt.py index 8b835951b..168b52da2 100644 --- a/src/chat/normal_chat/normal_prompt.py +++ b/src/chat/normal_chat/normal_prompt.py @@ -214,7 +214,6 @@ class PromptBuilder: except Exception as e: logger.error(f"关键词检测与反应时发生异常: {str(e)}", exc_info=True) - moderation_prompt_block = "请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。" # 构建action描述 (如果启用planner) diff --git a/src/chat/normal_chat/willing/mode_classical.py b/src/chat/normal_chat/willing/mode_classical.py index 1aa302945..fc030a7cd 100644 --- a/src/chat/normal_chat/willing/mode_classical.py +++ b/src/chat/normal_chat/willing/mode_classical.py @@ -42,9 +42,7 @@ class ClassicalWillingManager(BaseWillingManager): self.chat_reply_willing[chat_id] = min(current_willing, 3.0) - reply_probability = min( - max((current_willing - 0.5), 0.01) * 2, 1 - ) + reply_probability = min(max((current_willing - 0.5), 0.01) * 2, 1) # 检查群组权限(如果是群聊) if ( diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index 782b7500d..73ee59fd1 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -286,7 +286,7 @@ def _build_readable_messages_internal( message_details_with_flags.append((timestamp, name, content, is_action)) # print(f"content:{content}") # print(f"is_action:{is_action}") - + # print(f"message_details_with_flags:{message_details_with_flags}") # 应用截断逻辑 (如果 truncate 为 True) @@ -324,7 +324,7 @@ def _build_readable_messages_internal( else: # 如果不截断,直接使用原始列表 message_details = message_details_with_flags - + # print(f"message_details:{message_details}") # 3: 合并连续消息 (如果 merge_messages 为 True) @@ -336,12 +336,12 @@ def _build_readable_messages_internal( "start_time": message_details[0][0], "end_time": message_details[0][0], "content": [message_details[0][2]], - "is_action": message_details[0][3] + "is_action": message_details[0][3], } for i in range(1, len(message_details)): timestamp, name, content, is_action = message_details[i] - + # 对于动作记录,不进行合并 if is_action or current_merge["is_action"]: # 保存当前的合并块 @@ -352,7 +352,7 @@ def _build_readable_messages_internal( "start_time": timestamp, "end_time": timestamp, "content": [content], - "is_action": is_action + "is_action": is_action, } continue @@ -365,11 +365,11 @@ def _build_readable_messages_internal( merged_messages.append(current_merge) # 开始新的合并块 current_merge = { - "name": name, - "start_time": timestamp, - "end_time": timestamp, + "name": name, + "start_time": timestamp, + "end_time": timestamp, "content": [content], - "is_action": is_action + "is_action": is_action, } # 添加最后一个合并块 merged_messages.append(current_merge) @@ -381,10 +381,9 @@ def _build_readable_messages_internal( "start_time": timestamp, # 起始和结束时间相同 "end_time": timestamp, "content": [content], # 内容只有一个元素 - "is_action": is_action + "is_action": is_action, } ) - # 4 & 5: 格式化为字符串 output_lines = [] @@ -451,7 +450,7 @@ def build_readable_messages( 将消息列表转换为可读的文本格式。 如果提供了 read_mark,则在相应位置插入已读标记。 允许通过参数控制格式化行为。 - + Args: messages: 消息列表 replace_bot_name: 是否替换机器人名称为"你" @@ -463,22 +462,24 @@ def build_readable_messages( """ # 创建messages的深拷贝,避免修改原始列表 copy_messages = [msg.copy() for msg in messages] - + if show_actions and copy_messages: # 获取所有消息的时间范围 min_time = min(msg.get("time", 0) for msg in copy_messages) max_time = max(msg.get("time", 0) for msg in copy_messages) - + # 从第一条消息中获取chat_id chat_id = copy_messages[0].get("chat_id") if copy_messages else None - + # 获取这个时间范围内的动作记录,并匹配chat_id - actions = ActionRecords.select().where( - (ActionRecords.time >= min_time) & - (ActionRecords.time <= max_time) & - (ActionRecords.chat_id == chat_id) - ).order_by(ActionRecords.time) - + actions = ( + ActionRecords.select() + .where( + (ActionRecords.time >= min_time) & (ActionRecords.time <= max_time) & (ActionRecords.chat_id == chat_id) + ) + .order_by(ActionRecords.time) + ) + # 将动作记录转换为消息格式 for action in actions: # 只有当build_into_prompt为True时才添加动作记录 @@ -495,25 +496,22 @@ def build_readable_messages( "action_name": action.action_name, # 保存动作名称 } copy_messages.append(action_msg) - + # 重新按时间排序 copy_messages.sort(key=lambda x: x.get("time", 0)) if read_mark <= 0: # 没有有效的 read_mark,直接格式化所有消息 - + # for message in messages: - # print(f"message:{message}") - - + # print(f"message:{message}") + formatted_string, _ = _build_readable_messages_internal( copy_messages, replace_bot_name, merge_messages, timestamp_mode, truncate ) - + # print(f"formatted_string:{formatted_string}") - - - + return formatted_string else: # 按 read_mark 分割消息 @@ -521,10 +519,10 @@ def build_readable_messages( messages_after_mark = [msg for msg in copy_messages if msg.get("time", 0) > read_mark] # for message in messages_before_mark: - # print(f"message:{message}") - + # print(f"message:{message}") + # for message in messages_after_mark: - # print(f"message:{message}") + # print(f"message:{message}") # 分别格式化 formatted_before, _ = _build_readable_messages_internal( @@ -536,7 +534,7 @@ def build_readable_messages( merge_messages, timestamp_mode, ) - + # print(f"formatted_before:{formatted_before}") # print(f"formatted_after:{formatted_after}") diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index 3f6fd7b44..b9d6a6e15 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -154,7 +154,8 @@ class Messages(BaseModel): class Meta: # database = db # 继承自 BaseModel table_name = "messages" - + + class ActionRecords(BaseModel): """ 用于存储动作记录数据的模型。 @@ -162,11 +163,11 @@ class ActionRecords(BaseModel): action_id = TextField(index=True) # 消息 ID (更改自 IntegerField) time = DoubleField() # 消息时间戳 - + action_name = TextField() action_data = TextField() action_done = BooleanField(default=False) - + action_build_into_prompt = BooleanField(default=False) action_prompt_display = TextField() @@ -241,11 +242,10 @@ class PersonInfo(BaseModel): points = TextField(null=True) # 个人印象的点 forgotten_points = TextField(null=True) # 被遗忘的点 info_list = TextField(null=True) # 与Bot的互动 - + know_times = FloatField(null=True) # 认识时间 (时间戳) know_since = FloatField(null=True) # 首次印象总结时间 last_know = FloatField(null=True) # 最后一次印象总结时间 - class Meta: # database = db # 继承自 BaseModel @@ -403,20 +403,20 @@ def initialize_database(): logger.info(f"表 '{table_name}' 缺失字段 '{field_name}',正在添加...") field_type = field_obj.__class__.__name__ sql_type = { - 'TextField': 'TEXT', - 'IntegerField': 'INTEGER', - 'FloatField': 'FLOAT', - 'DoubleField': 'DOUBLE', - 'BooleanField': 'INTEGER', - 'DateTimeField': 'DATETIME' - }.get(field_type, 'TEXT') - alter_sql = f'ALTER TABLE {table_name} ADD COLUMN {field_name} {sql_type}' + "TextField": "TEXT", + "IntegerField": "INTEGER", + "FloatField": "FLOAT", + "DoubleField": "DOUBLE", + "BooleanField": "INTEGER", + "DateTimeField": "DATETIME", + }.get(field_type, "TEXT") + alter_sql = f"ALTER TABLE {table_name} ADD COLUMN {field_name} {sql_type}" if field_obj.null: - alter_sql += ' NULL' + alter_sql += " NULL" else: - alter_sql += ' NOT NULL' - if hasattr(field_obj, 'default') and field_obj.default is not None: - alter_sql += f' DEFAULT {field_obj.default}' + alter_sql += " NOT NULL" + if hasattr(field_obj, "default") and field_obj.default is not None: + alter_sql += f" DEFAULT {field_obj.default}" db.execute_sql(alter_sql) logger.info(f"字段 '{field_name}' 添加成功") diff --git a/src/config/auto_update.py b/src/config/auto_update.py index 54419a622..2088e3628 100644 --- a/src/config/auto_update.py +++ b/src/config/auto_update.py @@ -84,7 +84,7 @@ def update_config(): contains_regex = False if value and isinstance(value[0], dict) and "regex" in value[0]: contains_regex = True - + if contains_regex: target[key] = value else: diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 3adff5fac..34da536db 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -49,7 +49,7 @@ class IdentityConfig(ConfigBase): @dataclass class RelationshipConfig(ConfigBase): """关系配置类""" - + enable_relationship: bool = True give_name: bool = False @@ -58,6 +58,7 @@ class RelationshipConfig(ConfigBase): build_relationship_interval: int = 600 """构建关系间隔 单位秒,如果为0则不构建关系""" + @dataclass class ChatConfig(ConfigBase): """聊天配置类""" @@ -222,7 +223,7 @@ class EmojiConfig(ConfigBase): @dataclass class MemoryConfig(ConfigBase): """记忆配置类""" - + enable_memory: bool = True memory_build_interval: int = 600 @@ -329,6 +330,7 @@ class KeywordReactionConfig(ConfigBase): if not isinstance(rule, KeywordRuleConfig): raise ValueError(f"规则必须是KeywordRuleConfig类型,而不是{type(rule).__name__}") + @dataclass class ResponsePostProcessConfig(ConfigBase): """回复后处理配置类""" @@ -461,7 +463,7 @@ class LPMMKnowledgeConfig(ConfigBase): qa_res_top_k: int = 10 """QA最终结果的Top K数量""" - + @dataclass class ModelConfig(ConfigBase): diff --git a/src/experimental/PFC/message_storage.py b/src/experimental/PFC/message_storage.py index e2e1dd052..f049e002f 100644 --- a/src/experimental/PFC/message_storage.py +++ b/src/experimental/PFC/message_storage.py @@ -1,9 +1,12 @@ from abc import ABC, abstractmethod -from typing import List, Dict, Any +from typing import List, Dict, Any, Callable -# from src.common.database.database import db # Peewee db 导入 +from playhouse import shortcuts + +# from src.common.database.database import db # Peewee db 导入 from src.common.database.database_model import Messages # Peewee Messages 模型导入 -from playhouse.shortcuts import model_to_dict # 用于将模型实例转换为字典 + +model_to_dict: Callable[..., dict] = shortcuts.model_to_dict # Peewee 模型转换为字典的快捷函数 class MessageStorage(ABC): diff --git a/src/individuality/expression_style.py b/src/individuality/expression_style.py index 3f8ae8de7..7ff3b91ff 100644 --- a/src/individuality/expression_style.py +++ b/src/individuality/expression_style.py @@ -90,7 +90,7 @@ class PersonalityExpression: current_style_text = global_config.expression.expression_style current_personality = global_config.personality.personality_core - + meta_data = self._read_meta_data() last_style_text = meta_data.get("last_style_text") @@ -98,9 +98,10 @@ class PersonalityExpression: count = meta_data.get("count", 0) # 检查是否有任何变化 - if (current_style_text != last_style_text or - current_personality != last_personality): - logger.info(f"检测到变化:\n风格: '{last_style_text}' -> '{current_style_text}'\n人格: '{last_personality}' -> '{current_personality}'") + if current_style_text != last_style_text or current_personality != last_personality: + logger.info( + f"检测到变化:\n风格: '{last_style_text}' -> '{current_style_text}'\n人格: '{last_personality}' -> '{current_personality}'" + ) count = 0 if os.path.exists(self.expressions_file_path): try: @@ -196,7 +197,7 @@ class PersonalityExpression: "last_style_text": current_style_text, "last_personality": current_personality, "count": count, - "last_update_time": current_time + "last_update_time": current_time, } ) logger.info(f"成功处理。当前配置的计数现在是 {count},最后更新时间:{current_time}。") diff --git a/src/main.py b/src/main.py index f7df8ee16..004b68ba2 100644 --- a/src/main.py +++ b/src/main.py @@ -19,6 +19,7 @@ from .common.server import global_server, Server from rich.traceback import install from .chat.focus_chat.expressors.exprssion_learner import expression_learner from .api.main import start_api_server + # 导入actions模块,确保装饰器被执行 import src.chat.actions.default_actions # noqa @@ -40,7 +41,7 @@ class MainSystem: self.hippocampus_manager = hippocampus_manager else: self.hippocampus_manager = None - + self.individuality: Individuality = individuality # 使用消息API替代直接的FastAPI实例 @@ -74,11 +75,11 @@ class MainSystem: # 启动API服务器 start_api_server() logger.success("API服务器启动成功") - + # 加载所有actions,包括默认的和插件的 self._load_all_actions() logger.success("动作系统加载成功") - + # 初始化表情管理器 emoji_manager.initialize() logger.success("表情包管理器初始化成功") @@ -137,23 +138,25 @@ class MainSystem: try: # 导入统一的插件加载器 from src.plugins.plugin_loader import plugin_loader - + # 使用统一的插件加载器加载所有插件组件 loaded_actions, loaded_commands = plugin_loader.load_all_plugins() - + # 加载命令处理系统 try: # 导入命令处理系统 - from src.chat.command.command_handler import command_manager + logger.success("命令处理系统加载成功") except Exception as e: logger.error(f"加载命令处理系统失败: {e}") import traceback + logger.error(traceback.format_exc()) - + except Exception as e: logger.error(f"加载插件失败: {e}") import traceback + logger.error(traceback.format_exc()) async def schedule_tasks(self): @@ -165,17 +168,19 @@ class MainSystem: self.app.run(), self.server.run(), ] - + # 根据配置条件性地添加记忆系统相关任务 if global_config.memory.enable_memory and self.hippocampus_manager: - tasks.extend([ - self.build_memory_task(), - self.forget_memory_task(), - self.consolidate_memory_task(), - ]) - + tasks.extend( + [ + self.build_memory_task(), + self.forget_memory_task(), + self.consolidate_memory_task(), + ] + ) + tasks.append(self.learn_and_store_expression_task()) - + await asyncio.gather(*tasks) async def build_memory_task(self): @@ -190,9 +195,7 @@ class MainSystem: while True: await asyncio.sleep(global_config.memory.forget_memory_interval) logger.info("[记忆遗忘] 开始遗忘记忆...") - await self.hippocampus_manager.forget_memory( - percentage=global_config.memory.memory_forget_percentage - ) + await self.hippocampus_manager.forget_memory(percentage=global_config.memory.memory_forget_percentage) logger.info("[记忆遗忘] 记忆遗忘完成") async def consolidate_memory_task(self): diff --git a/src/person_info/impression_update_task.py b/src/person_info/impression_update_task.py index d6e1e2017..98b6ede36 100644 --- a/src/person_info/impression_update_task.py +++ b/src/person_info/impression_update_task.py @@ -11,6 +11,7 @@ from collections import defaultdict logger = get_logger("relation") + # 暂时弃用,改为实时更新 class ImpressionUpdateTask(AsyncTask): def __init__(self): @@ -25,10 +26,10 @@ class ImpressionUpdateTask(AsyncTask): # 获取最近的消息 current_time = int(time.time()) start_time = current_time - global_config.relationship.build_relationship_interval # 100分钟前 - + # 获取所有消息 messages = get_raw_msg_by_timestamp(timestamp_start=start_time, timestamp_end=current_time) - + if not messages: logger.info("没有找到需要处理的消息") return @@ -48,7 +49,7 @@ class ImpressionUpdateTask(AsyncTask): if len(msgs) < 30: logger.info(f"聊天组 {chat_id} 消息数小于30,跳过处理") continue - + chat_stream = chat_manager.get_stream(chat_id) if not chat_stream: logger.warning(f"未找到聊天组 {chat_id} 的chat_stream,跳过处理") @@ -56,26 +57,26 @@ class ImpressionUpdateTask(AsyncTask): # 找到bot的消息 bot_messages = [msg for msg in msgs if msg["user_nickname"] == global_config.bot.nickname] - + if not bot_messages: logger.info(f"聊天组 {chat_id} 没有bot消息,跳过处理") continue # 按时间排序所有消息 sorted_messages = sorted(msgs, key=lambda x: x["time"]) - + # 找到第一条和最后一条bot消息 first_bot_msg = bot_messages[0] last_bot_msg = bot_messages[-1] - + # 获取第一条bot消息前15条消息 first_bot_index = sorted_messages.index(first_bot_msg) start_index = max(0, first_bot_index - 25) - + # 获取最后一条bot消息后15条消息 last_bot_index = sorted_messages.index(last_bot_msg) end_index = min(len(sorted_messages), last_bot_index + 26) - + # 获取相关消息 relevant_messages = sorted_messages[start_index:end_index] @@ -85,7 +86,9 @@ class ImpressionUpdateTask(AsyncTask): # 计算权重 for bot_msg in bot_messages: bot_time = bot_msg["time"] - context_messages = [msg for msg in relevant_messages if abs(msg["time"] - bot_time) <= 600] # 前后10分钟 + context_messages = [ + msg for msg in relevant_messages if abs(msg["time"] - bot_time) <= 600 + ] # 前后10分钟 logger.debug(f"Bot消息 {bot_time} 的上下文消息数: {len(context_messages)}") for msg in context_messages: @@ -121,7 +124,7 @@ class ImpressionUpdateTask(AsyncTask): weights = [user[1]["weight"] for user in sorted_users] total_weight = sum(weights) # 计算每个用户的概率 - probabilities = [w/total_weight for w in weights] + probabilities = [w / total_weight for w in weights] # 使用累积概率进行选择 selected_indices = [] remaining_indices = list(range(len(sorted_users))) @@ -131,12 +134,12 @@ class ImpressionUpdateTask(AsyncTask): # 计算剩余索引的累积概率 remaining_probs = [probabilities[i] for i in remaining_indices] # 归一化概率 - remaining_probs = [p/sum(remaining_probs) for p in remaining_probs] + remaining_probs = [p / sum(remaining_probs) for p in remaining_probs] # 选择索引 chosen_idx = random.choices(remaining_indices, weights=remaining_probs, k=1)[0] selected_indices.append(chosen_idx) remaining_indices.remove(chosen_idx) - + selected_users = [sorted_users[i] for i in selected_indices] logger.info( f"开始进一步了解这些用户: {[msg[1]['messages'][0]['user_nickname'] for msg in selected_users]}" @@ -153,19 +156,16 @@ class ImpressionUpdateTask(AsyncTask): platform = data["messages"][0]["chat_info_platform"] user_id = data["messages"][0]["user_id"] cardname = data["messages"][0]["user_cardname"] - + is_known = await relationship_manager.is_known_some_one(platform, user_id) if not is_known: logger.info(f"首次认识用户: {user_nickname}") await relationship_manager.first_knowing_some_one(platform, user_id, user_nickname, cardname) - - + logger.info(f"开始更新用户 {user_nickname} 的印象") await relationship_manager.update_person_impression( - person_id=person_id, - timestamp=last_bot_msg["time"], - bot_engaged_messages=relevant_messages + person_id=person_id, timestamp=last_bot_msg["time"], bot_engaged_messages=relevant_messages ) logger.debug("印象更新任务执行完成") diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index 6e4d8219b..a62200c90 100644 --- a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -33,7 +33,7 @@ JSON_SERIALIZED_FIELDS = ["points", "forgotten_points", "info_list"] person_info_default = { "person_id": None, "person_name": None, - "name_reason": None, # Corrected from person_name_reason to match common usage if intended + "name_reason": None, # Corrected from person_name_reason to match common usage if intended "platform": "unknown", "user_id": "unknown", "nickname": "Unknown", @@ -42,11 +42,10 @@ person_info_default = { "last_know": None, # "user_cardname": None, # This field is not in Peewee model PersonInfo # "user_avatar": None, # This field is not in Peewee model PersonInfo - "impression": None, # Corrected from persion_impression + "impression": None, # Corrected from persion_impression "info_list": None, "points": None, "forgotten_points": None, - } @@ -126,7 +125,7 @@ class PersonInfoManager: for key, default_value in _person_info_default.items(): if key in model_fields: final_data[key] = default_value - + # Override with provided data if data: for key, value in data.items(): @@ -141,7 +140,7 @@ class PersonInfoManager: if key in final_data: if isinstance(final_data[key], (list, dict)): final_data[key] = json.dumps(final_data[key], ensure_ascii=False) - elif final_data[key] is None: # Default for lists is [], store as "[]" + elif final_data[key] is None: # Default for lists is [], store as "[]" final_data[key] = json.dumps([], ensure_ascii=False) # If it's already a string, assume it's valid JSON or a non-JSON string field @@ -165,12 +164,12 @@ class PersonInfoManager: return print(f"更新字段: {field_name},值: {value}") - + processed_value = value if field_name in JSON_SERIALIZED_FIELDS: if isinstance(value, (list, dict)): processed_value = json.dumps(value, ensure_ascii=False, indent=None) - elif value is None: # Store None as "[]" for JSON list fields + elif value is None: # Store None as "[]" for JSON list fields processed_value = json.dumps([], ensure_ascii=False, indent=None) # If value is already a string, assume it's pre-serialized or a non-JSON string. @@ -180,7 +179,7 @@ class PersonInfoManager: setattr(record, f_name, val_to_set) record.save() return True, False # Found and updated, no creation needed - return False, True # Not found, needs creation + return False, True # Not found, needs creation found, needs_creation = await asyncio.to_thread(_db_update_sync, person_id, field_name, processed_value) @@ -190,15 +189,14 @@ class PersonInfoManager: # Ensure platform and user_id are present for context if available from 'data' # but primarily, set the field that triggered the update. # The create_person_info will handle defaults and serialization. - creation_data[field_name] = value # Pass original value to create_person_info - + creation_data[field_name] = value # Pass original value to create_person_info + # Ensure platform and user_id are in creation_data if available, # otherwise create_person_info will use defaults. if data and "platform" in data: - creation_data["platform"] = data["platform"] + creation_data["platform"] = data["platform"] if data and "user_id" in data: - creation_data["user_id"] = data["user_id"] - + creation_data["user_id"] = data["user_id"] await self.create_person_info(person_id, creation_data) @@ -233,7 +231,7 @@ class PersonInfoManager: if isinstance(parsed_json, list) and parsed_json: parsed_json = parsed_json[0] - + if isinstance(parsed_json, dict): return parsed_json @@ -249,11 +247,11 @@ class PersonInfoManager: # 处理空昵称的情况 if not base_name or base_name.isspace(): base_name = "空格" - + # 检查基础名称是否已存在 if base_name not in self.person_name_list.values(): return base_name - + # 如果存在,添加数字后缀 counter = 1 while True: @@ -331,9 +329,11 @@ class PersonInfoManager: if not is_duplicate: await self.update_one_field(person_id, "person_name", generated_nickname) await self.update_one_field(person_id, "name_reason", result.get("reason", "未提供理由")) - - logger.info(f"成功给用户{user_nickname} {person_id} 取名 {generated_nickname},理由:{result.get('reason', '未提供理由')}") - + + logger.info( + f"成功给用户{user_nickname} {person_id} 取名 {generated_nickname},理由:{result.get('reason', '未提供理由')}" + ) + self.person_name_list[person_id] = generated_nickname return result else: @@ -379,7 +379,7 @@ class PersonInfoManager: """获取指定用户指定字段的值""" default_value_for_field = person_info_default.get(field_name) if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None: - default_value_for_field = [] # Ensure JSON fields default to [] if not in DB + default_value_for_field = [] # Ensure JSON fields default to [] if not in DB def _db_get_value_sync(p_id: str, f_name: str): record = PersonInfo.get_or_none(PersonInfo.person_id == p_id) @@ -391,32 +391,32 @@ class PersonInfoManager: return json.loads(val) except json.JSONDecodeError: logger.warning(f"字段 {f_name} for {p_id} 包含无效JSON: {val}. 返回默认值.") - return [] # Default for JSON fields on error - elif val is None: # Field exists in DB but is None - return [] # Default for JSON fields + return [] # Default for JSON fields on error + elif val is None: # Field exists in DB but is None + return [] # Default for JSON fields # If val is already a list/dict (e.g. if somehow set without serialization) - return val # Should ideally not happen if update_one_field is always used + return val # Should ideally not happen if update_one_field is always used return val - return None # Record not found + return None # Record not found try: value_from_db = await asyncio.to_thread(_db_get_value_sync, person_id, field_name) if value_from_db is not None: return value_from_db if field_name in person_info_default: - return default_value_for_field + return default_value_for_field logger.warning(f"字段 {field_name} 在 person_info_default 中未定义,且在数据库中未找到。") - return None # Ultimate fallback + return None # Ultimate fallback except Exception as e: logger.error(f"获取字段 {field_name} for {person_id} 时出错 (Peewee): {e}") # Fallback to default in case of any error during DB access if field_name in person_info_default: return default_value_for_field return None - + @staticmethod def get_value_sync(person_id: str, field_name: str): - """ 同步获取指定用户指定字段的值 """ + """同步获取指定用户指定字段的值""" default_value_for_field = person_info_default.get(field_name) if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None: default_value_for_field = [] @@ -430,12 +430,12 @@ class PersonInfoManager: return json.loads(val) except json.JSONDecodeError: logger.warning(f"字段 {field_name} for {person_id} 包含无效JSON: {val}. 返回默认值.") - return [] + return [] elif val is None: return [] - return val + return val return val - + if field_name in person_info_default: return default_value_for_field logger.warning(f"字段 {field_name} 在 person_info_default 中未定义,且在数据库中未找到。") @@ -534,7 +534,7 @@ class PersonInfoManager: "last_know": int(datetime.datetime.now().timestamp()), "impression": None, "points": [], - "forgotten_points": [] + "forgotten_points": [], } model_fields = PersonInfo._meta.fields.keys() filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields} diff --git a/src/person_info/relationship_manager.py b/src/person_info/relationship_manager.py index 558476bc0..19b53be1c 100644 --- a/src/person_info/relationship_manager.py +++ b/src/person_info/relationship_manager.py @@ -7,7 +7,6 @@ from src.llm_models.utils_model import LLMRequest from src.config.config import global_config from src.chat.utils.chat_message_builder import build_readable_messages from src.manager.mood_manager import mood_manager -from src.individuality.individuality import individuality import json from json_repair import repair_json from datetime import datetime @@ -90,9 +89,7 @@ class RelationshipManager: return is_known @staticmethod - async def first_knowing_some_one( - platform: str, user_id: str, user_nickname: str, user_cardname: str - ): + async def first_knowing_some_one(platform: str, user_id: str, user_nickname: str, user_cardname: str): """判断是否认识某人""" person_id = person_info_manager.get_person_id(platform, user_id) # 生成唯一的 person_name @@ -112,7 +109,7 @@ class RelationshipManager: ) # 尝试生成更好的名字 # await person_info_manager.qv_person_name( - # person_id=person_id, user_nickname=user_nickname, user_cardname=user_cardname, user_avatar=user_avatar + # person_id=person_id, user_nickname=user_nickname, user_cardname=user_cardname, user_avatar=user_avatar # ) async def build_relationship_info(self, person, is_id: bool = False) -> str: @@ -124,26 +121,24 @@ class RelationshipManager: person_name = await person_info_manager.get_value(person_id, "person_name") if not person_name or person_name == "none": return "" - impression = await person_info_manager.get_value(person_id, "impression") + # impression = await person_info_manager.get_value(person_id, "impression") points = await person_info_manager.get_value(person_id, "points") or [] - + if isinstance(points, str): try: points = ast.literal_eval(points) except (SyntaxError, ValueError): points = [] - + random_points = random.sample(points, min(5, len(points))) if points else [] - + nickname_str = await person_info_manager.get_value(person_id, "nickname") platform = await person_info_manager.get_value(person_id, "platform") relation_prompt = f"'{person_name}' ,ta在{platform}上的昵称是{nickname_str}。" - # if impression: - # relation_prompt += f"你对ta的印象是:{impression}。" + # relation_prompt += f"你对ta的印象是:{impression}。" - if random_points: for point in random_points: # print(f"point: {point}") @@ -151,13 +146,12 @@ class RelationshipManager: # print(f"point[0]: {point[0]}") point_str = f"时间:{point[2]}。内容:{point[0]}" relation_prompt += f"你记得{person_name}最近的点是:{point_str}。" - - + return relation_prompt async def _update_list_field(self, person_id: str, field_name: str, new_items: list) -> None: """更新列表类型的字段,将新项目添加到现有列表中 - + Args: person_id: 用户ID field_name: 字段名称 @@ -179,21 +173,21 @@ class RelationshipManager: """ person_name = await person_info_manager.get_value(person_id, "person_name") nickname = await person_info_manager.get_value(person_id, "nickname") - + alias_str = ", ".join(global_config.bot.alias_names) - personality_block = individuality.get_personality_prompt(x_person=2, level=2) - identity_block = individuality.get_identity_prompt(x_person=2, level=2) + # personality_block = individuality.get_personality_prompt(x_person=2, level=2) + # identity_block = individuality.get_identity_prompt(x_person=2, level=2) user_messages = bot_engaged_messages - + current_time = datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S") - + # 匿名化消息 # 创建用户名称映射 name_mapping = {} current_user = "A" user_count = 1 - + # 遍历消息,构建映射 for msg in user_messages: await person_info_manager.get_or_create_person( @@ -206,37 +200,31 @@ class RelationshipManager: replace_platform = msg.get("chat_info_platform") replace_person_id = person_info_manager.get_person_id(replace_platform, replace_user_id) replace_person_name = await person_info_manager.get_value(replace_person_id, "person_name") - + # 跳过机器人自己 if replace_user_id == global_config.bot.qq_account: name_mapping[f"{global_config.bot.nickname}"] = f"{global_config.bot.nickname}" continue - + # 跳过目标用户 if replace_person_name == person_name: name_mapping[replace_person_name] = f"{person_name}" continue - + # 其他用户映射 if replace_person_name not in name_mapping: - if current_user > 'Z': - current_user = 'A' + if current_user > "Z": + current_user = "A" user_count += 1 name_mapping[replace_person_name] = f"用户{current_user}{user_count if user_count > 1 else ''}" current_user = chr(ord(current_user) + 1) - - - - readable_messages = self.build_focus_readable_messages( - messages=user_messages, - target_person_id=person_id - ) - + readable_messages = self.build_focus_readable_messages(messages=user_messages, target_person_id=person_id) + for original_name, mapped_name in name_mapping.items(): # print(f"original_name: {original_name}, mapped_name: {mapped_name}") readable_messages = readable_messages.replace(f"{original_name}", f"{mapped_name}") - + prompt = f""" 你的名字是{global_config.bot.nickname},{global_config.bot.nickname}的别名是{alias_str}。 请不要混淆你自己和{global_config.bot.nickname}和{person_name}。 @@ -271,22 +259,22 @@ class RelationshipManager: "weight": 0 }} """ - + # 调用LLM生成印象 points, _ = await self.relationship_llm.generate_response_async(prompt=prompt) points = points.strip() - + # 还原用户名称 for original_name, mapped_name in name_mapping.items(): points = points.replace(mapped_name, original_name) - + # logger.info(f"prompt: {prompt}") # logger.info(f"points: {points}") - + if not points: logger.warning(f"未能从LLM获取 {person_name} 的新印象") return - + # 解析JSON并转换为元组列表 try: points = repair_json(points) @@ -307,7 +295,7 @@ class RelationshipManager: except (KeyError, TypeError) as e: logger.error(f"处理points数据失败: {e}, points: {points}") return - + current_points = await person_info_manager.get_value(person_id, "points") or [] if isinstance(current_points, str): try: @@ -318,7 +306,9 @@ class RelationshipManager: elif not isinstance(current_points, list): current_points = [] current_points.extend(points_list) - await person_info_manager.update_one_field(person_id, "points", json.dumps(current_points, ensure_ascii=False, indent=None)) + await person_info_manager.update_one_field( + person_id, "points", json.dumps(current_points, ensure_ascii=False, indent=None) + ) # 将新记录添加到现有记录中 if isinstance(current_points, list): @@ -326,14 +316,14 @@ class RelationshipManager: for new_point in points_list: similar_points = [] similar_indices = [] - + # 在现有points中查找相似的点 for i, existing_point in enumerate(current_points): # 使用组合的相似度检查方法 if self.check_similarity(new_point[0], existing_point[0]): similar_points.append(existing_point) similar_indices.append(i) - + if similar_points: # 合并相似的点 all_points = [new_point] + similar_points @@ -343,14 +333,14 @@ class RelationshipManager: total_weight = sum(p[1] for p in all_points) # 使用最长的描述 longest_desc = max(all_points, key=lambda x: len(x[0]))[0] - + # 创建合并后的点 merged_point = (longest_desc, total_weight, latest_time) - + # 从现有points中移除已合并的点 for idx in sorted(similar_indices, reverse=True): current_points.pop(idx) - + # 添加合并后的点 current_points.append(merged_point) else: @@ -359,7 +349,7 @@ class RelationshipManager: else: current_points = points_list -# 如果points超过10条,按权重随机选择多余的条目移动到forgotten_points + # 如果points超过10条,按权重随机选择多余的条目移动到forgotten_points if len(current_points) > 10: # 获取现有forgotten_points forgotten_points = await person_info_manager.get_value(person_id, "forgotten_points") or [] @@ -371,29 +361,29 @@ class RelationshipManager: forgotten_points = [] elif not isinstance(forgotten_points, list): forgotten_points = [] - + # 计算当前时间 current_time = datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S") - + # 计算每个点的最终权重(原始权重 * 时间权重) weighted_points = [] for point in current_points: time_weight = self.calculate_time_weight(point[2], current_time) final_weight = point[1] * time_weight weighted_points.append((point, final_weight)) - + # 计算总权重 total_weight = sum(w for _, w in weighted_points) - + # 按权重随机选择要保留的点 remaining_points = [] points_to_move = [] - + # 对每个点进行随机选择 for point, weight in weighted_points: # 计算保留概率(权重越高越可能保留) keep_probability = weight / total_weight - + if len(remaining_points) < 10: # 如果还没达到30条,直接保留 remaining_points.append(point) @@ -407,28 +397,26 @@ class RelationshipManager: else: # 不保留这个点 points_to_move.append(point) - + # 更新points和forgotten_points current_points = remaining_points forgotten_points.extend(points_to_move) - + # 检查forgotten_points是否达到5条 if len(forgotten_points) >= 10: # 构建压缩总结提示词 alias_str = ", ".join(global_config.bot.alias_names) - + # 按时间排序forgotten_points forgotten_points.sort(key=lambda x: x[2]) - + # 构建points文本 - points_text = "\n".join([ - f"时间:{point[2]}\n权重:{point[1]}\n内容:{point[0]}" - for point in forgotten_points - ]) - - + points_text = "\n".join( + [f"时间:{point[2]}\n权重:{point[1]}\n内容:{point[0]}" for point in forgotten_points] + ) + impression = await person_info_manager.get_value(person_id, "impression") or "" - + compress_prompt = f""" 你的名字是{global_config.bot.nickname},{global_config.bot.nickname}的别名是{alias_str}。 请不要混淆你自己和{global_config.bot.nickname}和{person_name}。 @@ -449,88 +437,85 @@ class RelationshipManager: """ # 调用LLM生成压缩总结 compressed_summary, _ = await self.relationship_llm.generate_response_async(prompt=compress_prompt) - + current_time = datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S") compressed_summary = f"截至{current_time},你对{person_name}的了解:{compressed_summary}" - + await person_info_manager.update_one_field(person_id, "impression", compressed_summary) - + forgotten_points = [] - # 这句代码的作用是:将更新后的 forgotten_points(遗忘的记忆点)列表,序列化为 JSON 字符串后,写回到数据库中的 forgotten_points 字段 - await person_info_manager.update_one_field(person_id, "forgotten_points", json.dumps(forgotten_points, ensure_ascii=False, indent=None)) - + await person_info_manager.update_one_field( + person_id, "forgotten_points", json.dumps(forgotten_points, ensure_ascii=False, indent=None) + ) + # 更新数据库 - await person_info_manager.update_one_field(person_id, "points", json.dumps(current_points, ensure_ascii=False, indent=None)) + await person_info_manager.update_one_field( + person_id, "points", json.dumps(current_points, ensure_ascii=False, indent=None) + ) know_times = await person_info_manager.get_value(person_id, "know_times") or 0 await person_info_manager.update_one_field(person_id, "know_times", know_times + 1) await person_info_manager.update_one_field(person_id, "last_know", timestamp) - logger.info(f"印象更新完成 for {person_name}") - - - + def build_focus_readable_messages(self, messages: list, target_person_id: str = None) -> str: - """格式化消息,只保留目标用户和bot消息附近的内容""" - # 找到目标用户和bot的消息索引 - target_indices = [] - for i, msg in enumerate(messages): - user_id = msg.get("user_id") - platform = msg.get("chat_info_platform") - person_id = person_info_manager.get_person_id(platform, user_id) - if person_id == target_person_id: - target_indices.append(i) - - if not target_indices: - return "" - - # 获取需要保留的消息索引 - keep_indices = set() - for idx in target_indices: - # 获取前后5条消息的索引 - start_idx = max(0, idx - 5) - end_idx = min(len(messages), idx + 6) - keep_indices.update(range(start_idx, end_idx)) - - print(keep_indices) - - # 将索引排序 - keep_indices = sorted(list(keep_indices)) - - # 按顺序构建消息组 - message_groups = [] - current_group = [] - - for i in range(len(messages)): - if i in keep_indices: - current_group.append(messages[i]) - elif current_group: - # 如果当前组不为空,且遇到不保留的消息,则结束当前组 - if current_group: - message_groups.append(current_group) - current_group = [] - - # 添加最后一组 - if current_group: - message_groups.append(current_group) - - # 构建最终的消息文本 - result = [] - for i, group in enumerate(message_groups): - if i > 0: - result.append("...") - group_text = build_readable_messages( - messages=group, - replace_bot_name=True, - timestamp_mode="normal_no_YMD", - truncate=False - ) - result.append(group_text) - - return "\n".join(result) - + """格式化消息,只保留目标用户和bot消息附近的内容""" + # 找到目标用户和bot的消息索引 + target_indices = [] + for i, msg in enumerate(messages): + user_id = msg.get("user_id") + platform = msg.get("chat_info_platform") + person_id = person_info_manager.get_person_id(platform, user_id) + if person_id == target_person_id: + target_indices.append(i) + + if not target_indices: + return "" + + # 获取需要保留的消息索引 + keep_indices = set() + for idx in target_indices: + # 获取前后5条消息的索引 + start_idx = max(0, idx - 5) + end_idx = min(len(messages), idx + 6) + keep_indices.update(range(start_idx, end_idx)) + + print(keep_indices) + + # 将索引排序 + keep_indices = sorted(list(keep_indices)) + + # 按顺序构建消息组 + message_groups = [] + current_group = [] + + for i in range(len(messages)): + if i in keep_indices: + current_group.append(messages[i]) + elif current_group: + # 如果当前组不为空,且遇到不保留的消息,则结束当前组 + if current_group: + message_groups.append(current_group) + current_group = [] + + # 添加最后一组 + if current_group: + message_groups.append(current_group) + + # 构建最终的消息文本 + result = [] + for i, group in enumerate(message_groups): + if i > 0: + result.append("...") + group_text = build_readable_messages( + messages=group, replace_bot_name=True, timestamp_mode="normal_no_YMD", truncate=False + ) + result.append(group_text) + + return "\n".join(result) + def calculate_time_weight(self, point_time: str, current_time: str) -> float: """计算基于时间的权重系数""" try: @@ -538,7 +523,7 @@ class RelationshipManager: current_timestamp = datetime.strptime(current_time, "%Y-%m-%d %H:%M:%S") time_diff = current_timestamp - point_timestamp hours_diff = time_diff.total_seconds() / 3600 - + if hours_diff <= 1: # 1小时内 return 1.0 elif hours_diff <= 24: # 1-24小时 @@ -564,18 +549,18 @@ class RelationshipManager: s1 = " ".join(str(x) for x in s1) if isinstance(s2, list): s2 = " ".join(str(x) for x in s2) - + # 转换为字符串类型 s1 = str(s1) s2 = str(s2) - + # 1. 使用 jieba 进行分词 s1_words = " ".join(jieba.cut(s1)) s2_words = " ".join(jieba.cut(s2)) - + # 2. 将两句话放入一个列表中 corpus = [s1_words, s2_words] - + # 3. 创建 TF-IDF 向量化器并进行计算 try: vectorizer = TfidfVectorizer() @@ -586,7 +571,7 @@ class RelationshipManager: # 4. 计算余弦相似度 similarity_matrix = cosine_similarity(tfidf_matrix) - + # 返回 s1 和 s2 的相似度 return similarity_matrix[0, 1] @@ -599,20 +584,20 @@ class RelationshipManager: def check_similarity(self, text1, text2, tfidf_threshold=0.5, seq_threshold=0.6): """ 使用两种方法检查文本相似度,只要其中一种方法达到阈值就认为是相似的。 - + Args: text1: 第一个文本 text2: 第二个文本 tfidf_threshold: TF-IDF相似度阈值 seq_threshold: SequenceMatcher相似度阈值 - + Returns: bool: 如果任一方法达到阈值则返回True """ # 计算两种相似度 tfidf_sim = self.tfidf_similarity(text1, text2) seq_sim = self.sequence_similarity(text1, text2) - + # 只要其中一种方法达到阈值就认为是相似的 return tfidf_sim > tfidf_threshold or seq_sim > seq_threshold diff --git a/src/plugins/doubao_pic/actions/generate_pic_config.py b/src/plugins/doubao_pic/actions/generate_pic_config.py index 1739f85e8..d9f689783 100644 --- a/src/plugins/doubao_pic/actions/generate_pic_config.py +++ b/src/plugins/doubao_pic/actions/generate_pic_config.py @@ -40,7 +40,7 @@ DEFAULT_CONFIG = { "default_guidance_scale": 2.5, "default_seed": 42, "cache_enabled": True, - "cache_max_size": 10 + "cache_max_size": 10, } @@ -49,37 +49,37 @@ def validate_and_fix_config(config_path: str) -> bool: try: with open(config_path, "r", encoding="utf-8") as f: config = toml.load(f) - + # 检查缺失的配置项 missing_keys = [] fixed = False - + for key, default_value in DEFAULT_CONFIG.items(): if key not in config: missing_keys.append(key) config[key] = default_value fixed = True logger.info(f"添加缺失的配置项: {key} = {default_value}") - + # 验证配置值的类型和范围 if isinstance(config.get("default_guidance_scale"), (int, float)): if not 0.1 <= config["default_guidance_scale"] <= 20.0: config["default_guidance_scale"] = 2.5 fixed = True logger.info("修复无效的 default_guidance_scale 值") - + if isinstance(config.get("default_seed"), (int, float)): config["default_seed"] = int(config["default_seed"]) else: config["default_seed"] = 42 fixed = True logger.info("修复无效的 default_seed 值") - + if config.get("cache_max_size") and not isinstance(config["cache_max_size"], int): config["cache_max_size"] = 10 fixed = True logger.info("修复无效的 cache_max_size 值") - + # 如果有修复,写回文件 if fixed: # 创建备份 @@ -87,14 +87,14 @@ def validate_and_fix_config(config_path: str) -> bool: if os.path.exists(config_path): os.rename(config_path, backup_path) logger.info(f"已创建配置备份: {backup_path}") - + # 写入修复后的配置 with open(config_path, "w", encoding="utf-8") as f: toml.dump(config, f) logger.info(f"配置文件已修复: {config_path}") - + return True - + except Exception as e: logger.error(f"验证配置文件时出错: {e}") return False diff --git a/src/plugins/doubao_pic/actions/pic_action.py b/src/plugins/doubao_pic/actions/pic_action.py index 8d5515366..193eeed73 100644 --- a/src/plugins/doubao_pic/actions/pic_action.py +++ b/src/plugins/doubao_pic/actions/pic_action.py @@ -37,15 +37,15 @@ class PicAction(PluginAction): ] enable_plugin = False action_config_file_name = "pic_action_config.toml" - + # 激活类型设置 focus_activation_type = ActionActivationType.LLM_JUDGE # Focus模式使用LLM判定,精确理解需求 - normal_activation_type = ActionActivationType.KEYWORD # Normal模式使用关键词激活,快速响应 - + normal_activation_type = ActionActivationType.KEYWORD # Normal模式使用关键词激活,快速响应 + # 关键词设置(用于Normal模式) activation_keywords = ["画", "绘制", "生成图片", "画图", "draw", "paint", "图片生成"] keyword_case_sensitive = False - + # LLM判定提示词(用于Focus模式) llm_judge_prompt = """ 判定是否需要使用图片生成动作的条件: @@ -67,31 +67,31 @@ class PicAction(PluginAction): 4. 技术讨论中提到绘图概念但无生成需求 5. 用户明确表示不需要图片时 """ - + # Random激活概率(备用) random_activation_probability = 0.15 # 适中概率,图片生成比较有趣 - + # 简单的请求缓存,避免短时间内重复请求 _request_cache = {} _cache_max_size = 10 - + # 模式启用设置 - 图片生成在所有模式下可用 mode_enable = ChatMode.ALL - + # 并行执行设置 - 图片生成可以与回复并行执行,不覆盖回复内容 parallel_action = False - + @classmethod def _get_cache_key(cls, description: str, model: str, size: str) -> str: """生成缓存键""" return f"{description[:100]}|{model}|{size}" # 限制描述长度避免键过长 - + @classmethod def _cleanup_cache(cls): """清理缓存,保持大小在限制内""" if len(cls._request_cache) > cls._cache_max_size: # 简单的FIFO策略,移除最旧的条目 - keys_to_remove = list(cls._request_cache.keys())[:-cls._cache_max_size//2] + keys_to_remove = list(cls._request_cache.keys())[: -cls._cache_max_size // 2] for key in keys_to_remove: del cls._request_cache[key] @@ -169,7 +169,7 @@ class PicAction(PluginAction): cached_result = self._request_cache[cache_key] logger.info(f"{self.log_prefix} 使用缓存的图片结果") await self.send_message_by_expressor("我之前画过类似的图片,用之前的结果~") - + # 直接发送缓存的结果 send_success = await self.send_message(type="image", data=cached_result) if send_success: @@ -258,7 +258,7 @@ class PicAction(PluginAction): # 缓存成功的结果 self._request_cache[cache_key] = base64_image_string self._cleanup_cache() - + await self.send_message_by_expressor("图片表情已发送!") return True, "图片表情已发送" else: @@ -370,7 +370,7 @@ class PicAction(PluginAction): def _validate_image_size(self, image_size: str) -> bool: """验证图片尺寸格式""" try: - width, height = map(int, image_size.split('x')) + width, height = map(int, image_size.split("x")) return 100 <= width <= 10000 and 100 <= height <= 10000 except (ValueError, TypeError): return False diff --git a/src/plugins/example_command_plugin/__init__.py b/src/plugins/example_command_plugin/__init__.py index 4f644bd2b..482b8c27c 100644 --- a/src/plugins/example_command_plugin/__init__.py +++ b/src/plugins/example_command_plugin/__init__.py @@ -11,4 +11,4 @@ - 用户输入特定格式的命令时触发 - 通过命令前缀(如/)快速执行特定功能 - 提供快速响应的交互方式 -""" \ No newline at end of file +""" diff --git a/src/plugins/example_command_plugin/commands/__init__.py b/src/plugins/example_command_plugin/commands/__init__.py index e8dce0578..9fb74a8c3 100644 --- a/src/plugins/example_command_plugin/commands/__init__.py +++ b/src/plugins/example_command_plugin/commands/__init__.py @@ -1,4 +1,4 @@ """示例命令包 包含示例命令的实现 -""" \ No newline at end of file +""" diff --git a/src/plugins/example_command_plugin/commands/custom_prefix_command.py b/src/plugins/example_command_plugin/commands/custom_prefix_command.py index 932cc062b..5297dd9ad 100644 --- a/src/plugins/example_command_plugin/commands/custom_prefix_command.py +++ b/src/plugins/example_command_plugin/commands/custom_prefix_command.py @@ -5,27 +5,28 @@ import random logger = get_logger("custom_prefix_command") + @register_command class DiceCommand(BaseCommand): """骰子命令,使用!前缀而不是/前缀""" - + command_name = "dice" command_description = "骰子命令,随机生成1-6的数字" command_pattern = r"^[!!](?:dice|骰子)(?:\s+(?P\d+))?$" # 匹配 !dice 或 !骰子,可选参数为骰子数量 command_help = "使用方法: !dice [数量] 或 !骰子 [数量] - 掷骰子,默认掷1个" command_examples = ["!dice", "!骰子", "!dice 3", "!骰子 5"] enable_command = True - + async def execute(self) -> Tuple[bool, Optional[str]]: """执行骰子命令 - + Returns: Tuple[bool, Optional[str]]: (是否执行成功, 回复消息) """ try: # 获取骰子数量,默认为1 count_str = self.matched_groups.get("count") - + # 确保count_str不为None if count_str is None: count = 1 # 默认值 @@ -38,10 +39,10 @@ class DiceCommand(BaseCommand): return False, "一次最多只能掷10个骰子" except ValueError: return False, "骰子数量必须是整数" - + # 生成随机数 results = [random.randint(1, 6) for _ in range(count)] - + # 构建回复消息 if count == 1: message = f"🎲 掷出了 {results[0]} 点" @@ -49,10 +50,10 @@ class DiceCommand(BaseCommand): dice_results = ", ".join(map(str, results)) total = sum(results) message = f"🎲 掷出了 {count} 个骰子: [{dice_results}],总点数: {total}" - + logger.info(f"{self.log_prefix} 执行骰子命令: {message}") return True, message - + except Exception as e: logger.error(f"{self.log_prefix} 执行骰子命令时出错: {e}") - return False, f"执行命令时出错: {str(e)}" \ No newline at end of file + return False, f"执行命令时出错: {str(e)}" diff --git a/src/plugins/example_command_plugin/commands/help_command.py b/src/plugins/example_command_plugin/commands/help_command.py index f2b440710..020f48300 100644 --- a/src/plugins/example_command_plugin/commands/help_command.py +++ b/src/plugins/example_command_plugin/commands/help_command.py @@ -4,90 +4,86 @@ from typing import Tuple, Optional logger = get_logger("help_command") + @register_command class HelpCommand(BaseCommand): """帮助命令,显示所有可用命令的帮助信息""" - + command_name = "help" command_description = "显示所有可用命令的帮助信息" command_pattern = r"^/help(?:\s+(?P\w+))?$" # 匹配 /help 或 /help 命令名 command_help = "使用方法: /help [命令名] - 显示所有命令或特定命令的帮助信息" command_examples = ["/help", "/help echo"] enable_command = True - + async def execute(self) -> Tuple[bool, Optional[str]]: """执行帮助命令 - + Returns: Tuple[bool, Optional[str]]: (是否执行成功, 回复消息) """ try: # 获取匹配到的命令名(如果有) command_name = self.matched_groups.get("command") - + # 如果指定了命令名,显示该命令的详细帮助 if command_name: logger.info(f"{self.log_prefix} 查询命令帮助: {command_name}") return self._show_command_help(command_name) - + # 否则,显示所有命令的简要帮助 logger.info(f"{self.log_prefix} 查询所有命令帮助") return self._show_all_commands() - + except Exception as e: logger.error(f"{self.log_prefix} 执行帮助命令时出错: {e}") return False, f"执行命令时出错: {str(e)}" - + def _show_command_help(self, command_name: str) -> Tuple[bool, str]: """显示特定命令的详细帮助信息 - + Args: command_name: 命令名称 - + Returns: Tuple[bool, str]: (是否执行成功, 回复消息) """ # 查找命令 command_cls = _COMMAND_REGISTRY.get(command_name) - + if not command_cls: return False, f"未找到命令: {command_name}" - + # 获取命令信息 description = getattr(command_cls, "command_description", "无描述") help_text = getattr(command_cls, "command_help", "无帮助信息") examples = getattr(command_cls, "command_examples", []) - + # 构建帮助信息 - help_info = [ - f"【命令】: {command_name}", - f"【描述】: {description}", - f"【用法】: {help_text}" - ] - + help_info = [f"【命令】: {command_name}", f"【描述】: {description}", f"【用法】: {help_text}"] + # 添加示例 if examples: help_info.append("【示例】:") for example in examples: help_info.append(f" {example}") - + return True, "\n".join(help_info) - + def _show_all_commands(self) -> Tuple[bool, str]: """显示所有可用命令的简要帮助信息 - + Returns: Tuple[bool, str]: (是否执行成功, 回复消息) """ # 获取所有已启用的命令 enabled_commands = { - name: cls for name, cls in _COMMAND_REGISTRY.items() - if getattr(cls, "enable_command", True) + name: cls for name, cls in _COMMAND_REGISTRY.items() if getattr(cls, "enable_command", True) } - + if not enabled_commands: return True, "当前没有可用的命令" - + # 构建命令列表 command_list = ["可用命令列表:"] for name, cls in sorted(enabled_commands.items()): @@ -107,9 +103,9 @@ class HelpCommand(BaseCommand): else: # 默认使用/name作为前缀 prefix = f"/{name}" - + command_list.append(f"{prefix} - {description}") - + command_list.append("\n使用 /help <命令名> 获取特定命令的详细帮助") - - return True, "\n".join(command_list) \ No newline at end of file + + return True, "\n".join(command_list) diff --git a/src/plugins/example_command_plugin/commands/message_info_command.py b/src/plugins/example_command_plugin/commands/message_info_command.py index aa30e24f5..4a73eb29b 100644 --- a/src/plugins/example_command_plugin/commands/message_info_command.py +++ b/src/plugins/example_command_plugin/commands/message_info_command.py @@ -1,43 +1,43 @@ from src.common.logger_manager import get_logger from src.chat.command.command_handler import BaseCommand, register_command from typing import Tuple, Optional -import json logger = get_logger("message_info_command") + @register_command class MessageInfoCommand(BaseCommand): """消息信息查看命令,展示发送命令的原始消息和相关信息""" - + command_name = "msginfo" command_description = "查看发送命令的原始消息信息" command_pattern = r"^/msginfo(?:\s+(?Pfull|simple))?$" command_help = "使用方法: /msginfo [full|simple] - 查看当前消息的详细信息" command_examples = ["/msginfo", "/msginfo full", "/msginfo simple"] enable_command = True - + async def execute(self) -> Tuple[bool, Optional[str]]: """执行消息信息查看命令""" try: detail_level = self.matched_groups.get("detail", "simple") - + logger.info(f"{self.log_prefix} 查看消息信息,详细级别: {detail_level}") - + if detail_level == "full": info_text = self._get_full_message_info() else: info_text = self._get_simple_message_info() - + return True, info_text - + except Exception as e: logger.error(f"{self.log_prefix} 获取消息信息时出错: {e}") return False, f"获取消息信息失败: {str(e)}" - + def _get_simple_message_info(self) -> str: """获取简化的消息信息""" message = self.message - + # 基础信息 info_lines = [ "📨 消息信息概览", @@ -45,157 +45,181 @@ class MessageInfoCommand(BaseCommand): f"⏰ 时间: {message.message_info.time}", f"🌐 平台: {message.message_info.platform}", ] - + # 发送者信息 user = message.message_info.user_info - info_lines.extend([ - "", - "👤 发送者信息:", - f" 用户ID: {user.user_id}", - f" 昵称: {user.user_nickname}", - f" 群名片: {user.user_cardname or '无'}", - ]) - + info_lines.extend( + [ + "", + "👤 发送者信息:", + f" 用户ID: {user.user_id}", + f" 昵称: {user.user_nickname}", + f" 群名片: {user.user_cardname or '无'}", + ] + ) + # 群聊信息(如果是群聊) if message.message_info.group_info: group = message.message_info.group_info - info_lines.extend([ - "", - "👥 群聊信息:", - f" 群ID: {group.group_id}", - f" 群名: {group.group_name or '未知'}", - ]) + info_lines.extend( + [ + "", + "👥 群聊信息:", + f" 群ID: {group.group_id}", + f" 群名: {group.group_name or '未知'}", + ] + ) else: - info_lines.extend([ - "", - "💬 消息类型: 私聊消息", - ]) - + info_lines.extend( + [ + "", + "💬 消息类型: 私聊消息", + ] + ) + # 消息内容 - info_lines.extend([ - "", - "📝 消息内容:", - f" 原始文本: {message.processed_plain_text}", - f" 是否表情: {'是' if getattr(message, 'is_emoji', False) else '否'}", - ]) - - # 聊天流信息 - if hasattr(message, 'chat_stream') and message.chat_stream: - chat_stream = message.chat_stream - info_lines.extend([ + info_lines.extend( + [ "", - "🔄 聊天流信息:", - f" 流ID: {chat_stream.stream_id}", - f" 是否激活: {'是' if chat_stream.is_active else '否'}", - ]) - + "📝 消息内容:", + f" 原始文本: {message.processed_plain_text}", + f" 是否表情: {'是' if getattr(message, 'is_emoji', False) else '否'}", + ] + ) + + # 聊天流信息 + if hasattr(message, "chat_stream") and message.chat_stream: + chat_stream = message.chat_stream + info_lines.extend( + [ + "", + "🔄 聊天流信息:", + f" 流ID: {chat_stream.stream_id}", + f" 是否激活: {'是' if chat_stream.is_active else '否'}", + ] + ) + return "\n".join(info_lines) - + def _get_full_message_info(self) -> str: """获取完整的消息信息(包含技术细节)""" message = self.message - + info_lines = [ "📨 完整消息信息", "=" * 40, ] - + # 消息基础信息 - info_lines.extend([ - "", - "🔍 基础消息信息:", - f" 消息ID: {message.message_info.message_id}", - f" 时间戳: {message.message_info.time}", - f" 平台: {message.message_info.platform}", - f" 处理后文本: {message.processed_plain_text}", - f" 详细文本: {message.detailed_plain_text[:100]}{'...' if len(message.detailed_plain_text) > 100 else ''}", - ]) - + info_lines.extend( + [ + "", + "🔍 基础消息信息:", + f" 消息ID: {message.message_info.message_id}", + f" 时间戳: {message.message_info.time}", + f" 平台: {message.message_info.platform}", + f" 处理后文本: {message.processed_plain_text}", + f" 详细文本: {message.detailed_plain_text[:100]}{'...' if len(message.detailed_plain_text) > 100 else ''}", + ] + ) + # 用户详细信息 user = message.message_info.user_info - info_lines.extend([ - "", - "👤 发送者详细信息:", - f" 用户ID: {user.user_id}", - f" 昵称: {user.user_nickname}", - f" 群名片: {user.user_cardname or '无'}", - f" 平台: {user.platform}", - ]) - + info_lines.extend( + [ + "", + "👤 发送者详细信息:", + f" 用户ID: {user.user_id}", + f" 昵称: {user.user_nickname}", + f" 群名片: {user.user_cardname or '无'}", + f" 平台: {user.platform}", + ] + ) + # 群聊详细信息 if message.message_info.group_info: group = message.message_info.group_info - info_lines.extend([ - "", - "👥 群聊详细信息:", - f" 群ID: {group.group_id}", - f" 群名: {group.group_name or '未知'}", - f" 平台: {group.platform}", - ]) + info_lines.extend( + [ + "", + "👥 群聊详细信息:", + f" 群ID: {group.group_id}", + f" 群名: {group.group_name or '未知'}", + f" 平台: {group.platform}", + ] + ) else: info_lines.append("\n💬 消息类型: 私聊消息") - + # 消息段信息 if message.message_segment: - info_lines.extend([ - "", - "📦 消息段信息:", - f" 类型: {message.message_segment.type}", - f" 数据类型: {type(message.message_segment.data).__name__}", - f" 数据预览: {str(message.message_segment.data)[:200]}{'...' if len(str(message.message_segment.data)) > 200 else ''}", - ]) - + info_lines.extend( + [ + "", + "📦 消息段信息:", + f" 类型: {message.message_segment.type}", + f" 数据类型: {type(message.message_segment.data).__name__}", + f" 数据预览: {str(message.message_segment.data)[:200]}{'...' if len(str(message.message_segment.data)) > 200 else ''}", + ] + ) + # 聊天流详细信息 - if hasattr(message, 'chat_stream') and message.chat_stream: + if hasattr(message, "chat_stream") and message.chat_stream: chat_stream = message.chat_stream - info_lines.extend([ - "", - "🔄 聊天流详细信息:", - f" 流ID: {chat_stream.stream_id}", - f" 平台: {chat_stream.platform}", - f" 是否激活: {'是' if chat_stream.is_active else '否'}", - f" 用户信息: {chat_stream.user_info.user_nickname} ({chat_stream.user_info.user_id})", - f" 群信息: {getattr(chat_stream.group_info, 'group_name', '私聊') if chat_stream.group_info else '私聊'}", - ]) - + info_lines.extend( + [ + "", + "🔄 聊天流详细信息:", + f" 流ID: {chat_stream.stream_id}", + f" 平台: {chat_stream.platform}", + f" 是否激活: {'是' if chat_stream.is_active else '否'}", + f" 用户信息: {chat_stream.user_info.user_nickname} ({chat_stream.user_info.user_id})", + f" 群信息: {getattr(chat_stream.group_info, 'group_name', '私聊') if chat_stream.group_info else '私聊'}", + ] + ) + # 回复信息 - if hasattr(message, 'reply') and message.reply: - info_lines.extend([ - "", - "↩️ 回复信息:", - f" 回复消息ID: {message.reply.message_info.message_id}", - f" 回复内容: {message.reply.processed_plain_text[:100]}{'...' if len(message.reply.processed_plain_text) > 100 else ''}", - ]) - + if hasattr(message, "reply") and message.reply: + info_lines.extend( + [ + "", + "↩️ 回复信息:", + f" 回复消息ID: {message.reply.message_info.message_id}", + f" 回复内容: {message.reply.processed_plain_text[:100]}{'...' if len(message.reply.processed_plain_text) > 100 else ''}", + ] + ) + # 原始消息数据(如果存在) - if hasattr(message, 'raw_message') and message.raw_message: - info_lines.extend([ - "", - "🗂️ 原始消息数据:", - f" 数据类型: {type(message.raw_message).__name__}", - f" 数据大小: {len(str(message.raw_message))} 字符", - ]) - + if hasattr(message, "raw_message") and message.raw_message: + info_lines.extend( + [ + "", + "🗂️ 原始消息数据:", + f" 数据类型: {type(message.raw_message).__name__}", + f" 数据大小: {len(str(message.raw_message))} 字符", + ] + ) + return "\n".join(info_lines) @register_command class SenderInfoCommand(BaseCommand): """发送者信息命令,快速查看发送者信息""" - + command_name = "whoami" command_description = "查看发送命令的用户信息" command_pattern = r"^/whoami$" command_help = "使用方法: /whoami - 查看你的用户信息" command_examples = ["/whoami"] enable_command = True - + async def execute(self) -> Tuple[bool, Optional[str]]: """执行发送者信息查看命令""" try: user = self.message.message_info.user_info group = self.message.message_info.group_info - + info_lines = [ "👤 你的身份信息", f"🆔 用户ID: {user.user_id}", @@ -203,19 +227,21 @@ class SenderInfoCommand(BaseCommand): f"🏷️ 群名片: {user.user_cardname or '无'}", f"🌐 平台: {user.platform}", ] - + if group: - info_lines.extend([ - "", - "👥 当前群聊:", - f"🆔 群ID: {group.group_id}", - f"📝 群名: {group.group_name or '未知'}", - ]) + info_lines.extend( + [ + "", + "👥 当前群聊:", + f"🆔 群ID: {group.group_id}", + f"📝 群名: {group.group_name or '未知'}", + ] + ) else: info_lines.append("\n💬 当前在私聊中") - + return True, "\n".join(info_lines) - + except Exception as e: logger.error(f"{self.log_prefix} 获取发送者信息时出错: {e}") return False, f"获取发送者信息失败: {str(e)}" @@ -224,59 +250,65 @@ class SenderInfoCommand(BaseCommand): @register_command class ChatStreamInfoCommand(BaseCommand): """聊天流信息命令""" - + command_name = "streaminfo" command_description = "查看当前聊天流的详细信息" command_pattern = r"^/streaminfo$" command_help = "使用方法: /streaminfo - 查看当前聊天流信息" command_examples = ["/streaminfo"] enable_command = True - + async def execute(self) -> Tuple[bool, Optional[str]]: """执行聊天流信息查看命令""" try: - if not hasattr(self.message, 'chat_stream') or not self.message.chat_stream: + if not hasattr(self.message, "chat_stream") or not self.message.chat_stream: return False, "无法获取聊天流信息" - + chat_stream = self.message.chat_stream - + info_lines = [ "🔄 聊天流信息", f"🆔 流ID: {chat_stream.stream_id}", f"🌐 平台: {chat_stream.platform}", f"⚡ 状态: {'激活' if chat_stream.is_active else '非激活'}", ] - + # 用户信息 if chat_stream.user_info: - info_lines.extend([ - "", - "👤 关联用户:", - f" ID: {chat_stream.user_info.user_id}", - f" 昵称: {chat_stream.user_info.user_nickname}", - ]) - + info_lines.extend( + [ + "", + "👤 关联用户:", + f" ID: {chat_stream.user_info.user_id}", + f" 昵称: {chat_stream.user_info.user_nickname}", + ] + ) + # 群信息 if chat_stream.group_info: - info_lines.extend([ - "", - "👥 关联群聊:", - f" 群ID: {chat_stream.group_info.group_id}", - f" 群名: {chat_stream.group_info.group_name or '未知'}", - ]) + info_lines.extend( + [ + "", + "👥 关联群聊:", + f" 群ID: {chat_stream.group_info.group_id}", + f" 群名: {chat_stream.group_info.group_name or '未知'}", + ] + ) else: info_lines.append("\n💬 类型: 私聊流") - + # 最近消息统计 - if hasattr(chat_stream, 'last_messages'): + if hasattr(chat_stream, "last_messages"): msg_count = len(chat_stream.last_messages) - info_lines.extend([ - "", - f"📈 消息统计: 记录了 {msg_count} 条最近消息", - ]) - + info_lines.extend( + [ + "", + f"📈 消息统计: 记录了 {msg_count} 条最近消息", + ] + ) + return True, "\n".join(info_lines) - + except Exception as e: logger.error(f"{self.log_prefix} 获取聊天流信息时出错: {e}") - return False, f"获取聊天流信息失败: {str(e)}" \ No newline at end of file + return False, f"获取聊天流信息失败: {str(e)}" diff --git a/src/plugins/example_command_plugin/commands/send_msg_commad.py b/src/plugins/example_command_plugin/commands/send_msg_commad.py index 7953eb5af..0b4176467 100644 --- a/src/plugins/example_command_plugin/commands/send_msg_commad.py +++ b/src/plugins/example_command_plugin/commands/send_msg_commad.py @@ -5,43 +5,41 @@ from typing import Tuple, Optional logger = get_logger("send_msg_command") + @register_command class SendMessageCommand(BaseCommand, MessageAPI): """发送消息命令,可以向指定群聊或私聊发送消息""" - + command_name = "send" command_description = "向指定群聊或私聊发送消息" command_pattern = r"^/send\s+(?Pgroup|user)\s+(?P\d+)\s+(?P.+)$" command_help = "使用方法: /send <消息内容> - 发送消息到指定群聊或用户" - command_examples = [ - "/send group 123456789 大家好!", - "/send user 987654321 私聊消息" - ] + command_examples = ["/send group 123456789 大家好!", "/send user 987654321 私聊消息"] enable_command = True - + def __init__(self, message): super().__init__(message) # 初始化MessageAPI需要的服务(虽然这里不会用到,但保持一致性) self._services = {} self.log_prefix = f"[Command:{self.command_name}]" - + async def execute(self) -> Tuple[bool, Optional[str]]: """执行发送消息命令 - + Returns: Tuple[bool, Optional[str]]: (是否执行成功, 回复消息) """ try: # 获取匹配到的参数 target_type = self.matched_groups.get("target_type") # group 或 user - target_id = self.matched_groups.get("target_id") # 群ID或用户ID - content = self.matched_groups.get("content") # 消息内容 - + target_id = self.matched_groups.get("target_id") # 群ID或用户ID + content = self.matched_groups.get("content") # 消息内容 + if not all([target_type, target_id, content]): return False, "命令参数不完整,请检查格式" - + logger.info(f"{self.log_prefix} 执行发送消息命令: {target_type}:{target_id} -> {content[:50]}...") - + # 根据目标类型调用不同的发送方法 if target_type == "group": success = await self._send_to_group(target_id, content) @@ -51,24 +49,24 @@ class SendMessageCommand(BaseCommand, MessageAPI): target_desc = f"用户 {target_id}" else: return False, f"不支持的目标类型: {target_type},只支持 group 或 user" - + # 返回执行结果 if success: return True, f"✅ 消息已成功发送到 {target_desc}" else: return False, f"❌ 消息发送失败,可能是目标 {target_desc} 不存在或没有权限" - + except Exception as e: logger.error(f"{self.log_prefix} 执行发送消息命令时出错: {e}") return False, f"命令执行出错: {str(e)}" - + async def _send_to_group(self, group_id: str, content: str) -> bool: """发送消息到群聊 - + Args: group_id: 群聊ID content: 消息内容 - + Returns: bool: 是否发送成功 """ @@ -76,27 +74,27 @@ class SendMessageCommand(BaseCommand, MessageAPI): success = await self.send_text_to_group( text=content, group_id=group_id, - platform="qq" # 默认使用QQ平台 + platform="qq", # 默认使用QQ平台 ) - + if success: logger.info(f"{self.log_prefix} 成功发送消息到群聊 {group_id}") else: logger.warning(f"{self.log_prefix} 发送消息到群聊 {group_id} 失败") - + return success - + except Exception as e: logger.error(f"{self.log_prefix} 发送群聊消息时出错: {e}") return False - + async def _send_to_user(self, user_id: str, content: str) -> bool: """发送消息到私聊 - + Args: user_id: 用户ID content: 消息内容 - + Returns: bool: 是否发送成功 """ @@ -104,16 +102,16 @@ class SendMessageCommand(BaseCommand, MessageAPI): success = await self.send_text_to_user( text=content, user_id=user_id, - platform="qq" # 默认使用QQ平台 + platform="qq", # 默认使用QQ平台 ) - + if success: logger.info(f"{self.log_prefix} 成功发送消息到用户 {user_id}") else: logger.warning(f"{self.log_prefix} 发送消息到用户 {user_id} 失败") - + return success - + except Exception as e: logger.error(f"{self.log_prefix} 发送私聊消息时出错: {e}") - return False \ No newline at end of file + return False diff --git a/src/plugins/example_command_plugin/commands/send_msg_enhanced.py b/src/plugins/example_command_plugin/commands/send_msg_enhanced.py index 810d4f15d..bd46da916 100644 --- a/src/plugins/example_command_plugin/commands/send_msg_enhanced.py +++ b/src/plugins/example_command_plugin/commands/send_msg_enhanced.py @@ -5,10 +5,11 @@ from typing import Tuple, Optional logger = get_logger("send_msg_enhanced") + @register_command class SendMessageEnhancedCommand(BaseCommand, MessageAPI): """增强版发送消息命令,支持多种消息类型和平台""" - + command_name = "sendfull" command_description = "增强版消息发送命令,支持多种类型和平台" command_pattern = r"^/sendfull\s+(?Ptext|image|emoji)\s+(?Pgroup|user)\s+(?P\d+)(?:\s+(?P\w+))?\s+(?P.+)$" @@ -17,108 +18,93 @@ class SendMessageEnhancedCommand(BaseCommand, MessageAPI): "/sendfull text group 123456789 qq 大家好!这是文本消息", "/sendfull image user 987654321 https://example.com/image.jpg", "/sendfull emoji group 123456789 😄", - "/sendfull text user 987654321 qq 私聊消息" + "/sendfull text user 987654321 qq 私聊消息", ] enable_command = True - + def __init__(self, message): super().__init__(message) self._services = {} self.log_prefix = f"[Command:{self.command_name}]" - + async def execute(self) -> Tuple[bool, Optional[str]]: """执行增强版发送消息命令""" try: # 获取匹配参数 - msg_type = self.matched_groups.get("msg_type") # 消息类型: text/image/emoji - target_type = self.matched_groups.get("target_type") # 目标类型: group/user - target_id = self.matched_groups.get("target_id") # 目标ID - platform = self.matched_groups.get("platform") or "qq" # 平台,默认qq - content = self.matched_groups.get("content") # 内容 - + msg_type = self.matched_groups.get("msg_type") # 消息类型: text/image/emoji + target_type = self.matched_groups.get("target_type") # 目标类型: group/user + target_id = self.matched_groups.get("target_id") # 目标ID + platform = self.matched_groups.get("platform") or "qq" # 平台,默认qq + content = self.matched_groups.get("content") # 内容 + if not all([msg_type, target_type, target_id, content]): return False, "命令参数不完整,请检查格式" - + # 验证消息类型 valid_types = ["text", "image", "emoji"] if msg_type not in valid_types: return False, f"不支持的消息类型: {msg_type},支持的类型: {', '.join(valid_types)}" - + # 验证目标类型 if target_type not in ["group", "user"]: return False, "目标类型只能是 group 或 user" - + logger.info(f"{self.log_prefix} 执行发送命令: {msg_type} -> {target_type}:{target_id} (平台:{platform})") - + # 根据消息类型和目标类型发送消息 - is_group = (target_type == "group") + is_group = target_type == "group" success = await self.send_message_to_target( - message_type=msg_type, - content=content, - platform=platform, - target_id=target_id, - is_group=is_group + message_type=msg_type, content=content, platform=platform, target_id=target_id, is_group=is_group ) - + # 构建结果消息 target_desc = f"{'群聊' if is_group else '用户'} {target_id} (平台: {platform})" - msg_type_desc = { - "text": "文本", - "image": "图片", - "emoji": "表情" - }.get(msg_type, msg_type) - + msg_type_desc = {"text": "文本", "image": "图片", "emoji": "表情"}.get(msg_type, msg_type) + if success: return True, f"✅ {msg_type_desc}消息已成功发送到 {target_desc}" else: return False, f"❌ {msg_type_desc}消息发送失败,可能是目标 {target_desc} 不存在或没有权限" - + except Exception as e: logger.error(f"{self.log_prefix} 执行增强发送命令时出错: {e}") return False, f"命令执行出错: {str(e)}" -@register_command +@register_command class SendQuickCommand(BaseCommand, MessageAPI): """快速发送文本消息命令""" - + command_name = "msg" command_description = "快速发送文本消息到群聊" command_pattern = r"^/msg\s+(?P\d+)\s+(?P.+)$" command_help = "使用方法: /msg <群ID> <消息内容> - 快速发送文本到指定群聊" - command_examples = [ - "/msg 123456789 大家好!", - "/msg 987654321 这是一条快速消息" - ] + command_examples = ["/msg 123456789 大家好!", "/msg 987654321 这是一条快速消息"] enable_command = True - + def __init__(self, message): super().__init__(message) self._services = {} self.log_prefix = f"[Command:{self.command_name}]" - + async def execute(self) -> Tuple[bool, Optional[str]]: """执行快速发送消息命令""" try: group_id = self.matched_groups.get("group_id") content = self.matched_groups.get("content") - + if not all([group_id, content]): return False, "命令参数不完整" - + logger.info(f"{self.log_prefix} 快速发送到群 {group_id}: {content[:50]}...") - - success = await self.send_text_to_group( - text=content, - group_id=group_id, - platform="qq" - ) - + + success = await self.send_text_to_group(text=content, group_id=group_id, platform="qq") + if success: return True, f"✅ 消息已发送到群 {group_id}" else: return False, f"❌ 发送到群 {group_id} 失败" - + except Exception as e: logger.error(f"{self.log_prefix} 快速发送命令出错: {e}") return False, f"发送失败: {str(e)}" @@ -127,44 +113,37 @@ class SendQuickCommand(BaseCommand, MessageAPI): @register_command class SendPrivateCommand(BaseCommand, MessageAPI): """发送私聊消息命令""" - + command_name = "pm" command_description = "发送私聊消息到指定用户" command_pattern = r"^/pm\s+(?P\d+)\s+(?P.+)$" command_help = "使用方法: /pm <用户ID> <消息内容> - 发送私聊消息" - command_examples = [ - "/pm 123456789 你好!", - "/pm 987654321 这是私聊消息" - ] + command_examples = ["/pm 123456789 你好!", "/pm 987654321 这是私聊消息"] enable_command = True - + def __init__(self, message): super().__init__(message) self._services = {} self.log_prefix = f"[Command:{self.command_name}]" - + async def execute(self) -> Tuple[bool, Optional[str]]: """执行私聊发送命令""" try: user_id = self.matched_groups.get("user_id") content = self.matched_groups.get("content") - + if not all([user_id, content]): return False, "命令参数不完整" - + logger.info(f"{self.log_prefix} 发送私聊到用户 {user_id}: {content[:50]}...") - - success = await self.send_text_to_user( - text=content, - user_id=user_id, - platform="qq" - ) - + + success = await self.send_text_to_user(text=content, user_id=user_id, platform="qq") + if success: return True, f"✅ 私聊消息已发送到用户 {user_id}" else: return False, f"❌ 发送私聊到用户 {user_id} 失败" - + except Exception as e: logger.error(f"{self.log_prefix} 私聊发送命令出错: {e}") - return False, f"私聊发送失败: {str(e)}" \ No newline at end of file + return False, f"私聊发送失败: {str(e)}" diff --git a/src/plugins/example_command_plugin/commands/send_msg_with_context.py b/src/plugins/example_command_plugin/commands/send_msg_with_context.py index dd6d8de87..a2b485fff 100644 --- a/src/plugins/example_command_plugin/commands/send_msg_with_context.py +++ b/src/plugins/example_command_plugin/commands/send_msg_with_context.py @@ -6,173 +6,153 @@ import time logger = get_logger("send_msg_with_context") + @register_command class ContextAwareSendCommand(BaseCommand, MessageAPI): """上下文感知的发送消息命令,展示如何利用原始消息信息""" - + command_name = "csend" command_description = "带上下文感知的发送消息命令" - command_pattern = r"^/csend\s+(?Pgroup|user|here|reply)\s+(?P.*?)(?:\s+(?P.*))?$" + command_pattern = ( + r"^/csend\s+(?Pgroup|user|here|reply)\s+(?P.*?)(?:\s+(?P.*))?$" + ) command_help = "使用方法: /csend <参数> [内容]" command_examples = [ "/csend group 123456789 大家好!", - "/csend user 987654321 私聊消息", + "/csend user 987654321 私聊消息", "/csend here 在当前聊天发送", - "/csend reply 回复当前群/私聊" + "/csend reply 回复当前群/私聊", ] enable_command = True - + # 管理员用户ID列表(示例) ADMIN_USERS = ["123456789", "987654321"] # 可以从配置文件读取 - + def __init__(self, message): super().__init__(message) self._services = {} self.log_prefix = f"[Command:{self.command_name}]" - + async def execute(self) -> Tuple[bool, Optional[str]]: """执行上下文感知的发送命令""" try: # 获取命令发送者信息 sender = self.message.message_info.user_info current_group = self.message.message_info.group_info - + # 权限检查 if not self._check_permission(sender.user_id): return False, f"❌ 权限不足,只有管理员可以使用此命令\n你的ID: {sender.user_id}" - + # 解析命令参数 target_type = self.matched_groups.get("target_type") target_id_or_content = self.matched_groups.get("target_id_or_content", "") content = self.matched_groups.get("content", "") - + # 根据目标类型处理不同情况 if target_type == "here": # 发送到当前聊天 return await self._send_to_current_chat(target_id_or_content, sender, current_group) - + elif target_type == "reply": # 回复到当前聊天,带发送者信息 return await self._send_reply_with_context(target_id_or_content, sender, current_group) - + elif target_type in ["group", "user"]: # 发送到指定目标 if not content: return False, "指定群聊或用户时需要提供消息内容" return await self._send_to_target(target_type, target_id_or_content, content, sender) - + else: return False, f"不支持的目标类型: {target_type}" - + except Exception as e: logger.error(f"{self.log_prefix} 执行上下文感知发送命令时出错: {e}") return False, f"命令执行出错: {str(e)}" - + def _check_permission(self, user_id: str) -> bool: """检查用户权限""" return user_id in self.ADMIN_USERS - + async def _send_to_current_chat(self, content: str, sender, current_group) -> Tuple[bool, str]: """发送到当前聊天""" if not content: return False, "消息内容不能为空" - + # 构建带发送者信息的消息 timestamp = time.strftime("%H:%M:%S", time.localtime()) if current_group: # 群聊 formatted_content = f"[管理员转发 {timestamp}] {sender.user_nickname}({sender.user_id}): {content}" success = await self.send_text_to_group( - text=formatted_content, - group_id=current_group.group_id, - platform="qq" + text=formatted_content, group_id=current_group.group_id, platform="qq" ) target_desc = f"当前群聊 {current_group.group_name}({current_group.group_id})" else: # 私聊 formatted_content = f"[管理员消息 {timestamp}]: {content}" - success = await self.send_text_to_user( - text=formatted_content, - user_id=sender.user_id, - platform="qq" - ) + success = await self.send_text_to_user(text=formatted_content, user_id=sender.user_id, platform="qq") target_desc = "当前私聊" - + if success: return True, f"✅ 消息已发送到{target_desc}" else: return False, f"❌ 发送到{target_desc}失败" - + async def _send_reply_with_context(self, content: str, sender, current_group) -> Tuple[bool, str]: """发送回复,带完整上下文信息""" if not content: return False, "回复内容不能为空" - + # 获取当前时间和环境信息 timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) - + # 构建上下文信息 context_info = [ f"📢 管理员回复 [{timestamp}]", f"👤 发送者: {sender.user_nickname}({sender.user_id})", ] - + if current_group: context_info.append(f"👥 当前群聊: {current_group.group_name}({current_group.group_id})") target_desc = f"群聊 {current_group.group_name}" else: context_info.append("💬 当前环境: 私聊") target_desc = "私聊" - - context_info.extend([ - f"📝 回复内容: {content}", - "─" * 30 - ]) - + + context_info.extend([f"📝 回复内容: {content}", "─" * 30]) + formatted_content = "\n".join(context_info) - + # 发送消息 if current_group: success = await self.send_text_to_group( - text=formatted_content, - group_id=current_group.group_id, - platform="qq" + text=formatted_content, group_id=current_group.group_id, platform="qq" ) else: - success = await self.send_text_to_user( - text=formatted_content, - user_id=sender.user_id, - platform="qq" - ) - + success = await self.send_text_to_user(text=formatted_content, user_id=sender.user_id, platform="qq") + if success: return True, f"✅ 带上下文的回复已发送到{target_desc}" else: return False, f"❌ 发送上下文回复到{target_desc}失败" - + async def _send_to_target(self, target_type: str, target_id: str, content: str, sender) -> Tuple[bool, str]: """发送到指定目标,带发送者追踪信息""" timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) - + # 构建带追踪信息的消息 tracking_info = f"[管理转发 {timestamp}] 来自 {sender.user_nickname}({sender.user_id})" formatted_content = f"{tracking_info}\n{content}" - + if target_type == "group": - success = await self.send_text_to_group( - text=formatted_content, - group_id=target_id, - platform="qq" - ) + success = await self.send_text_to_group(text=formatted_content, group_id=target_id, platform="qq") target_desc = f"群聊 {target_id}" else: # user - success = await self.send_text_to_user( - text=formatted_content, - user_id=target_id, - platform="qq" - ) + success = await self.send_text_to_user(text=formatted_content, user_id=target_id, platform="qq") target_desc = f"用户 {target_id}" - + if success: return True, f"✅ 带追踪信息的消息已发送到{target_desc}" else: @@ -182,21 +162,21 @@ class ContextAwareSendCommand(BaseCommand, MessageAPI): @register_command class MessageContextCommand(BaseCommand): """消息上下文命令,展示如何获取和利用上下文信息""" - + command_name = "context" command_description = "显示当前消息的完整上下文信息" command_pattern = r"^/context$" command_help = "使用方法: /context - 显示当前环境的上下文信息" command_examples = ["/context"] enable_command = True - + async def execute(self) -> Tuple[bool, Optional[str]]: """显示上下文信息""" try: message = self.message user = message.message_info.user_info group = message.message_info.group_info - + # 构建上下文信息 context_lines = [ "🌐 当前上下文信息", @@ -212,42 +192,50 @@ class MessageContextCommand(BaseCommand): f" 群名片: {user.user_cardname or '无'}", f" 平台: {user.platform}", ] - + if group: - context_lines.extend([ - "", - "👥 群聊环境:", - f" 群ID: {group.group_id}", - f" 群名: {group.group_name or '未知'}", - f" 平台: {group.platform}", - ]) + context_lines.extend( + [ + "", + "👥 群聊环境:", + f" 群ID: {group.group_id}", + f" 群名: {group.group_name or '未知'}", + f" 平台: {group.platform}", + ] + ) else: - context_lines.extend([ - "", - "💬 私聊环境", - ]) - + context_lines.extend( + [ + "", + "💬 私聊环境", + ] + ) + # 添加聊天流信息 - if hasattr(message, 'chat_stream') and message.chat_stream: + if hasattr(message, "chat_stream") and message.chat_stream: chat_stream = message.chat_stream - context_lines.extend([ - "", - "🔄 聊天流:", - f" 流ID: {chat_stream.stream_id}", - f" 激活状态: {'激活' if chat_stream.is_active else '非激活'}", - ]) - + context_lines.extend( + [ + "", + "🔄 聊天流:", + f" 流ID: {chat_stream.stream_id}", + f" 激活状态: {'激活' if chat_stream.is_active else '非激活'}", + ] + ) + # 添加消息内容信息 - context_lines.extend([ - "", - "📝 消息内容:", - f" 原始内容: {message.processed_plain_text}", - f" 消息长度: {len(message.processed_plain_text)} 字符", - f" 消息ID: {message.message_info.message_id}", - ]) - + context_lines.extend( + [ + "", + "📝 消息内容:", + f" 原始内容: {message.processed_plain_text}", + f" 消息长度: {len(message.processed_plain_text)} 字符", + f" 消息ID: {message.message_info.message_id}", + ] + ) + return True, "\n".join(context_lines) - + except Exception as e: logger.error(f"{self.log_prefix} 获取上下文信息时出错: {e}") - return False, f"获取上下文失败: {str(e)}" \ No newline at end of file + return False, f"获取上下文失败: {str(e)}" diff --git a/src/plugins/mute_plugin/actions/__init__.py b/src/plugins/mute_plugin/actions/__init__.py index a715e2fa7..e44fd983c 100644 --- a/src/plugins/mute_plugin/actions/__init__.py +++ b/src/plugins/mute_plugin/actions/__init__.py @@ -1,2 +1,3 @@ """测试插件动作模块""" + from . import mute_action # noqa diff --git a/src/plugins/mute_plugin/actions/mute_action.py b/src/plugins/mute_plugin/actions/mute_action.py index 969076e70..a50e18ed0 100644 --- a/src/plugins/mute_plugin/actions/mute_action.py +++ b/src/plugins/mute_plugin/actions/mute_action.py @@ -22,21 +22,20 @@ class MuteAction(PluginAction): "当有人刷屏时使用", "当有人发了擦边,或者色情内容时使用", "当有人要求禁言自己时使用", - "如果某人已经被禁言了,就不要再次禁言了,除非你想追加时间!!" + "如果某人已经被禁言了,就不要再次禁言了,除非你想追加时间!!", ] enable_plugin = False # 启用插件 associated_types = ["command", "text"] action_config_file_name = "mute_action_config.toml" - + # 激活类型设置 focus_activation_type = ActionActivationType.LLM_JUDGE # Focus模式使用LLM判定,确保谨慎 - normal_activation_type = ActionActivationType.KEYWORD # Normal模式使用关键词激活,快速响应 - - + normal_activation_type = ActionActivationType.KEYWORD # Normal模式使用关键词激活,快速响应 + # 关键词设置(用于Normal模式) activation_keywords = ["禁言", "mute", "ban", "silence"] keyword_case_sensitive = False - + # LLM判定提示词(用于Focus模式) llm_judge_prompt = """ 判定是否需要使用禁言动作的严格条件: @@ -59,13 +58,13 @@ class MuteAction(PluginAction): 注意:禁言是严厉措施,只在明确违规或用户主动要求时使用。 宁可保守也不要误判,保护用户的发言权利。 """ - + # Random激活概率(备用) random_activation_probability = 0.05 # 设置很低的概率作为兜底 # 模式启用设置 - 禁言功能在所有模式下都可用 mode_enable = ChatMode.ALL - + # 并行执行设置 - 禁言动作可以与回复并行执行,不覆盖回复内容 parallel_action = False @@ -73,15 +72,15 @@ class MuteAction(PluginAction): super().__init__(*args, **kwargs) # 生成配置文件(如果不存在) self._generate_config_if_needed() - + def _generate_config_if_needed(self): """生成配置文件(如果不存在)""" import os - + # 获取动作文件所在目录 current_dir = os.path.dirname(os.path.abspath(__file__)) config_path = os.path.join(current_dir, "mute_action_config.toml") - + if not os.path.exists(config_path): config_content = """\ # 禁言动作配置文件 @@ -130,11 +129,10 @@ log_mute_history = true def _get_template_message(self, target: str, duration_str: str, reason: str) -> str: """获取模板化的禁言消息""" - templates = self.config.get("templates", [ - "好的,禁言 {target} {duration},理由:{reason}" - ]) - + templates = self.config.get("templates", ["好的,禁言 {target} {duration},理由:{reason}"]) + import random + template = random.choice(templates) return template.format(target=target, duration=duration_str, reason=reason) @@ -162,7 +160,7 @@ log_mute_history = true # 获取时长限制配置 min_duration, max_duration, default_duration = self._get_duration_limits() - + # 验证时长格式并转换 try: duration_int = int(duration) @@ -170,9 +168,11 @@ log_mute_history = true error_msg = "禁言时长必须大于0" logger.error(f"{self.log_prefix} {error_msg}") error_templates = self.config.get("error_messages", ["禁言时长必须是正数哦~"]) - await self.send_message_by_expressor(error_templates[2] if len(error_templates) > 2 else "禁言时长必须是正数哦~") + await self.send_message_by_expressor( + error_templates[2] if len(error_templates) > 2 else "禁言时长必须是正数哦~" + ) return False, error_msg - + # 限制禁言时长范围 if duration_int < min_duration: duration_int = min_duration @@ -180,12 +180,14 @@ log_mute_history = true elif duration_int > max_duration: duration_int = max_duration logger.info(f"{self.log_prefix} 禁言时长过长,调整为{max_duration}秒") - - except (ValueError, TypeError) as e: + + except (ValueError, TypeError): error_msg = f"禁言时长格式无效: {duration}" logger.error(f"{self.log_prefix} {error_msg}") error_templates = self.config.get("error_messages", ["禁言时长必须是数字哦~"]) - await self.send_message_by_expressor(error_templates[3] if len(error_templates) > 3 else "禁言时长必须是数字哦~") + await self.send_message_by_expressor( + error_templates[3] if len(error_templates) > 3 else "禁言时长必须是数字哦~" + ) return False, error_msg # 获取用户ID @@ -206,7 +208,7 @@ log_mute_history = true # 发送表达情绪的消息 enable_formatting = self.config.get("enable_duration_formatting", True) time_str = self._format_duration(duration_int) if enable_formatting else f"{duration_int}秒" - + # 使用模板化消息 message = self._get_template_message(target, time_str, reason) await self.send_message_by_expressor(message) diff --git a/src/plugins/plugin_loader.py b/src/plugins/plugin_loader.py index 7779c1307..107150570 100644 --- a/src/plugins/plugin_loader.py +++ b/src/plugins/plugin_loader.py @@ -1,7 +1,7 @@ import importlib import pkgutil import os -from typing import Dict, List, Tuple +from typing import Dict, Tuple from src.common.logger_manager import get_logger logger = get_logger("plugin_loader") @@ -9,52 +9,53 @@ logger = get_logger("plugin_loader") class PluginLoader: """统一的插件加载器,负责加载插件的所有组件(actions、commands等)""" - + def __init__(self): self.loaded_actions = 0 self.loaded_commands = 0 self.plugin_stats: Dict[str, Dict[str, int]] = {} # 统计每个插件加载的组件数量 self.plugin_sources: Dict[str, str] = {} # 记录每个插件来自哪个路径 - + def load_all_plugins(self) -> Tuple[int, int]: """加载所有插件的所有组件 - + Returns: Tuple[int, int]: (加载的动作数量, 加载的命令数量) """ # 定义插件搜索路径(优先级从高到低) plugin_paths = [ ("plugins", "plugins"), # 项目根目录的plugins文件夹 - ("src.plugins", os.path.join("src", "plugins")) # src下的plugins文件夹 + ("src.plugins", os.path.join("src", "plugins")), # src下的plugins文件夹 ] - + total_plugins_found = 0 - + for plugin_import_path, plugin_dir_path in plugin_paths: try: plugins_loaded = self._load_plugins_from_path(plugin_import_path, plugin_dir_path) total_plugins_found += plugins_loaded - + except Exception as e: logger.error(f"从路径 {plugin_dir_path} 加载插件失败: {e}") import traceback + logger.error(traceback.format_exc()) - + if total_plugins_found == 0: logger.info("未找到任何插件目录或插件") - + # 输出加载统计 self._log_loading_stats() - + return self.loaded_actions, self.loaded_commands - + def _load_plugins_from_path(self, plugin_import_path: str, plugin_dir_path: str) -> int: """从指定路径加载插件 - + Args: plugin_import_path: 插件的导入路径 (如 "plugins" 或 "src.plugins") plugin_dir_path: 插件目录的文件系统路径 - + Returns: int: 找到的插件包数量 """ @@ -62,9 +63,9 @@ class PluginLoader: if not os.path.exists(plugin_dir_path): logger.debug(f"插件目录 {plugin_dir_path} 不存在,跳过") return 0 - + logger.info(f"正在从 {plugin_dir_path} 加载插件...") - + # 导入插件包 try: plugins_package = importlib.import_module(plugin_import_path) @@ -72,122 +73,120 @@ class PluginLoader: except ImportError as e: logger.warning(f"导入插件包 {plugin_import_path} 失败: {e}") return 0 - + # 遍历插件包中的所有子包 plugins_found = 0 - for _, plugin_name, is_pkg in pkgutil.iter_modules( - plugins_package.__path__, plugins_package.__name__ + "." - ): + for _, plugin_name, is_pkg in pkgutil.iter_modules(plugins_package.__path__, plugins_package.__name__ + "."): if not is_pkg: continue - + logger.debug(f"检测到插件: {plugin_name}") # 记录插件来源 self.plugin_sources[plugin_name] = plugin_dir_path self._load_single_plugin(plugin_name) plugins_found += 1 - + if plugins_found > 0: logger.info(f"从 {plugin_dir_path} 找到 {plugins_found} 个插件包") else: logger.debug(f"从 {plugin_dir_path} 未找到任何插件包") - + return plugins_found - + def _load_single_plugin(self, plugin_name: str) -> None: """加载单个插件的所有组件 - + Args: plugin_name: 插件名称 """ plugin_stats = {"actions": 0, "commands": 0} - + # 加载动作组件 actions_count = self._load_plugin_actions(plugin_name) plugin_stats["actions"] = actions_count self.loaded_actions += actions_count - - # 加载命令组件 + + # 加载命令组件 commands_count = self._load_plugin_commands(plugin_name) plugin_stats["commands"] = commands_count self.loaded_commands += commands_count - + # 记录插件统计信息 if actions_count > 0 or commands_count > 0: self.plugin_stats[plugin_name] = plugin_stats logger.info(f"插件 {plugin_name} 加载完成: {actions_count} 个动作, {commands_count} 个命令") - + def _load_plugin_actions(self, plugin_name: str) -> int: """加载插件的动作组件 - + Args: plugin_name: 插件名称 - + Returns: int: 加载的动作数量 """ loaded_count = 0 - + # 优先检查插件是否有actions子包 plugin_actions_path = f"{plugin_name}.actions" plugin_actions_dir = plugin_name.replace(".", os.path.sep) + os.path.sep + "actions" - + actions_loaded_from_subdir = False - + # 首先尝试从actions子目录加载 if os.path.exists(plugin_actions_dir): loaded_count += self._load_from_actions_subdir(plugin_name, plugin_actions_path, plugin_actions_dir) if loaded_count > 0: actions_loaded_from_subdir = True - + # 如果actions子目录不存在或加载失败,尝试从插件根目录加载 if not actions_loaded_from_subdir: loaded_count += self._load_actions_from_root_dir(plugin_name) - + return loaded_count - + def _load_plugin_commands(self, plugin_name: str) -> int: """加载插件的命令组件 - + Args: plugin_name: 插件名称 - + Returns: int: 加载的命令数量 """ loaded_count = 0 - + # 优先检查插件是否有commands子包 plugin_commands_path = f"{plugin_name}.commands" plugin_commands_dir = plugin_name.replace(".", os.path.sep) + os.path.sep + "commands" - + commands_loaded_from_subdir = False - + # 首先尝试从commands子目录加载 if os.path.exists(plugin_commands_dir): loaded_count += self._load_from_commands_subdir(plugin_name, plugin_commands_path, plugin_commands_dir) if loaded_count > 0: commands_loaded_from_subdir = True - + # 如果commands子目录不存在或加载失败,尝试从插件根目录加载 if not commands_loaded_from_subdir: loaded_count += self._load_commands_from_root_dir(plugin_name) - + return loaded_count - + def _load_from_actions_subdir(self, plugin_name: str, plugin_actions_path: str, plugin_actions_dir: str) -> int: """从actions子目录加载动作""" loaded_count = 0 - + try: # 尝试导入插件的actions包 actions_module = importlib.import_module(plugin_actions_path) logger.debug(f"成功加载插件动作模块: {plugin_actions_path}") - + # 遍历actions目录中的所有Python文件 actions_dir = os.path.dirname(actions_module.__file__) for file in os.listdir(actions_dir): - if file.endswith('.py') and file != '__init__.py': + if file.endswith(".py") and file != "__init__.py": action_module_name = f"{plugin_actions_path}.{file[:-3]}" try: importlib.import_module(action_module_name) @@ -195,25 +194,25 @@ class PluginLoader: loaded_count += 1 except Exception as e: logger.error(f"加载动作失败: {action_module_name}, 错误: {e}") - + except ImportError as e: logger.debug(f"插件 {plugin_name} 的actions子包导入失败: {e}") - + return loaded_count - + def _load_from_commands_subdir(self, plugin_name: str, plugin_commands_path: str, plugin_commands_dir: str) -> int: """从commands子目录加载命令""" loaded_count = 0 - + try: # 尝试导入插件的commands包 commands_module = importlib.import_module(plugin_commands_path) logger.debug(f"成功加载插件命令模块: {plugin_commands_path}") - + # 遍历commands目录中的所有Python文件 commands_dir = os.path.dirname(commands_module.__file__) for file in os.listdir(commands_dir): - if file.endswith('.py') and file != '__init__.py': + if file.endswith(".py") and file != "__init__.py": command_module_name = f"{plugin_commands_path}.{file[:-3]}" try: importlib.import_module(command_module_name) @@ -221,29 +220,29 @@ class PluginLoader: loaded_count += 1 except Exception as e: logger.error(f"加载命令失败: {command_module_name}, 错误: {e}") - + except ImportError as e: logger.debug(f"插件 {plugin_name} 的commands子包导入失败: {e}") - + return loaded_count - + def _load_actions_from_root_dir(self, plugin_name: str) -> int: """从插件根目录加载动作文件""" loaded_count = 0 - + try: # 导入插件包本身 plugin_module = importlib.import_module(plugin_name) logger.debug(f"尝试从插件根目录加载动作: {plugin_name}") - + # 遍历插件根目录中的所有Python文件 plugin_dir = os.path.dirname(plugin_module.__file__) for file in os.listdir(plugin_dir): - if file.endswith('.py') and file != '__init__.py': + if file.endswith(".py") and file != "__init__.py": # 跳过非动作文件(根据命名约定) - if not (file.endswith('_action.py') or file.endswith('_actions.py') or 'action' in file): + if not (file.endswith("_action.py") or file.endswith("_actions.py") or "action" in file): continue - + action_module_name = f"{plugin_name}.{file[:-3]}" try: importlib.import_module(action_module_name) @@ -251,29 +250,29 @@ class PluginLoader: loaded_count += 1 except Exception as e: logger.error(f"加载动作失败: {action_module_name}, 错误: {e}") - + except ImportError as e: logger.debug(f"插件 {plugin_name} 导入失败: {e}") - + return loaded_count - + def _load_commands_from_root_dir(self, plugin_name: str) -> int: """从插件根目录加载命令文件""" loaded_count = 0 - + try: # 导入插件包本身 plugin_module = importlib.import_module(plugin_name) logger.debug(f"尝试从插件根目录加载命令: {plugin_name}") - + # 遍历插件根目录中的所有Python文件 plugin_dir = os.path.dirname(plugin_module.__file__) for file in os.listdir(plugin_dir): - if file.endswith('.py') and file != '__init__.py': + if file.endswith(".py") and file != "__init__.py": # 跳过非命令文件(根据命名约定) - if not (file.endswith('_command.py') or file.endswith('_commands.py') or 'command' in file): + if not (file.endswith("_command.py") or file.endswith("_commands.py") or "command" in file): continue - + command_module_name = f"{plugin_name}.{file[:-3]}" try: importlib.import_module(command_module_name) @@ -281,23 +280,25 @@ class PluginLoader: loaded_count += 1 except Exception as e: logger.error(f"加载命令失败: {command_module_name}, 错误: {e}") - + except ImportError as e: logger.debug(f"插件 {plugin_name} 导入失败: {e}") - + return loaded_count - + def _log_loading_stats(self) -> None: """输出加载统计信息""" logger.success(f"插件加载完成: 总计 {self.loaded_actions} 个动作, {self.loaded_commands} 个命令") - + if self.plugin_stats: logger.info("插件加载详情:") for plugin_name, stats in self.plugin_stats.items(): - plugin_display_name = plugin_name.split('.')[-1] # 只显示插件名称,不显示完整路径 + plugin_display_name = plugin_name.split(".")[-1] # 只显示插件名称,不显示完整路径 source_path = self.plugin_sources.get(plugin_name, "未知路径") - logger.info(f" {plugin_display_name} (来源: {source_path}): {stats['actions']} 动作, {stats['commands']} 命令") + logger.info( + f" {plugin_display_name} (来源: {source_path}): {stats['actions']} 动作, {stats['commands']} 命令" + ) # 创建全局插件加载器实例 -plugin_loader = PluginLoader() \ No newline at end of file +plugin_loader = PluginLoader() diff --git a/src/plugins/tts_plgin/actions/tts_action.py b/src/plugins/tts_plgin/actions/tts_action.py index 12a67a0c2..0e64dcb4c 100644 --- a/src/plugins/tts_plgin/actions/tts_action.py +++ b/src/plugins/tts_plgin/actions/tts_action.py @@ -23,14 +23,14 @@ class TTSAction(PluginAction): ] enable_plugin = True # 启用插件 associated_types = ["tts_text"] - + focus_activation_type = ActionActivationType.LLM_JUDGE normal_activation_type = ActionActivationType.KEYWORD - + # 关键词配置 - Normal模式下使用关键词触发 activation_keywords = ["语音", "tts", "播报", "读出来", "语音播放", "听", "朗读"] keyword_case_sensitive = False - + # 并行执行设置 - TTS可以与回复并行执行,不覆盖回复内容 parallel_action = False diff --git a/src/plugins/vtb_action/actions/vtb_action.py b/src/plugins/vtb_action/actions/vtb_action.py index 2d3a8e507..30b625c39 100644 --- a/src/plugins/vtb_action/actions/vtb_action.py +++ b/src/plugins/vtb_action/actions/vtb_action.py @@ -22,11 +22,11 @@ class VTBAction(PluginAction): ] enable_plugin = True # 启用插件 associated_types = ["vtb_text"] - + # 激活类型设置 focus_activation_type = ActionActivationType.LLM_JUDGE # Focus模式使用LLM判定,精确识别情感表达需求 - normal_activation_type = ActionActivationType.RANDOM # Normal模式使用随机激活,增加趣味性 - + normal_activation_type = ActionActivationType.RANDOM # Normal模式使用随机激活,增加趣味性 + # LLM判定提示词(用于Focus模式) llm_judge_prompt = """ 判定是否需要使用VTB虚拟主播动作的条件: @@ -41,7 +41,7 @@ class VTBAction(PluginAction): 3. 不涉及情感的日常对话 4. 已经有足够的情感表达 """ - + # Random激活概率(用于Normal模式) random_activation_probability = 0.08 # 较低概率,避免过度使用