🤖 自动格式化代码 [skip ci]

This commit is contained in:
github-actions[bot]
2025-05-23 07:49:17 +00:00
parent 68b248dd9d
commit c874af5c87
4 changed files with 112 additions and 65 deletions

View File

@@ -8,7 +8,7 @@ from src.chat.person_info.person_info import person_info_manager
from abc import abstractmethod from abc import abstractmethod
import os import os
import inspect import inspect
import toml # 导入 toml 库 import toml # 导入 toml 库
logger = get_logger("plugin_action") logger = get_logger("plugin_action")
@@ -18,16 +18,25 @@ class PluginAction(BaseAction):
封装了主程序内部依赖提供简化的API接口给插件开发者 封装了主程序内部依赖提供简化的API接口给插件开发者
""" """
action_config_file_name: Optional[str] = None # 插件可以覆盖此属性来指定配置文件名
def __init__(self, action_data: dict, reasoning: str, cycle_timers: dict, thinking_id: str, global_config: Optional[dict] = None, **kwargs): action_config_file_name: Optional[str] = None # 插件可以覆盖此属性来指定配置文件名
def __init__(
self,
action_data: dict,
reasoning: str,
cycle_timers: dict,
thinking_id: str,
global_config: Optional[dict] = None,
**kwargs,
):
"""初始化插件动作基类""" """初始化插件动作基类"""
super().__init__(action_data, reasoning, cycle_timers, thinking_id) super().__init__(action_data, reasoning, cycle_timers, thinking_id)
# 存储内部服务和对象引用 # 存储内部服务和对象引用
self._services = {} self._services = {}
self._global_config = global_config # 存储全局配置的只读引用 self._global_config = global_config # 存储全局配置的只读引用
self.config: Dict[str, Any] = {} # 用于存储插件自身的配置 self.config: Dict[str, Any] = {} # 用于存储插件自身的配置
# 从kwargs提取必要的内部服务 # 从kwargs提取必要的内部服务
if "observations" in kwargs: if "observations" in kwargs:
@@ -38,7 +47,7 @@ class PluginAction(BaseAction):
self._services["chat_stream"] = kwargs["chat_stream"] self._services["chat_stream"] = kwargs["chat_stream"]
self.log_prefix = kwargs.get("log_prefix", "") self.log_prefix = kwargs.get("log_prefix", "")
self._load_plugin_config() # 初始化时加载插件配置 self._load_plugin_config() # 初始化时加载插件配置
def _load_plugin_config(self): def _load_plugin_config(self):
""" """
@@ -49,7 +58,9 @@ class PluginAction(BaseAction):
仅支持 TOML (.toml) 格式。 仅支持 TOML (.toml) 格式。
""" """
if not self.action_config_file_name: if not self.action_config_file_name:
logger.debug(f"{self.log_prefix} 插件 {self.__class__.__name__} 未指定 action_config_file_name不加载插件配置。") logger.debug(
f"{self.log_prefix} 插件 {self.__class__.__name__} 未指定 action_config_file_name不加载插件配置。"
)
return return
try: try:
@@ -58,23 +69,29 @@ class PluginAction(BaseAction):
config_file_path = os.path.join(plugin_dir, self.action_config_file_name) config_file_path = os.path.join(plugin_dir, self.action_config_file_name)
if not os.path.exists(config_file_path): if not os.path.exists(config_file_path):
logger.warning(f"{self.log_prefix} 插件 {self.__class__.__name__} 的配置文件 {config_file_path} 不存在。") logger.warning(
f"{self.log_prefix} 插件 {self.__class__.__name__} 的配置文件 {config_file_path} 不存在。"
)
return return
file_ext = os.path.splitext(self.action_config_file_name)[1].lower() file_ext = os.path.splitext(self.action_config_file_name)[1].lower()
if file_ext == '.toml': if file_ext == ".toml":
with open(config_file_path, 'r', encoding='utf-8') as f: with open(config_file_path, "r", encoding="utf-8") as f:
self.config = toml.load(f) or {} self.config = toml.load(f) or {}
logger.info(f"{self.log_prefix} 插件 {self.__class__.__name__} 的配置已从 {config_file_path} 加载。") logger.info(f"{self.log_prefix} 插件 {self.__class__.__name__} 的配置已从 {config_file_path} 加载。")
else: else:
logger.warning(f"{self.log_prefix} 不支持的插件配置文件格式: {file_ext}。仅支持 .toml。插件配置未加载。") logger.warning(
self.config = {} #确保未加载时为空字典 f"{self.log_prefix} 不支持的插件配置文件格式: {file_ext}。仅支持 .toml。插件配置未加载。"
)
self.config = {} # 确保未加载时为空字典
return return
except Exception as e: except Exception as e:
logger.error(f"{self.log_prefix} 加载插件 {self.__class__.__name__} 的配置文件 {self.action_config_file_name} 时出错: {e}") logger.error(
self.config = {} # 出错时确保 config 是一个空字典 f"{self.log_prefix} 加载插件 {self.__class__.__name__} 的配置文件 {self.action_config_file_name} 时出错: {e}"
)
self.config = {} # 出错时确保 config 是一个空字典
def get_global_config(self, key: str, default: Any = None) -> Any: def get_global_config(self, key: str, default: Any = None) -> Any:
""" """

View File

@@ -2,4 +2,3 @@
# 导入所有动作模块以确保装饰器被执行 # 导入所有动作模块以确保装饰器被执行
from . import pic_action # noqa from . import pic_action # noqa

View File

@@ -27,6 +27,7 @@ default_seed = 42
# custom_parameter = "some_value" # custom_parameter = "some_value"
""" """
def generate_config(): def generate_config():
# 获取当前脚本所在的目录 # 获取当前脚本所在的目录
current_dir = os.path.dirname(os.path.abspath(__file__)) current_dir = os.path.dirname(os.path.abspath(__file__))
@@ -44,5 +45,6 @@ def generate_config():
print(f"配置文件已存在: {config_file_path}") print(f"配置文件已存在: {config_file_path}")
print("未进行任何更改。如果您想重新生成,请先删除或重命名现有文件。") print("未进行任何更改。如果您想重新生成,请先删除或重命名现有文件。")
if __name__ == "__main__": if __name__ == "__main__":
generate_config() generate_config()

View File

@@ -1,9 +1,9 @@
import asyncio import asyncio
import json import json
import urllib.request import urllib.request
import urllib.error import urllib.error
import base64 # 新增用于Base64编码 import base64 # 新增用于Base64编码
import traceback # 新增:用于打印堆栈跟踪 import traceback # 新增:用于打印堆栈跟踪
from typing import Tuple from typing import Tuple
from src.chat.focus_chat.planners.actions.plugin_action import PluginAction, register_action from src.chat.focus_chat.planners.actions.plugin_action import PluginAction, register_action
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
@@ -22,9 +22,7 @@ class PicAction(PluginAction):
"""根据描述使用火山引擎HTTP API生成图片的动作处理类""" """根据描述使用火山引擎HTTP API生成图片的动作处理类"""
action_name = "pic_action" action_name = "pic_action"
action_description = ( action_description = "可以根据特定的描述,使用火山引擎模型生成并发送一张图片 (通过HTTP API)"
"可以根据特定的描述,使用火山引擎模型生成并发送一张图片 (通过HTTP API)"
)
action_parameters = { action_parameters = {
"description": "图片描述,输入你想要生成并发送的图片的描述,必填", "description": "图片描述,输入你想要生成并发送的图片的描述,必填",
"size": "图片尺寸,例如 '1024x1024' (可选, 默认从配置或 '1024x1024')", "size": "图片尺寸,例如 '1024x1024' (可选, 默认从配置或 '1024x1024')",
@@ -36,17 +34,27 @@ class PicAction(PluginAction):
default = False default = False
action_config_file_name = "pic_action_config.toml" action_config_file_name = "pic_action_config.toml"
def __init__(self, action_data: dict, reasoning: str, cycle_timers: dict, thinking_id: str, global_config: dict = None, **kwargs): def __init__(
self,
action_data: dict,
reasoning: str,
cycle_timers: dict,
thinking_id: str,
global_config: dict = None,
**kwargs,
):
super().__init__(action_data, reasoning, cycle_timers, thinking_id, global_config, **kwargs) super().__init__(action_data, reasoning, cycle_timers, thinking_id, global_config, **kwargs)
http_base_url = self.config.get("base_url") http_base_url = self.config.get("base_url")
http_api_key = self.config.get("volcano_generate_api_key") http_api_key = self.config.get("volcano_generate_api_key")
if not (http_base_url and http_api_key): if not (http_base_url and http_api_key):
logger.error(f"{self.log_prefix} PicAction初始化, 但HTTP配置 (base_url 或 volcano_generate_api_key) 缺失. HTTP图片生成将失败.") logger.error(
f"{self.log_prefix} PicAction初始化, 但HTTP配置 (base_url 或 volcano_generate_api_key) 缺失. HTTP图片生成将失败."
)
else: else:
logger.info(f"{self.log_prefix} HTTP方式初始化完成. Base URL: {http_base_url}, API Key已配置.") logger.info(f"{self.log_prefix} HTTP方式初始化完成. Base URL: {http_base_url}, API Key已配置.")
# _restore_env_vars 方法不再需要,已移除 # _restore_env_vars 方法不再需要,已移除
async def process(self) -> Tuple[bool, str]: async def process(self) -> Tuple[bool, str]:
@@ -70,40 +78,50 @@ class PicAction(PluginAction):
default_model = self.config.get("default_model", "doubao-seedream-3-0-t2i-250415") default_model = self.config.get("default_model", "doubao-seedream-3-0-t2i-250415")
image_size = self.action_data.get("size", self.config.get("default_size", "1024x1024")) image_size = self.action_data.get("size", self.config.get("default_size", "1024x1024"))
# guidance_scale 现在完全由配置文件控制 # guidance_scale 现在完全由配置文件控制
guidance_scale_input = self.config.get("default_guidance_scale", 2.5) # 默认2.5 guidance_scale_input = self.config.get("default_guidance_scale", 2.5) # 默认2.5
guidance_scale_val = 2.5 # Fallback default guidance_scale_val = 2.5 # Fallback default
try: try:
guidance_scale_val = float(guidance_scale_input) guidance_scale_val = float(guidance_scale_input)
except (ValueError, TypeError): except (ValueError, TypeError):
logger.warning(f"{self.log_prefix} 配置文件中的 default_guidance_scale 值 '{guidance_scale_input}' 无效 (应为浮点数),使用默认值 2.5。") logger.warning(
f"{self.log_prefix} 配置文件中的 default_guidance_scale 值 '{guidance_scale_input}' 无效 (应为浮点数),使用默认值 2.5。"
)
guidance_scale_val = 2.5 guidance_scale_val = 2.5
# Seed parameter - ensure it's always an integer # Seed parameter - ensure it's always an integer
seed_config_value = self.config.get("default_seed") seed_config_value = self.config.get("default_seed")
seed_val = 42 # Default seed if not configured or invalid seed_val = 42 # Default seed if not configured or invalid
if seed_config_value is not None: if seed_config_value is not None:
try: try:
seed_val = int(seed_config_value) seed_val = int(seed_config_value)
except (ValueError, TypeError): except (ValueError, TypeError):
logger.warning(f"{self.log_prefix} 配置文件中的 default_seed ('{seed_config_value}') 无效,将使用默认种子 42。") logger.warning(
f"{self.log_prefix} 配置文件中的 default_seed ('{seed_config_value}') 无效,将使用默认种子 42。"
)
# seed_val is already 42 # seed_val is already 42
else: else:
logger.info(f"{self.log_prefix} 未在配置中找到 default_seed将使用默认种子 42。建议在配置文件中添加 default_seed。") logger.info(
f"{self.log_prefix} 未在配置中找到 default_seed将使用默认种子 42。建议在配置文件中添加 default_seed。"
)
# seed_val is already 42 # seed_val is already 42
# Watermark 现在完全由配置文件控制 # Watermark 现在完全由配置文件控制
effective_watermark_source = self.config.get("default_watermark", True) # 默认True effective_watermark_source = self.config.get("default_watermark", True) # 默认True
if isinstance(effective_watermark_source, bool): if isinstance(effective_watermark_source, bool):
watermark_val = effective_watermark_source watermark_val = effective_watermark_source
elif isinstance(effective_watermark_source, str): elif isinstance(effective_watermark_source, str):
watermark_val = effective_watermark_source.lower() == 'true' watermark_val = effective_watermark_source.lower() == "true"
else: else:
logger.warning(f"{self.log_prefix} 配置文件中的 default_watermark 值 '{effective_watermark_source}' 无效 (应为布尔值或 \'true\'/'false\'),使用默认值 True。") logger.warning(
f"{self.log_prefix} 配置文件中的 default_watermark 值 '{effective_watermark_source}' 无效 (应为布尔值或 'true'/'false'),使用默认值 True。"
)
watermark_val = True watermark_val = True
await self.send_message_by_expressor(f"收到!正在为您生成关于 '{description}' 的图片,请稍候...(模型: {default_model}, 尺寸: {image_size}") await self.send_message_by_expressor(
f"收到!正在为您生成关于 '{description}' 的图片,请稍候...(模型: {default_model}, 尺寸: {image_size}"
)
try: try:
success, result = await asyncio.to_thread( success, result = await asyncio.to_thread(
@@ -113,23 +131,20 @@ class PicAction(PluginAction):
size=image_size, size=image_size,
seed=seed_val, seed=seed_val,
guidance_scale=guidance_scale_val, guidance_scale=guidance_scale_val,
watermark=watermark_val watermark=watermark_val,
) )
except Exception as e: except Exception as e:
logger.error(f"{self.log_prefix} (HTTP) 异步请求执行失败: {e!r}", exc_info=True) logger.error(f"{self.log_prefix} (HTTP) 异步请求执行失败: {e!r}", exc_info=True)
traceback.print_exc() traceback.print_exc()
success = False success = False
result = f"图片生成服务遇到意外问题: {str(e)[:100]}" result = f"图片生成服务遇到意外问题: {str(e)[:100]}"
if success: if success:
image_url = result image_url = result
logger.info(f"{self.log_prefix} 图片URL获取成功: {image_url[:70]}... 下载并编码.") logger.info(f"{self.log_prefix} 图片URL获取成功: {image_url[:70]}... 下载并编码.")
try: try:
encode_success, encode_result = await asyncio.to_thread( encode_success, encode_result = await asyncio.to_thread(self._download_and_encode_base64, image_url)
self._download_and_encode_base64,
image_url
)
except Exception as e: except Exception as e:
logger.error(f"{self.log_prefix} (B64) 异步下载/编码失败: {e!r}", exc_info=True) logger.error(f"{self.log_prefix} (B64) 异步下载/编码失败: {e!r}", exc_info=True)
traceback.print_exc() traceback.print_exc()
@@ -149,7 +164,7 @@ class PicAction(PluginAction):
await self.send_message_by_expressor(f"获取到图片URL但在处理图片时失败了{encode_result}") await self.send_message_by_expressor(f"获取到图片URL但在处理图片时失败了{encode_result}")
return False, f"图片处理失败(Base64): {encode_result}" return False, f"图片处理失败(Base64): {encode_result}"
else: else:
error_message = result error_message = result
await self.send_message_by_expressor(f"哎呀,生成图片时遇到问题:{error_message}") await self.send_message_by_expressor(f"哎呀,生成图片时遇到问题:{error_message}")
return False, f"图片生成失败: {error_message}" return False, f"图片生成失败: {error_message}"
@@ -160,22 +175,24 @@ class PicAction(PluginAction):
with urllib.request.urlopen(image_url, timeout=30) as response: with urllib.request.urlopen(image_url, timeout=30) as response:
if response.status == 200: if response.status == 200:
image_bytes = response.read() image_bytes = response.read()
base64_encoded_image = base64.b64encode(image_bytes).decode('utf-8') base64_encoded_image = base64.b64encode(image_bytes).decode("utf-8")
logger.info(f"{self.log_prefix} (B64) 图片下载编码完成. Base64长度: {len(base64_encoded_image)}") logger.info(f"{self.log_prefix} (B64) 图片下载编码完成. Base64长度: {len(base64_encoded_image)}")
return True, base64_encoded_image return True, base64_encoded_image
else: else:
error_msg = f"下载图片失败 (状态: {response.status})" error_msg = f"下载图片失败 (状态: {response.status})"
logger.error(f"{self.log_prefix} (B64) {error_msg} URL: {image_url}") logger.error(f"{self.log_prefix} (B64) {error_msg} URL: {image_url}")
return False, error_msg return False, error_msg
except Exception as e: # Catches all exceptions from urlopen, b64encode, etc. except Exception as e: # Catches all exceptions from urlopen, b64encode, etc.
logger.error(f"{self.log_prefix} (B64) 下载或编码时错误: {e!r}", exc_info=True) logger.error(f"{self.log_prefix} (B64) 下载或编码时错误: {e!r}", exc_info=True)
traceback.print_exc() traceback.print_exc()
return False, f"下载或编码图片时发生错误: {str(e)[:100]}" return False, f"下载或编码图片时发生错误: {str(e)[:100]}"
def _make_http_image_request(self, prompt: str, model: str, size: str, seed: int | None, guidance_scale: float, watermark: bool) -> Tuple[bool, str]: def _make_http_image_request(
self, prompt: str, model: str, size: str, seed: int | None, guidance_scale: float, watermark: bool
) -> Tuple[bool, str]:
base_url = self.config.get("base_url") base_url = self.config.get("base_url")
generate_api_key = self.config.get("volcano_generate_api_key") generate_api_key = self.config.get("volcano_generate_api_key")
endpoint = f"{base_url.rstrip('/')}/images/generations" endpoint = f"{base_url.rstrip('/')}/images/generations"
payload_dict = { payload_dict = {
@@ -185,22 +202,26 @@ class PicAction(PluginAction):
"size": size, "size": size,
"guidance_scale": guidance_scale, "guidance_scale": guidance_scale,
"watermark": watermark, "watermark": watermark,
"seed": seed, # seed is now always an int from process() "seed": seed, # seed is now always an int from process()
"api-key": generate_api_key "api-key": generate_api_key,
} }
# if seed is not None: # No longer needed, seed is always an int # if seed is not None: # No longer needed, seed is always an int
# payload_dict["seed"] = seed # payload_dict["seed"] = seed
data = json.dumps(payload_dict).encode('utf-8') data = json.dumps(payload_dict).encode("utf-8")
headers = { headers = {
"Content-Type": "application/json", "Content-Type": "application/json",
"Accept": "application/json", "Accept": "application/json",
"Authorization": f"Bearer {generate_api_key}" "Authorization": f"Bearer {generate_api_key}",
} }
logger.info(f"{self.log_prefix} (HTTP) 发起图片请求: {model}, Prompt: {prompt[:30]}... To: {endpoint}") logger.info(f"{self.log_prefix} (HTTP) 发起图片请求: {model}, Prompt: {prompt[:30]}... To: {endpoint}")
logger.debug(f"{self.log_prefix} (HTTP) Request Headers: {{...Authorization: Bearer {generate_api_key[:10]}...}}") logger.debug(
logger.debug(f"{self.log_prefix} (HTTP) Request Body (api-key omitted): {json.dumps({k: v for k, v in payload_dict.items() if k != 'api-key'})}") f"{self.log_prefix} (HTTP) Request Headers: {{...Authorization: Bearer {generate_api_key[:10]}...}}"
)
logger.debug(
f"{self.log_prefix} (HTTP) Request Body (api-key omitted): {json.dumps({k: v for k, v in payload_dict.items() if k != 'api-key'})}"
)
req = urllib.request.Request(endpoint, data=data, headers=headers, method="POST") req = urllib.request.Request(endpoint, data=data, headers=headers, method="POST")
@@ -208,26 +229,34 @@ class PicAction(PluginAction):
with urllib.request.urlopen(req, timeout=60) as response: with urllib.request.urlopen(req, timeout=60) as response:
response_status = response.status response_status = response.status
response_body_bytes = response.read() response_body_bytes = response.read()
response_body_str = response_body_bytes.decode('utf-8') response_body_str = response_body_bytes.decode("utf-8")
logger.info(f"{self.log_prefix} (HTTP) 响应: {response_status}. Preview: {response_body_str[:150]}...") logger.info(f"{self.log_prefix} (HTTP) 响应: {response_status}. Preview: {response_body_str[:150]}...")
if 200 <= response_status < 300: if 200 <= response_status < 300:
response_data = json.loads(response_body_str) response_data = json.loads(response_body_str)
image_url = None image_url = None
if isinstance(response_data.get("data"), list) and response_data["data"] and isinstance(response_data["data"][0], dict): if (
isinstance(response_data.get("data"), list)
and response_data["data"]
and isinstance(response_data["data"][0], dict)
):
image_url = response_data["data"][0].get("url") image_url = response_data["data"][0].get("url")
elif response_data.get("url"): elif response_data.get("url"):
image_url = response_data.get("url") image_url = response_data.get("url")
if image_url: if image_url:
logger.info(f"{self.log_prefix} (HTTP) 图片生成成功URL: {image_url[:70]}...") logger.info(f"{self.log_prefix} (HTTP) 图片生成成功URL: {image_url[:70]}...")
return True, image_url return True, image_url
else: else:
logger.error(f"{self.log_prefix} (HTTP) API成功但无图片URL. 响应预览: {response_body_str[:300]}...") logger.error(
f"{self.log_prefix} (HTTP) API成功但无图片URL. 响应预览: {response_body_str[:300]}..."
)
return False, "图片生成API响应成功但未找到图片URL" return False, "图片生成API响应成功但未找到图片URL"
else: else:
logger.error(f"{self.log_prefix} (HTTP) API请求失败. 状态: {response.status}. 正文: {response_body_str[:300]}...") logger.error(
f"{self.log_prefix} (HTTP) API请求失败. 状态: {response.status}. 正文: {response_body_str[:300]}..."
)
return False, f"图片API请求失败(状态码 {response.status})" return False, f"图片API请求失败(状态码 {response.status})"
except Exception as e: except Exception as e:
logger.error(f"{self.log_prefix} (HTTP) 图片生成时意外错误: {e!r}", exc_info=True) logger.error(f"{self.log_prefix} (HTTP) 图片生成时意外错误: {e!r}", exc_info=True)