chore: 代码清理和格式化
对项目代码进行了一系列小幅度的清理和改进,包括:
- 移除未使用的导入语句
- 统一代码格式,如调整空行和导入顺序
- 优化日志输出的可读性
- 更新类型注解以符合现代 Python 语法
- 修复代码风格检查器(linter)报告的问题
This commit is contained in:
@@ -3,7 +3,6 @@ import datetime
|
|||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
|
|
||||||
|
|||||||
@@ -146,7 +146,7 @@ class StreamContext(BaseDataModel):
|
|||||||
# 同步打断计数到ChatStream
|
# 同步打断计数到ChatStream
|
||||||
await self._sync_interruption_count_to_stream()
|
await self._sync_interruption_count_to_stream()
|
||||||
|
|
||||||
|
|
||||||
async def _sync_interruption_count_to_stream(self):
|
async def _sync_interruption_count_to_stream(self):
|
||||||
"""同步打断计数到ChatStream"""
|
"""同步打断计数到ChatStream"""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
import os
|
import os
|
||||||
import socket
|
import socket
|
||||||
|
|
||||||
from fastapi import APIRouter, FastAPI
|
from fastapi import APIRouter, FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware # 新增导入
|
from fastapi.middleware.cors import CORSMiddleware # 新增导入
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
from uvicorn import Config, Server as UvicornServer
|
from uvicorn import Config
|
||||||
|
from uvicorn import Server as UvicornServer
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
|
|||||||
@@ -126,7 +126,7 @@ class ChatConfig(ValidatedConfigBase):
|
|||||||
interruption_min_probability: float = Field(
|
interruption_min_probability: float = Field(
|
||||||
default=0.1, ge=0.0, le=1.0, description="最低打断概率(即使达到较高打断次数,也保证有此概率的打断机会)"
|
default=0.1, ge=0.0, le=1.0, description="最低打断概率(即使达到较高打断次数,也保证有此概率的打断机会)"
|
||||||
)
|
)
|
||||||
|
|
||||||
# DEPRECATED: interruption_probability_factor (已废弃的配置项)
|
# DEPRECATED: interruption_probability_factor (已废弃的配置项)
|
||||||
# 新的线性概率模型不再需要复杂的概率因子
|
# 新的线性概率模型不再需要复杂的概率因子
|
||||||
# 保留此字段是为了向后兼容,现有配置文件不会报错
|
# 保留此字段是为了向后兼容,现有配置文件不会报错
|
||||||
|
|||||||
@@ -265,7 +265,7 @@ async def get_intercom_group_context(group_name: str, limit_per_chat: int = 20,
|
|||||||
else: # private
|
else: # private
|
||||||
if stream.user_info and stream.user_info.user_id == chat_raw_id and not stream.group_info:
|
if stream.user_info and stream.user_info.user_id == chat_raw_id and not stream.group_info:
|
||||||
found_stream = stream
|
found_stream = stream
|
||||||
break
|
break
|
||||||
if not found_stream:
|
if not found_stream:
|
||||||
logger.warning(f"在已加载的聊天流中找不到ID为 {chat_raw_id} 的聊天。")
|
logger.warning(f"在已加载的聊天流中找不到ID为 {chat_raw_id} 的聊天。")
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -236,10 +236,10 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
|||||||
|
|
||||||
if is_mentioned:
|
if is_mentioned:
|
||||||
if is_at:
|
if is_at:
|
||||||
logger.debug(f"[提及分计算] 直接@机器人,返回1.0")
|
logger.debug("[提及分计算] 直接@机器人,返回1.0")
|
||||||
return 1.0 # 直接@机器人,最高分
|
return 1.0 # 直接@机器人,最高分
|
||||||
else:
|
else:
|
||||||
logger.debug(f"[提及分计算] 提及机器人,返回0.8")
|
logger.debug("[提及分计算] 提及机器人,返回0.8")
|
||||||
return 0.8 # 提及机器人名字,高分
|
return 0.8 # 提及机器人名字,高分
|
||||||
else:
|
else:
|
||||||
# 检查是否被提及(文本匹配)
|
# 检查是否被提及(文本匹配)
|
||||||
@@ -248,13 +248,13 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
|||||||
|
|
||||||
# 如果被提及或是私聊,都视为提及了bot
|
# 如果被提及或是私聊,都视为提及了bot
|
||||||
if is_text_mentioned:
|
if is_text_mentioned:
|
||||||
logger.debug(f"[提及分计算] 文本提及机器人,返回提及分")
|
logger.debug("[提及分计算] 文本提及机器人,返回提及分")
|
||||||
return global_config.affinity_flow.mention_bot_interest_score
|
return global_config.affinity_flow.mention_bot_interest_score
|
||||||
elif is_private_chat:
|
elif is_private_chat:
|
||||||
logger.debug(f"[提及分计算] 私聊消息,返回提及分")
|
logger.debug("[提及分计算] 私聊消息,返回提及分")
|
||||||
return global_config.affinity_flow.mention_bot_interest_score
|
return global_config.affinity_flow.mention_bot_interest_score
|
||||||
else:
|
else:
|
||||||
logger.debug(f"[提及分计算] 未提及机器人,返回0.0")
|
logger.debug("[提及分计算] 未提及机器人,返回0.0")
|
||||||
return 0.0 # 未提及机器人
|
return 0.0 # 未提及机器人
|
||||||
|
|
||||||
def _apply_no_reply_boost(self, base_score: float) -> float:
|
def _apply_no_reply_boost(self, base_score: float) -> float:
|
||||||
|
|||||||
@@ -500,7 +500,7 @@ class ChatterPlanExecutor:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 添加到chat_stream的已读消息中
|
# 添加到chat_stream的已读消息中
|
||||||
if hasattr(chat_stream, 'stream_context') and chat_stream.stream_context:
|
if hasattr(chat_stream, "stream_context") and chat_stream.stream_context:
|
||||||
chat_stream.stream_context.history_messages.append(bot_message)
|
chat_stream.stream_context.history_messages.append(bot_message)
|
||||||
logger.debug(f"机器人回复已添加到已读消息: {reply_content[:50]}...")
|
logger.debug(f"机器人回复已添加到已读消息: {reply_content[:50]}...")
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -14,4 +14,4 @@ __plugin_meta__ = PluginMetadata(
|
|||||||
"is_built_in": False,
|
"is_built_in": False,
|
||||||
"plugin_type": "tools",
|
"plugin_type": "tools",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,11 +1,9 @@
|
|||||||
"""
|
"""
|
||||||
TTS 语音合成 Action
|
TTS 语音合成 Action
|
||||||
"""
|
"""
|
||||||
import asyncio
|
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.plugin_system.apis import config_api, generator_api
|
from src.plugin_system.apis import generator_api
|
||||||
from src.plugin_system.base.base_action import ActionActivationType, BaseAction, ChatMode
|
from src.plugin_system.base.base_action import ActionActivationType, BaseAction, ChatMode
|
||||||
|
|
||||||
from ..services.manager import get_service
|
from ..services.manager import get_service
|
||||||
@@ -44,7 +42,7 @@ class TTSVoiceAction(BaseAction):
|
|||||||
# 关键配置项现在由 TTSService 管理
|
# 关键配置项现在由 TTSService 管理
|
||||||
self.tts_service = get_service("tts")
|
self.tts_service = get_service("tts")
|
||||||
|
|
||||||
async def execute(self) -> Tuple[bool, str]:
|
async def execute(self) -> tuple[bool, str]:
|
||||||
"""
|
"""
|
||||||
执行 Action 的核心逻辑
|
执行 Action 的核心逻辑
|
||||||
"""
|
"""
|
||||||
@@ -79,7 +77,7 @@ class TTSVoiceAction(BaseAction):
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await self._handle_error_and_reply("generic_error", e)
|
await self._handle_error_and_reply("generic_error", e)
|
||||||
return False, f"语音合成出错: {str(e)}"
|
return False, f"语音合成出错: {e!s}"
|
||||||
|
|
||||||
async def _generate_final_text(self, initial_text: str) -> str:
|
async def _generate_final_text(self, initial_text: str) -> str:
|
||||||
"""请求主回复模型生成或优化文本"""
|
"""请求主回复模型生成或优化文本"""
|
||||||
@@ -89,7 +87,7 @@ class TTSVoiceAction(BaseAction):
|
|||||||
"请基于规划器提供的初步文本,结合对话历史和自己的人设,将它优化成一句自然、富有感情、适合用语音说出的话。"
|
"请基于规划器提供的初步文本,结合对话历史和自己的人设,将它优化成一句自然、富有感情、适合用语音说出的话。"
|
||||||
"最终指令:请务-必确保文本听起来像真实的、自然的口语对话,而不是书面语。"
|
"最终指令:请务-必确保文本听起来像真实的、自然的口语对话,而不是书面语。"
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"{self.log_prefix} 请求主回复模型(replyer)全新生成TTS文本...")
|
logger.info(f"{self.log_prefix} 请求主回复模型(replyer)全新生成TTS文本...")
|
||||||
success, response_set, _ = await generator_api.rewrite_reply(
|
success, response_set, _ = await generator_api.rewrite_reply(
|
||||||
chat_stream=self.chat_stream,
|
chat_stream=self.chat_stream,
|
||||||
@@ -101,11 +99,11 @@ class TTSVoiceAction(BaseAction):
|
|||||||
text = "".join(str(seg[1]) if isinstance(seg, tuple) else str(seg) for seg in response_set).strip()
|
text = "".join(str(seg[1]) if isinstance(seg, tuple) else str(seg) for seg in response_set).strip()
|
||||||
logger.info(f"{self.log_prefix} 成功生成高质量TTS文本: {text}")
|
logger.info(f"{self.log_prefix} 成功生成高质量TTS文本: {text}")
|
||||||
return text
|
return text
|
||||||
|
|
||||||
if initial_text:
|
if initial_text:
|
||||||
logger.warning(f"{self.log_prefix} 主模型生成失败,使用规划器原始文本作为兜底。")
|
logger.warning(f"{self.log_prefix} 主模型生成失败,使用规划器原始文本作为兜底。")
|
||||||
return initial_text
|
return initial_text
|
||||||
|
|
||||||
raise Exception("主模型未能生成回复,且规划器也未提供兜底文本。")
|
raise Exception("主模型未能生成回复,且规划器也未提供兜底文本。")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -119,11 +117,11 @@ class TTSVoiceAction(BaseAction):
|
|||||||
error_prompts = {
|
error_prompts = {
|
||||||
"generic_error": {
|
"generic_error": {
|
||||||
"raw_reply": "糟糕,我的思路好像缠成一团毛线球了,需要一点时间来解开...你能耐心等我一下吗?",
|
"raw_reply": "糟糕,我的思路好像缠成一团毛线球了,需要一点时间来解开...你能耐心等我一下吗?",
|
||||||
"reason": f"客观原因:插件在执行时发生了未知异常。详细信息: {str(exception)}"
|
"reason": f"客观原因:插件在执行时发生了未知异常。详细信息: {exception!s}"
|
||||||
},
|
},
|
||||||
"tts_api_error": {
|
"tts_api_error": {
|
||||||
"raw_reply": "我的麦克风好像有点小情绪,突然不想工作了...我正在哄它呢,请稍等片刻哦!",
|
"raw_reply": "我的麦克风好像有点小情绪,突然不想工作了...我正在哄它呢,请稍等片刻哦!",
|
||||||
"reason": f"客观原因:语音合成服务返回了一个错误。详细信息: {str(exception)}"
|
"reason": f"客观原因:语音合成服务返回了一个错误。详细信息: {exception!s}"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
prompt_data = error_prompts.get(error_context, error_prompts["generic_error"])
|
prompt_data = error_prompts.get(error_context, error_prompts["generic_error"])
|
||||||
@@ -144,6 +142,6 @@ class TTSVoiceAction(BaseAction):
|
|||||||
await self.send_text("唔...我的思路好像卡壳了,请稍等一下哦!")
|
await self.send_text("唔...我的思路好像卡壳了,请稍等一下哦!")
|
||||||
|
|
||||||
await self.store_action_info(
|
await self.store_action_info(
|
||||||
action_prompt_display=f"语音合成失败: {str(exception)}",
|
action_prompt_display=f"语音合成失败: {exception!s}",
|
||||||
action_done=False
|
action_done=False
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
"""
|
"""
|
||||||
TTS 语音合成命令
|
TTS 语音合成命令
|
||||||
"""
|
"""
|
||||||
import re
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.plugin_system.base.command_args import CommandArgs
|
from src.plugin_system.base.command_args import CommandArgs
|
||||||
from src.plugin_system.base.plus_command import PlusCommand
|
from src.plugin_system.base.plus_command import PlusCommand
|
||||||
|
|||||||
@@ -1,14 +1,14 @@
|
|||||||
"""
|
"""
|
||||||
TTS Voice 插件 - 重构版
|
TTS Voice 插件 - 重构版
|
||||||
"""
|
"""
|
||||||
import toml
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, List, Tuple, Type, Dict
|
from typing import Any
|
||||||
|
|
||||||
|
import toml
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.plugin_system import BasePlugin, ComponentInfo, register_plugin
|
from src.plugin_system import BasePlugin, ComponentInfo, register_plugin
|
||||||
from src.plugin_system.base.component_types import PermissionNodeField
|
from src.plugin_system.base.component_types import PermissionNodeField
|
||||||
from src.plugin_system.base.config_types import ConfigField
|
|
||||||
|
|
||||||
from .actions.tts_action import TTSVoiceAction
|
from .actions.tts_action import TTSVoiceAction
|
||||||
from .commands.tts_command import TTSVoiceCommand
|
from .commands.tts_command import TTSVoiceCommand
|
||||||
@@ -57,34 +57,34 @@ class TTSVoicePlugin(BasePlugin):
|
|||||||
"""
|
"""
|
||||||
# 需要手动加载的顶级配置节
|
# 需要手动加载的顶级配置节
|
||||||
manual_load_keys = ["tts_styles", "spatial_effects", "tts_advanced", "tts"]
|
manual_load_keys = ["tts_styles", "spatial_effects", "tts_advanced", "tts"]
|
||||||
top_key = key.split('.')[0]
|
top_key = key.split(".")[0]
|
||||||
|
|
||||||
if top_key in manual_load_keys:
|
if top_key in manual_load_keys:
|
||||||
try:
|
try:
|
||||||
plugin_file = Path(__file__).resolve()
|
plugin_file = Path(__file__).resolve()
|
||||||
bot_root = plugin_file.parent.parent.parent.parent.parent
|
bot_root = plugin_file.parent.parent.parent.parent.parent
|
||||||
config_file = bot_root / "config" / "plugins" / self.plugin_name / self.config_file_name
|
config_file = bot_root / "config" / "plugins" / self.plugin_name / self.config_file_name
|
||||||
|
|
||||||
if not config_file.is_file():
|
if not config_file.is_file():
|
||||||
logger.error(f"TTS config file not found at robustly constructed path: {config_file}")
|
logger.error(f"TTS config file not found at robustly constructed path: {config_file}")
|
||||||
return default
|
return default
|
||||||
|
|
||||||
full_config = toml.loads(config_file.read_text(encoding="utf-8"))
|
full_config = toml.loads(config_file.read_text(encoding="utf-8"))
|
||||||
|
|
||||||
# 支持点状路径访问
|
# 支持点状路径访问
|
||||||
value = full_config
|
value = full_config
|
||||||
for k in key.split('.'):
|
for k in key.split("."):
|
||||||
if isinstance(value, dict):
|
if isinstance(value, dict):
|
||||||
value = value.get(k)
|
value = value.get(k)
|
||||||
else:
|
else:
|
||||||
return default
|
return default
|
||||||
|
|
||||||
return value if value is not None else default
|
return value if value is not None else default
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to manually load '{key}' from config: {e}", exc_info=True)
|
logger.error(f"Failed to manually load '{key}' from config: {e}", exc_info=True)
|
||||||
return default
|
return default
|
||||||
|
|
||||||
return self.get_config(key, default)
|
return self.get_config(key, default)
|
||||||
|
|
||||||
async def on_plugin_loaded(self):
|
async def on_plugin_loaded(self):
|
||||||
@@ -100,7 +100,7 @@ class TTSVoicePlugin(BasePlugin):
|
|||||||
register_service("tts", self.tts_service)
|
register_service("tts", self.tts_service)
|
||||||
logger.info("TTSService 已成功初始化并注册。")
|
logger.info("TTSService 已成功初始化并注册。")
|
||||||
|
|
||||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]:
|
||||||
"""
|
"""
|
||||||
返回插件包含的组件列表。
|
返回插件包含的组件列表。
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -4,9 +4,9 @@
|
|||||||
用于注册和获取插件内部使用的服务实例。
|
用于注册和获取插件内部使用的服务实例。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
_services: Dict[str, Any] = {}
|
_services: dict[str, Any] = {}
|
||||||
|
|
||||||
|
|
||||||
def register_service(name: str, instance: Any) -> None:
|
def register_service(name: str, instance: Any) -> None:
|
||||||
|
|||||||
@@ -6,14 +6,14 @@ import base64
|
|||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import tempfile
|
from collections.abc import Callable
|
||||||
from typing import Any, Callable, Dict, Optional
|
from typing import Any
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
from pedalboard import Convolution, Pedalboard, Reverb
|
from pedalboard import Convolution, Pedalboard, Reverb
|
||||||
from pedalboard.io import AudioFile
|
from pedalboard.io import AudioFile
|
||||||
|
|
||||||
import aiohttp
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
logger = get_logger("tts_voice_plugin.service")
|
logger = get_logger("tts_voice_plugin.service")
|
||||||
@@ -24,7 +24,7 @@ class TTSService:
|
|||||||
|
|
||||||
def __init__(self, get_config_func: Callable[[str, Any], Any]):
|
def __init__(self, get_config_func: Callable[[str, Any], Any]):
|
||||||
self.get_config = get_config_func
|
self.get_config = get_config_func
|
||||||
self.tts_styles: Dict[str, Any] = {}
|
self.tts_styles: dict[str, Any] = {}
|
||||||
self.timeout: int = 60
|
self.timeout: int = 60
|
||||||
self.max_text_length: int = 500
|
self.max_text_length: int = 500
|
||||||
self._load_config()
|
self._load_config()
|
||||||
@@ -43,12 +43,12 @@ class TTSService:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"TTS服务配置加载失败: {e}", exc_info=True)
|
logger.error(f"TTS服务配置加载失败: {e}", exc_info=True)
|
||||||
|
|
||||||
def _load_tts_styles(self) -> Dict[str, Dict[str, Any]]:
|
def _load_tts_styles(self) -> dict[str, dict[str, Any]]:
|
||||||
"""加载 TTS 风格配置"""
|
"""加载 TTS 风格配置"""
|
||||||
styles = {}
|
styles = {}
|
||||||
global_server = self.get_config("tts.server", "http://127.0.0.1:9880")
|
global_server = self.get_config("tts.server", "http://127.0.0.1:9880")
|
||||||
tts_styles_config = self.get_config("tts_styles", [])
|
tts_styles_config = self.get_config("tts_styles", [])
|
||||||
|
|
||||||
if not isinstance(tts_styles_config, list):
|
if not isinstance(tts_styles_config, list):
|
||||||
logger.error(f"tts_styles 配置不是一个列表, 而是 {type(tts_styles_config)}")
|
logger.error(f"tts_styles 配置不是一个列表, 而是 {type(tts_styles_config)}")
|
||||||
return styles
|
return styles
|
||||||
@@ -57,7 +57,7 @@ class TTSService:
|
|||||||
if not default_cfg:
|
if not default_cfg:
|
||||||
logger.error("在 tts_styles 配置中未找到 'default' 风格,这是必需的。")
|
logger.error("在 tts_styles 配置中未找到 'default' 风格,这是必需的。")
|
||||||
return styles
|
return styles
|
||||||
|
|
||||||
default_refer_wav = default_cfg.get("refer_wav_path", "")
|
default_refer_wav = default_cfg.get("refer_wav_path", "")
|
||||||
default_prompt_text = default_cfg.get("prompt_text", "")
|
default_prompt_text = default_cfg.get("prompt_text", "")
|
||||||
default_gpt_weights = default_cfg.get("gpt_weights", "")
|
default_gpt_weights = default_cfg.get("gpt_weights", "")
|
||||||
@@ -68,7 +68,7 @@ class TTSService:
|
|||||||
|
|
||||||
for style_cfg in tts_styles_config:
|
for style_cfg in tts_styles_config:
|
||||||
if not isinstance(style_cfg, dict): continue
|
if not isinstance(style_cfg, dict): continue
|
||||||
|
|
||||||
style_name = style_cfg.get("style_name")
|
style_name = style_cfg.get("style_name")
|
||||||
if not style_name: continue
|
if not style_name: continue
|
||||||
|
|
||||||
@@ -86,9 +86,9 @@ class TTSService:
|
|||||||
|
|
||||||
# ... [其他方法保持不变] ...
|
# ... [其他方法保持不变] ...
|
||||||
def _detect_language(self, text: str) -> str:
|
def _detect_language(self, text: str) -> str:
|
||||||
chinese_chars = len(re.findall(r'[\u4e00-\u9fff]', text))
|
chinese_chars = len(re.findall(r"[\u4e00-\u9fff]", text))
|
||||||
english_chars = len(re.findall(r'[a-zA-Z]', text))
|
english_chars = len(re.findall(r"[a-zA-Z]", text))
|
||||||
japanese_chars = len(re.findall(r'[\u3040-\u309f\u30a0-\u30ff]', text))
|
japanese_chars = len(re.findall(r"[\u3040-\u309f\u30a0-\u30ff]", text))
|
||||||
total_chars = chinese_chars + english_chars + japanese_chars
|
total_chars = chinese_chars + english_chars + japanese_chars
|
||||||
if total_chars == 0: return "zh"
|
if total_chars == 0: return "zh"
|
||||||
if chinese_chars / total_chars > 0.3: return "zh"
|
if chinese_chars / total_chars > 0.3: return "zh"
|
||||||
@@ -98,41 +98,41 @@ class TTSService:
|
|||||||
|
|
||||||
def _clean_text_for_tts(self, text: str) -> str:
|
def _clean_text_for_tts(self, text: str) -> str:
|
||||||
# 1. 基本清理
|
# 1. 基本清理
|
||||||
text = re.sub(r'[\((\[【].*?[\))\]】]', '', text)
|
text = re.sub(r"[\((\[【].*?[\))\]】]", "", text)
|
||||||
text = re.sub(r'([,。!?、;:,.!?;:~\-`])\1+', r'\1', text)
|
text = re.sub(r"([,。!?、;:,.!?;:~\-`])\1+", r"\1", text)
|
||||||
text = re.sub(r'~{2,}|~{2,}', ',', text)
|
text = re.sub(r"~{2,}|~{2,}", ",", text)
|
||||||
text = re.sub(r'\.{3,}|…{1,}', '。', text)
|
text = re.sub(r"\.{3,}|…{1,}", "。", text)
|
||||||
|
|
||||||
# 2. 词语替换
|
# 2. 词语替换
|
||||||
replacements = {'www': '哈哈哈', 'hhh': '哈哈', '233': '哈哈', '666': '厉害', '88': '拜拜'}
|
replacements = {"www": "哈哈哈", "hhh": "哈哈", "233": "哈哈", "666": "厉害", "88": "拜拜"}
|
||||||
for old, new in replacements.items():
|
for old, new in replacements.items():
|
||||||
text = text.replace(old, new)
|
text = text.replace(old, new)
|
||||||
|
|
||||||
# 3. 移除不必要的字符 (恢复使用更安全的原版正则,避免误删)
|
# 3. 移除不必要的字符 (恢复使用更安全的原版正则,避免误删)
|
||||||
text = re.sub(r'[^\u4e00-\u9fff\u3040-\u309f\u30a0-\u30ffa-zA-Z0-9\s,。!?、;:,.!?;:~~]', '', text)
|
text = re.sub(r"[^\u4e00-\u9fff\u3040-\u309f\u30a0-\u30ffa-zA-Z0-9\s,。!?、;:,.!?;:~~]", "", text)
|
||||||
|
|
||||||
# 4. 确保结尾有标点
|
# 4. 确保结尾有标点
|
||||||
if text and not text.endswith(tuple(',。!?、;:,.!?;:')):
|
if text and not text.endswith(tuple(",。!?、;:,.!?;:")):
|
||||||
text += '。'
|
text += "。"
|
||||||
|
|
||||||
# 5. 智能截断 (保留改进的截断逻辑)
|
# 5. 智能截断 (保留改进的截断逻辑)
|
||||||
if len(text) > self.max_text_length:
|
if len(text) > self.max_text_length:
|
||||||
cut_text = text[:self.max_text_length]
|
cut_text = text[:self.max_text_length]
|
||||||
punctuation = '。!?.…'
|
punctuation = "。!?.…"
|
||||||
last_punc_pos = max(cut_text.rfind(p) for p in punctuation)
|
last_punc_pos = max(cut_text.rfind(p) for p in punctuation)
|
||||||
|
|
||||||
if last_punc_pos != -1:
|
if last_punc_pos != -1:
|
||||||
text = cut_text[:last_punc_pos + 1]
|
text = cut_text[:last_punc_pos + 1]
|
||||||
else:
|
else:
|
||||||
last_comma_pos = max(cut_text.rfind(p) for p in ',、;,;')
|
last_comma_pos = max(cut_text.rfind(p) for p in ",、;,;")
|
||||||
if last_comma_pos != -1:
|
if last_comma_pos != -1:
|
||||||
text = cut_text[:last_comma_pos + 1]
|
text = cut_text[:last_comma_pos + 1]
|
||||||
else:
|
else:
|
||||||
text = cut_text
|
text = cut_text
|
||||||
|
|
||||||
return text.strip()
|
return text.strip()
|
||||||
|
|
||||||
async def _call_tts_api(self, server_config: Dict, text: str, text_language: str, **kwargs) -> Optional[bytes]:
|
async def _call_tts_api(self, server_config: dict, text: str, text_language: str, **kwargs) -> bytes | None:
|
||||||
"""
|
"""
|
||||||
最终修复版:先切换模型,然后仅通过路径发送合成请求。
|
最终修复版:先切换模型,然后仅通过路径发送合成请求。
|
||||||
"""
|
"""
|
||||||
@@ -144,7 +144,7 @@ class TTSService:
|
|||||||
base_url = server_config["url"].rstrip("/")
|
base_url = server_config["url"].rstrip("/")
|
||||||
|
|
||||||
# --- 步骤一:像稳定版一样,先切换模型 ---
|
# --- 步骤一:像稳定版一样,先切换模型 ---
|
||||||
async def switch_model_weights(weights_path: Optional[str], weight_type: str):
|
async def switch_model_weights(weights_path: str | None, weight_type: str):
|
||||||
if not weights_path: return
|
if not weights_path: return
|
||||||
api_endpoint = f"/set_{weight_type}_weights"
|
api_endpoint = f"/set_{weight_type}_weights"
|
||||||
switch_url = f"{base_url}{api_endpoint}"
|
switch_url = f"{base_url}{api_endpoint}"
|
||||||
@@ -173,12 +173,12 @@ class TTSService:
|
|||||||
# "gpt_model_path": kwargs.get("gpt_weights"),
|
# "gpt_model_path": kwargs.get("gpt_weights"),
|
||||||
# "sovits_model_path": kwargs.get("sovits_weights"),
|
# "sovits_model_path": kwargs.get("sovits_weights"),
|
||||||
}
|
}
|
||||||
|
|
||||||
# 合并高级配置
|
# 合并高级配置
|
||||||
advanced_config = self.get_config("tts_advanced", {})
|
advanced_config = self.get_config("tts_advanced", {})
|
||||||
if isinstance(advanced_config, dict):
|
if isinstance(advanced_config, dict):
|
||||||
data.update({k: v for k, v in advanced_config.items() if v is not None})
|
data.update({k: v for k, v in advanced_config.items() if v is not None})
|
||||||
|
|
||||||
# 优先使用风格特定的语速
|
# 优先使用风格特定的语速
|
||||||
if server_config.get("speed_factor") is not None:
|
if server_config.get("speed_factor") is not None:
|
||||||
data["speed_factor"] = server_config["speed_factor"]
|
data["speed_factor"] = server_config["speed_factor"]
|
||||||
@@ -202,7 +202,7 @@ class TTSService:
|
|||||||
logger.error(f"TTS API调用异常: {e}", exc_info=True)
|
logger.error(f"TTS API调用异常: {e}", exc_info=True)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _apply_spatial_audio_effect(self, audio_data: bytes) -> Optional[bytes]:
|
async def _apply_spatial_audio_effect(self, audio_data: bytes) -> bytes | None:
|
||||||
"""根据配置应用空间效果(混响和卷积)"""
|
"""根据配置应用空间效果(混响和卷积)"""
|
||||||
try:
|
try:
|
||||||
effects_config = self.get_config("spatial_effects", {})
|
effects_config = self.get_config("spatial_effects", {})
|
||||||
@@ -249,9 +249,9 @@ class TTSService:
|
|||||||
# 将处理后的音频数据写回内存中的字节流
|
# 将处理后的音频数据写回内存中的字节流
|
||||||
with io.BytesIO() as output_stream:
|
with io.BytesIO() as output_stream:
|
||||||
# 使用 soundfile 写入,因为它更稳定
|
# 使用 soundfile 写入,因为它更稳定
|
||||||
sf.write(output_stream, effected.T, f.samplerate, format='WAV')
|
sf.write(output_stream, effected.T, f.samplerate, format="WAV")
|
||||||
processed_audio_data = output_stream.getvalue()
|
processed_audio_data = output_stream.getvalue()
|
||||||
|
|
||||||
logger.info("成功应用空间效果。")
|
logger.info("成功应用空间效果。")
|
||||||
return processed_audio_data
|
return processed_audio_data
|
||||||
|
|
||||||
@@ -259,9 +259,9 @@ class TTSService:
|
|||||||
logger.error(f"应用空间效果时出错: {e}", exc_info=True)
|
logger.error(f"应用空间效果时出错: {e}", exc_info=True)
|
||||||
return audio_data # 如果出错,返回原始音频
|
return audio_data # 如果出错,返回原始音频
|
||||||
|
|
||||||
async def generate_voice(self, text: str, style_hint: str = "default") -> Optional[str]:
|
async def generate_voice(self, text: str, style_hint: str = "default") -> str | None:
|
||||||
self._load_config()
|
self._load_config()
|
||||||
|
|
||||||
if not self.tts_styles:
|
if not self.tts_styles:
|
||||||
logger.error("TTS风格配置为空,无法生成语音。")
|
logger.error("TTS风格配置为空,无法生成语音。")
|
||||||
return None
|
return None
|
||||||
@@ -306,5 +306,5 @@ class TTSService:
|
|||||||
else:
|
else:
|
||||||
logger.warning("空间音频效果应用失败,将使用原始音频。")
|
logger.warning("空间音频效果应用失败,将使用原始音频。")
|
||||||
|
|
||||||
return base64.b64encode(audio_data).decode('utf-8')
|
return base64.b64encode(audio_data).decode("utf-8")
|
||||||
return None
|
return None
|
||||||
|
|||||||
Reference in New Issue
Block a user