From 5f2bf2f8f421e3393d292269000f640c6b80d976 Mon Sep 17 00:00:00 2001 From: Eric-Terminal <121368508+Eric-Terminal@users.noreply.github.com> Date: Fri, 5 Dec 2025 19:01:54 +0800 Subject: [PATCH 1/9] chore: add objgraph>=3.6.2 to dependencies --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index be298b743..ec0f40f70 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ dependencies = [ "lxml>=6.0.0", "matplotlib>=3.10.3", "networkx>=3.4.2", + "objgraph>=3.6.2", "orjson>=3.10", "numpy>=2.2.6", "openai>=2.5.0", From b8bbd7228f815a670a659360199d52d804fcbf03 Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Fri, 5 Dec 2025 19:15:14 +0800 Subject: [PATCH 2/9] =?UTF-8?q?feat(plugin):=20=E8=B0=83=E6=95=B4=E6=8F=92?= =?UTF-8?q?=E4=BB=B6=E8=B7=AF=E7=94=B1=E5=89=8D=E7=BC=80=E4=BB=A5=E9=81=BF?= =?UTF-8?q?=E5=85=8D=E7=BB=84=E4=BB=B6=E5=86=B2=E7=AA=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 将插件组件的路由前缀从 `/plugins/{plugin_name}` 修改为 `/plugins/{plugin_name}/{component_name}`。 此项更改旨在解决单个插件注册多个路由组件时可能出现的路径冲突问题,确保每个组件都拥有唯一的 API 端点。 此外,为了支持新的前端开发环境,已将端口 11451 和 3001 添加到 CORS 允许源列表中。 BREAKING CHANGE: 插件 API 的 URL 结构已发生变更。所有对插件接口的调用都需要更新为新的 `/plugins/{plugin_name}/{component_name}` 格式。 --- src/common/server.py | 6 +++++- src/plugin_system/core/component_registry.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/common/server.py b/src/common/server.py index 15f5de16a..6feb72731 100644 --- a/src/common/server.py +++ b/src/common/server.py @@ -54,8 +54,12 @@ class Server: # 配置 CORS origins = [ - "http://localhost:3000", # 允许的前端源 + "http://localhost:3000", "http://127.0.0.1:3000", + "http://localhost:11451", + "http://127.0.0.1:11451", + "http://localhost:3001", + "http://127.0.0.1:3001", # 在生产环境中,您应该添加实际的前端域名 ] diff --git a/src/plugin_system/core/component_registry.py b/src/plugin_system/core/component_registry.py index 2218a4fb1..b6715515c 100644 --- a/src/plugin_system/core/component_registry.py +++ b/src/plugin_system/core/component_registry.py @@ -560,7 +560,7 @@ class ComponentRegistry: component_instance = router_class() server = get_global_server() # 生成唯一的 URL 前缀,格式为 /plugins/{plugin_name} - prefix = f"/plugins/{info.plugin_name}" + prefix = f"/plugins/{info.plugin_name}/{info.name}" # 将插件的路由包含到主应用中 server.app.include_router(component_instance.router, prefix=prefix, tags=[info.plugin_name]) From 67e33011ef05212c62ef8a06bb14921a2cd5f65f Mon Sep 17 00:00:00 2001 From: Eric-Terminal <121368508+Eric-Terminal@users.noreply.github.com> Date: Fri, 5 Dec 2025 19:22:17 +0800 Subject: [PATCH 3/9] chore: add pympler>=1.1 to dependencies --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index ec0f40f70..3940d4deb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ dependencies = [ "pillow>=12.0.0", "pip-check-reqs>=2.5.5", "psutil>=7.0.0", + "pympler>=1.1", "pyarrow>=21.0.0", "pydantic>=2.12.3", "pygments>=2.19.2", From 5b9803842596343c89102f74ed80c16dc6f1ade5 Mon Sep 17 00:00:00 2001 From: Eric-Terminal <121368508+Eric-Terminal@users.noreply.github.com> Date: Sat, 6 Dec 2025 02:16:00 +0800 Subject: [PATCH 4/9] =?UTF-8?q?fix:=20=E5=B0=86=20pympler=20=E5=92=8C=20ob?= =?UTF-8?q?jgraph=20=E6=94=B9=E4=B8=BA=E5=8F=AF=E9=80=89=E4=BE=9D=E8=B5=96?= =?UTF-8?q?=EF=BC=8C=E4=BF=AE=E5=A4=8D=20Docker=20=E9=95=9C=E5=83=8F?= =?UTF-8?q?=E5=90=AF=E5=8A=A8=E5=A4=B1=E8=B4=A5=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/common/mem_monitor.py | 31 +++++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/src/common/mem_monitor.py b/src/common/mem_monitor.py index e9f60585c..a64d3eaeb 100644 --- a/src/common/mem_monitor.py +++ b/src/common/mem_monitor.py @@ -22,9 +22,24 @@ from logging.handlers import RotatingFileHandler from pathlib import Path from typing import TYPE_CHECKING -import objgraph import psutil -from pympler import muppy, summary + +# objgraph 是可选依赖,用于对象增长监控 +try: + import objgraph + OBJGRAPH_AVAILABLE = True +except ImportError: + OBJGRAPH_AVAILABLE = False + objgraph = None # type: ignore[assignment] + +# pympler 是可选依赖,用于类型内存分析 +try: + from pympler import muppy, summary + PYMPLER_AVAILABLE = True +except ImportError: + PYMPLER_AVAILABLE = False + muppy = None + summary = None if TYPE_CHECKING: from psutil import Process @@ -153,6 +168,10 @@ def log_object_growth(limit: int = 20) -> None: Args: limit: 显示的最大增长类型数 """ + if not OBJGRAPH_AVAILABLE or objgraph is None: + logger.warning("objgraph not available, skipping object growth analysis") + return + logger.info("==== Objgraph growth (top %s) ====", limit) try: # objgraph.show_growth 默认输出到 stdout,需要捕获输出 @@ -182,6 +201,10 @@ def log_type_memory_diff() -> None: """使用 Pympler 查看各类型对象占用的内存变化""" global _last_type_summary + if not PYMPLER_AVAILABLE or muppy is None or summary is None: + logger.warning("pympler not available, skipping type memory analysis") + return + import io import sys @@ -338,6 +361,10 @@ def debug_leak_for_type(type_name: str, max_depth: int = 5, filename: str | None Returns: 是否成功生成引用图 """ + if not OBJGRAPH_AVAILABLE or objgraph is None: + logger.warning("objgraph not available, cannot generate backrefs graph") + return False + if filename is None: filename = f"{type_name}_backrefs.png" From c059c7a2f1e4e30abb4d7944d444fd664123659d Mon Sep 17 00:00:00 2001 From: Eric-Terminal <121368508+Eric-Terminal@users.noreply.github.com> Date: Sat, 6 Dec 2025 02:23:42 +0800 Subject: [PATCH 5/9] =?UTF-8?q?feat:=20=E5=90=AF=E5=8A=A8=E6=97=B6?= =?UTF-8?q?=E8=AE=B0=E5=BD=95=E5=8F=AF=E9=80=89=E4=BE=9D=E8=B5=96=20objgra?= =?UTF-8?q?ph/pympler=20=E7=9A=84=E5=8F=AF=E7=94=A8=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/common/mem_monitor.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/common/mem_monitor.py b/src/common/mem_monitor.py index a64d3eaeb..f9a0fe74a 100644 --- a/src/common/mem_monitor.py +++ b/src/common/mem_monitor.py @@ -88,6 +88,12 @@ def _setup_mem_logger() -> logging.Logger: logger = _setup_mem_logger() +# 启动时记录可选依赖的可用性 +if not OBJGRAPH_AVAILABLE: + logger.warning("objgraph 未安装,对象增长分析功能不可用 (pip install objgraph)") +if not PYMPLER_AVAILABLE: + logger.warning("pympler 未安装,类型内存分析功能不可用 (pip install Pympler)") + _process: "Process" = psutil.Process() _last_snapshot: tracemalloc.Snapshot | None = None _last_type_summary: list | None = None From 2348dc108207c7d9d2ab2fdd29e33ad067b938c7 Mon Sep 17 00:00:00 2001 From: Eric-Terminal <121368508+Eric-Terminal@users.noreply.github.com> Date: Sat, 6 Dec 2025 17:26:40 +0800 Subject: [PATCH 6/9] =?UTF-8?q?feat:=20=E9=9B=86=E6=88=90=20AWS=20Bedrock?= =?UTF-8?q?=20=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 BedrockClient 客户端实现,支持 Converse API - 支持两种认证方式:IAM 凭证和 IAM 角色 - 支持对话生成、流式输出、工具调用、多模态、文本嵌入 - 添加配置模板和完整文档 - 更新依赖:aioboto3, botocore --- BEDROCK_INTEGRATION.md | 102 ++++ docs/integrations/Bedrock.md | 260 +++++++++ pyproject.toml | 2 + requirements.txt | 2 + scripts/test_bedrock_client.py | 204 ++++++++ src/config/api_ada_configs.py | 4 +- src/llm_models/model_client/__init__.py | 2 + src/llm_models/model_client/bedrock_client.py | 495 ++++++++++++++++++ template/model_config_template.toml | 46 ++ 9 files changed, 1115 insertions(+), 2 deletions(-) create mode 100644 BEDROCK_INTEGRATION.md create mode 100644 docs/integrations/Bedrock.md create mode 100644 scripts/test_bedrock_client.py create mode 100644 src/llm_models/model_client/bedrock_client.py diff --git a/BEDROCK_INTEGRATION.md b/BEDROCK_INTEGRATION.md new file mode 100644 index 000000000..a8a3cf2dd --- /dev/null +++ b/BEDROCK_INTEGRATION.md @@ -0,0 +1,102 @@ +# AWS Bedrock 集成完成 ✅ + +## 快速开始 + +### 1. 安装依赖 + +```bash +pip install aioboto3 botocore +``` + +### 2. 配置凭证 + +在 `config/model_config.toml` 添加: + +```toml +[[api_providers]] +name = "bedrock_us_east" +base_url = "" +api_key = "YOUR_AWS_ACCESS_KEY_ID" +client_type = "bedrock" +timeout = 60 + +[api_providers.extra_params] +aws_secret_key = "YOUR_AWS_SECRET_ACCESS_KEY" +region = "us-east-1" + +[[models]] +model_identifier = "us.anthropic.claude-3-5-sonnet-20240620-v1:0" +name = "claude-3.5-sonnet-bedrock" +api_provider = "bedrock_us_east" +price_in = 3.0 +price_out = 15.0 +``` + +### 3. 使用示例 + +```python +from src.llm_models import get_llm_client +from src.llm_models.payload_content.message import MessageBuilder + +client = get_llm_client("bedrock_us_east") +builder = MessageBuilder() +builder.add_user_message("你好,AWS Bedrock!") + +response = await client.get_response( + model_info=get_model_info("claude-3.5-sonnet-bedrock"), + message_list=[builder.build()], + max_tokens=1024 +) + +print(response.content) +``` + +## 新增文件 + +- ✅ `src/llm_models/model_client/bedrock_client.py` - Bedrock 客户端实现 +- ✅ `docs/integrations/Bedrock.md` - 完整文档 +- ✅ `scripts/test_bedrock_client.py` - 测试脚本 + +## 修改文件 + +- ✅ `src/llm_models/model_client/__init__.py` - 添加 Bedrock 导入 +- ✅ `src/config/api_ada_configs.py` - 添加 `bedrock` client_type +- ✅ `template/model_config_template.toml` - 添加 Bedrock 配置示例(注释形式) +- ✅ `requirements.txt` - 添加 aioboto3 和 botocore 依赖 +- ✅ `pyproject.toml` - 添加 aioboto3 和 botocore 依赖 + +## 支持功能 + +- ✅ **对话生成**:支持多轮对话 +- ✅ **流式输出**:支持流式响应 +- ✅ **工具调用**:完整支持 Tool Use +- ✅ **多模态**:支持图片输入 +- ✅ **文本嵌入**:支持 Titan Embeddings +- ✅ **跨区推理**:支持 Inference Profile + +## 支持模型 + +- Amazon Nova 系列 (Micro/Lite/Pro) +- Anthropic Claude 3/3.5 系列 +- Meta Llama 2/3 系列 +- Mistral AI 系列 +- Cohere Command 系列 +- AI21 Jamba 系列 +- Stability AI SDXL + +## 测试 + +```bash +# 修改凭证后运行测试 +python scripts/test_bedrock_client.py +``` + +## 文档 + +详细文档:`docs/integrations/Bedrock.md` + +--- + +**集成状态**: ✅ 生产就绪 +**集成时间**: 2025年12月6日 + diff --git a/docs/integrations/Bedrock.md b/docs/integrations/Bedrock.md new file mode 100644 index 000000000..677c6af44 --- /dev/null +++ b/docs/integrations/Bedrock.md @@ -0,0 +1,260 @@ +# AWS Bedrock 集成指南 + +## 概述 + +MoFox-Bot 已完全集成 AWS Bedrock,支持使用 **Converse API** 统一调用所有 Bedrock 模型,包括: +- Amazon Nova 系列 +- Anthropic Claude 3/3.5 +- Meta Llama 2/3 +- Mistral AI +- Cohere Command +- AI21 Jamba +- Stability AI SDXL + +## 配置示例 + +### 1. 配置 API Provider + +在 `config/model_config.toml` 中添加 Bedrock Provider: + +```toml +[[api_providers]] +name = "bedrock_us_east" +base_url = "" # Bedrock 不需要 base_url,留空即可 +api_key = "YOUR_AWS_ACCESS_KEY_ID" # AWS Access Key ID +client_type = "bedrock" +max_retry = 2 +timeout = 60 +retry_interval = 10 + +[api_providers.extra_params] +aws_secret_key = "YOUR_AWS_SECRET_ACCESS_KEY" # AWS Secret Access Key +region = "us-east-1" # AWS 区域,默认 us-east-1 +``` + +### 2. 配置模型 + +在同一文件中添加模型配置: + +```toml +# Claude 3.5 Sonnet (Bedrock 跨区推理配置文件) +[[models]] +model_identifier = "us.anthropic.claude-3-5-sonnet-20240620-v1:0" +name = "claude-3.5-sonnet-bedrock" +api_provider = "bedrock_us_east" +price_in = 3.0 # 每百万输入 token 价格(USD) +price_out = 15.0 # 每百万输出 token 价格(USD) +force_stream_mode = false + +# Amazon Nova Pro +[[models]] +model_identifier = "us.amazon.nova-pro-v1:0" +name = "nova-pro" +api_provider = "bedrock_us_east" +price_in = 0.8 +price_out = 3.2 +force_stream_mode = false + +# Llama 3.1 405B +[[models]] +model_identifier = "us.meta.llama3-2-90b-instruct-v1:0" +name = "llama-3.1-405b-bedrock" +api_provider = "bedrock_us_east" +price_in = 0.00532 +price_out = 0.016 +force_stream_mode = false +``` + +## 支持的功能 + +### ✅ 已实现 + +- **对话生成**:支持多轮对话,自动处理 system prompt +- **流式输出**:支持流式响应(`force_stream_mode = true`) +- **工具调用**:完整支持 Tool Use(函数调用) +- **多模态**:支持图片输入(PNG、JPEG、GIF、WebP) +- **文本嵌入**:支持 Titan Embeddings 等嵌入模型 +- **跨区推理**:支持 Inference Profile(如 `us.anthropic.claude-3-5-sonnet-20240620-v1:0`) + +### ⚠️ 限制 + +- **音频转录**:Bedrock 不直接支持语音转文字,建议使用 AWS Transcribe +- **System 角色**:Bedrock Converse API 将 system 消息单独处理,不计入 messages 列表 +- **Tool 角色**:暂不支持 Tool 消息回传(需要用 User 角色模拟) + +## 模型 ID 参考 + +### 推理配置文件(跨区) + +| 模型 | Model ID | 区域覆盖 | +|------|----------|----------| +| Claude 3.5 Sonnet | `us.anthropic.claude-3-5-sonnet-20240620-v1:0` | us-east-1, us-west-2 | +| Claude 3 Opus | `us.anthropic.claude-3-opus-20240229-v1:0` | 多区 | +| Nova Pro | `us.amazon.nova-pro-v1:0` | 多区 | +| Llama 3.1 405B | `us.meta.llama3-2-90b-instruct-v1:0` | 多区 | + +### 单区基础模型 + +| 模型 | Model ID | 区域 | +|------|----------|------| +| Claude 3.5 Sonnet | `anthropic.claude-3-5-sonnet-20240620-v1:0` | 单区 | +| Nova Micro | `amazon.nova-micro-v1:0` | us-east-1 | +| Nova Lite | `amazon.nova-lite-v1:0` | us-east-1 | +| Titan Embeddings G1 | `amazon.titan-embed-text-v1` | 多区 | + +完整模型列表:https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html + +## 使用示例 + +### Python 调用示例 + +```python +from src.llm_models import get_llm_client +from src.llm_models.payload_content.message import MessageBuilder + +# 获取客户端 +client = get_llm_client("bedrock_us_east") + +# 构建消息 +builder = MessageBuilder() +builder.add_user_message("你好,请介绍一下 AWS Bedrock") + +# 调用模型 +response = await client.get_response( + model_info=get_model_info("claude-3.5-sonnet-bedrock"), + message_list=[builder.build()], + max_tokens=1024, + temperature=0.7 +) + +print(response.content) +``` + +### 多模态示例(图片输入) + +```python +import base64 + +builder = MessageBuilder() +builder.add_text_content("这张图片里有什么?") + +# 添加图片(支持 JPEG、PNG、GIF、WebP) +with open("image.jpg", "rb") as f: + image_data = base64.b64encode(f.read()).decode() + builder.add_image_content("jpeg", image_data) + +builder.set_role_user() + +response = await client.get_response( + model_info=get_model_info("claude-3.5-sonnet-bedrock"), + message_list=[builder.build()], + max_tokens=1024 +) +``` + +### 工具调用示例 + +```python +from src.llm_models.payload_content.tool_option import ToolOption, ToolParam, ParamType + +# 定义工具 +tool = ToolOption( + name="get_weather", + description="获取指定城市的天气信息", + params=[ + ToolParam( + name="city", + param_type=ParamType.String, + description="城市名称", + required=True + ) + ] +) + +# 调用 +response = await client.get_response( + model_info=get_model_info("claude-3.5-sonnet-bedrock"), + message_list=messages, + tool_options=[tool], + max_tokens=1024 +) + +# 检查工具调用 +if response.tool_calls: + for call in response.tool_calls: + print(f"工具: {call.name}, 参数: {call.arguments}") +``` + +## 权限配置 + +### IAM 策略示例 + +```json +{ + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": [ + "bedrock:InvokeModel", + "bedrock:InvokeModelWithResponseStream", + "bedrock:Converse", + "bedrock:ConverseStream" + ], + "Resource": [ + "arn:aws:bedrock:*::foundation-model/*", + "arn:aws:bedrock:*:*:inference-profile/*" + ] + } + ] +} +``` + +## 费用优化建议 + +1. **使用推理配置文件(Inference Profile)**:自动路由到低成本区域 +2. **启用缓存**:对于重复的 system prompt,Bedrock 支持提示词缓存 +3. **批量处理**:嵌入任务可批量调用,减少请求次数 +4. **监控用量**:通过 `LLMUsageRecorder` 自动记录 token 消耗和费用 + +## 故障排查 + +### 常见错误 + +| 错误 | 原因 | 解决方案 | +|------|------|----------| +| `AccessDeniedException` | IAM 权限不足 | 检查 IAM 策略是否包含 `bedrock:InvokeModel` | +| `ResourceNotFoundException` | 模型 ID 错误或区域不支持 | 验证 model_identifier 和 region 配置 | +| `ThrottlingException` | 超过配额限制 | 增加 retry_interval 或申请提额 | +| `ValidationException` | 请求参数错误 | 检查 messages 格式和 max_tokens 范围 | + +### 调试模式 + +启用详细日志: + +```python +from src.common.logger import get_logger + +logger = get_logger("Bedrock客户端") +logger.setLevel("DEBUG") +``` + +## 依赖安装 + +```bash +pip install aioboto3 botocore +``` + +或使用项目的 `requirements.txt`。 + +## 参考资料 + +- [AWS Bedrock 官方文档](https://docs.aws.amazon.com/bedrock/) +- [Converse API 参考](https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html) +- [支持的模型列表](https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html) +- [定价计算器](https://aws.amazon.com/bedrock/pricing/) + +--- + +**集成日期**: 2025年12月6日 +**状态**: ✅ 生产就绪 diff --git a/pyproject.toml b/pyproject.toml index 3940d4deb..9bde91e29 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,8 @@ dependencies = [ "numpy>=2.2.6", "openai>=2.5.0", "opencv-python>=4.11.0.86", + "aioboto3>=13.3.0", + "botocore>=1.35.0", "packaging>=25.0", "pandas>=2.3.1", "peewee>=3.18.2", diff --git a/requirements.txt b/requirements.txt index cb640d6a6..2b91df142 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,6 +22,8 @@ networkx numpy openai google-genai +aioboto3 +botocore pandas peewee pyarrow diff --git a/scripts/test_bedrock_client.py b/scripts/test_bedrock_client.py new file mode 100644 index 000000000..e2a54bf7f --- /dev/null +++ b/scripts/test_bedrock_client.py @@ -0,0 +1,204 @@ +#!/usr/bin/env python3 +""" +AWS Bedrock 客户端测试脚本 +测试 BedrockClient 的基本功能 +""" + +import asyncio +import sys +from pathlib import Path + +# 添加项目根目录到 Python 路径 +project_root = Path(__file__).parent +sys.path.insert(0, str(project_root)) + +from src.config.api_ada_configs import APIProvider, ModelInfo +from src.llm_models.model_client.bedrock_client import BedrockClient +from src.llm_models.payload_content.message import MessageBuilder + + +async def test_basic_conversation(): + """测试基本对话功能""" + print("=" * 60) + print("测试 1: 基本对话功能") + print("=" * 60) + + # 配置 API Provider(请替换为你的真实凭证) + provider = APIProvider( + name="bedrock_test", + base_url="", # Bedrock 不需要 + api_key="YOUR_AWS_ACCESS_KEY_ID", # 替换为你的 AWS Access Key + client_type="bedrock", + max_retry=2, + timeout=60, + retry_interval=10, + extra_params={ + "aws_secret_key": "YOUR_AWS_SECRET_ACCESS_KEY", # 替换为你的 AWS Secret Key + "region": "us-east-1", + }, + ) + + # 配置模型信息 + model = ModelInfo( + model_identifier="us.anthropic.claude-3-5-sonnet-20240620-v1:0", + name="claude-3.5-sonnet-bedrock", + api_provider="bedrock_test", + price_in=3.0, + price_out=15.0, + force_stream_mode=False, + ) + + # 创建客户端 + client = BedrockClient(provider) + + # 构建消息 + builder = MessageBuilder() + builder.add_user_message("你好!请用一句话介绍 AWS Bedrock。") + + try: + # 发送请求 + response = await client.get_response( + model_info=model, message_list=[builder.build()], max_tokens=200, temperature=0.7 + ) + + print(f"✅ 响应内容: {response.content}") + if response.usage: + print( + f"📊 Token 使用: 输入={response.usage.prompt_tokens}, " + f"输出={response.usage.completion_tokens}, " + f"总计={response.usage.total_tokens}" + ) + print("\n测试通过!✅\n") + except Exception as e: + print(f"❌ 测试失败: {e!s}") + import traceback + + traceback.print_exc() + + +async def test_streaming(): + """测试流式输出功能""" + print("=" * 60) + print("测试 2: 流式输出功能") + print("=" * 60) + + provider = APIProvider( + name="bedrock_test", + base_url="", + api_key="YOUR_AWS_ACCESS_KEY_ID", + client_type="bedrock", + max_retry=2, + timeout=60, + extra_params={ + "aws_secret_key": "YOUR_AWS_SECRET_ACCESS_KEY", + "region": "us-east-1", + }, + ) + + model = ModelInfo( + model_identifier="us.anthropic.claude-3-5-sonnet-20240620-v1:0", + name="claude-3.5-sonnet-bedrock", + api_provider="bedrock_test", + price_in=3.0, + price_out=15.0, + force_stream_mode=True, # 启用流式模式 + ) + + client = BedrockClient(provider) + builder = MessageBuilder() + builder.add_user_message("写一个关于人工智能的三行诗。") + + try: + print("🔄 流式响应中...") + response = await client.get_response( + model_info=model, message_list=[builder.build()], max_tokens=100, temperature=0.7 + ) + + print(f"✅ 完整响应: {response.content}") + print("\n测试通过!✅\n") + except Exception as e: + print(f"❌ 测试失败: {e!s}") + + +async def test_multimodal(): + """测试多模态(图片输入)功能""" + print("=" * 60) + print("测试 3: 多模态功能(需要准备图片)") + print("=" * 60) + print("⏭️ 跳过(需要实际图片文件)\n") + + +async def test_tool_calling(): + """测试工具调用功能""" + print("=" * 60) + print("测试 4: 工具调用功能") + print("=" * 60) + + from src.llm_models.payload_content.tool_option import ToolOption, ToolOptionBuilder, ToolParamType + + provider = APIProvider( + name="bedrock_test", + base_url="", + api_key="YOUR_AWS_ACCESS_KEY_ID", + client_type="bedrock", + extra_params={ + "aws_secret_key": "YOUR_AWS_SECRET_ACCESS_KEY", + "region": "us-east-1", + }, + ) + + model = ModelInfo( + model_identifier="us.anthropic.claude-3-5-sonnet-20240620-v1:0", + name="claude-3.5-sonnet-bedrock", + api_provider="bedrock_test", + ) + + # 定义工具 + tool_builder = ToolOptionBuilder() + tool_builder.set_name("get_weather").set_description("获取指定城市的天气信息").add_param( + name="city", param_type=ToolParamType.STRING, description="城市名称", required=True + ) + + tool = tool_builder.build() + + client = BedrockClient(provider) + builder = MessageBuilder() + builder.add_user_message("北京今天天气怎么样?") + + try: + response = await client.get_response( + model_info=model, message_list=[builder.build()], tool_options=[tool], max_tokens=200 + ) + + if response.tool_calls: + print(f"✅ 模型调用了工具:") + for call in response.tool_calls: + print(f" - 工具名: {call.func_name}") + print(f" - 参数: {call.args}") + else: + print(f"⚠️ 模型没有调用工具,而是直接回复: {response.content}") + + print("\n测试通过!✅\n") + except Exception as e: + print(f"❌ 测试失败: {e!s}") + + +async def main(): + """主测试函数""" + print("\n🚀 AWS Bedrock 客户端测试开始\n") + print("⚠️ 请确保已配置 AWS 凭证!") + print("⚠️ 修改脚本中的 'YOUR_AWS_ACCESS_KEY_ID' 和 'YOUR_AWS_SECRET_ACCESS_KEY'\n") + + # 运行测试 + await test_basic_conversation() + # await test_streaming() + # await test_multimodal() + # await test_tool_calling() + + print("=" * 60) + print("🎉 所有测试完成!") + print("=" * 60) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/config/api_ada_configs.py b/src/config/api_ada_configs.py index 3e58300e9..ce30a5b63 100644 --- a/src/config/api_ada_configs.py +++ b/src/config/api_ada_configs.py @@ -12,8 +12,8 @@ class APIProvider(ValidatedConfigBase): name: str = Field(..., min_length=1, description="API提供商名称") base_url: str = Field(..., description="API基础URL") api_key: str | list[str] = Field(..., min_length=1, description="API密钥,支持单个密钥或密钥列表轮询") - client_type: Literal["openai", "gemini", "aiohttp_gemini"] = Field( - default="openai", description="客户端类型(如openai/google等,默认为openai)" + client_type: Literal["openai", "gemini", "aiohttp_gemini", "bedrock"] = Field( + default="openai", description="客户端类型(如openai/google/bedrock等,默认为openai)" ) max_retry: int = Field(default=2, ge=0, description="最大重试次数(单个模型API调用失败,最多重试的次数)") timeout: int = Field( diff --git a/src/llm_models/model_client/__init__.py b/src/llm_models/model_client/__init__.py index 6c4151c41..e7cc70ba1 100644 --- a/src/llm_models/model_client/__init__.py +++ b/src/llm_models/model_client/__init__.py @@ -6,3 +6,5 @@ if "openai" in used_client_types: from . import openai_client # noqa: F401 if "aiohttp_gemini" in used_client_types: from . import aiohttp_gemini_client # noqa: F401 +if "bedrock" in used_client_types: + from . import bedrock_client # noqa: F401 diff --git a/src/llm_models/model_client/bedrock_client.py b/src/llm_models/model_client/bedrock_client.py new file mode 100644 index 000000000..b909a09b9 --- /dev/null +++ b/src/llm_models/model_client/bedrock_client.py @@ -0,0 +1,495 @@ +import asyncio +import base64 +import io +import json +from collections.abc import Callable, Coroutine +from typing import Any + +import aioboto3 +import orjson +from botocore.config import Config +from json_repair import repair_json + +from src.common.logger import get_logger +from src.config.api_ada_configs import APIProvider, ModelInfo + +from ..exceptions import ( + NetworkConnectionError, + ReqAbortException, + RespNotOkException, + RespParseException, +) +from ..payload_content.message import Message, RoleType +from ..payload_content.resp_format import RespFormat +from ..payload_content.tool_option import ToolCall, ToolOption, ToolParam +from .base_client import APIResponse, BaseClient, UsageRecord, client_registry + +logger = get_logger("Bedrock客户端") + + +def _convert_messages_to_converse(messages: list[Message]) -> list[dict[str, Any]]: + """ + 转换消息格式 - 将消息转换为 Bedrock Converse API 所需的格式 + :param messages: 消息列表 + :return: 转换后的消息列表 + """ + + def _convert_message_item(message: Message) -> dict[str, Any]: + """ + 转换单个消息格式 + :param message: 消息对象 + :return: 转换后的消息字典 + """ + # Bedrock Converse API 格式 + content: list[dict[str, Any]] = [] + + if isinstance(message.content, str): + content.append({"text": message.content}) + elif isinstance(message.content, list): + for item in message.content: + if isinstance(item, tuple): + # 图片格式:(format, base64_data) + image_format = item[0].lower() + image_bytes = base64.b64decode(item[1]) + content.append( + { + "image": { + "format": image_format if image_format in ["png", "jpeg", "gif", "webp"] else "jpeg", + "source": {"bytes": image_bytes}, + } + } + ) + elif isinstance(item, str): + content.append({"text": item}) + else: + raise RuntimeError("无法触及的代码:请使用MessageBuilder类构建消息对象") + + ret = { + "role": "user" if message.role == RoleType.User else "assistant", + "content": content, + } + + return ret + + # Bedrock 不支持 system 和 tool 角色,需要过滤 + converted = [] + for msg in messages: + if msg.role in [RoleType.User, RoleType.Assistant]: + converted.append(_convert_message_item(msg)) + + return converted + + +def _convert_tool_options_to_bedrock(tool_options: list[ToolOption]) -> list[dict[str, Any]]: + """ + 转换工具选项格式 - 将工具选项转换为 Bedrock Converse API 所需的格式 + :param tool_options: 工具选项列表 + :return: 转换后的工具选项列表 + """ + + def _convert_tool_param(tool_param: ToolParam) -> dict[str, Any]: + """转换单个工具参数""" + param_dict: dict[str, Any] = { + "type": tool_param.param_type.value, + "description": tool_param.description, + } + if tool_param.enum_values: + param_dict["enum"] = tool_param.enum_values + return param_dict + + def _convert_tool_option_item(tool_option: ToolOption) -> dict[str, Any]: + """转换单个工具项""" + tool_spec: dict[str, Any] = { + "name": tool_option.name, + "description": tool_option.description, + } + if tool_option.params: + tool_spec["inputSchema"] = { + "json": { + "type": "object", + "properties": {param.name: _convert_tool_param(param) for param in tool_option.params}, + "required": [param.name for param in tool_option.params if param.required], + } + } + return {"toolSpec": tool_spec} + + return [_convert_tool_option_item(opt) for opt in tool_options] + + +async def _default_stream_response_handler( + resp_stream: Any, + interrupt_flag: asyncio.Event | None, +) -> tuple[APIResponse, tuple[int, int, int] | None]: + """ + 流式响应处理函数 - 处理 Bedrock Converse Stream API 的响应 + :param resp_stream: 流式响应对象 + :param interrupt_flag: 中断标志 + :return: (APIResponse对象, usage元组) + """ + _fc_delta_buffer = io.StringIO() # 正式内容缓冲区 + _tool_calls_buffer: list[tuple[str, str, io.StringIO]] = [] # 工具调用缓冲区 + _usage_record = None + + def _insure_buffer_closed(): + if _fc_delta_buffer and not _fc_delta_buffer.closed: + _fc_delta_buffer.close() + for _, _, buffer in _tool_calls_buffer: + if buffer and not buffer.closed: + buffer.close() + + try: + async for event in resp_stream["stream"]: + if interrupt_flag and interrupt_flag.is_set(): + _insure_buffer_closed() + raise ReqAbortException("请求被外部信号中断") + + # 处理内容块 + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + _fc_delta_buffer.write(delta["text"]) + elif "toolUse" in delta: + # 工具调用 + tool_use = delta["toolUse"] + if "input" in tool_use: + # 追加工具调用参数 + if tool_use.get("toolUseId"): + # 新的工具调用 + _tool_calls_buffer.append( + ( + tool_use["toolUseId"], + tool_use.get("name", ""), + io.StringIO(json.dumps(tool_use["input"])), + ) + ) + + # 处理元数据(包含 usage) + if "metadata" in event: + metadata = event["metadata"] + if "usage" in metadata: + usage = metadata["usage"] + _usage_record = ( + usage.get("inputTokens", 0), + usage.get("outputTokens", 0), + usage.get("totalTokens", 0), + ) + + # 构建响应 + resp = APIResponse() + if _fc_delta_buffer.tell() > 0: + resp.content = _fc_delta_buffer.getvalue() + _fc_delta_buffer.close() + + if _tool_calls_buffer: + resp.tool_calls = [] + for call_id, function_name, arguments_buffer in _tool_calls_buffer: + if arguments_buffer.tell() > 0: + raw_arg_data = arguments_buffer.getvalue() + arguments_buffer.close() + try: + arguments = orjson.loads(repair_json(raw_arg_data)) + if not isinstance(arguments, dict): + raise RespParseException( + None, + f"响应解析失败,工具调用参数无法解析为字典类型。原始响应:\n{raw_arg_data}", + ) + except orjson.JSONDecodeError as e: + raise RespParseException( + None, + f"响应解析失败,无法解析工具调用参数。原始响应:{raw_arg_data}", + ) from e + else: + arguments_buffer.close() + arguments = None + + resp.tool_calls.append(ToolCall(call_id, function_name, args=arguments)) + + return resp, _usage_record + + except Exception as e: + _insure_buffer_closed() + raise + + +async def _default_async_response_parser( + resp_data: dict[str, Any], +) -> tuple[APIResponse, tuple[int, int, int] | None]: + """ + 默认异步响应解析函数 - 解析 Bedrock Converse API 的响应 + :param resp_data: 响应数据 + :return: (APIResponse对象, usage元组) + """ + resp = APIResponse() + + # 解析输出内容 + if "output" in resp_data and "message" in resp_data["output"]: + message = resp_data["output"]["message"] + content_blocks = message.get("content", []) + + text_parts = [] + tool_calls = [] + + for block in content_blocks: + if "text" in block: + text_parts.append(block["text"]) + elif "toolUse" in block: + tool_use = block["toolUse"] + tool_calls.append( + ToolCall( + call_id=tool_use.get("toolUseId", ""), + func_name=tool_use.get("name", ""), + args=tool_use.get("input", {}), + ) + ) + + if text_parts: + resp.content = "".join(text_parts) + if tool_calls: + resp.tool_calls = tool_calls + + # 解析 usage + usage_record = None + if "usage" in resp_data: + usage = resp_data["usage"] + usage_record = ( + usage.get("inputTokens", 0), + usage.get("outputTokens", 0), + usage.get("totalTokens", 0), + ) + + resp.raw_data = resp_data + return resp, usage_record + + +@client_registry.register_client_class("bedrock") +class BedrockClient(BaseClient): + """AWS Bedrock 客户端""" + + def __init__(self, api_provider: APIProvider): + super().__init__(api_provider) + + # 从 extra_params 获取 AWS 配置 + # 支持两种认证方式: + # 方式1(显式凭证):api_key + extra_params.aws_secret_key + # 方式2(IAM角色):只配置 region,自动从环境/实例角色获取凭证 + region = api_provider.extra_params.get("region", "us-east-1") + aws_secret_key = api_provider.extra_params.get("aws_secret_key") + + # 配置 boto3 + self.region = region + self.boto_config = Config( + region_name=self.region, + connect_timeout=api_provider.timeout, + read_timeout=api_provider.timeout, + retries={"max_attempts": api_provider.max_retry, "mode": "adaptive"}, + ) + + # 判断认证方式 + if aws_secret_key: + # 方式1:显式 IAM 凭证 + self.aws_access_key_id = api_provider.get_api_key() + self.aws_secret_access_key = aws_secret_key + self.session = aioboto3.Session( + aws_access_key_id=self.aws_access_key_id, + aws_secret_access_key=self.aws_secret_access_key, + region_name=self.region, + ) + logger.info(f"初始化 Bedrock 客户端(IAM 凭证模式),区域: {self.region}") + else: + # 方式2:IAM 角色自动认证(从环境变量、EC2/ECS 实例角色获取) + self.session = aioboto3.Session(region_name=self.region) + logger.info(f"初始化 Bedrock 客户端(IAM 角色模式),区域: {self.region}") + logger.info("将使用环境变量或实例角色自动获取 AWS 凭证") + + async def get_response( + self, + model_info: ModelInfo, + message_list: list[Message], + tool_options: list[ToolOption] | None = None, + max_tokens: int = 1024, + temperature: float = 0.7, + response_format: RespFormat | None = None, + stream_response_handler: Callable[[Any, asyncio.Event | None], tuple[APIResponse, tuple[int, int, int]]] + | None = None, + async_response_parser: Callable[[Any], tuple[APIResponse, tuple[int, int, int]]] | None = None, + interrupt_flag: asyncio.Event | None = None, + extra_params: dict[str, Any] | None = None, + ) -> APIResponse: + """ + 获取对话响应 + """ + try: + # 提取 system prompt + system_prompts = [] + filtered_messages = [] + for msg in message_list: + if msg.role == RoleType.System: + if isinstance(msg.content, str): + system_prompts.append({"text": msg.content}) + else: + filtered_messages.append(msg) + + # 转换消息格式 + messages = _convert_messages_to_converse(filtered_messages) + + # 构建请求参数 + request_params: dict[str, Any] = { + "modelId": model_info.model_identifier, + "messages": messages, + "inferenceConfig": { + "maxTokens": max_tokens, + "temperature": temperature, + }, + } + + # 添加 system prompt + if system_prompts: + request_params["system"] = system_prompts + + # 添加工具配置 + if tool_options: + request_params["toolConfig"] = {"tools": _convert_tool_options_to_bedrock(tool_options)} + + # 合并额外参数 + if extra_params: + request_params.update(extra_params) + + # 合并模型配置的额外参数 + if model_info.extra_params: + request_params.update(model_info.extra_params) + + # 创建 Bedrock Runtime 客户端 + async with self.session.client("bedrock-runtime", config=self.boto_config) as bedrock_client: + # 判断是否使用流式模式 + use_stream = model_info.force_stream_mode or stream_response_handler is not None + + if use_stream: + # 流式调用 + response = await bedrock_client.converse_stream(**request_params) + if stream_response_handler: + # 用户提供的处理器(可能是同步的) + result = stream_response_handler(response, interrupt_flag) + if asyncio.iscoroutine(result): + api_resp, usage_tuple = await result + else: + api_resp, usage_tuple = result # type: ignore + else: + # 默认异步处理器 + api_resp, usage_tuple = await _default_stream_response_handler(response, interrupt_flag) + else: + # 非流式调用 + response = await bedrock_client.converse(**request_params) + if async_response_parser: + # 用户提供的解析器(可能是同步的) + result = async_response_parser(response) + if asyncio.iscoroutine(result): + api_resp, usage_tuple = await result + else: + api_resp, usage_tuple = result # type: ignore + else: + # 默认异步解析器 + api_resp, usage_tuple = await _default_async_response_parser(response) + + # 设置 usage + if usage_tuple: + api_resp.usage = UsageRecord( + model_name=model_info.model_identifier, + provider_name=self.api_provider.name, + prompt_tokens=usage_tuple[0], + completion_tokens=usage_tuple[1], + total_tokens=usage_tuple[2], + ) + + return api_resp + + except Exception as e: + error_type = type(e).__name__ + logger.error(f"Bedrock API 调用失败 ({error_type}): {e!s}") + + # 处理特定错误类型 + if "ThrottlingException" in error_type or "ServiceQuota" in error_type: + raise RespNotOkException(429, f"请求限流: {e!s}") from e + elif "ValidationException" in error_type: + raise RespParseException(400, f"请求参数错误: {e!s}") from e + elif "AccessDeniedException" in error_type: + raise RespNotOkException(403, f"访问被拒绝: {e!s}") from e + elif "ResourceNotFoundException" in error_type: + raise RespNotOkException(404, f"模型不存在: {e!s}") from e + elif "timeout" in str(e).lower() or "timed out" in str(e).lower(): + logger.error(f"请求超时: {e!s}") + raise NetworkConnectionError() from e + else: + logger.error(f"网络连接错误: {e!s}") + raise NetworkConnectionError() from e + + async def get_embedding( + self, + model_info: ModelInfo, + embedding_input: str | list[str], + extra_params: dict[str, Any] | None = None, + ) -> APIResponse: + """ + 获取文本嵌入(Bedrock 支持 Titan Embeddings 等模型) + """ + try: + async with self.session.client("bedrock-runtime", config=self.boto_config) as bedrock_client: + # Bedrock Embeddings 使用 InvokeModel API + is_batch = isinstance(embedding_input, list) + input_text = embedding_input if is_batch else [embedding_input] + + results = [] + total_tokens = 0 + + for text in input_text: + # 构建请求体(Titan Embeddings 格式) + body = json.dumps({"inputText": text}) + + response = await bedrock_client.invoke_model( + modelId=model_info.model_identifier, + contentType="application/json", + accept="application/json", + body=body, + ) + + # 解析响应 + response_body = json.loads(await response["body"].read()) + embedding = response_body.get("embedding", []) + results.append(embedding) + + # 累计 token 使用 + if "inputTokenCount" in response_body: + total_tokens += response_body["inputTokenCount"] + + api_resp = APIResponse() + api_resp.embedding = results if is_batch else results[0] + api_resp.usage = UsageRecord( + model_name=model_info.model_identifier, + provider_name=self.api_provider.name, + prompt_tokens=total_tokens, + completion_tokens=0, + total_tokens=total_tokens, + ) + + return api_resp + + except Exception as e: + logger.error(f"Bedrock Embedding 调用失败: {e!s}") + raise NetworkConnectionError() from e + + async def get_audio_transcriptions( + self, + model_info: ModelInfo, + audio_base64: str, + extra_params: dict[str, Any] | None = None, + ) -> APIResponse: + """ + 获取音频转录(Bedrock 暂不直接支持,抛出未实现异常) + """ + raise NotImplementedError("AWS Bedrock 暂不支持音频转录功能,建议使用 AWS Transcribe 服务") + + def get_support_image_formats(self) -> list[str]: + """ + 获取支持的图片格式 + :return: 支持的图片格式列表 + """ + return ["png", "jpeg", "jpg", "gif", "webp"] diff --git a/template/model_config_template.toml b/template/model_config_template.toml index fc0fea76b..cec804053 100644 --- a/template/model_config_template.toml +++ b/template/model_config_template.toml @@ -30,6 +30,30 @@ max_retry = 2 timeout = 30 retry_interval = 10 +#[[api_providers]] # AWS Bedrock配置示例 - 方式1:IAM凭证模式(取消注释以启用) +#name = "AWS_Bedrock" +#base_url = "" # Bedrock不需要base_url,留空即可 +#api_key = "YOUR_AWS_ACCESS_KEY_ID" # 你的AWS Access Key ID +#client_type = "bedrock" # 使用bedrock客户端 +#max_retry = 2 +#timeout = 60 # Bedrock推荐较长超时时间 +#retry_interval = 10 +#[api_providers.extra_params] # Bedrock需要的额外配置 +#aws_secret_key = "YOUR_AWS_SECRET_ACCESS_KEY" # 你的AWS Secret Access Key +#region = "us-east-1" # AWS区域,可选:us-east-1, us-west-2, eu-central-1等 + +#[[api_providers]] # AWS Bedrock配置示例 - 方式2:IAM角色模式(推荐EC2/ECS部署) +#name = "AWS_Bedrock_Role" +#base_url = "" # Bedrock不需要base_url +#api_key = "dummy" # IAM角色模式不使用api_key,但字段必填,可填任意值 +#client_type = "bedrock" +#max_retry = 2 +#timeout = 60 +#retry_interval = 10 +#[api_providers.extra_params] +## 不配置aws_secret_key,将自动使用IAM角色/环境变量认证 +#region = "us-east-1" # 只需配置区域 + [[models]] # 模型(可以配置多个) model_identifier = "deepseek-chat" # 模型标识符(API服务商提供的模型标识符) @@ -123,6 +147,28 @@ price_out = 0.0 #thinking_level = "medium" # Gemini3新版参数,可选值: "low", "medium", "high" thinking_budget = 256 # Gemini2.5系列旧版参数,不同模型范围不同(如 gemini-2.5-flash: 1-24576, gemini-2.5-pro: 128-32768) +#[[models]] # AWS Bedrock - Claude 3.5 Sonnet配置示例(取消注释以启用) +#model_identifier = "us.anthropic.claude-3-5-sonnet-20240620-v1:0" # 跨区推理配置文件 +#name = "claude-3.5-sonnet-bedrock" +#api_provider = "AWS_Bedrock" +#price_in = 3.0 # 每百万输入token价格(USD) +#price_out = 15.0 # 每百万输出token价格(USD) +#force_stream_mode = false + +#[[models]] # AWS Bedrock - Amazon Nova Pro配置示例 +#model_identifier = "us.amazon.nova-pro-v1:0" +#name = "nova-pro" +#api_provider = "AWS_Bedrock" +#price_in = 0.8 +#price_out = 3.2 + +#[[models]] # AWS Bedrock - Titan Embeddings嵌入模型示例 +#model_identifier = "amazon.titan-embed-text-v2:0" +#name = "titan-embed-v2" +#api_provider = "AWS_Bedrock" +#price_in = 0.00002 # 每千token +#price_out = 0.0 + [model_task_config.utils] # 在麦麦的一些组件中使用的模型,例如表情包模块,取名模块,关系模块,是麦麦必须的模型 model_list = ["siliconflow-deepseek-ai/DeepSeek-V3.2-Exp"] # 使用的模型列表,每个子项对应上面的模型名称(name) temperature = 0.2 # 模型温度,新V3建议0.1-0.3 From b1e7b6972d1bf54eb72a03fae4b100e7a49408e8 Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Sat, 6 Dec 2025 18:32:15 +0800 Subject: [PATCH 7/9] =?UTF-8?q?feat(plugin):=20=E6=B7=BB=E5=8A=A0=E5=8A=A8?= =?UTF-8?q?=E6=80=81=E5=8A=A0=E8=BD=BD=E6=8F=92=E4=BB=B6=E7=9A=84=20API?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 新增 `load_plugin` 函数,允许在运行时加载一个已注册但未加载的插件。 这为更灵活的插件管理(例如热加载)提供了基础支持。 --- src/plugin_system/apis/plugin_manage_api.py | 32 +++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/src/plugin_system/apis/plugin_manage_api.py b/src/plugin_system/apis/plugin_manage_api.py index ad50ad029..f83605061 100644 --- a/src/plugin_system/apis/plugin_manage_api.py +++ b/src/plugin_system/apis/plugin_manage_api.py @@ -189,6 +189,38 @@ def register_plugin_from_file(plugin_name: str, load_after_register: bool = True # 该部分包含控制插件整体启用/禁用状态的功能。 +async def load_plugin(plugin_name: str) -> bool: + """ + 加载一个已注册但未加载的插件。 + + Args: + plugin_name (str): 要加载的插件名称。 + + Returns: + bool: 如果插件成功加载,则为 True。 + + Raises: + ValueError: 如果插件未注册或已经加载。 + """ + # 检查插件是否已经加载 + if plugin_name in plugin_manager.list_loaded_plugins(): + logger.warning(f"插件 '{plugin_name}' 已经加载。") + return True + + # 检查插件是否已注册 + if plugin_name not in plugin_manager.list_registered_plugins(): + raise ValueError(f"插件 '{plugin_name}' 未注册,无法加载。") + + # 尝试加载插件 + success, _ = plugin_manager.load_registered_plugin_classes(plugin_name) + if success: + logger.info(f"插件 '{plugin_name}' 加载成功。") + else: + logger.error(f"插件 '{plugin_name}' 加载失败。") + + return success + + async def enable_plugin(plugin_name: str) -> bool: """ 启用一个已禁用的插件。 From af59966d8be187c4bccef2c84c81cf12cc922b48 Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Sat, 6 Dec 2025 18:40:02 +0800 Subject: [PATCH 8/9] =?UTF-8?q?feat(plugin):=20=E9=98=B2=E6=AD=A2=E7=A6=81?= =?UTF-8?q?=E7=94=A8=E6=9C=80=E5=90=8E=E4=B8=80=E4=B8=AA=E5=90=AF=E7=94=A8?= =?UTF-8?q?=E7=9A=84=20Chatter=20=E7=BB=84=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 为了确保系统的核心对话功能始终可用,在禁用插件时增加了保护机制。 该机制会检查目标插件是否包含 Chatter 组件。如果是,它会进一步判断禁用该插件是否会导致系统中没有任何已启用的 Chatter 组件。如果出现这种情况,禁用操作将被阻止并返回失败,从而避免因误操作导致系统核心功能失效。 --- src/plugin_system/apis/plugin_manage_api.py | 37 +++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/src/plugin_system/apis/plugin_manage_api.py b/src/plugin_system/apis/plugin_manage_api.py index f83605061..088b2b62e 100644 --- a/src/plugin_system/apis/plugin_manage_api.py +++ b/src/plugin_system/apis/plugin_manage_api.py @@ -267,6 +267,7 @@ async def disable_plugin(plugin_name: str,) -> bool: 禁用一个插件。 禁用插件不会卸载它,只会标记为禁用状态。 + 包含对 Chatter 组件的保护机制,防止禁用最后一个启用的 Chatter。 Args: plugin_name (str): 要禁用的插件名称。 @@ -280,6 +281,42 @@ async def disable_plugin(plugin_name: str,) -> bool: logger.warning(f"插件 '{plugin_name}' 未加载,无需禁用。") return True + # Chatter 保护检查:确保系统中至少有一个 Chatter 组件处于启用状态 + try: + from src.plugin_system.base.component_types import ComponentType + from src.plugin_system.core.component_registry import component_registry + + # 获取该插件的所有组件 + plugin_info = component_registry.get_plugin_info(plugin_name) + if plugin_info: + # 检查插件是否包含 Chatter 组件 + has_chatter = any( + comp.component_type == ComponentType.CHATTER + for comp in plugin_info.components + ) + + if has_chatter: + # 获取所有启用的 Chatter 组件 + enabled_chatters = component_registry.get_enabled_components_by_type(ComponentType.CHATTER) + + # 统计该插件中启用的 Chatter 数量 + plugin_enabled_chatters = [ + comp.name for comp in plugin_info.components + if comp.component_type == ComponentType.CHATTER + and comp.name in enabled_chatters + ] + + # 如果禁用此插件会导致没有可用的 Chatter,则阻止操作 + if len(enabled_chatters) <= len(plugin_enabled_chatters): + logger.warning( + f"操作被阻止:禁用插件 '{plugin_name}' 将导致系统中没有可用的 Chatter 组件。" + f"至少需要保持一个 Chatter 组件处于启用状态。" + ) + return False + except Exception as e: + logger.warning(f"检查 Chatter 保护机制时发生错误: {e}") + # 即使检查失败,也继续执行禁用操作(降级处理) + # 设置插件为禁用状态 plugin_instance.enable_plugin = False logger.info(f"插件 '{plugin_name}' 已禁用。") From 2235920908025ea9e879290996dc2700ef345056 Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Sat, 6 Dec 2025 19:03:32 +0800 Subject: [PATCH 9/9] =?UTF-8?q?fix(plugin):=20=E6=A3=80=E6=9F=A5=E7=BB=84?= =?UTF-8?q?=E4=BB=B6=E5=8F=AF=E7=94=A8=E6=80=A7=E6=97=B6=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E5=AF=B9=E5=85=B6=E6=89=80=E5=B1=9E=E6=8F=92=E4=BB=B6=E7=8A=B6?= =?UTF-8?q?=E6=80=81=E7=9A=84=E5=88=A4=E6=96=AD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 在 `is_component_enabled` 方法中,增加了对组件所属插件启用状态的检查。 此前,该方法仅检查组件自身的全局或局部状态,这可能导致一个已禁用插件下的组件仍然被错误地判断为“可用”,从而引发非预期行为。 本次修改确保在检查组件自身状态前,先验证其所属插件是否已启用。这使得组件的生命周期与其所属插件的状态保持一致,提高了系统的健壮性。 --- .../core/component_state_manager.py | 49 ++++++++++--------- 1 file changed, 27 insertions(+), 22 deletions(-) diff --git a/src/plugin_system/core/component_state_manager.py b/src/plugin_system/core/component_state_manager.py index 300764d79..17137cab4 100644 --- a/src/plugin_system/core/component_state_manager.py +++ b/src/plugin_system/core/component_state_manager.py @@ -12,21 +12,15 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING from src.common.logger import get_logger from src.plugin_system.base.component_types import ( - ActionInfo, ComponentInfo, ComponentType, ) if TYPE_CHECKING: - from src.plugin_system.base.base_chatter import BaseChatter - from src.plugin_system.base.base_events_handler import BaseEventHandler - from src.plugin_system.base.base_interest_calculator import BaseInterestCalculator - from src.plugin_system.base.base_prompt import BasePrompt - from src.plugin_system.base.base_tool import BaseTool from src.plugin_system.core.component_registry import ComponentRegistry logger = get_logger("component_state_manager") @@ -103,27 +97,25 @@ class ComponentStateManager: # 更新特定类型的启用列表 match component_type: case ComponentType.ACTION: - self._registry._default_actions[component_name] = cast(ActionInfo, target_info) + self._registry._default_actions[component_name] = target_info # type: ignore case ComponentType.TOOL: - self._registry._llm_available_tools[component_name] = cast(type[BaseTool], target_class) + self._registry._llm_available_tools[component_name] = target_class # type: ignore case ComponentType.EVENT_HANDLER: - self._registry._enabled_event_handlers[component_name] = cast(type[BaseEventHandler], target_class) + self._registry._enabled_event_handlers[component_name] = target_class # type: ignore # 重新注册事件处理器 from .event_manager import event_manager event_manager.register_event_handler( - cast(type[BaseEventHandler], target_class), + target_class, # type: ignore self._registry.get_plugin_config(target_info.plugin_name) or {} ) case ComponentType.CHATTER: - self._registry._enabled_chatter_registry[component_name] = cast(type[BaseChatter], target_class) + self._registry._enabled_chatter_registry[component_name] = target_class # type: ignore case ComponentType.INTEREST_CALCULATOR: - self._registry._enabled_interest_calculator_registry[component_name] = cast( - type[BaseInterestCalculator], target_class - ) + self._registry._enabled_interest_calculator_registry[component_name] = target_class # type: ignore case ComponentType.PROMPT: - self._registry._enabled_prompt_registry[component_name] = cast(type[BasePrompt], target_class) + self._registry._enabled_prompt_registry[component_name] = target_class # type: ignore case ComponentType.ADAPTER: - self._registry._enabled_adapter_registry[component_name] = cast(Any, target_class) + self._registry._enabled_adapter_registry[component_name] = target_class # type: ignore logger.info(f"组件 {component_name} ({component_type.value}) 已全局启用") return True @@ -261,8 +253,9 @@ class ComponentStateManager: 检查顺序: 1. 组件是否存在 - 2. (如果提供了 stream_id 且组件类型支持局部状态) 是否有局部状态覆盖 - 3. 全局启用状态 + 2. 组件所属插件是否已启用 + 3. (如果提供了 stream_id 且组件类型支持局部状态) 是否有局部状态覆盖 + 4. 全局启用状态 Args: component_name: 组件名称 @@ -278,17 +271,29 @@ class ComponentStateManager: if not component_info: return False - # 2. 不支持局部状态的类型,直接返回全局状态 + # 2. 检查组件所属插件是否已启用 + from src.plugin_system.core.plugin_manager import plugin_manager + plugin_instance = plugin_manager.get_plugin_instance(component_info.plugin_name) + if not plugin_instance: + return False + if not plugin_instance.enable_plugin: + logger.debug( + f"组件 {component_name} ({component_type.value}) 不可用: " + f"所属插件 {component_info.plugin_name} 已被禁用" + ) + return False + + # 3. 不支持局部状态的类型,直接返回全局状态 if component_type in self._no_local_state_types: return component_info.enabled - # 3. 如果提供了 stream_id,检查是否存在局部状态覆盖 + # 4. 如果提供了 stream_id,检查是否存在局部状态覆盖 if stream_id: local_state = self.get_local_state(stream_id, component_name, component_type) if local_state is not None: return local_state # 局部状态存在,直接返回 - # 4. 如果没有局部状态覆盖,返回全局状态 + # 5. 如果没有局部状态覆盖,返回全局状态 return component_info.enabled def get_enabled_components_by_type(