Compare commits
10 Commits
d302baff5f
...
767aad407a
| Author | SHA1 | Date | |
|---|---|---|---|
|
767aad407a
|
|||
|
5757999ae5
|
|||
|
42293a2b39
|
|||
|
|
1667bdc4c0 | ||
|
|
b372cb8fe0 | ||
|
|
2235920908 | ||
|
|
af59966d8b | ||
|
|
70c8557e02 | ||
|
|
b1e7b6972d | ||
|
|
2348dc1082 |
32
.gitea/workflows/build.yaml
Normal file
32
.gitea/workflows/build.yaml
Normal file
@@ -0,0 +1,32 @@
|
||||
name: Build and Push Docker Image
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- dev
|
||||
- gitea
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Login to Docker Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: docker.gardel.top
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_PASSWORD }}
|
||||
- name: Build and Push Docker Image
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
file: ./Dockerfile
|
||||
push: true
|
||||
tags: docker.gardel.top/gardel/mofox:dev
|
||||
build-args: |
|
||||
BUILD_DATE=$(date -u +'%Y-%m-%dT%H:%M:%SZ')
|
||||
VCS_REF=${{ github.sha }}
|
||||
149
.github/workflows/docker-image.yml
vendored
149
.github/workflows/docker-image.yml
vendored
@@ -1,149 +0,0 @@
|
||||
name: Docker Build and Push
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
- dev
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
- "v*"
|
||||
- "*.*.*"
|
||||
- "*.*.*-*"
|
||||
workflow_dispatch: # 允许手动触发工作流
|
||||
|
||||
# Workflow's jobs
|
||||
jobs:
|
||||
build-amd64:
|
||||
name: Build AMD64 Image
|
||||
runs-on: ubuntu-24.04
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
steps:
|
||||
- name: Check out git repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
buildkitd-flags: --debug
|
||||
|
||||
# Log in docker hub
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
# Generate metadata for Docker images
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ secrets.DOCKERHUB_USERNAME }}/mofox
|
||||
|
||||
# Build and push AMD64 image by digest
|
||||
- name: Build and push AMD64
|
||||
id: build
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
platforms: linux/amd64
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
file: ./Dockerfile
|
||||
cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/mofox:amd64-buildcache
|
||||
cache-to: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/mofox:amd64-buildcache,mode=max
|
||||
outputs: type=image,name=${{ secrets.DOCKERHUB_USERNAME }}/mofox,push-by-digest=true,name-canonical=true,push=true
|
||||
build-args: |
|
||||
BUILD_DATE=$(date -u +'%Y-%m-%dT%H:%M:%SZ')
|
||||
VCS_REF=${{ github.sha }}
|
||||
|
||||
build-arm64:
|
||||
name: Build ARM64 Image
|
||||
runs-on: ubuntu-24.04-arm
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
steps:
|
||||
- name: Check out git repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
buildkitd-flags: --debug
|
||||
|
||||
# Log in docker hub
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
# Generate metadata for Docker images
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ secrets.DOCKERHUB_USERNAME }}/mofox
|
||||
|
||||
# Build and push ARM64 image by digest
|
||||
- name: Build and push ARM64
|
||||
id: build
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
platforms: linux/arm64/v8
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
file: ./Dockerfile
|
||||
cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/mofox:arm64-buildcache
|
||||
cache-to: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/mofox:arm64-buildcache,mode=max
|
||||
outputs: type=image,name=${{ secrets.DOCKERHUB_USERNAME }}/mofox,push-by-digest=true,name-canonical=true,push=true
|
||||
build-args: |
|
||||
BUILD_DATE=$(date -u +'%Y-%m-%dT%H:%M:%SZ')
|
||||
VCS_REF=${{ github.sha }}
|
||||
|
||||
create-manifest:
|
||||
name: Create Multi-Arch Manifest
|
||||
runs-on: ubuntu-24.04
|
||||
needs:
|
||||
- build-amd64
|
||||
- build-arm64
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
# Log in docker hub
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
# Generate metadata for Docker images
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ secrets.DOCKERHUB_USERNAME }}/mofox
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
type=ref,event=tag
|
||||
type=raw,value=latest,enable=${{ github.ref == 'refs/heads/main' }}
|
||||
type=semver,pattern={{version}}
|
||||
type=semver,pattern={{major}}.{{minor}}
|
||||
type=semver,pattern={{major}}
|
||||
type=sha,prefix=${{ github.ref_name }}-,enable=${{ github.ref_type == 'branch' }}
|
||||
|
||||
- name: Create and Push Manifest
|
||||
run: |
|
||||
# 为每个标签创建多架构镜像
|
||||
for tag in $(echo "${{ steps.meta.outputs.tags }}" | tr '\n' ' '); do
|
||||
echo "Creating manifest for $tag"
|
||||
docker buildx imagetools create -t $tag \
|
||||
${{ secrets.DOCKERHUB_USERNAME }}/mofox@${{ needs.build-amd64.outputs.digest }} \
|
||||
${{ secrets.DOCKERHUB_USERNAME }}/mofox@${{ needs.build-arm64.outputs.digest }}
|
||||
done
|
||||
102
BEDROCK_INTEGRATION.md
Normal file
102
BEDROCK_INTEGRATION.md
Normal file
@@ -0,0 +1,102 @@
|
||||
# AWS Bedrock 集成完成 ✅
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 1. 安装依赖
|
||||
|
||||
```bash
|
||||
pip install aioboto3 botocore
|
||||
```
|
||||
|
||||
### 2. 配置凭证
|
||||
|
||||
在 `config/model_config.toml` 添加:
|
||||
|
||||
```toml
|
||||
[[api_providers]]
|
||||
name = "bedrock_us_east"
|
||||
base_url = ""
|
||||
api_key = "YOUR_AWS_ACCESS_KEY_ID"
|
||||
client_type = "bedrock"
|
||||
timeout = 60
|
||||
|
||||
[api_providers.extra_params]
|
||||
aws_secret_key = "YOUR_AWS_SECRET_ACCESS_KEY"
|
||||
region = "us-east-1"
|
||||
|
||||
[[models]]
|
||||
model_identifier = "us.anthropic.claude-3-5-sonnet-20240620-v1:0"
|
||||
name = "claude-3.5-sonnet-bedrock"
|
||||
api_provider = "bedrock_us_east"
|
||||
price_in = 3.0
|
||||
price_out = 15.0
|
||||
```
|
||||
|
||||
### 3. 使用示例
|
||||
|
||||
```python
|
||||
from src.llm_models import get_llm_client
|
||||
from src.llm_models.payload_content.message import MessageBuilder
|
||||
|
||||
client = get_llm_client("bedrock_us_east")
|
||||
builder = MessageBuilder()
|
||||
builder.add_user_message("你好,AWS Bedrock!")
|
||||
|
||||
response = await client.get_response(
|
||||
model_info=get_model_info("claude-3.5-sonnet-bedrock"),
|
||||
message_list=[builder.build()],
|
||||
max_tokens=1024
|
||||
)
|
||||
|
||||
print(response.content)
|
||||
```
|
||||
|
||||
## 新增文件
|
||||
|
||||
- ✅ `src/llm_models/model_client/bedrock_client.py` - Bedrock 客户端实现
|
||||
- ✅ `docs/integrations/Bedrock.md` - 完整文档
|
||||
- ✅ `scripts/test_bedrock_client.py` - 测试脚本
|
||||
|
||||
## 修改文件
|
||||
|
||||
- ✅ `src/llm_models/model_client/__init__.py` - 添加 Bedrock 导入
|
||||
- ✅ `src/config/api_ada_configs.py` - 添加 `bedrock` client_type
|
||||
- ✅ `template/model_config_template.toml` - 添加 Bedrock 配置示例(注释形式)
|
||||
- ✅ `requirements.txt` - 添加 aioboto3 和 botocore 依赖
|
||||
- ✅ `pyproject.toml` - 添加 aioboto3 和 botocore 依赖
|
||||
|
||||
## 支持功能
|
||||
|
||||
- ✅ **对话生成**:支持多轮对话
|
||||
- ✅ **流式输出**:支持流式响应
|
||||
- ✅ **工具调用**:完整支持 Tool Use
|
||||
- ✅ **多模态**:支持图片输入
|
||||
- ✅ **文本嵌入**:支持 Titan Embeddings
|
||||
- ✅ **跨区推理**:支持 Inference Profile
|
||||
|
||||
## 支持模型
|
||||
|
||||
- Amazon Nova 系列 (Micro/Lite/Pro)
|
||||
- Anthropic Claude 3/3.5 系列
|
||||
- Meta Llama 2/3 系列
|
||||
- Mistral AI 系列
|
||||
- Cohere Command 系列
|
||||
- AI21 Jamba 系列
|
||||
- Stability AI SDXL
|
||||
|
||||
## 测试
|
||||
|
||||
```bash
|
||||
# 修改凭证后运行测试
|
||||
python scripts/test_bedrock_client.py
|
||||
```
|
||||
|
||||
## 文档
|
||||
|
||||
详细文档:`docs/integrations/Bedrock.md`
|
||||
|
||||
---
|
||||
|
||||
**集成状态**: ✅ 生产就绪
|
||||
**集成时间**: 2025年12月6日
|
||||
|
||||
260
docs/integrations/Bedrock.md
Normal file
260
docs/integrations/Bedrock.md
Normal file
@@ -0,0 +1,260 @@
|
||||
# AWS Bedrock 集成指南
|
||||
|
||||
## 概述
|
||||
|
||||
MoFox-Bot 已完全集成 AWS Bedrock,支持使用 **Converse API** 统一调用所有 Bedrock 模型,包括:
|
||||
- Amazon Nova 系列
|
||||
- Anthropic Claude 3/3.5
|
||||
- Meta Llama 2/3
|
||||
- Mistral AI
|
||||
- Cohere Command
|
||||
- AI21 Jamba
|
||||
- Stability AI SDXL
|
||||
|
||||
## 配置示例
|
||||
|
||||
### 1. 配置 API Provider
|
||||
|
||||
在 `config/model_config.toml` 中添加 Bedrock Provider:
|
||||
|
||||
```toml
|
||||
[[api_providers]]
|
||||
name = "bedrock_us_east"
|
||||
base_url = "" # Bedrock 不需要 base_url,留空即可
|
||||
api_key = "YOUR_AWS_ACCESS_KEY_ID" # AWS Access Key ID
|
||||
client_type = "bedrock"
|
||||
max_retry = 2
|
||||
timeout = 60
|
||||
retry_interval = 10
|
||||
|
||||
[api_providers.extra_params]
|
||||
aws_secret_key = "YOUR_AWS_SECRET_ACCESS_KEY" # AWS Secret Access Key
|
||||
region = "us-east-1" # AWS 区域,默认 us-east-1
|
||||
```
|
||||
|
||||
### 2. 配置模型
|
||||
|
||||
在同一文件中添加模型配置:
|
||||
|
||||
```toml
|
||||
# Claude 3.5 Sonnet (Bedrock 跨区推理配置文件)
|
||||
[[models]]
|
||||
model_identifier = "us.anthropic.claude-3-5-sonnet-20240620-v1:0"
|
||||
name = "claude-3.5-sonnet-bedrock"
|
||||
api_provider = "bedrock_us_east"
|
||||
price_in = 3.0 # 每百万输入 token 价格(USD)
|
||||
price_out = 15.0 # 每百万输出 token 价格(USD)
|
||||
force_stream_mode = false
|
||||
|
||||
# Amazon Nova Pro
|
||||
[[models]]
|
||||
model_identifier = "us.amazon.nova-pro-v1:0"
|
||||
name = "nova-pro"
|
||||
api_provider = "bedrock_us_east"
|
||||
price_in = 0.8
|
||||
price_out = 3.2
|
||||
force_stream_mode = false
|
||||
|
||||
# Llama 3.1 405B
|
||||
[[models]]
|
||||
model_identifier = "us.meta.llama3-2-90b-instruct-v1:0"
|
||||
name = "llama-3.1-405b-bedrock"
|
||||
api_provider = "bedrock_us_east"
|
||||
price_in = 0.00532
|
||||
price_out = 0.016
|
||||
force_stream_mode = false
|
||||
```
|
||||
|
||||
## 支持的功能
|
||||
|
||||
### ✅ 已实现
|
||||
|
||||
- **对话生成**:支持多轮对话,自动处理 system prompt
|
||||
- **流式输出**:支持流式响应(`force_stream_mode = true`)
|
||||
- **工具调用**:完整支持 Tool Use(函数调用)
|
||||
- **多模态**:支持图片输入(PNG、JPEG、GIF、WebP)
|
||||
- **文本嵌入**:支持 Titan Embeddings 等嵌入模型
|
||||
- **跨区推理**:支持 Inference Profile(如 `us.anthropic.claude-3-5-sonnet-20240620-v1:0`)
|
||||
|
||||
### ⚠️ 限制
|
||||
|
||||
- **音频转录**:Bedrock 不直接支持语音转文字,建议使用 AWS Transcribe
|
||||
- **System 角色**:Bedrock Converse API 将 system 消息单独处理,不计入 messages 列表
|
||||
- **Tool 角色**:暂不支持 Tool 消息回传(需要用 User 角色模拟)
|
||||
|
||||
## 模型 ID 参考
|
||||
|
||||
### 推理配置文件(跨区)
|
||||
|
||||
| 模型 | Model ID | 区域覆盖 |
|
||||
|------|----------|----------|
|
||||
| Claude 3.5 Sonnet | `us.anthropic.claude-3-5-sonnet-20240620-v1:0` | us-east-1, us-west-2 |
|
||||
| Claude 3 Opus | `us.anthropic.claude-3-opus-20240229-v1:0` | 多区 |
|
||||
| Nova Pro | `us.amazon.nova-pro-v1:0` | 多区 |
|
||||
| Llama 3.1 405B | `us.meta.llama3-2-90b-instruct-v1:0` | 多区 |
|
||||
|
||||
### 单区基础模型
|
||||
|
||||
| 模型 | Model ID | 区域 |
|
||||
|------|----------|------|
|
||||
| Claude 3.5 Sonnet | `anthropic.claude-3-5-sonnet-20240620-v1:0` | 单区 |
|
||||
| Nova Micro | `amazon.nova-micro-v1:0` | us-east-1 |
|
||||
| Nova Lite | `amazon.nova-lite-v1:0` | us-east-1 |
|
||||
| Titan Embeddings G1 | `amazon.titan-embed-text-v1` | 多区 |
|
||||
|
||||
完整模型列表:https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html
|
||||
|
||||
## 使用示例
|
||||
|
||||
### Python 调用示例
|
||||
|
||||
```python
|
||||
from src.llm_models import get_llm_client
|
||||
from src.llm_models.payload_content.message import MessageBuilder
|
||||
|
||||
# 获取客户端
|
||||
client = get_llm_client("bedrock_us_east")
|
||||
|
||||
# 构建消息
|
||||
builder = MessageBuilder()
|
||||
builder.add_user_message("你好,请介绍一下 AWS Bedrock")
|
||||
|
||||
# 调用模型
|
||||
response = await client.get_response(
|
||||
model_info=get_model_info("claude-3.5-sonnet-bedrock"),
|
||||
message_list=[builder.build()],
|
||||
max_tokens=1024,
|
||||
temperature=0.7
|
||||
)
|
||||
|
||||
print(response.content)
|
||||
```
|
||||
|
||||
### 多模态示例(图片输入)
|
||||
|
||||
```python
|
||||
import base64
|
||||
|
||||
builder = MessageBuilder()
|
||||
builder.add_text_content("这张图片里有什么?")
|
||||
|
||||
# 添加图片(支持 JPEG、PNG、GIF、WebP)
|
||||
with open("image.jpg", "rb") as f:
|
||||
image_data = base64.b64encode(f.read()).decode()
|
||||
builder.add_image_content("jpeg", image_data)
|
||||
|
||||
builder.set_role_user()
|
||||
|
||||
response = await client.get_response(
|
||||
model_info=get_model_info("claude-3.5-sonnet-bedrock"),
|
||||
message_list=[builder.build()],
|
||||
max_tokens=1024
|
||||
)
|
||||
```
|
||||
|
||||
### 工具调用示例
|
||||
|
||||
```python
|
||||
from src.llm_models.payload_content.tool_option import ToolOption, ToolParam, ParamType
|
||||
|
||||
# 定义工具
|
||||
tool = ToolOption(
|
||||
name="get_weather",
|
||||
description="获取指定城市的天气信息",
|
||||
params=[
|
||||
ToolParam(
|
||||
name="city",
|
||||
param_type=ParamType.String,
|
||||
description="城市名称",
|
||||
required=True
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# 调用
|
||||
response = await client.get_response(
|
||||
model_info=get_model_info("claude-3.5-sonnet-bedrock"),
|
||||
message_list=messages,
|
||||
tool_options=[tool],
|
||||
max_tokens=1024
|
||||
)
|
||||
|
||||
# 检查工具调用
|
||||
if response.tool_calls:
|
||||
for call in response.tool_calls:
|
||||
print(f"工具: {call.name}, 参数: {call.arguments}")
|
||||
```
|
||||
|
||||
## 权限配置
|
||||
|
||||
### IAM 策略示例
|
||||
|
||||
```json
|
||||
{
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Action": [
|
||||
"bedrock:InvokeModel",
|
||||
"bedrock:InvokeModelWithResponseStream",
|
||||
"bedrock:Converse",
|
||||
"bedrock:ConverseStream"
|
||||
],
|
||||
"Resource": [
|
||||
"arn:aws:bedrock:*::foundation-model/*",
|
||||
"arn:aws:bedrock:*:*:inference-profile/*"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
## 费用优化建议
|
||||
|
||||
1. **使用推理配置文件(Inference Profile)**:自动路由到低成本区域
|
||||
2. **启用缓存**:对于重复的 system prompt,Bedrock 支持提示词缓存
|
||||
3. **批量处理**:嵌入任务可批量调用,减少请求次数
|
||||
4. **监控用量**:通过 `LLMUsageRecorder` 自动记录 token 消耗和费用
|
||||
|
||||
## 故障排查
|
||||
|
||||
### 常见错误
|
||||
|
||||
| 错误 | 原因 | 解决方案 |
|
||||
|------|------|----------|
|
||||
| `AccessDeniedException` | IAM 权限不足 | 检查 IAM 策略是否包含 `bedrock:InvokeModel` |
|
||||
| `ResourceNotFoundException` | 模型 ID 错误或区域不支持 | 验证 model_identifier 和 region 配置 |
|
||||
| `ThrottlingException` | 超过配额限制 | 增加 retry_interval 或申请提额 |
|
||||
| `ValidationException` | 请求参数错误 | 检查 messages 格式和 max_tokens 范围 |
|
||||
|
||||
### 调试模式
|
||||
|
||||
启用详细日志:
|
||||
|
||||
```python
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("Bedrock客户端")
|
||||
logger.setLevel("DEBUG")
|
||||
```
|
||||
|
||||
## 依赖安装
|
||||
|
||||
```bash
|
||||
pip install aioboto3 botocore
|
||||
```
|
||||
|
||||
或使用项目的 `requirements.txt`。
|
||||
|
||||
## 参考资料
|
||||
|
||||
- [AWS Bedrock 官方文档](https://docs.aws.amazon.com/bedrock/)
|
||||
- [Converse API 参考](https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html)
|
||||
- [支持的模型列表](https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html)
|
||||
- [定价计算器](https://aws.amazon.com/bedrock/pricing/)
|
||||
|
||||
---
|
||||
|
||||
**集成日期**: 2025年12月6日
|
||||
**状态**: ✅ 生产就绪
|
||||
@@ -37,6 +37,8 @@ dependencies = [
|
||||
"numpy>=2.2.6",
|
||||
"openai>=2.5.0",
|
||||
"opencv-python>=4.11.0.86",
|
||||
"aioboto3>=13.3.0",
|
||||
"botocore>=1.35.0",
|
||||
"packaging>=25.0",
|
||||
"pandas>=2.3.1",
|
||||
"peewee>=3.18.2",
|
||||
|
||||
@@ -22,6 +22,8 @@ networkx
|
||||
numpy
|
||||
openai
|
||||
google-genai
|
||||
aioboto3
|
||||
botocore
|
||||
pandas
|
||||
peewee
|
||||
pyarrow
|
||||
|
||||
204
scripts/test_bedrock_client.py
Normal file
204
scripts/test_bedrock_client.py
Normal file
@@ -0,0 +1,204 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
AWS Bedrock 客户端测试脚本
|
||||
测试 BedrockClient 的基本功能
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目根目录到 Python 路径
|
||||
project_root = Path(__file__).parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from src.config.api_ada_configs import APIProvider, ModelInfo
|
||||
from src.llm_models.model_client.bedrock_client import BedrockClient
|
||||
from src.llm_models.payload_content.message import MessageBuilder
|
||||
|
||||
|
||||
async def test_basic_conversation():
|
||||
"""测试基本对话功能"""
|
||||
print("=" * 60)
|
||||
print("测试 1: 基本对话功能")
|
||||
print("=" * 60)
|
||||
|
||||
# 配置 API Provider(请替换为你的真实凭证)
|
||||
provider = APIProvider(
|
||||
name="bedrock_test",
|
||||
base_url="", # Bedrock 不需要
|
||||
api_key="YOUR_AWS_ACCESS_KEY_ID", # 替换为你的 AWS Access Key
|
||||
client_type="bedrock",
|
||||
max_retry=2,
|
||||
timeout=60,
|
||||
retry_interval=10,
|
||||
extra_params={
|
||||
"aws_secret_key": "YOUR_AWS_SECRET_ACCESS_KEY", # 替换为你的 AWS Secret Key
|
||||
"region": "us-east-1",
|
||||
},
|
||||
)
|
||||
|
||||
# 配置模型信息
|
||||
model = ModelInfo(
|
||||
model_identifier="us.anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||
name="claude-3.5-sonnet-bedrock",
|
||||
api_provider="bedrock_test",
|
||||
price_in=3.0,
|
||||
price_out=15.0,
|
||||
force_stream_mode=False,
|
||||
)
|
||||
|
||||
# 创建客户端
|
||||
client = BedrockClient(provider)
|
||||
|
||||
# 构建消息
|
||||
builder = MessageBuilder()
|
||||
builder.add_user_message("你好!请用一句话介绍 AWS Bedrock。")
|
||||
|
||||
try:
|
||||
# 发送请求
|
||||
response = await client.get_response(
|
||||
model_info=model, message_list=[builder.build()], max_tokens=200, temperature=0.7
|
||||
)
|
||||
|
||||
print(f"✅ 响应内容: {response.content}")
|
||||
if response.usage:
|
||||
print(
|
||||
f"📊 Token 使用: 输入={response.usage.prompt_tokens}, "
|
||||
f"输出={response.usage.completion_tokens}, "
|
||||
f"总计={response.usage.total_tokens}"
|
||||
)
|
||||
print("\n测试通过!✅\n")
|
||||
except Exception as e:
|
||||
print(f"❌ 测试失败: {e!s}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
async def test_streaming():
|
||||
"""测试流式输出功能"""
|
||||
print("=" * 60)
|
||||
print("测试 2: 流式输出功能")
|
||||
print("=" * 60)
|
||||
|
||||
provider = APIProvider(
|
||||
name="bedrock_test",
|
||||
base_url="",
|
||||
api_key="YOUR_AWS_ACCESS_KEY_ID",
|
||||
client_type="bedrock",
|
||||
max_retry=2,
|
||||
timeout=60,
|
||||
extra_params={
|
||||
"aws_secret_key": "YOUR_AWS_SECRET_ACCESS_KEY",
|
||||
"region": "us-east-1",
|
||||
},
|
||||
)
|
||||
|
||||
model = ModelInfo(
|
||||
model_identifier="us.anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||
name="claude-3.5-sonnet-bedrock",
|
||||
api_provider="bedrock_test",
|
||||
price_in=3.0,
|
||||
price_out=15.0,
|
||||
force_stream_mode=True, # 启用流式模式
|
||||
)
|
||||
|
||||
client = BedrockClient(provider)
|
||||
builder = MessageBuilder()
|
||||
builder.add_user_message("写一个关于人工智能的三行诗。")
|
||||
|
||||
try:
|
||||
print("🔄 流式响应中...")
|
||||
response = await client.get_response(
|
||||
model_info=model, message_list=[builder.build()], max_tokens=100, temperature=0.7
|
||||
)
|
||||
|
||||
print(f"✅ 完整响应: {response.content}")
|
||||
print("\n测试通过!✅\n")
|
||||
except Exception as e:
|
||||
print(f"❌ 测试失败: {e!s}")
|
||||
|
||||
|
||||
async def test_multimodal():
|
||||
"""测试多模态(图片输入)功能"""
|
||||
print("=" * 60)
|
||||
print("测试 3: 多模态功能(需要准备图片)")
|
||||
print("=" * 60)
|
||||
print("⏭️ 跳过(需要实际图片文件)\n")
|
||||
|
||||
|
||||
async def test_tool_calling():
|
||||
"""测试工具调用功能"""
|
||||
print("=" * 60)
|
||||
print("测试 4: 工具调用功能")
|
||||
print("=" * 60)
|
||||
|
||||
from src.llm_models.payload_content.tool_option import ToolOption, ToolOptionBuilder, ToolParamType
|
||||
|
||||
provider = APIProvider(
|
||||
name="bedrock_test",
|
||||
base_url="",
|
||||
api_key="YOUR_AWS_ACCESS_KEY_ID",
|
||||
client_type="bedrock",
|
||||
extra_params={
|
||||
"aws_secret_key": "YOUR_AWS_SECRET_ACCESS_KEY",
|
||||
"region": "us-east-1",
|
||||
},
|
||||
)
|
||||
|
||||
model = ModelInfo(
|
||||
model_identifier="us.anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||
name="claude-3.5-sonnet-bedrock",
|
||||
api_provider="bedrock_test",
|
||||
)
|
||||
|
||||
# 定义工具
|
||||
tool_builder = ToolOptionBuilder()
|
||||
tool_builder.set_name("get_weather").set_description("获取指定城市的天气信息").add_param(
|
||||
name="city", param_type=ToolParamType.STRING, description="城市名称", required=True
|
||||
)
|
||||
|
||||
tool = tool_builder.build()
|
||||
|
||||
client = BedrockClient(provider)
|
||||
builder = MessageBuilder()
|
||||
builder.add_user_message("北京今天天气怎么样?")
|
||||
|
||||
try:
|
||||
response = await client.get_response(
|
||||
model_info=model, message_list=[builder.build()], tool_options=[tool], max_tokens=200
|
||||
)
|
||||
|
||||
if response.tool_calls:
|
||||
print(f"✅ 模型调用了工具:")
|
||||
for call in response.tool_calls:
|
||||
print(f" - 工具名: {call.func_name}")
|
||||
print(f" - 参数: {call.args}")
|
||||
else:
|
||||
print(f"⚠️ 模型没有调用工具,而是直接回复: {response.content}")
|
||||
|
||||
print("\n测试通过!✅\n")
|
||||
except Exception as e:
|
||||
print(f"❌ 测试失败: {e!s}")
|
||||
|
||||
|
||||
async def main():
|
||||
"""主测试函数"""
|
||||
print("\n🚀 AWS Bedrock 客户端测试开始\n")
|
||||
print("⚠️ 请确保已配置 AWS 凭证!")
|
||||
print("⚠️ 修改脚本中的 'YOUR_AWS_ACCESS_KEY_ID' 和 'YOUR_AWS_SECRET_ACCESS_KEY'\n")
|
||||
|
||||
# 运行测试
|
||||
await test_basic_conversation()
|
||||
# await test_streaming()
|
||||
# await test_multimodal()
|
||||
# await test_tool_calling()
|
||||
|
||||
print("=" * 60)
|
||||
print("🎉 所有测试完成!")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -4,6 +4,7 @@ import binascii
|
||||
import hashlib
|
||||
import io
|
||||
import json
|
||||
import json_repair
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
@@ -1023,6 +1024,15 @@ class EmojiManager:
|
||||
- 必须是表情包,非普通截图。
|
||||
- 图中文字不超过5个。
|
||||
请确保你的最终输出是严格的JSON对象,不要添加任何额外解释或文本。
|
||||
输出格式:
|
||||
```json
|
||||
{{
|
||||
"detailed_description": "",
|
||||
"keywords": [],
|
||||
"refined_sentence": "",
|
||||
"is_compliant": true
|
||||
}}
|
||||
```
|
||||
"""
|
||||
|
||||
image_data_for_vlm, image_format_for_vlm = image_base64, image_format
|
||||
@@ -1042,16 +1052,14 @@ class EmojiManager:
|
||||
if not vlm_response_str:
|
||||
continue
|
||||
|
||||
match = re.search(r"\{.*\}", vlm_response_str, re.DOTALL)
|
||||
if match:
|
||||
vlm_response_json = json.loads(match.group(0))
|
||||
description = vlm_response_json.get("detailed_description", "")
|
||||
emotions = vlm_response_json.get("keywords", [])
|
||||
refined_description = vlm_response_json.get("refined_sentence", "")
|
||||
is_compliant = vlm_response_json.get("is_compliant", False)
|
||||
if description and emotions and refined_description:
|
||||
logger.info("[VLM分析] 成功解析VLM返回的JSON数据。")
|
||||
break
|
||||
vlm_response_json = self._parse_json_response(vlm_response_str)
|
||||
description = vlm_response_json.get("detailed_description", "")
|
||||
emotions = vlm_response_json.get("keywords", [])
|
||||
refined_description = vlm_response_json.get("refined_sentence", "")
|
||||
is_compliant = vlm_response_json.get("is_compliant", False)
|
||||
if description and emotions and refined_description:
|
||||
logger.info("[VLM分析] 成功解析VLM返回的JSON数据。")
|
||||
break
|
||||
logger.warning("[VLM分析] VLM返回的JSON数据不完整或格式错误,准备重试。")
|
||||
except (json.JSONDecodeError, AttributeError) as e:
|
||||
logger.error(f"VLM JSON解析失败 (第 {i+1}/3 次): {e}")
|
||||
@@ -1122,7 +1130,7 @@ class EmojiManager:
|
||||
if emoji_base64 is None: # 再次检查读取
|
||||
logger.error(f"[注册失败] 无法读取图片以生成描述: {filename}")
|
||||
return False
|
||||
|
||||
|
||||
# 等待描述生成完成
|
||||
description, emotions = await self.build_emoji_description(emoji_base64)
|
||||
|
||||
@@ -1135,7 +1143,7 @@ class EmojiManager:
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 删除描述生成失败文件时出错: {e!s}")
|
||||
return False
|
||||
|
||||
|
||||
new_emoji.description = description
|
||||
new_emoji.emotion = emotions
|
||||
except Exception as build_desc_error:
|
||||
@@ -1196,6 +1204,29 @@ class EmojiManager:
|
||||
logger.error(f"[错误] 删除异常处理文件时出错: {remove_error}")
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def _parse_json_response(cls, response: str) -> dict[str, Any] | None:
|
||||
"""解析 LLM 的 JSON 响应"""
|
||||
try:
|
||||
# 尝试提取 JSON 代码块
|
||||
json_match = re.search(r"```json\s*(.*?)\s*```", response, re.DOTALL)
|
||||
if json_match:
|
||||
json_str = json_match.group(1)
|
||||
else:
|
||||
# 尝试直接解析
|
||||
json_str = response.strip()
|
||||
|
||||
# 移除可能的注释
|
||||
json_str = re.sub(r"//.*", "", json_str)
|
||||
json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL)
|
||||
|
||||
data = json_repair.loads(json_str)
|
||||
return data
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"JSON 解析失败: {e}, 响应: {response[:200]}")
|
||||
return None
|
||||
|
||||
|
||||
emoji_manager = None
|
||||
|
||||
|
||||
@@ -12,8 +12,8 @@ class APIProvider(ValidatedConfigBase):
|
||||
name: str = Field(..., min_length=1, description="API提供商名称")
|
||||
base_url: str = Field(..., description="API基础URL")
|
||||
api_key: str | list[str] = Field(..., min_length=1, description="API密钥,支持单个密钥或密钥列表轮询")
|
||||
client_type: Literal["openai", "gemini", "aiohttp_gemini"] = Field(
|
||||
default="openai", description="客户端类型(如openai/google等,默认为openai)"
|
||||
client_type: Literal["openai", "gemini", "aiohttp_gemini", "bedrock"] = Field(
|
||||
default="openai", description="客户端类型(如openai/google/bedrock等,默认为openai)"
|
||||
)
|
||||
max_retry: int = Field(default=2, ge=0, description="最大重试次数(单个模型API调用失败,最多重试的次数)")
|
||||
timeout: int = Field(
|
||||
|
||||
@@ -6,3 +6,5 @@ if "openai" in used_client_types:
|
||||
from . import openai_client # noqa: F401
|
||||
if "aiohttp_gemini" in used_client_types:
|
||||
from . import aiohttp_gemini_client # noqa: F401
|
||||
if "bedrock" in used_client_types:
|
||||
from . import bedrock_client # noqa: F401
|
||||
|
||||
495
src/llm_models/model_client/bedrock_client.py
Normal file
495
src/llm_models/model_client/bedrock_client.py
Normal file
@@ -0,0 +1,495 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
from collections.abc import Callable, Coroutine
|
||||
from typing import Any
|
||||
|
||||
import aioboto3
|
||||
import orjson
|
||||
from botocore.config import Config
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.api_ada_configs import APIProvider, ModelInfo
|
||||
|
||||
from ..exceptions import (
|
||||
NetworkConnectionError,
|
||||
ReqAbortException,
|
||||
RespNotOkException,
|
||||
RespParseException,
|
||||
)
|
||||
from ..payload_content.message import Message, RoleType
|
||||
from ..payload_content.resp_format import RespFormat
|
||||
from ..payload_content.tool_option import ToolCall, ToolOption, ToolParam
|
||||
from .base_client import APIResponse, BaseClient, UsageRecord, client_registry
|
||||
|
||||
logger = get_logger("Bedrock客户端")
|
||||
|
||||
|
||||
def _convert_messages_to_converse(messages: list[Message]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
转换消息格式 - 将消息转换为 Bedrock Converse API 所需的格式
|
||||
:param messages: 消息列表
|
||||
:return: 转换后的消息列表
|
||||
"""
|
||||
|
||||
def _convert_message_item(message: Message) -> dict[str, Any]:
|
||||
"""
|
||||
转换单个消息格式
|
||||
:param message: 消息对象
|
||||
:return: 转换后的消息字典
|
||||
"""
|
||||
# Bedrock Converse API 格式
|
||||
content: list[dict[str, Any]] = []
|
||||
|
||||
if isinstance(message.content, str):
|
||||
content.append({"text": message.content})
|
||||
elif isinstance(message.content, list):
|
||||
for item in message.content:
|
||||
if isinstance(item, tuple):
|
||||
# 图片格式:(format, base64_data)
|
||||
image_format = item[0].lower()
|
||||
image_bytes = base64.b64decode(item[1])
|
||||
content.append(
|
||||
{
|
||||
"image": {
|
||||
"format": image_format if image_format in ["png", "jpeg", "gif", "webp"] else "jpeg",
|
||||
"source": {"bytes": image_bytes},
|
||||
}
|
||||
}
|
||||
)
|
||||
elif isinstance(item, str):
|
||||
content.append({"text": item})
|
||||
else:
|
||||
raise RuntimeError("无法触及的代码:请使用MessageBuilder类构建消息对象")
|
||||
|
||||
ret = {
|
||||
"role": "user" if message.role == RoleType.User else "assistant",
|
||||
"content": content,
|
||||
}
|
||||
|
||||
return ret
|
||||
|
||||
# Bedrock 不支持 system 和 tool 角色,需要过滤
|
||||
converted = []
|
||||
for msg in messages:
|
||||
if msg.role in [RoleType.User, RoleType.Assistant]:
|
||||
converted.append(_convert_message_item(msg))
|
||||
|
||||
return converted
|
||||
|
||||
|
||||
def _convert_tool_options_to_bedrock(tool_options: list[ToolOption]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
转换工具选项格式 - 将工具选项转换为 Bedrock Converse API 所需的格式
|
||||
:param tool_options: 工具选项列表
|
||||
:return: 转换后的工具选项列表
|
||||
"""
|
||||
|
||||
def _convert_tool_param(tool_param: ToolParam) -> dict[str, Any]:
|
||||
"""转换单个工具参数"""
|
||||
param_dict: dict[str, Any] = {
|
||||
"type": tool_param.param_type.value,
|
||||
"description": tool_param.description,
|
||||
}
|
||||
if tool_param.enum_values:
|
||||
param_dict["enum"] = tool_param.enum_values
|
||||
return param_dict
|
||||
|
||||
def _convert_tool_option_item(tool_option: ToolOption) -> dict[str, Any]:
|
||||
"""转换单个工具项"""
|
||||
tool_spec: dict[str, Any] = {
|
||||
"name": tool_option.name,
|
||||
"description": tool_option.description,
|
||||
}
|
||||
if tool_option.params:
|
||||
tool_spec["inputSchema"] = {
|
||||
"json": {
|
||||
"type": "object",
|
||||
"properties": {param.name: _convert_tool_param(param) for param in tool_option.params},
|
||||
"required": [param.name for param in tool_option.params if param.required],
|
||||
}
|
||||
}
|
||||
return {"toolSpec": tool_spec}
|
||||
|
||||
return [_convert_tool_option_item(opt) for opt in tool_options]
|
||||
|
||||
|
||||
async def _default_stream_response_handler(
|
||||
resp_stream: Any,
|
||||
interrupt_flag: asyncio.Event | None,
|
||||
) -> tuple[APIResponse, tuple[int, int, int] | None]:
|
||||
"""
|
||||
流式响应处理函数 - 处理 Bedrock Converse Stream API 的响应
|
||||
:param resp_stream: 流式响应对象
|
||||
:param interrupt_flag: 中断标志
|
||||
:return: (APIResponse对象, usage元组)
|
||||
"""
|
||||
_fc_delta_buffer = io.StringIO() # 正式内容缓冲区
|
||||
_tool_calls_buffer: list[tuple[str, str, io.StringIO]] = [] # 工具调用缓冲区
|
||||
_usage_record = None
|
||||
|
||||
def _insure_buffer_closed():
|
||||
if _fc_delta_buffer and not _fc_delta_buffer.closed:
|
||||
_fc_delta_buffer.close()
|
||||
for _, _, buffer in _tool_calls_buffer:
|
||||
if buffer and not buffer.closed:
|
||||
buffer.close()
|
||||
|
||||
try:
|
||||
async for event in resp_stream["stream"]:
|
||||
if interrupt_flag and interrupt_flag.is_set():
|
||||
_insure_buffer_closed()
|
||||
raise ReqAbortException("请求被外部信号中断")
|
||||
|
||||
# 处理内容块
|
||||
if "contentBlockDelta" in event:
|
||||
delta = event["contentBlockDelta"]["delta"]
|
||||
if "text" in delta:
|
||||
_fc_delta_buffer.write(delta["text"])
|
||||
elif "toolUse" in delta:
|
||||
# 工具调用
|
||||
tool_use = delta["toolUse"]
|
||||
if "input" in tool_use:
|
||||
# 追加工具调用参数
|
||||
if tool_use.get("toolUseId"):
|
||||
# 新的工具调用
|
||||
_tool_calls_buffer.append(
|
||||
(
|
||||
tool_use["toolUseId"],
|
||||
tool_use.get("name", ""),
|
||||
io.StringIO(json.dumps(tool_use["input"])),
|
||||
)
|
||||
)
|
||||
|
||||
# 处理元数据(包含 usage)
|
||||
if "metadata" in event:
|
||||
metadata = event["metadata"]
|
||||
if "usage" in metadata:
|
||||
usage = metadata["usage"]
|
||||
_usage_record = (
|
||||
usage.get("inputTokens", 0),
|
||||
usage.get("outputTokens", 0),
|
||||
usage.get("totalTokens", 0),
|
||||
)
|
||||
|
||||
# 构建响应
|
||||
resp = APIResponse()
|
||||
if _fc_delta_buffer.tell() > 0:
|
||||
resp.content = _fc_delta_buffer.getvalue()
|
||||
_fc_delta_buffer.close()
|
||||
|
||||
if _tool_calls_buffer:
|
||||
resp.tool_calls = []
|
||||
for call_id, function_name, arguments_buffer in _tool_calls_buffer:
|
||||
if arguments_buffer.tell() > 0:
|
||||
raw_arg_data = arguments_buffer.getvalue()
|
||||
arguments_buffer.close()
|
||||
try:
|
||||
arguments = orjson.loads(repair_json(raw_arg_data))
|
||||
if not isinstance(arguments, dict):
|
||||
raise RespParseException(
|
||||
None,
|
||||
f"响应解析失败,工具调用参数无法解析为字典类型。原始响应:\n{raw_arg_data}",
|
||||
)
|
||||
except orjson.JSONDecodeError as e:
|
||||
raise RespParseException(
|
||||
None,
|
||||
f"响应解析失败,无法解析工具调用参数。原始响应:{raw_arg_data}",
|
||||
) from e
|
||||
else:
|
||||
arguments_buffer.close()
|
||||
arguments = None
|
||||
|
||||
resp.tool_calls.append(ToolCall(call_id, function_name, args=arguments))
|
||||
|
||||
return resp, _usage_record
|
||||
|
||||
except Exception as e:
|
||||
_insure_buffer_closed()
|
||||
raise
|
||||
|
||||
|
||||
async def _default_async_response_parser(
|
||||
resp_data: dict[str, Any],
|
||||
) -> tuple[APIResponse, tuple[int, int, int] | None]:
|
||||
"""
|
||||
默认异步响应解析函数 - 解析 Bedrock Converse API 的响应
|
||||
:param resp_data: 响应数据
|
||||
:return: (APIResponse对象, usage元组)
|
||||
"""
|
||||
resp = APIResponse()
|
||||
|
||||
# 解析输出内容
|
||||
if "output" in resp_data and "message" in resp_data["output"]:
|
||||
message = resp_data["output"]["message"]
|
||||
content_blocks = message.get("content", [])
|
||||
|
||||
text_parts = []
|
||||
tool_calls = []
|
||||
|
||||
for block in content_blocks:
|
||||
if "text" in block:
|
||||
text_parts.append(block["text"])
|
||||
elif "toolUse" in block:
|
||||
tool_use = block["toolUse"]
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
call_id=tool_use.get("toolUseId", ""),
|
||||
func_name=tool_use.get("name", ""),
|
||||
args=tool_use.get("input", {}),
|
||||
)
|
||||
)
|
||||
|
||||
if text_parts:
|
||||
resp.content = "".join(text_parts)
|
||||
if tool_calls:
|
||||
resp.tool_calls = tool_calls
|
||||
|
||||
# 解析 usage
|
||||
usage_record = None
|
||||
if "usage" in resp_data:
|
||||
usage = resp_data["usage"]
|
||||
usage_record = (
|
||||
usage.get("inputTokens", 0),
|
||||
usage.get("outputTokens", 0),
|
||||
usage.get("totalTokens", 0),
|
||||
)
|
||||
|
||||
resp.raw_data = resp_data
|
||||
return resp, usage_record
|
||||
|
||||
|
||||
@client_registry.register_client_class("bedrock")
|
||||
class BedrockClient(BaseClient):
|
||||
"""AWS Bedrock 客户端"""
|
||||
|
||||
def __init__(self, api_provider: APIProvider):
|
||||
super().__init__(api_provider)
|
||||
|
||||
# 从 extra_params 获取 AWS 配置
|
||||
# 支持两种认证方式:
|
||||
# 方式1(显式凭证):api_key + extra_params.aws_secret_key
|
||||
# 方式2(IAM角色):只配置 region,自动从环境/实例角色获取凭证
|
||||
region = api_provider.extra_params.get("region", "us-east-1")
|
||||
aws_secret_key = api_provider.extra_params.get("aws_secret_key")
|
||||
|
||||
# 配置 boto3
|
||||
self.region = region
|
||||
self.boto_config = Config(
|
||||
region_name=self.region,
|
||||
connect_timeout=api_provider.timeout,
|
||||
read_timeout=api_provider.timeout,
|
||||
retries={"max_attempts": api_provider.max_retry, "mode": "adaptive"},
|
||||
)
|
||||
|
||||
# 判断认证方式
|
||||
if aws_secret_key:
|
||||
# 方式1:显式 IAM 凭证
|
||||
self.aws_access_key_id = api_provider.get_api_key()
|
||||
self.aws_secret_access_key = aws_secret_key
|
||||
self.session = aioboto3.Session(
|
||||
aws_access_key_id=self.aws_access_key_id,
|
||||
aws_secret_access_key=self.aws_secret_access_key,
|
||||
region_name=self.region,
|
||||
)
|
||||
logger.info(f"初始化 Bedrock 客户端(IAM 凭证模式),区域: {self.region}")
|
||||
else:
|
||||
# 方式2:IAM 角色自动认证(从环境变量、EC2/ECS 实例角色获取)
|
||||
self.session = aioboto3.Session(region_name=self.region)
|
||||
logger.info(f"初始化 Bedrock 客户端(IAM 角色模式),区域: {self.region}")
|
||||
logger.info("将使用环境变量或实例角色自动获取 AWS 凭证")
|
||||
|
||||
async def get_response(
|
||||
self,
|
||||
model_info: ModelInfo,
|
||||
message_list: list[Message],
|
||||
tool_options: list[ToolOption] | None = None,
|
||||
max_tokens: int = 1024,
|
||||
temperature: float = 0.7,
|
||||
response_format: RespFormat | None = None,
|
||||
stream_response_handler: Callable[[Any, asyncio.Event | None], tuple[APIResponse, tuple[int, int, int]]]
|
||||
| None = None,
|
||||
async_response_parser: Callable[[Any], tuple[APIResponse, tuple[int, int, int]]] | None = None,
|
||||
interrupt_flag: asyncio.Event | None = None,
|
||||
extra_params: dict[str, Any] | None = None,
|
||||
) -> APIResponse:
|
||||
"""
|
||||
获取对话响应
|
||||
"""
|
||||
try:
|
||||
# 提取 system prompt
|
||||
system_prompts = []
|
||||
filtered_messages = []
|
||||
for msg in message_list:
|
||||
if msg.role == RoleType.System:
|
||||
if isinstance(msg.content, str):
|
||||
system_prompts.append({"text": msg.content})
|
||||
else:
|
||||
filtered_messages.append(msg)
|
||||
|
||||
# 转换消息格式
|
||||
messages = _convert_messages_to_converse(filtered_messages)
|
||||
|
||||
# 构建请求参数
|
||||
request_params: dict[str, Any] = {
|
||||
"modelId": model_info.model_identifier,
|
||||
"messages": messages,
|
||||
"inferenceConfig": {
|
||||
"maxTokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
},
|
||||
}
|
||||
|
||||
# 添加 system prompt
|
||||
if system_prompts:
|
||||
request_params["system"] = system_prompts
|
||||
|
||||
# 添加工具配置
|
||||
if tool_options:
|
||||
request_params["toolConfig"] = {"tools": _convert_tool_options_to_bedrock(tool_options)}
|
||||
|
||||
# 合并额外参数
|
||||
if extra_params:
|
||||
request_params.update(extra_params)
|
||||
|
||||
# 合并模型配置的额外参数
|
||||
if model_info.extra_params:
|
||||
request_params.update(model_info.extra_params)
|
||||
|
||||
# 创建 Bedrock Runtime 客户端
|
||||
async with self.session.client("bedrock-runtime", config=self.boto_config) as bedrock_client:
|
||||
# 判断是否使用流式模式
|
||||
use_stream = model_info.force_stream_mode or stream_response_handler is not None
|
||||
|
||||
if use_stream:
|
||||
# 流式调用
|
||||
response = await bedrock_client.converse_stream(**request_params)
|
||||
if stream_response_handler:
|
||||
# 用户提供的处理器(可能是同步的)
|
||||
result = stream_response_handler(response, interrupt_flag)
|
||||
if asyncio.iscoroutine(result):
|
||||
api_resp, usage_tuple = await result
|
||||
else:
|
||||
api_resp, usage_tuple = result # type: ignore
|
||||
else:
|
||||
# 默认异步处理器
|
||||
api_resp, usage_tuple = await _default_stream_response_handler(response, interrupt_flag)
|
||||
else:
|
||||
# 非流式调用
|
||||
response = await bedrock_client.converse(**request_params)
|
||||
if async_response_parser:
|
||||
# 用户提供的解析器(可能是同步的)
|
||||
result = async_response_parser(response)
|
||||
if asyncio.iscoroutine(result):
|
||||
api_resp, usage_tuple = await result
|
||||
else:
|
||||
api_resp, usage_tuple = result # type: ignore
|
||||
else:
|
||||
# 默认异步解析器
|
||||
api_resp, usage_tuple = await _default_async_response_parser(response)
|
||||
|
||||
# 设置 usage
|
||||
if usage_tuple:
|
||||
api_resp.usage = UsageRecord(
|
||||
model_name=model_info.model_identifier,
|
||||
provider_name=self.api_provider.name,
|
||||
prompt_tokens=usage_tuple[0],
|
||||
completion_tokens=usage_tuple[1],
|
||||
total_tokens=usage_tuple[2],
|
||||
)
|
||||
|
||||
return api_resp
|
||||
|
||||
except Exception as e:
|
||||
error_type = type(e).__name__
|
||||
logger.error(f"Bedrock API 调用失败 ({error_type}): {e!s}")
|
||||
|
||||
# 处理特定错误类型
|
||||
if "ThrottlingException" in error_type or "ServiceQuota" in error_type:
|
||||
raise RespNotOkException(429, f"请求限流: {e!s}") from e
|
||||
elif "ValidationException" in error_type:
|
||||
raise RespParseException(400, f"请求参数错误: {e!s}") from e
|
||||
elif "AccessDeniedException" in error_type:
|
||||
raise RespNotOkException(403, f"访问被拒绝: {e!s}") from e
|
||||
elif "ResourceNotFoundException" in error_type:
|
||||
raise RespNotOkException(404, f"模型不存在: {e!s}") from e
|
||||
elif "timeout" in str(e).lower() or "timed out" in str(e).lower():
|
||||
logger.error(f"请求超时: {e!s}")
|
||||
raise NetworkConnectionError() from e
|
||||
else:
|
||||
logger.error(f"网络连接错误: {e!s}")
|
||||
raise NetworkConnectionError() from e
|
||||
|
||||
async def get_embedding(
|
||||
self,
|
||||
model_info: ModelInfo,
|
||||
embedding_input: str | list[str],
|
||||
extra_params: dict[str, Any] | None = None,
|
||||
) -> APIResponse:
|
||||
"""
|
||||
获取文本嵌入(Bedrock 支持 Titan Embeddings 等模型)
|
||||
"""
|
||||
try:
|
||||
async with self.session.client("bedrock-runtime", config=self.boto_config) as bedrock_client:
|
||||
# Bedrock Embeddings 使用 InvokeModel API
|
||||
is_batch = isinstance(embedding_input, list)
|
||||
input_text = embedding_input if is_batch else [embedding_input]
|
||||
|
||||
results = []
|
||||
total_tokens = 0
|
||||
|
||||
for text in input_text:
|
||||
# 构建请求体(Titan Embeddings 格式)
|
||||
body = json.dumps({"inputText": text})
|
||||
|
||||
response = await bedrock_client.invoke_model(
|
||||
modelId=model_info.model_identifier,
|
||||
contentType="application/json",
|
||||
accept="application/json",
|
||||
body=body,
|
||||
)
|
||||
|
||||
# 解析响应
|
||||
response_body = json.loads(await response["body"].read())
|
||||
embedding = response_body.get("embedding", [])
|
||||
results.append(embedding)
|
||||
|
||||
# 累计 token 使用
|
||||
if "inputTokenCount" in response_body:
|
||||
total_tokens += response_body["inputTokenCount"]
|
||||
|
||||
api_resp = APIResponse()
|
||||
api_resp.embedding = results if is_batch else results[0]
|
||||
api_resp.usage = UsageRecord(
|
||||
model_name=model_info.model_identifier,
|
||||
provider_name=self.api_provider.name,
|
||||
prompt_tokens=total_tokens,
|
||||
completion_tokens=0,
|
||||
total_tokens=total_tokens,
|
||||
)
|
||||
|
||||
return api_resp
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Bedrock Embedding 调用失败: {e!s}")
|
||||
raise NetworkConnectionError() from e
|
||||
|
||||
async def get_audio_transcriptions(
|
||||
self,
|
||||
model_info: ModelInfo,
|
||||
audio_base64: str,
|
||||
extra_params: dict[str, Any] | None = None,
|
||||
) -> APIResponse:
|
||||
"""
|
||||
获取音频转录(Bedrock 暂不直接支持,抛出未实现异常)
|
||||
"""
|
||||
raise NotImplementedError("AWS Bedrock 暂不支持音频转录功能,建议使用 AWS Transcribe 服务")
|
||||
|
||||
def get_support_image_formats(self) -> list[str]:
|
||||
"""
|
||||
获取支持的图片格式
|
||||
:return: 支持的图片格式列表
|
||||
"""
|
||||
return ["png", "jpeg", "jpg", "gif", "webp"]
|
||||
@@ -187,8 +187,8 @@ class ShortTermMemoryManager:
|
||||
"importance": 0.7,
|
||||
"attributes": {{
|
||||
"time": "时间信息",
|
||||
"attribute1": "其他属性1"
|
||||
"attribute2": "其他属性2"
|
||||
"attribute1": "其他属性1",
|
||||
"attribute2": "其他属性2",
|
||||
...
|
||||
}}
|
||||
}}
|
||||
@@ -327,7 +327,7 @@ class ShortTermMemoryManager:
|
||||
# 创建决策对象
|
||||
# 将 LLM 返回的大写操作名转换为小写(适配枚举定义)
|
||||
operation_str = data.get("operation", "CREATE_NEW").lower()
|
||||
|
||||
|
||||
decision = ShortTermDecision(
|
||||
operation=ShortTermOperation(operation_str),
|
||||
target_memory_id=data.get("target_memory_id"),
|
||||
@@ -597,35 +597,35 @@ class ShortTermMemoryManager:
|
||||
# 1. 正常筛选:重要性达标的记忆
|
||||
candidates = [mem for mem in self.memories if mem.importance >= self.transfer_importance_threshold]
|
||||
candidate_ids = {mem.id for mem in candidates}
|
||||
|
||||
|
||||
# 2. 检查低重要性记忆是否积压
|
||||
# 剩余的都是低重要性记忆
|
||||
low_importance_memories = [mem for mem in self.memories if mem.id not in candidate_ids]
|
||||
|
||||
|
||||
# 如果低重要性记忆数量超过了上限(说明积压严重)
|
||||
# 我们需要清理掉一部分,而不是转移它们
|
||||
if len(low_importance_memories) > self.max_memories:
|
||||
# 目标保留数量(降至上限的 90%)
|
||||
target_keep_count = int(self.max_memories * 0.9)
|
||||
num_to_remove = len(low_importance_memories) - target_keep_count
|
||||
|
||||
|
||||
if num_to_remove > 0:
|
||||
# 按创建时间排序,删除最早的
|
||||
low_importance_memories.sort(key=lambda x: x.created_at)
|
||||
to_remove = low_importance_memories[:num_to_remove]
|
||||
|
||||
|
||||
for mem in to_remove:
|
||||
if mem in self.memories:
|
||||
self.memories.remove(mem)
|
||||
|
||||
|
||||
logger.info(
|
||||
f"短期记忆清理: 移除了 {len(to_remove)} 条低重要性记忆 "
|
||||
f"(保留 {len(self.memories)} 条)"
|
||||
)
|
||||
|
||||
|
||||
# 触发保存
|
||||
asyncio.create_task(self._save_to_disk())
|
||||
|
||||
|
||||
return candidates
|
||||
|
||||
async def clear_transferred_memories(self, memory_ids: list[str]) -> None:
|
||||
|
||||
@@ -189,6 +189,38 @@ def register_plugin_from_file(plugin_name: str, load_after_register: bool = True
|
||||
# 该部分包含控制插件整体启用/禁用状态的功能。
|
||||
|
||||
|
||||
async def load_plugin(plugin_name: str) -> bool:
|
||||
"""
|
||||
加载一个已注册但未加载的插件。
|
||||
|
||||
Args:
|
||||
plugin_name (str): 要加载的插件名称。
|
||||
|
||||
Returns:
|
||||
bool: 如果插件成功加载,则为 True。
|
||||
|
||||
Raises:
|
||||
ValueError: 如果插件未注册或已经加载。
|
||||
"""
|
||||
# 检查插件是否已经加载
|
||||
if plugin_name in plugin_manager.list_loaded_plugins():
|
||||
logger.warning(f"插件 '{plugin_name}' 已经加载。")
|
||||
return True
|
||||
|
||||
# 检查插件是否已注册
|
||||
if plugin_name not in plugin_manager.list_registered_plugins():
|
||||
raise ValueError(f"插件 '{plugin_name}' 未注册,无法加载。")
|
||||
|
||||
# 尝试加载插件
|
||||
success, _ = plugin_manager.load_registered_plugin_classes(plugin_name)
|
||||
if success:
|
||||
logger.info(f"插件 '{plugin_name}' 加载成功。")
|
||||
else:
|
||||
logger.error(f"插件 '{plugin_name}' 加载失败。")
|
||||
|
||||
return success
|
||||
|
||||
|
||||
async def enable_plugin(plugin_name: str) -> bool:
|
||||
"""
|
||||
启用一个已禁用的插件。
|
||||
@@ -235,6 +267,7 @@ async def disable_plugin(plugin_name: str,) -> bool:
|
||||
禁用一个插件。
|
||||
|
||||
禁用插件不会卸载它,只会标记为禁用状态。
|
||||
包含对 Chatter 组件的保护机制,防止禁用最后一个启用的 Chatter。
|
||||
|
||||
Args:
|
||||
plugin_name (str): 要禁用的插件名称。
|
||||
@@ -248,6 +281,42 @@ async def disable_plugin(plugin_name: str,) -> bool:
|
||||
logger.warning(f"插件 '{plugin_name}' 未加载,无需禁用。")
|
||||
return True
|
||||
|
||||
# Chatter 保护检查:确保系统中至少有一个 Chatter 组件处于启用状态
|
||||
try:
|
||||
from src.plugin_system.base.component_types import ComponentType
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
# 获取该插件的所有组件
|
||||
plugin_info = component_registry.get_plugin_info(plugin_name)
|
||||
if plugin_info:
|
||||
# 检查插件是否包含 Chatter 组件
|
||||
has_chatter = any(
|
||||
comp.component_type == ComponentType.CHATTER
|
||||
for comp in plugin_info.components
|
||||
)
|
||||
|
||||
if has_chatter:
|
||||
# 获取所有启用的 Chatter 组件
|
||||
enabled_chatters = component_registry.get_enabled_components_by_type(ComponentType.CHATTER)
|
||||
|
||||
# 统计该插件中启用的 Chatter 数量
|
||||
plugin_enabled_chatters = [
|
||||
comp.name for comp in plugin_info.components
|
||||
if comp.component_type == ComponentType.CHATTER
|
||||
and comp.name in enabled_chatters
|
||||
]
|
||||
|
||||
# 如果禁用此插件会导致没有可用的 Chatter,则阻止操作
|
||||
if len(enabled_chatters) <= len(plugin_enabled_chatters):
|
||||
logger.warning(
|
||||
f"操作被阻止:禁用插件 '{plugin_name}' 将导致系统中没有可用的 Chatter 组件。"
|
||||
f"至少需要保持一个 Chatter 组件处于启用状态。"
|
||||
)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning(f"检查 Chatter 保护机制时发生错误: {e}")
|
||||
# 即使检查失败,也继续执行禁用操作(降级处理)
|
||||
|
||||
# 设置插件为禁用状态
|
||||
plugin_instance.enable_plugin = False
|
||||
logger.info(f"插件 '{plugin_name}' 已禁用。")
|
||||
|
||||
@@ -12,21 +12,15 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import (
|
||||
ActionInfo,
|
||||
ComponentInfo,
|
||||
ComponentType,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.plugin_system.base.base_chatter import BaseChatter
|
||||
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
||||
from src.plugin_system.base.base_interest_calculator import BaseInterestCalculator
|
||||
from src.plugin_system.base.base_prompt import BasePrompt
|
||||
from src.plugin_system.base.base_tool import BaseTool
|
||||
from src.plugin_system.core.component_registry import ComponentRegistry
|
||||
|
||||
logger = get_logger("component_state_manager")
|
||||
@@ -103,27 +97,25 @@ class ComponentStateManager:
|
||||
# 更新特定类型的启用列表
|
||||
match component_type:
|
||||
case ComponentType.ACTION:
|
||||
self._registry._default_actions[component_name] = cast(ActionInfo, target_info)
|
||||
self._registry._default_actions[component_name] = target_info # type: ignore
|
||||
case ComponentType.TOOL:
|
||||
self._registry._llm_available_tools[component_name] = cast(type[BaseTool], target_class)
|
||||
self._registry._llm_available_tools[component_name] = target_class # type: ignore
|
||||
case ComponentType.EVENT_HANDLER:
|
||||
self._registry._enabled_event_handlers[component_name] = cast(type[BaseEventHandler], target_class)
|
||||
self._registry._enabled_event_handlers[component_name] = target_class # type: ignore
|
||||
# 重新注册事件处理器
|
||||
from .event_manager import event_manager
|
||||
event_manager.register_event_handler(
|
||||
cast(type[BaseEventHandler], target_class),
|
||||
target_class, # type: ignore
|
||||
self._registry.get_plugin_config(target_info.plugin_name) or {}
|
||||
)
|
||||
case ComponentType.CHATTER:
|
||||
self._registry._enabled_chatter_registry[component_name] = cast(type[BaseChatter], target_class)
|
||||
self._registry._enabled_chatter_registry[component_name] = target_class # type: ignore
|
||||
case ComponentType.INTEREST_CALCULATOR:
|
||||
self._registry._enabled_interest_calculator_registry[component_name] = cast(
|
||||
type[BaseInterestCalculator], target_class
|
||||
)
|
||||
self._registry._enabled_interest_calculator_registry[component_name] = target_class # type: ignore
|
||||
case ComponentType.PROMPT:
|
||||
self._registry._enabled_prompt_registry[component_name] = cast(type[BasePrompt], target_class)
|
||||
self._registry._enabled_prompt_registry[component_name] = target_class # type: ignore
|
||||
case ComponentType.ADAPTER:
|
||||
self._registry._enabled_adapter_registry[component_name] = cast(Any, target_class)
|
||||
self._registry._enabled_adapter_registry[component_name] = target_class # type: ignore
|
||||
|
||||
logger.info(f"组件 {component_name} ({component_type.value}) 已全局启用")
|
||||
return True
|
||||
@@ -261,8 +253,9 @@ class ComponentStateManager:
|
||||
|
||||
检查顺序:
|
||||
1. 组件是否存在
|
||||
2. (如果提供了 stream_id 且组件类型支持局部状态) 是否有局部状态覆盖
|
||||
3. 全局启用状态
|
||||
2. 组件所属插件是否已启用
|
||||
3. (如果提供了 stream_id 且组件类型支持局部状态) 是否有局部状态覆盖
|
||||
4. 全局启用状态
|
||||
|
||||
Args:
|
||||
component_name: 组件名称
|
||||
@@ -278,17 +271,29 @@ class ComponentStateManager:
|
||||
if not component_info:
|
||||
return False
|
||||
|
||||
# 2. 不支持局部状态的类型,直接返回全局状态
|
||||
# 2. 检查组件所属插件是否已启用
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
plugin_instance = plugin_manager.get_plugin_instance(component_info.plugin_name)
|
||||
if not plugin_instance:
|
||||
return False
|
||||
if not plugin_instance.enable_plugin:
|
||||
logger.debug(
|
||||
f"组件 {component_name} ({component_type.value}) 不可用: "
|
||||
f"所属插件 {component_info.plugin_name} 已被禁用"
|
||||
)
|
||||
return False
|
||||
|
||||
# 3. 不支持局部状态的类型,直接返回全局状态
|
||||
if component_type in self._no_local_state_types:
|
||||
return component_info.enabled
|
||||
|
||||
# 3. 如果提供了 stream_id,检查是否存在局部状态覆盖
|
||||
# 4. 如果提供了 stream_id,检查是否存在局部状态覆盖
|
||||
if stream_id:
|
||||
local_state = self.get_local_state(stream_id, component_name, component_type)
|
||||
if local_state is not None:
|
||||
return local_state # 局部状态存在,直接返回
|
||||
|
||||
# 4. 如果没有局部状态覆盖,返回全局状态
|
||||
# 5. 如果没有局部状态覆盖,返回全局状态
|
||||
return component_info.enabled
|
||||
|
||||
def get_enabled_components_by_type(
|
||||
|
||||
@@ -58,7 +58,7 @@ class ChatterPlanFilter:
|
||||
prompt, used_message_id_list = await self._build_prompt(plan)
|
||||
plan.llm_prompt = prompt
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"规划器原始提示词:{prompt}"
|
||||
) # 叫你不要改你耳朵聋吗😡😡😡😡😡
|
||||
|
||||
|
||||
@@ -16,46 +16,6 @@ from ..services.manager import get_service
|
||||
logger = get_logger("tts_voice_plugin.action")
|
||||
|
||||
|
||||
def _get_available_styles() -> list[str]:
|
||||
"""动态读取配置文件,获取所有可用的TTS风格名称"""
|
||||
try:
|
||||
# 这个路径构建逻辑是为了确保无论从哪里启动,都能准确定位到配置文件
|
||||
plugin_file = Path(__file__).resolve()
|
||||
# Bot/src/plugins/built_in/tts_voice_plugin/actions -> Bot
|
||||
bot_root = plugin_file.parent.parent.parent.parent.parent.parent
|
||||
config_file = bot_root / "config" / "plugins" / "tts_voice_plugin" / "config.toml"
|
||||
|
||||
if not config_file.is_file():
|
||||
logger.warning("在 tts_action 中未找到 tts_voice_plugin 的配置文件,无法动态加载风格列表。")
|
||||
return ["default"]
|
||||
|
||||
config = toml.loads(config_file.read_text(encoding="utf-8"))
|
||||
|
||||
styles_config = config.get("tts_styles", [])
|
||||
if not isinstance(styles_config, list):
|
||||
|
||||
return ["default"]
|
||||
|
||||
# 使用显式循环和类型检查来提取 style_name,以确保 Pylance 类型检查通过
|
||||
style_names: list[str] = []
|
||||
for style in styles_config:
|
||||
if isinstance(style, dict):
|
||||
name = style.get("style_name")
|
||||
# 确保 name 是一个非空字符串
|
||||
if isinstance(name, str) and name:
|
||||
style_names.append(name)
|
||||
|
||||
return style_names if style_names else ["default"]
|
||||
except Exception as e:
|
||||
logger.error(f"动态加载TTS风格列表时出错: {e}")
|
||||
return ["default"] # 出现任何错误都回退
|
||||
|
||||
|
||||
# 在类定义之前执行函数,获取风格列表
|
||||
AVAILABLE_STYLES = _get_available_styles()
|
||||
STYLE_OPTIONS_DESC = ", ".join(f"'{s}'" for s in AVAILABLE_STYLES)
|
||||
|
||||
|
||||
class TTSVoiceAction(BaseAction):
|
||||
"""
|
||||
通过关键词或规划器自动触发 TTS 语音合成
|
||||
@@ -75,7 +35,7 @@ class TTSVoiceAction(BaseAction):
|
||||
},
|
||||
"voice_style": {
|
||||
"type": "string",
|
||||
"description": f"语音的风格。可用选项: [{STYLE_OPTIONS_DESC}]。请根据对话的情感和上下文选择一个最合适的风格。如果未提供,将使用默认风格。",
|
||||
"description": "语音的风格。请根据对话的情感和上下文选择一个最合适的风格。如果未提供,将使用默认风格。",
|
||||
"required": False
|
||||
},
|
||||
"text_language": {
|
||||
@@ -115,6 +75,109 @@ class TTSVoiceAction(BaseAction):
|
||||
super().__init__(*args, **kwargs)
|
||||
# 关键配置项现在由 TTSService 管理
|
||||
self.tts_service = get_service("tts")
|
||||
|
||||
# 动态更新 voice_style 参数描述(包含可用风格)
|
||||
self._update_voice_style_parameter()
|
||||
|
||||
def _update_voice_style_parameter(self):
|
||||
"""动态更新 voice_style 参数描述,包含实际可用的风格选项"""
|
||||
try:
|
||||
available_styles = self._get_available_styles_safe()
|
||||
if available_styles:
|
||||
styles_list = "、".join(available_styles)
|
||||
updated_description = (
|
||||
f"语音的风格。请根据对话的情感和上下文选择一个最合适的风格。"
|
||||
f"当前可用风格:{styles_list}。如果未提供,将使用默认风格。"
|
||||
)
|
||||
# 更新实例的参数描述
|
||||
self.action_parameters["voice_style"]["description"] = updated_description
|
||||
logger.debug(f"{self.log_prefix} 已更新语音风格参数描述,包含 {len(available_styles)} 个可用风格")
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 无法获取可用语音风格,使用默认参数描述")
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 更新语音风格参数时出错: {e}")
|
||||
|
||||
def _get_available_styles_safe(self) -> list[str]:
|
||||
"""安全地获取可用语音风格列表"""
|
||||
try:
|
||||
# 首先尝试从TTS服务获取
|
||||
if hasattr(self.tts_service, 'get_available_styles'):
|
||||
styles = self.tts_service.get_available_styles()
|
||||
if styles:
|
||||
return styles
|
||||
|
||||
# 回退到直接读取配置文件
|
||||
plugin_file = Path(__file__).resolve()
|
||||
bot_root = plugin_file.parent.parent.parent.parent.parent.parent
|
||||
config_file = bot_root / "config" / "plugins" / "tts_voice_plugin" / "config.toml"
|
||||
|
||||
if config_file.exists():
|
||||
with open(config_file, 'r', encoding='utf-8') as f:
|
||||
config = toml.load(f)
|
||||
styles_config = config.get('tts_styles', [])
|
||||
|
||||
if isinstance(styles_config, list):
|
||||
style_names = []
|
||||
for style in styles_config:
|
||||
if isinstance(style, dict):
|
||||
name = style.get('style_name')
|
||||
if isinstance(name, str) and name:
|
||||
style_names.append(name)
|
||||
return style_names if style_names else ['default']
|
||||
except Exception as e:
|
||||
logger.debug(f"{self.log_prefix} 获取可用语音风格时出错: {e}")
|
||||
|
||||
return ['default'] # 安全回退
|
||||
|
||||
@classmethod
|
||||
def get_action_info(cls) -> "ActionInfo":
|
||||
"""重写获取Action信息的方法,动态更新参数描述"""
|
||||
# 先调用父类方法获取基础信息
|
||||
info = super().get_action_info()
|
||||
|
||||
# 尝试动态更新 voice_style 参数描述
|
||||
try:
|
||||
# 尝试获取可用风格(不创建完整实例)
|
||||
available_styles = cls._get_available_styles_for_info()
|
||||
if available_styles:
|
||||
styles_list = "、".join(available_styles)
|
||||
updated_description = (
|
||||
f"语音的风格。请根据对话的情感和上下文选择一个最合适的风格。"
|
||||
f"当前可用风格:{styles_list}。如果未提供,将使用默认风格。"
|
||||
)
|
||||
# 更新参数描述
|
||||
info.action_parameters["voice_style"]["description"] = updated_description
|
||||
except Exception as e:
|
||||
logger.debug(f"[TTSVoiceAction] 在获取Action信息时更新参数描述失败: {e}")
|
||||
|
||||
return info
|
||||
|
||||
@classmethod
|
||||
def _get_available_styles_for_info(cls) -> list[str]:
|
||||
"""为 get_action_info 方法获取可用风格(类方法版本)"""
|
||||
try:
|
||||
# 构建配置文件路径
|
||||
plugin_file = Path(__file__).resolve()
|
||||
bot_root = plugin_file.parent.parent.parent.parent.parent.parent
|
||||
config_file = bot_root / "config" / "plugins" / "tts_voice_plugin" / "config.toml"
|
||||
|
||||
if config_file.exists():
|
||||
with open(config_file, 'r', encoding='utf-8') as f:
|
||||
config = toml.load(f)
|
||||
styles_config = config.get('tts_styles', [])
|
||||
|
||||
if isinstance(styles_config, list):
|
||||
style_names = []
|
||||
for style in styles_config:
|
||||
if isinstance(style, dict):
|
||||
name = style.get('style_name')
|
||||
if isinstance(name, str) and name:
|
||||
style_names.append(name)
|
||||
return style_names if style_names else ['default']
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return ['default'] # 安全回退
|
||||
|
||||
async def go_activate(self, llm_judge_model=None) -> bool:
|
||||
"""
|
||||
|
||||
@@ -28,7 +28,7 @@ class TTSVoicePlugin(BasePlugin):
|
||||
plugin_description = "基于GPT-SoVITS的文本转语音插件(重构版)"
|
||||
plugin_version = "3.1.2"
|
||||
plugin_author = "Kilo Code & 靚仔"
|
||||
enable_plugin = True
|
||||
# enable_plugin 应该由配置文件控制,不在此处硬编码
|
||||
config_file_name = "config.toml"
|
||||
dependencies: ClassVar[list[str]] = []
|
||||
|
||||
@@ -61,7 +61,7 @@ class TTSVoicePlugin(BasePlugin):
|
||||
|
||||
default_config_content = """# 插件基础配置
|
||||
[plugin]
|
||||
enable = true
|
||||
enable = false
|
||||
keywords = [
|
||||
"发语音", "语音", "说句话", "用语音说", "听你", "听声音", "想听你", "想听声音",
|
||||
"讲个话", "说段话", "念一下", "读一下", "用嘴说", "说", "能发语音吗","亲口"
|
||||
|
||||
@@ -30,6 +30,30 @@ max_retry = 2
|
||||
timeout = 30
|
||||
retry_interval = 10
|
||||
|
||||
#[[api_providers]] # AWS Bedrock配置示例 - 方式1:IAM凭证模式(取消注释以启用)
|
||||
#name = "AWS_Bedrock"
|
||||
#base_url = "" # Bedrock不需要base_url,留空即可
|
||||
#api_key = "YOUR_AWS_ACCESS_KEY_ID" # 你的AWS Access Key ID
|
||||
#client_type = "bedrock" # 使用bedrock客户端
|
||||
#max_retry = 2
|
||||
#timeout = 60 # Bedrock推荐较长超时时间
|
||||
#retry_interval = 10
|
||||
#[api_providers.extra_params] # Bedrock需要的额外配置
|
||||
#aws_secret_key = "YOUR_AWS_SECRET_ACCESS_KEY" # 你的AWS Secret Access Key
|
||||
#region = "us-east-1" # AWS区域,可选:us-east-1, us-west-2, eu-central-1等
|
||||
|
||||
#[[api_providers]] # AWS Bedrock配置示例 - 方式2:IAM角色模式(推荐EC2/ECS部署)
|
||||
#name = "AWS_Bedrock_Role"
|
||||
#base_url = "" # Bedrock不需要base_url
|
||||
#api_key = "dummy" # IAM角色模式不使用api_key,但字段必填,可填任意值
|
||||
#client_type = "bedrock"
|
||||
#max_retry = 2
|
||||
#timeout = 60
|
||||
#retry_interval = 10
|
||||
#[api_providers.extra_params]
|
||||
## 不配置aws_secret_key,将自动使用IAM角色/环境变量认证
|
||||
#region = "us-east-1" # 只需配置区域
|
||||
|
||||
|
||||
[[models]] # 模型(可以配置多个)
|
||||
model_identifier = "deepseek-chat" # 模型标识符(API服务商提供的模型标识符)
|
||||
@@ -123,6 +147,28 @@ price_out = 0.0
|
||||
#thinking_level = "medium" # Gemini3新版参数,可选值: "low", "medium", "high"
|
||||
thinking_budget = 256 # Gemini2.5系列旧版参数,不同模型范围不同(如 gemini-2.5-flash: 1-24576, gemini-2.5-pro: 128-32768)
|
||||
|
||||
#[[models]] # AWS Bedrock - Claude 3.5 Sonnet配置示例(取消注释以启用)
|
||||
#model_identifier = "us.anthropic.claude-3-5-sonnet-20240620-v1:0" # 跨区推理配置文件
|
||||
#name = "claude-3.5-sonnet-bedrock"
|
||||
#api_provider = "AWS_Bedrock"
|
||||
#price_in = 3.0 # 每百万输入token价格(USD)
|
||||
#price_out = 15.0 # 每百万输出token价格(USD)
|
||||
#force_stream_mode = false
|
||||
|
||||
#[[models]] # AWS Bedrock - Amazon Nova Pro配置示例
|
||||
#model_identifier = "us.amazon.nova-pro-v1:0"
|
||||
#name = "nova-pro"
|
||||
#api_provider = "AWS_Bedrock"
|
||||
#price_in = 0.8
|
||||
#price_out = 3.2
|
||||
|
||||
#[[models]] # AWS Bedrock - Titan Embeddings嵌入模型示例
|
||||
#model_identifier = "amazon.titan-embed-text-v2:0"
|
||||
#name = "titan-embed-v2"
|
||||
#api_provider = "AWS_Bedrock"
|
||||
#price_in = 0.00002 # 每千token
|
||||
#price_out = 0.0
|
||||
|
||||
[model_task_config.utils] # 在麦麦的一些组件中使用的模型,例如表情包模块,取名模块,关系模块,是麦麦必须的模型
|
||||
model_list = ["siliconflow-deepseek-ai/DeepSeek-V3.2-Exp"] # 使用的模型列表,每个子项对应上面的模型名称(name)
|
||||
temperature = 0.2 # 模型温度,新V3建议0.1-0.3
|
||||
|
||||
Reference in New Issue
Block a user