From 095cbbe58cee3bd9970cb0d2b7e44e4f1b3ef4ba Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Mon, 9 Jun 2025 16:53:11 +0800 Subject: [PATCH] =?UTF-8?q?ref:=E4=BF=AE=E6=94=B9=E4=BA=86=E6=8F=92?= =?UTF-8?q?=E4=BB=B6api=E7=9A=84=E6=96=87=E4=BB=B6=E7=BB=93=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/actions/plugin_action.py | 669 +------------------- src/chat/actions/plugin_api/__init__.py | 13 + src/chat/actions/plugin_api/config_api.py | 53 ++ src/chat/actions/plugin_api/database_api.py | 381 +++++++++++ src/chat/actions/plugin_api/llm_api.py | 61 ++ src/chat/actions/plugin_api/message_api.py | 231 +++++++ src/chat/actions/plugin_api/utils_api.py | 121 ++++ 7 files changed, 873 insertions(+), 656 deletions(-) create mode 100644 src/chat/actions/plugin_api/__init__.py create mode 100644 src/chat/actions/plugin_api/config_api.py create mode 100644 src/chat/actions/plugin_api/database_api.py create mode 100644 src/chat/actions/plugin_api/llm_api.py create mode 100644 src/chat/actions/plugin_api/message_api.py create mode 100644 src/chat/actions/plugin_api/utils_api.py diff --git a/src/chat/actions/plugin_action.py b/src/chat/actions/plugin_action.py index 373ac7f28..24944c63e 100644 --- a/src/chat/actions/plugin_action.py +++ b/src/chat/actions/plugin_action.py @@ -4,29 +4,29 @@ from src.chat.actions.base_action import BaseAction, register_action, ActionActi from src.chat.heart_flow.observation.chatting_observation import ChattingObservation from src.chat.focus_chat.hfc_utils import create_empty_anchor_message from src.common.logger_manager import get_logger -from src.llm_models.utils_model import LLMRequest -from src.person_info.person_info import person_info_manager -from abc import abstractmethod from src.config.config import global_config import os import inspect import toml # 导入 toml 库 -from src.common.database.database_model import ActionRecords -from src.common.database.database import db -from peewee import Model, DoesNotExist -import json -import time +from abc import abstractmethod + +# 导入拆分后的API模块 +from src.chat.actions.plugin_api.message_api import MessageAPI +from src.chat.actions.plugin_api.llm_api import LLMAPI +from src.chat.actions.plugin_api.database_api import DatabaseAPI +from src.chat.actions.plugin_api.config_api import ConfigAPI +from src.chat.actions.plugin_api.utils_api import UtilsAPI # 以下为类型注解需要 -from src.chat.message_receive.chat_stream import ChatStream -from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor -from src.chat.focus_chat.replyer.default_replyer import DefaultReplyer -from src.chat.focus_chat.info.obs_info import ObsInfo +from src.chat.message_receive.chat_stream import ChatStream # noqa +from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor # noqa +from src.chat.focus_chat.replyer.default_replyer import DefaultReplyer # noqa +from src.chat.focus_chat.info.obs_info import ObsInfo # noqa logger = get_logger("plugin_action") -class PluginAction(BaseAction): +class PluginAction(BaseAction, MessageAPI, LLMAPI, DatabaseAPI, ConfigAPI, UtilsAPI): """插件动作基类 封装了主程序内部依赖,提供简化的API接口给插件开发者 @@ -118,284 +118,6 @@ class PluginAction(BaseAction): ) self.config = {} # 出错时确保 config 是一个空字典 - def get_global_config(self, key: str, default: Any = None) -> Any: - """ - 安全地从全局配置中获取一个值。 - 插件应使用此方法读取全局配置,以保证只读和隔离性。 - """ - - return global_config.get(key, default) - - async def get_user_id_by_person_name(self, person_name: str) -> Tuple[str, str]: - """根据用户名获取用户ID""" - person_id = person_info_manager.get_person_id_by_person_name(person_name) - user_id = await person_info_manager.get_value(person_id, "user_id") - platform = await person_info_manager.get_value(person_id, "platform") - return platform, user_id - - # 提供简化的API方法 - async def send_message(self, type: str, data: str, target: Optional[str] = "", display_message: str = "") -> bool: - """发送消息的简化方法 - - Args: - text: 要发送的消息文本 - target: 目标消息(可选) - - Returns: - bool: 是否发送成功 - """ - try: - expressor: DefaultExpressor = self._services.get("expressor") - chat_stream: ChatStream = self._services.get("chat_stream") - - if not expressor or not chat_stream: - logger.error(f"{self.log_prefix} 无法发送消息:缺少必要的内部服务") - return False - - # 构造简化的动作数据 - # reply_data = {"text": text, "target": target or "", "emojis": []} - - # 获取锚定消息(如果有) - observations = self._services.get("observations", []) - - if len(observations) > 0: - chatting_observation: ChattingObservation = next( - obs for obs in observations if isinstance(obs, ChattingObservation) - ) - - anchor_message = chatting_observation.search_message_by_text(target) - else: - anchor_message = None - - # 如果没有找到锚点消息,创建一个占位符 - if not anchor_message: - logger.info(f"{self.log_prefix} 未找到锚点消息,创建占位符") - anchor_message = await create_empty_anchor_message( - chat_stream.platform, chat_stream.group_info, chat_stream - ) - else: - anchor_message.update_chat_stream(chat_stream) - - response_set = [ - (type, data), - ] - - # 调用内部方法发送消息 - success = await expressor.send_response_messages( - anchor_message=anchor_message, - response_set=response_set, - display_message=display_message, - ) - - return success - except Exception as e: - logger.error(f"{self.log_prefix} 发送消息时出错: {e}") - traceback.print_exc() - return False - - async def send_message_by_expressor(self, text: str, target: Optional[str] = None) -> bool: - """发送消息的简化方法 - - Args: - text: 要发送的消息文本 - target: 目标消息(可选) - - Returns: - bool: 是否发送成功 - """ - expressor: DefaultExpressor = self._services.get("expressor") - chat_stream: ChatStream = self._services.get("chat_stream") - - if not expressor or not chat_stream: - logger.error(f"{self.log_prefix} 无法发送消息:缺少必要的内部服务") - return False - - # 构造简化的动作数据 - reply_data = {"text": text, "target": target or "", "emojis": []} - - # 获取锚定消息(如果有) - observations = self._services.get("observations", []) - - # 查找 ChattingObservation 实例 - chatting_observation = None - for obs in observations: - if isinstance(obs, ChattingObservation): - chatting_observation = obs - break - - if not chatting_observation: - logger.warning(f"{self.log_prefix} 未找到 ChattingObservation 实例,创建占位符") - anchor_message = await create_empty_anchor_message( - chat_stream.platform, chat_stream.group_info, chat_stream - ) - else: - anchor_message = chatting_observation.search_message_by_text(reply_data["target"]) - if not anchor_message: - logger.info(f"{self.log_prefix} 未找到锚点消息,创建占位符") - anchor_message = await create_empty_anchor_message( - chat_stream.platform, chat_stream.group_info, chat_stream - ) - else: - anchor_message.update_chat_stream(chat_stream) - - # 调用内部方法发送消息 - success, _ = await expressor.deal_reply( - cycle_timers=self.cycle_timers, - action_data=reply_data, - anchor_message=anchor_message, - reasoning=self.reasoning, - thinking_id=self.thinking_id, - ) - - return success - - async def send_message_by_replyer(self, target: Optional[str] = None, extra_info_block: Optional[str] = None) -> bool: - """通过 replyer 发送消息的简化方法 - - Args: - text: 要发送的消息文本 - target: 目标消息(可选) - - Returns: - bool: 是否发送成功 - """ - replyer: DefaultReplyer = self._services.get("replyer") - chat_stream: ChatStream = self._services.get("chat_stream") - - if not replyer or not chat_stream: - logger.error(f"{self.log_prefix} 无法发送消息:缺少必要的内部服务") - return False - - # 构造简化的动作数据 - reply_data = {"target": target or "", "extra_info_block": extra_info_block} - - # 获取锚定消息(如果有) - observations = self._services.get("observations", []) - - # 查找 ChattingObservation 实例 - chatting_observation = None - for obs in observations: - if isinstance(obs, ChattingObservation): - chatting_observation = obs - break - - if not chatting_observation: - logger.warning(f"{self.log_prefix} 未找到 ChattingObservation 实例,创建占位符") - anchor_message = await create_empty_anchor_message( - chat_stream.platform, chat_stream.group_info, chat_stream - ) - else: - anchor_message = chatting_observation.search_message_by_text(reply_data["target"]) - if not anchor_message: - logger.info(f"{self.log_prefix} 未找到锚点消息,创建占位符") - anchor_message = await create_empty_anchor_message( - chat_stream.platform, chat_stream.group_info, chat_stream - ) - else: - anchor_message.update_chat_stream(chat_stream) - - # 调用内部方法发送消息 - success, _ = await replyer.deal_reply( - cycle_timers=self.cycle_timers, - action_data=reply_data, - anchor_message=anchor_message, - reasoning=self.reasoning, - thinking_id=self.thinking_id, - ) - - return success - - def get_chat_type(self) -> str: - """获取当前聊天类型 - - Returns: - str: 聊天类型 ("group" 或 "private") - """ - chat_stream: ChatStream = self._services.get("chat_stream") - if chat_stream and hasattr(chat_stream, "group_info"): - return "group" if chat_stream.group_info else "private" - return "unknown" - - def get_recent_messages(self, count: int = 5) -> List[Dict[str, Any]]: - """获取最近的消息 - - Args: - count: 要获取的消息数量 - - Returns: - List[Dict]: 消息列表,每个消息包含发送者、内容等信息 - """ - messages = [] - observations = self._services.get("observations", []) - - if observations and len(observations) > 0: - obs = observations[0] - if hasattr(obs, "get_talking_message"): - obs: ObsInfo - raw_messages = obs.get_talking_message() - # 转换为简化格式 - for msg in raw_messages[-count:]: - simple_msg = { - "sender": msg.get("sender", "未知"), - "content": msg.get("content", ""), - "timestamp": msg.get("timestamp", 0), - } - messages.append(simple_msg) - - return messages - - def get_available_models(self) -> Dict[str, Any]: - """获取所有可用的模型配置 - - Returns: - Dict[str, Any]: 模型配置字典,key为模型名称,value为模型配置 - """ - if not hasattr(global_config, "model"): - logger.error(f"{self.log_prefix} 无法获取模型列表:全局配置中未找到 model 配置") - return {} - - models = global_config.model - - return models - - async def generate_with_model( - self, - prompt: str, - model_config: Dict[str, Any], - request_type: str = "plugin.generate", - **kwargs - ) -> Tuple[bool, str]: - """使用指定模型生成内容 - - Args: - prompt: 提示词 - model_config: 模型配置(从 get_available_models 获取的模型配置) - temperature: 温度参数,控制随机性 (0-1) - max_tokens: 最大生成token数 - request_type: 请求类型标识 - **kwargs: 其他模型特定参数 - - Returns: - Tuple[bool, str]: (是否成功, 生成的内容或错误信息) - """ - try: - - - logger.info(f"prompt: {prompt}") - - llm_request = LLMRequest( - model=model_config, - request_type=request_type, - **kwargs - ) - - response,(resoning , model_name) = await llm_request.generate_response_async(prompt) - return True, response, resoning, model_name - except Exception as e: - error_msg = f"生成内容时出错: {str(e)}" - logger.error(f"{self.log_prefix} {error_msg}") - return False, error_msg - @abstractmethod async def process(self) -> Tuple[bool, str]: """插件处理逻辑,子类必须实现此方法 @@ -412,368 +134,3 @@ class PluginAction(BaseAction): Tuple[bool, str]: (是否执行成功, 回复文本) """ return await self.process() - - async def store_action_info(self, action_build_into_prompt: bool = False, action_prompt_display: str = "", action_done: bool = True) -> None: - """存储action执行信息到数据库 - - Args: - action_build_into_prompt: 是否构建到提示中 - action_prompt_display: 动作显示内容 - """ - try: - chat_stream: ChatStream = self._services.get("chat_stream") - if not chat_stream: - logger.error(f"{self.log_prefix} 无法存储action信息:缺少chat_stream服务") - return - - action_time = time.time() - action_id = f"{action_time}_{self.thinking_id}" - - ActionRecords.create( - action_id=action_id, - time=action_time, - action_name=self.__class__.__name__, - action_data=str(self.action_data), - action_done=action_done, - action_build_into_prompt=action_build_into_prompt, - action_prompt_display=action_prompt_display, - chat_id=chat_stream.stream_id, - chat_info_stream_id=chat_stream.stream_id, - chat_info_platform=chat_stream.platform, - user_id=chat_stream.user_info.user_id if chat_stream.user_info else "", - user_nickname=chat_stream.user_info.user_nickname if chat_stream.user_info else "", - user_cardname=chat_stream.user_info.user_cardname if chat_stream.user_info else "" - ) - logger.debug(f"{self.log_prefix} 已存储action信息: {action_prompt_display}") - except Exception as e: - logger.error(f"{self.log_prefix} 存储action信息时出错: {e}") - traceback.print_exc() - - async def db_query( - self, - model_class: Type[Model], - query_type: str = "get", - filters: Dict[str, Any] = None, - data: Dict[str, Any] = None, - limit: int = None, - order_by: List[str] = None, - single_result: bool = False - ) -> Union[List[Dict[str, Any]], Dict[str, Any], None]: - """执行数据库查询操作 - - 这个方法提供了一个通用接口来执行数据库操作,包括查询、创建、更新和删除记录。 - - Args: - model_class: Peewee 模型类,例如 ActionRecords, Messages 等 - query_type: 查询类型,可选值: "get", "create", "update", "delete", "count" - filters: 过滤条件字典,键为字段名,值为要匹配的值 - data: 用于创建或更新的数据字典 - limit: 限制结果数量 - order_by: 排序字段列表,使用字段名,前缀'-'表示降序 - single_result: 是否只返回单个结果 - - Returns: - 根据查询类型返回不同的结果: - - "get": 返回查询结果列表或单个结果(如果 single_result=True) - - "create": 返回创建的记录 - - "update": 返回受影响的行数 - - "delete": 返回受影响的行数 - - "count": 返回记录数量 - - 示例: - # 查询最近10条消息 - messages = await self.db_query( - Messages, - query_type="get", - filters={"chat_id": chat_stream.stream_id}, - limit=10, - order_by=["-time"] - ) - - # 创建一条记录 - new_record = await self.db_query( - ActionRecords, - query_type="create", - data={"action_id": "123", "time": time.time(), "action_name": "TestAction"} - ) - - # 更新记录 - updated_count = await self.db_query( - ActionRecords, - query_type="update", - filters={"action_id": "123"}, - data={"action_done": True} - ) - - # 删除记录 - deleted_count = await self.db_query( - ActionRecords, - query_type="delete", - filters={"action_id": "123"} - ) - - # 计数 - count = await self.db_query( - Messages, - query_type="count", - filters={"chat_id": chat_stream.stream_id} - ) - """ - try: - # 构建基本查询 - if query_type in ["get", "update", "delete", "count"]: - query = model_class.select() - - # 应用过滤条件 - if filters: - for field, value in filters.items(): - query = query.where(getattr(model_class, field) == value) - - # 执行查询 - if query_type == "get": - # 应用排序 - if order_by: - for field in order_by: - if field.startswith("-"): - query = query.order_by(getattr(model_class, field[1:]).desc()) - else: - query = query.order_by(getattr(model_class, field)) - - # 应用限制 - if limit: - query = query.limit(limit) - - # 执行查询 - results = list(query.dicts()) - - # 返回结果 - if single_result: - return results[0] if results else None - return results - - elif query_type == "create": - if not data: - raise ValueError("创建记录需要提供data参数") - - # 创建记录 - record = model_class.create(**data) - # 返回创建的记录 - return model_class.select().where(model_class.id == record.id).dicts().get() - - elif query_type == "update": - if not data: - raise ValueError("更新记录需要提供data参数") - - # 更新记录 - return query.update(**data).execute() - - elif query_type == "delete": - # 删除记录 - return query.delete().execute() - - elif query_type == "count": - # 计数 - return query.count() - - else: - raise ValueError(f"不支持的查询类型: {query_type}") - - except DoesNotExist: - # 记录不存在 - if query_type == "get" and single_result: - return None - return [] - - except Exception as e: - logger.error(f"{self.log_prefix} 数据库操作出错: {e}") - traceback.print_exc() - - # 根据查询类型返回合适的默认值 - if query_type == "get": - return None if single_result else [] - elif query_type in ["create", "update", "delete", "count"]: - return None - - async def db_raw_query( - self, - sql: str, - params: List[Any] = None, - fetch_results: bool = True - ) -> Union[List[Dict[str, Any]], int, None]: - """执行原始SQL查询 - - 警告: 使用此方法需要小心,确保SQL语句已正确构造以避免SQL注入风险。 - - Args: - sql: 原始SQL查询字符串 - params: 查询参数列表,用于替换SQL中的占位符 - fetch_results: 是否获取查询结果,对于SELECT查询设为True,对于 - UPDATE/INSERT/DELETE等操作设为False - - Returns: - 如果fetch_results为True,返回查询结果列表; - 如果fetch_results为False,返回受影响的行数; - 如果出错,返回None - """ - try: - cursor = db.execute_sql(sql, params or []) - - if fetch_results: - # 获取列名 - columns = [col[0] for col in cursor.description] - - # 构建结果字典列表 - results = [] - for row in cursor.fetchall(): - results.append(dict(zip(columns, row))) - - return results - else: - # 返回受影响的行数 - return cursor.rowcount - - except Exception as e: - logger.error(f"{self.log_prefix} 执行原始SQL查询出错: {e}") - traceback.print_exc() - return None - - async def db_save( - self, - model_class: Type[Model], - data: Dict[str, Any], - key_field: str = None, - key_value: Any = None - ) -> Union[Dict[str, Any], None]: - """保存数据到数据库(创建或更新) - - 如果提供了key_field和key_value,会先尝试查找匹配的记录进行更新; - 如果没有找到匹配记录,或未提供key_field和key_value,则创建新记录。 - - Args: - model_class: Peewee模型类,如ActionRecords, Messages等 - data: 要保存的数据字典 - key_field: 用于查找现有记录的字段名,例如"action_id" - key_value: 用于查找现有记录的字段值 - - Returns: - Dict[str, Any]: 保存后的记录数据 - None: 如果操作失败 - - 示例: - # 创建或更新一条记录 - record = await self.db_save( - ActionRecords, - { - "action_id": "123", - "time": time.time(), - "action_name": "TestAction", - "action_done": True - }, - key_field="action_id", - key_value="123" - ) - """ - try: - # 如果提供了key_field和key_value,尝试更新现有记录 - if key_field and key_value is not None: - # 查找现有记录 - existing_records = list(model_class.select().where( - getattr(model_class, key_field) == key_value - ).limit(1)) - - if existing_records: - # 更新现有记录 - existing_record = existing_records[0] - for field, value in data.items(): - setattr(existing_record, field, value) - existing_record.save() - - # 返回更新后的记录 - updated_record = model_class.select().where( - model_class.id == existing_record.id - ).dicts().get() - return updated_record - - # 如果没有找到现有记录或未提供key_field和key_value,创建新记录 - new_record = model_class.create(**data) - - # 返回创建的记录 - created_record = model_class.select().where( - model_class.id == new_record.id - ).dicts().get() - return created_record - - except Exception as e: - logger.error(f"{self.log_prefix} 保存数据库记录出错: {e}") - traceback.print_exc() - return None - - async def db_get( - self, - model_class: Type[Model], - filters: Dict[str, Any] = None, - order_by: str = None, - limit: int = None - ) -> Union[List[Dict[str, Any]], Dict[str, Any], None]: - """从数据库获取记录 - - 这是db_query方法的简化版本,专注于数据检索操作。 - - Args: - model_class: Peewee模型类 - filters: 过滤条件,字段名和值的字典 - order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间降序 - limit: 结果数量限制,如果为1则返回单个记录而不是列表 - - Returns: - 如果limit=1,返回单个记录字典或None; - 否则返回记录字典列表或空列表。 - - 示例: - # 获取单个记录 - record = await self.db_get( - ActionRecords, - filters={"action_id": "123"}, - limit=1 - ) - - # 获取最近10条记录 - records = await self.db_get( - Messages, - filters={"chat_id": chat_stream.stream_id}, - order_by="-time", - limit=10 - ) - """ - try: - # 构建查询 - query = model_class.select() - - # 应用过滤条件 - if filters: - for field, value in filters.items(): - query = query.where(getattr(model_class, field) == value) - - # 应用排序 - if order_by: - if order_by.startswith("-"): - query = query.order_by(getattr(model_class, order_by[1:]).desc()) - else: - query = query.order_by(getattr(model_class, order_by)) - - # 应用限制 - if limit: - query = query.limit(limit) - - # 执行查询 - results = list(query.dicts()) - - # 返回结果 - if limit == 1: - return results[0] if results else None - return results - - except Exception as e: - logger.error(f"{self.log_prefix} 获取数据库记录出错: {e}") - traceback.print_exc() - return None if limit == 1 else [] diff --git a/src/chat/actions/plugin_api/__init__.py b/src/chat/actions/plugin_api/__init__.py new file mode 100644 index 000000000..1db320ddb --- /dev/null +++ b/src/chat/actions/plugin_api/__init__.py @@ -0,0 +1,13 @@ +from src.chat.actions.plugin_api.message_api import MessageAPI +from src.chat.actions.plugin_api.llm_api import LLMAPI +from src.chat.actions.plugin_api.database_api import DatabaseAPI +from src.chat.actions.plugin_api.config_api import ConfigAPI +from src.chat.actions.plugin_api.utils_api import UtilsAPI + +__all__ = [ + 'MessageAPI', + 'LLMAPI', + 'DatabaseAPI', + 'ConfigAPI', + 'UtilsAPI', +] \ No newline at end of file diff --git a/src/chat/actions/plugin_api/config_api.py b/src/chat/actions/plugin_api/config_api.py new file mode 100644 index 000000000..f136cea7e --- /dev/null +++ b/src/chat/actions/plugin_api/config_api.py @@ -0,0 +1,53 @@ +from typing import Any +from src.common.logger_manager import get_logger +from src.config.config import global_config +from src.person_info.person_info import person_info_manager + +logger = get_logger("config_api") + +class ConfigAPI: + """配置API模块 + + 提供了配置读取和用户信息获取等功能 + """ + + def get_global_config(self, key: str, default: Any = None) -> Any: + """ + 安全地从全局配置中获取一个值。 + 插件应使用此方法读取全局配置,以保证只读和隔离性。 + + Args: + key: 配置键名 + default: 如果配置不存在时返回的默认值 + + Returns: + Any: 配置值或默认值 + """ + return global_config.get(key, default) + + async def get_user_id_by_person_name(self, person_name: str) -> tuple[str, str]: + """根据用户名获取用户ID + + Args: + person_name: 用户名 + + Returns: + tuple[str, str]: (平台, 用户ID) + """ + person_id = person_info_manager.get_person_id_by_person_name(person_name) + user_id = await person_info_manager.get_value(person_id, "user_id") + platform = await person_info_manager.get_value(person_id, "platform") + return platform, user_id + + async def get_person_info(self, person_id: str, key: str, default: Any = None) -> Any: + """获取用户信息 + + Args: + person_id: 用户ID + key: 信息键名 + default: 默认值 + + Returns: + Any: 用户信息值或默认值 + """ + return await person_info_manager.get_value(person_id, key, default) \ No newline at end of file diff --git a/src/chat/actions/plugin_api/database_api.py b/src/chat/actions/plugin_api/database_api.py new file mode 100644 index 000000000..d8a45aefa --- /dev/null +++ b/src/chat/actions/plugin_api/database_api.py @@ -0,0 +1,381 @@ +import traceback +import time +from typing import Dict, List, Any, Union, Type +from src.common.logger_manager import get_logger +from src.common.database.database_model import ActionRecords +from src.common.database.database import db +from peewee import Model, DoesNotExist + +logger = get_logger("database_api") + +class DatabaseAPI: + """数据库API模块 + + 提供了数据库操作相关的功能 + """ + + async def store_action_info(self, action_build_into_prompt: bool = False, action_prompt_display: str = "", action_done: bool = True) -> None: + """存储action执行信息到数据库 + + Args: + action_build_into_prompt: 是否构建到提示中 + action_prompt_display: 动作显示内容 + action_done: 动作是否已完成 + """ + try: + chat_stream = self._services.get("chat_stream") + if not chat_stream: + logger.error(f"{self.log_prefix} 无法存储action信息:缺少chat_stream服务") + return + + action_time = time.time() + action_id = f"{action_time}_{self.thinking_id}" + + ActionRecords.create( + action_id=action_id, + time=action_time, + action_name=self.__class__.__name__, + action_data=str(self.action_data), + action_done=action_done, + action_build_into_prompt=action_build_into_prompt, + action_prompt_display=action_prompt_display, + chat_id=chat_stream.stream_id, + chat_info_stream_id=chat_stream.stream_id, + chat_info_platform=chat_stream.platform, + user_id=chat_stream.user_info.user_id if chat_stream.user_info else "", + user_nickname=chat_stream.user_info.user_nickname if chat_stream.user_info else "", + user_cardname=chat_stream.user_info.user_cardname if chat_stream.user_info else "" + ) + logger.debug(f"{self.log_prefix} 已存储action信息: {action_prompt_display}") + except Exception as e: + logger.error(f"{self.log_prefix} 存储action信息时出错: {e}") + traceback.print_exc() + + async def db_query( + self, + model_class: Type[Model], + query_type: str = "get", + filters: Dict[str, Any] = None, + data: Dict[str, Any] = None, + limit: int = None, + order_by: List[str] = None, + single_result: bool = False + ) -> Union[List[Dict[str, Any]], Dict[str, Any], None]: + """执行数据库查询操作 + + 这个方法提供了一个通用接口来执行数据库操作,包括查询、创建、更新和删除记录。 + + Args: + model_class: Peewee 模型类,例如 ActionRecords, Messages 等 + query_type: 查询类型,可选值: "get", "create", "update", "delete", "count" + filters: 过滤条件字典,键为字段名,值为要匹配的值 + data: 用于创建或更新的数据字典 + limit: 限制结果数量 + order_by: 排序字段列表,使用字段名,前缀'-'表示降序 + single_result: 是否只返回单个结果 + + Returns: + 根据查询类型返回不同的结果: + - "get": 返回查询结果列表或单个结果(如果 single_result=True) + - "create": 返回创建的记录 + - "update": 返回受影响的行数 + - "delete": 返回受影响的行数 + - "count": 返回记录数量 + + 示例: + # 查询最近10条消息 + messages = await self.db_query( + Messages, + query_type="get", + filters={"chat_id": chat_stream.stream_id}, + limit=10, + order_by=["-time"] + ) + + # 创建一条记录 + new_record = await self.db_query( + ActionRecords, + query_type="create", + data={"action_id": "123", "time": time.time(), "action_name": "TestAction"} + ) + + # 更新记录 + updated_count = await self.db_query( + ActionRecords, + query_type="update", + filters={"action_id": "123"}, + data={"action_done": True} + ) + + # 删除记录 + deleted_count = await self.db_query( + ActionRecords, + query_type="delete", + filters={"action_id": "123"} + ) + + # 计数 + count = await self.db_query( + Messages, + query_type="count", + filters={"chat_id": chat_stream.stream_id} + ) + """ + try: + # 构建基本查询 + if query_type in ["get", "update", "delete", "count"]: + query = model_class.select() + + # 应用过滤条件 + if filters: + for field, value in filters.items(): + query = query.where(getattr(model_class, field) == value) + + # 执行查询 + if query_type == "get": + # 应用排序 + if order_by: + for field in order_by: + if field.startswith("-"): + query = query.order_by(getattr(model_class, field[1:]).desc()) + else: + query = query.order_by(getattr(model_class, field)) + + # 应用限制 + if limit: + query = query.limit(limit) + + # 执行查询 + results = list(query.dicts()) + + # 返回结果 + if single_result: + return results[0] if results else None + return results + + elif query_type == "create": + if not data: + raise ValueError("创建记录需要提供data参数") + + # 创建记录 + record = model_class.create(**data) + # 返回创建的记录 + return model_class.select().where(model_class.id == record.id).dicts().get() + + elif query_type == "update": + if not data: + raise ValueError("更新记录需要提供data参数") + + # 更新记录 + return query.update(**data).execute() + + elif query_type == "delete": + # 删除记录 + return query.delete().execute() + + elif query_type == "count": + # 计数 + return query.count() + + else: + raise ValueError(f"不支持的查询类型: {query_type}") + + except DoesNotExist: + # 记录不存在 + if query_type == "get" and single_result: + return None + return [] + + except Exception as e: + logger.error(f"{self.log_prefix} 数据库操作出错: {e}") + traceback.print_exc() + + # 根据查询类型返回合适的默认值 + if query_type == "get": + return None if single_result else [] + elif query_type in ["create", "update", "delete", "count"]: + return None + + async def db_raw_query( + self, + sql: str, + params: List[Any] = None, + fetch_results: bool = True + ) -> Union[List[Dict[str, Any]], int, None]: + """执行原始SQL查询 + + 警告: 使用此方法需要小心,确保SQL语句已正确构造以避免SQL注入风险。 + + Args: + sql: 原始SQL查询字符串 + params: 查询参数列表,用于替换SQL中的占位符 + fetch_results: 是否获取查询结果,对于SELECT查询设为True,对于 + UPDATE/INSERT/DELETE等操作设为False + + Returns: + 如果fetch_results为True,返回查询结果列表; + 如果fetch_results为False,返回受影响的行数; + 如果出错,返回None + """ + try: + cursor = db.execute_sql(sql, params or []) + + if fetch_results: + # 获取列名 + columns = [col[0] for col in cursor.description] + + # 构建结果字典列表 + results = [] + for row in cursor.fetchall(): + results.append(dict(zip(columns, row))) + + return results + else: + # 返回受影响的行数 + return cursor.rowcount + + except Exception as e: + logger.error(f"{self.log_prefix} 执行原始SQL查询出错: {e}") + traceback.print_exc() + return None + + async def db_save( + self, + model_class: Type[Model], + data: Dict[str, Any], + key_field: str = None, + key_value: Any = None + ) -> Union[Dict[str, Any], None]: + """保存数据到数据库(创建或更新) + + 如果提供了key_field和key_value,会先尝试查找匹配的记录进行更新; + 如果没有找到匹配记录,或未提供key_field和key_value,则创建新记录。 + + Args: + model_class: Peewee模型类,如ActionRecords, Messages等 + data: 要保存的数据字典 + key_field: 用于查找现有记录的字段名,例如"action_id" + key_value: 用于查找现有记录的字段值 + + Returns: + Dict[str, Any]: 保存后的记录数据 + None: 如果操作失败 + + 示例: + # 创建或更新一条记录 + record = await self.db_save( + ActionRecords, + { + "action_id": "123", + "time": time.time(), + "action_name": "TestAction", + "action_done": True + }, + key_field="action_id", + key_value="123" + ) + """ + try: + # 如果提供了key_field和key_value,尝试更新现有记录 + if key_field and key_value is not None: + # 查找现有记录 + existing_records = list(model_class.select().where( + getattr(model_class, key_field) == key_value + ).limit(1)) + + if existing_records: + # 更新现有记录 + existing_record = existing_records[0] + for field, value in data.items(): + setattr(existing_record, field, value) + existing_record.save() + + # 返回更新后的记录 + updated_record = model_class.select().where( + model_class.id == existing_record.id + ).dicts().get() + return updated_record + + # 如果没有找到现有记录或未提供key_field和key_value,创建新记录 + new_record = model_class.create(**data) + + # 返回创建的记录 + created_record = model_class.select().where( + model_class.id == new_record.id + ).dicts().get() + return created_record + + except Exception as e: + logger.error(f"{self.log_prefix} 保存数据库记录出错: {e}") + traceback.print_exc() + return None + + async def db_get( + self, + model_class: Type[Model], + filters: Dict[str, Any] = None, + order_by: str = None, + limit: int = None + ) -> Union[List[Dict[str, Any]], Dict[str, Any], None]: + """从数据库获取记录 + + 这是db_query方法的简化版本,专注于数据检索操作。 + + Args: + model_class: Peewee模型类 + filters: 过滤条件,字段名和值的字典 + order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间降序 + limit: 结果数量限制,如果为1则返回单个记录而不是列表 + + Returns: + 如果limit=1,返回单个记录字典或None; + 否则返回记录字典列表或空列表。 + + 示例: + # 获取单个记录 + record = await self.db_get( + ActionRecords, + filters={"action_id": "123"}, + limit=1 + ) + + # 获取最近10条记录 + records = await self.db_get( + Messages, + filters={"chat_id": chat_stream.stream_id}, + order_by="-time", + limit=10 + ) + """ + try: + # 构建查询 + query = model_class.select() + + # 应用过滤条件 + if filters: + for field, value in filters.items(): + query = query.where(getattr(model_class, field) == value) + + # 应用排序 + if order_by: + if order_by.startswith("-"): + query = query.order_by(getattr(model_class, order_by[1:]).desc()) + else: + query = query.order_by(getattr(model_class, order_by)) + + # 应用限制 + if limit: + query = query.limit(limit) + + # 执行查询 + results = list(query.dicts()) + + # 返回结果 + if limit == 1: + return results[0] if results else None + return results + + except Exception as e: + logger.error(f"{self.log_prefix} 获取数据库记录出错: {e}") + traceback.print_exc() + return None if limit == 1 else [] \ No newline at end of file diff --git a/src/chat/actions/plugin_api/llm_api.py b/src/chat/actions/plugin_api/llm_api.py new file mode 100644 index 000000000..0e80e897b --- /dev/null +++ b/src/chat/actions/plugin_api/llm_api.py @@ -0,0 +1,61 @@ +from typing import Tuple, Dict, Any +from src.common.logger_manager import get_logger +from src.llm_models.utils_model import LLMRequest +from src.config.config import global_config + +logger = get_logger("llm_api") + +class LLMAPI: + """LLM API模块 + + 提供了与LLM模型交互的功能 + """ + + def get_available_models(self) -> Dict[str, Any]: + """获取所有可用的模型配置 + + Returns: + Dict[str, Any]: 模型配置字典,key为模型名称,value为模型配置 + """ + if not hasattr(global_config, "model"): + logger.error(f"{self.log_prefix} 无法获取模型列表:全局配置中未找到 model 配置") + return {} + + models = global_config.model + + return models + + async def generate_with_model( + self, + prompt: str, + model_config: Dict[str, Any], + request_type: str = "plugin.generate", + **kwargs + ) -> Tuple[bool, str, str, str]: + """使用指定模型生成内容 + + Args: + prompt: 提示词 + model_config: 模型配置(从 get_available_models 获取的模型配置) + request_type: 请求类型标识 + **kwargs: 其他模型特定参数,如temperature、max_tokens等 + + Returns: + Tuple[bool, str, str, str]: (是否成功, 生成的内容, 推理过程, 模型名称) + """ + try: + logger.info(f"{self.log_prefix} 使用模型生成内容,提示词: {prompt[:100]}...") + + llm_request = LLMRequest( + model=model_config, + request_type=request_type, + **kwargs + ) + + response, (reasoning, model_name) = await llm_request.generate_response_async(prompt) + return True, response, reasoning, model_name + + except Exception as e: + error_msg = f"生成内容时出错: {str(e)}" + logger.error(f"{self.log_prefix} {error_msg}") + return False, error_msg, "", "" \ No newline at end of file diff --git a/src/chat/actions/plugin_api/message_api.py b/src/chat/actions/plugin_api/message_api.py new file mode 100644 index 000000000..38816a30e --- /dev/null +++ b/src/chat/actions/plugin_api/message_api.py @@ -0,0 +1,231 @@ +import traceback +from typing import Optional, List, Dict, Any +from src.common.logger_manager import get_logger +from src.chat.heart_flow.observation.chatting_observation import ChattingObservation +from src.chat.focus_chat.hfc_utils import create_empty_anchor_message + +# 以下为类型注解需要 +from src.chat.message_receive.chat_stream import ChatStream +from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor +from src.chat.focus_chat.replyer.default_replyer import DefaultReplyer +from src.chat.focus_chat.info.obs_info import ObsInfo + +logger = get_logger("message_api") + +class MessageAPI: + """消息API模块 + + 提供了发送消息、获取消息历史等功能 + """ + + async def send_message(self, type: str, data: str, target: Optional[str] = "", display_message: str = "") -> bool: + """发送消息的简化方法 + + Args: + type: 消息类型,如"text"、"image"等 + data: 消息内容 + target: 目标消息(可选) + display_message: 显示的消息内容(可选) + + Returns: + bool: 是否发送成功 + """ + try: + expressor: DefaultExpressor = self._services.get("expressor") + chat_stream: ChatStream = self._services.get("chat_stream") + + if not expressor or not chat_stream: + logger.error(f"{self.log_prefix} 无法发送消息:缺少必要的内部服务") + return False + + # 获取锚定消息(如果有) + observations = self._services.get("observations", []) + + if len(observations) > 0: + chatting_observation: ChattingObservation = next( + (obs for obs in observations if isinstance(obs, ChattingObservation)), None + ) + + if chatting_observation: + anchor_message = chatting_observation.search_message_by_text(target) + else: + anchor_message = None + else: + anchor_message = None + + # 如果没有找到锚点消息,创建一个占位符 + if not anchor_message: + logger.info(f"{self.log_prefix} 未找到锚点消息,创建占位符") + anchor_message = await create_empty_anchor_message( + chat_stream.platform, chat_stream.group_info, chat_stream + ) + else: + anchor_message.update_chat_stream(chat_stream) + + response_set = [ + (type, data), + ] + + # 调用内部方法发送消息 + success = await expressor.send_response_messages( + anchor_message=anchor_message, + response_set=response_set, + display_message=display_message, + ) + + return success + except Exception as e: + logger.error(f"{self.log_prefix} 发送消息时出错: {e}") + traceback.print_exc() + return False + + async def send_message_by_expressor(self, text: str, target: Optional[str] = None) -> bool: + """通过expressor发送文本消息的简化方法 + + Args: + text: 要发送的消息文本 + target: 目标消息(可选) + + Returns: + bool: 是否发送成功 + """ + expressor: DefaultExpressor = self._services.get("expressor") + chat_stream: ChatStream = self._services.get("chat_stream") + + if not expressor or not chat_stream: + logger.error(f"{self.log_prefix} 无法发送消息:缺少必要的内部服务") + return False + + # 构造简化的动作数据 + reply_data = {"text": text, "target": target or "", "emojis": []} + + # 获取锚定消息(如果有) + observations = self._services.get("observations", []) + + # 查找 ChattingObservation 实例 + chatting_observation = None + for obs in observations: + if isinstance(obs, ChattingObservation): + chatting_observation = obs + break + + if not chatting_observation: + logger.warning(f"{self.log_prefix} 未找到 ChattingObservation 实例,创建占位符") + anchor_message = await create_empty_anchor_message( + chat_stream.platform, chat_stream.group_info, chat_stream + ) + else: + anchor_message = chatting_observation.search_message_by_text(reply_data["target"]) + if not anchor_message: + logger.info(f"{self.log_prefix} 未找到锚点消息,创建占位符") + anchor_message = await create_empty_anchor_message( + chat_stream.platform, chat_stream.group_info, chat_stream + ) + else: + anchor_message.update_chat_stream(chat_stream) + + # 调用内部方法发送消息 + success, _ = await expressor.deal_reply( + cycle_timers=self.cycle_timers, + action_data=reply_data, + anchor_message=anchor_message, + reasoning=self.reasoning, + thinking_id=self.thinking_id, + ) + + return success + + async def send_message_by_replyer(self, target: Optional[str] = None, extra_info_block: Optional[str] = None) -> bool: + """通过replyer发送消息的简化方法 + + Args: + target: 目标消息(可选) + extra_info_block: 额外信息块(可选) + + Returns: + bool: 是否发送成功 + """ + replyer: DefaultReplyer = self._services.get("replyer") + chat_stream: ChatStream = self._services.get("chat_stream") + + if not replyer or not chat_stream: + logger.error(f"{self.log_prefix} 无法发送消息:缺少必要的内部服务") + return False + + # 构造简化的动作数据 + reply_data = {"target": target or "", "extra_info_block": extra_info_block} + + # 获取锚定消息(如果有) + observations = self._services.get("observations", []) + + # 查找 ChattingObservation 实例 + chatting_observation = None + for obs in observations: + if isinstance(obs, ChattingObservation): + chatting_observation = obs + break + + if not chatting_observation: + logger.warning(f"{self.log_prefix} 未找到 ChattingObservation 实例,创建占位符") + anchor_message = await create_empty_anchor_message( + chat_stream.platform, chat_stream.group_info, chat_stream + ) + else: + anchor_message = chatting_observation.search_message_by_text(reply_data["target"]) + if not anchor_message: + logger.info(f"{self.log_prefix} 未找到锚点消息,创建占位符") + anchor_message = await create_empty_anchor_message( + chat_stream.platform, chat_stream.group_info, chat_stream + ) + else: + anchor_message.update_chat_stream(chat_stream) + + # 调用内部方法发送消息 + success, _ = await replyer.deal_reply( + cycle_timers=self.cycle_timers, + action_data=reply_data, + anchor_message=anchor_message, + reasoning=self.reasoning, + thinking_id=self.thinking_id, + ) + + return success + + def get_chat_type(self) -> str: + """获取当前聊天类型 + + Returns: + str: 聊天类型 ("group" 或 "private") + """ + chat_stream: ChatStream = self._services.get("chat_stream") + if chat_stream and hasattr(chat_stream, "group_info"): + return "group" if chat_stream.group_info else "private" + return "unknown" + + def get_recent_messages(self, count: int = 5) -> List[Dict[str, Any]]: + """获取最近的消息 + + Args: + count: 要获取的消息数量 + + Returns: + List[Dict]: 消息列表,每个消息包含发送者、内容等信息 + """ + messages = [] + observations = self._services.get("observations", []) + + if observations and len(observations) > 0: + obs = observations[0] + if hasattr(obs, "get_talking_message"): + obs: ObsInfo + raw_messages = obs.get_talking_message() + # 转换为简化格式 + for msg in raw_messages[-count:]: + simple_msg = { + "sender": msg.get("sender", "未知"), + "content": msg.get("content", ""), + "timestamp": msg.get("timestamp", 0), + } + messages.append(simple_msg) + + return messages \ No newline at end of file diff --git a/src/chat/actions/plugin_api/utils_api.py b/src/chat/actions/plugin_api/utils_api.py new file mode 100644 index 000000000..b5c476fa1 --- /dev/null +++ b/src/chat/actions/plugin_api/utils_api.py @@ -0,0 +1,121 @@ +import os +import json +import time +from typing import Any, Dict, List, Optional +from src.common.logger_manager import get_logger + +logger = get_logger("utils_api") + +class UtilsAPI: + """工具类API模块 + + 提供了各种辅助功能 + """ + + def get_plugin_path(self) -> str: + """获取当前插件的路径 + + Returns: + str: 插件目录的绝对路径 + """ + import inspect + plugin_module_path = inspect.getfile(self.__class__) + plugin_dir = os.path.dirname(plugin_module_path) + return plugin_dir + + def read_json_file(self, file_path: str, default: Any = None) -> Any: + """读取JSON文件 + + Args: + file_path: 文件路径,可以是相对于插件目录的路径 + default: 如果文件不存在或读取失败时返回的默认值 + + Returns: + Any: JSON数据或默认值 + """ + try: + # 如果是相对路径,则相对于插件目录 + if not os.path.isabs(file_path): + file_path = os.path.join(self.get_plugin_path(), file_path) + + if not os.path.exists(file_path): + logger.warning(f"{self.log_prefix} 文件不存在: {file_path}") + return default + + with open(file_path, 'r', encoding='utf-8') as f: + return json.load(f) + except Exception as e: + logger.error(f"{self.log_prefix} 读取JSON文件出错: {e}") + return default + + def write_json_file(self, file_path: str, data: Any, indent: int = 2) -> bool: + """写入JSON文件 + + Args: + file_path: 文件路径,可以是相对于插件目录的路径 + data: 要写入的数据 + indent: JSON缩进 + + Returns: + bool: 是否写入成功 + """ + try: + # 如果是相对路径,则相对于插件目录 + if not os.path.isabs(file_path): + file_path = os.path.join(self.get_plugin_path(), file_path) + + # 确保目录存在 + os.makedirs(os.path.dirname(file_path), exist_ok=True) + + with open(file_path, 'w', encoding='utf-8') as f: + json.dump(data, f, ensure_ascii=False, indent=indent) + return True + except Exception as e: + logger.error(f"{self.log_prefix} 写入JSON文件出错: {e}") + return False + + def get_timestamp(self) -> int: + """获取当前时间戳 + + Returns: + int: 当前时间戳(秒) + """ + return int(time.time()) + + def format_time(self, timestamp: Optional[int] = None, format_str: str = "%Y-%m-%d %H:%M:%S") -> str: + """格式化时间 + + Args: + timestamp: 时间戳,如果为None则使用当前时间 + format_str: 时间格式字符串 + + Returns: + str: 格式化后的时间字符串 + """ + import datetime + if timestamp is None: + timestamp = time.time() + return datetime.datetime.fromtimestamp(timestamp).strftime(format_str) + + def parse_time(self, time_str: str, format_str: str = "%Y-%m-%d %H:%M:%S") -> int: + """解析时间字符串为时间戳 + + Args: + time_str: 时间字符串 + format_str: 时间格式字符串 + + Returns: + int: 时间戳(秒) + """ + import datetime + dt = datetime.datetime.strptime(time_str, format_str) + return int(dt.timestamp()) + + def generate_unique_id(self) -> str: + """生成唯一ID + + Returns: + str: 唯一ID + """ + import uuid + return str(uuid.uuid4()) \ No newline at end of file