chore: 代码清理和格式化
对项目代码进行了一系列小幅度的清理和改进,包括: - 移除未使用的导入语句 - 统一代码格式,如调整空行和导入顺序 - 优化日志输出的可读性 - 更新类型注解以符合现代 Python 语法 - 修复代码风格检查器(linter)报告的问题
This commit is contained in:
@@ -3,7 +3,6 @@ import datetime
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
|
||||
|
||||
@@ -146,7 +146,7 @@ class StreamContext(BaseDataModel):
|
||||
# 同步打断计数到ChatStream
|
||||
await self._sync_interruption_count_to_stream()
|
||||
|
||||
|
||||
|
||||
async def _sync_interruption_count_to_stream(self):
|
||||
"""同步打断计数到ChatStream"""
|
||||
try:
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import os
|
||||
import socket
|
||||
|
||||
from fastapi import APIRouter, FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware # 新增导入
|
||||
from rich.traceback import install
|
||||
from uvicorn import Config, Server as UvicornServer
|
||||
from uvicorn import Config
|
||||
from uvicorn import Server as UvicornServer
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
|
||||
@@ -126,7 +126,7 @@ class ChatConfig(ValidatedConfigBase):
|
||||
interruption_min_probability: float = Field(
|
||||
default=0.1, ge=0.0, le=1.0, description="最低打断概率(即使达到较高打断次数,也保证有此概率的打断机会)"
|
||||
)
|
||||
|
||||
|
||||
# DEPRECATED: interruption_probability_factor (已废弃的配置项)
|
||||
# 新的线性概率模型不再需要复杂的概率因子
|
||||
# 保留此字段是为了向后兼容,现有配置文件不会报错
|
||||
|
||||
@@ -265,7 +265,7 @@ async def get_intercom_group_context(group_name: str, limit_per_chat: int = 20,
|
||||
else: # private
|
||||
if stream.user_info and stream.user_info.user_id == chat_raw_id and not stream.group_info:
|
||||
found_stream = stream
|
||||
break
|
||||
break
|
||||
if not found_stream:
|
||||
logger.warning(f"在已加载的聊天流中找不到ID为 {chat_raw_id} 的聊天。")
|
||||
continue
|
||||
|
||||
@@ -236,10 +236,10 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
||||
|
||||
if is_mentioned:
|
||||
if is_at:
|
||||
logger.debug(f"[提及分计算] 直接@机器人,返回1.0")
|
||||
logger.debug("[提及分计算] 直接@机器人,返回1.0")
|
||||
return 1.0 # 直接@机器人,最高分
|
||||
else:
|
||||
logger.debug(f"[提及分计算] 提及机器人,返回0.8")
|
||||
logger.debug("[提及分计算] 提及机器人,返回0.8")
|
||||
return 0.8 # 提及机器人名字,高分
|
||||
else:
|
||||
# 检查是否被提及(文本匹配)
|
||||
@@ -248,13 +248,13 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
||||
|
||||
# 如果被提及或是私聊,都视为提及了bot
|
||||
if is_text_mentioned:
|
||||
logger.debug(f"[提及分计算] 文本提及机器人,返回提及分")
|
||||
logger.debug("[提及分计算] 文本提及机器人,返回提及分")
|
||||
return global_config.affinity_flow.mention_bot_interest_score
|
||||
elif is_private_chat:
|
||||
logger.debug(f"[提及分计算] 私聊消息,返回提及分")
|
||||
logger.debug("[提及分计算] 私聊消息,返回提及分")
|
||||
return global_config.affinity_flow.mention_bot_interest_score
|
||||
else:
|
||||
logger.debug(f"[提及分计算] 未提及机器人,返回0.0")
|
||||
logger.debug("[提及分计算] 未提及机器人,返回0.0")
|
||||
return 0.0 # 未提及机器人
|
||||
|
||||
def _apply_no_reply_boost(self, base_score: float) -> float:
|
||||
|
||||
@@ -500,7 +500,7 @@ class ChatterPlanExecutor:
|
||||
)
|
||||
|
||||
# 添加到chat_stream的已读消息中
|
||||
if hasattr(chat_stream, 'stream_context') and chat_stream.stream_context:
|
||||
if hasattr(chat_stream, "stream_context") and chat_stream.stream_context:
|
||||
chat_stream.stream_context.history_messages.append(bot_message)
|
||||
logger.debug(f"机器人回复已添加到已读消息: {reply_content[:50]}...")
|
||||
else:
|
||||
|
||||
@@ -14,4 +14,4 @@ __plugin_meta__ = PluginMetadata(
|
||||
"is_built_in": False,
|
||||
"plugin_type": "tools",
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
"""
|
||||
TTS 语音合成 Action
|
||||
"""
|
||||
import asyncio
|
||||
from typing import Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import config_api, generator_api
|
||||
from src.plugin_system.apis import generator_api
|
||||
from src.plugin_system.base.base_action import ActionActivationType, BaseAction, ChatMode
|
||||
|
||||
from ..services.manager import get_service
|
||||
@@ -44,7 +42,7 @@ class TTSVoiceAction(BaseAction):
|
||||
# 关键配置项现在由 TTSService 管理
|
||||
self.tts_service = get_service("tts")
|
||||
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
async def execute(self) -> tuple[bool, str]:
|
||||
"""
|
||||
执行 Action 的核心逻辑
|
||||
"""
|
||||
@@ -79,7 +77,7 @@ class TTSVoiceAction(BaseAction):
|
||||
|
||||
except Exception as e:
|
||||
await self._handle_error_and_reply("generic_error", e)
|
||||
return False, f"语音合成出错: {str(e)}"
|
||||
return False, f"语音合成出错: {e!s}"
|
||||
|
||||
async def _generate_final_text(self, initial_text: str) -> str:
|
||||
"""请求主回复模型生成或优化文本"""
|
||||
@@ -89,7 +87,7 @@ class TTSVoiceAction(BaseAction):
|
||||
"请基于规划器提供的初步文本,结合对话历史和自己的人设,将它优化成一句自然、富有感情、适合用语音说出的话。"
|
||||
"最终指令:请务-必确保文本听起来像真实的、自然的口语对话,而不是书面语。"
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"{self.log_prefix} 请求主回复模型(replyer)全新生成TTS文本...")
|
||||
success, response_set, _ = await generator_api.rewrite_reply(
|
||||
chat_stream=self.chat_stream,
|
||||
@@ -101,11 +99,11 @@ class TTSVoiceAction(BaseAction):
|
||||
text = "".join(str(seg[1]) if isinstance(seg, tuple) else str(seg) for seg in response_set).strip()
|
||||
logger.info(f"{self.log_prefix} 成功生成高质量TTS文本: {text}")
|
||||
return text
|
||||
|
||||
|
||||
if initial_text:
|
||||
logger.warning(f"{self.log_prefix} 主模型生成失败,使用规划器原始文本作为兜底。")
|
||||
return initial_text
|
||||
|
||||
|
||||
raise Exception("主模型未能生成回复,且规划器也未提供兜底文本。")
|
||||
|
||||
except Exception as e:
|
||||
@@ -119,11 +117,11 @@ class TTSVoiceAction(BaseAction):
|
||||
error_prompts = {
|
||||
"generic_error": {
|
||||
"raw_reply": "糟糕,我的思路好像缠成一团毛线球了,需要一点时间来解开...你能耐心等我一下吗?",
|
||||
"reason": f"客观原因:插件在执行时发生了未知异常。详细信息: {str(exception)}"
|
||||
"reason": f"客观原因:插件在执行时发生了未知异常。详细信息: {exception!s}"
|
||||
},
|
||||
"tts_api_error": {
|
||||
"raw_reply": "我的麦克风好像有点小情绪,突然不想工作了...我正在哄它呢,请稍等片刻哦!",
|
||||
"reason": f"客观原因:语音合成服务返回了一个错误。详细信息: {str(exception)}"
|
||||
"reason": f"客观原因:语音合成服务返回了一个错误。详细信息: {exception!s}"
|
||||
}
|
||||
}
|
||||
prompt_data = error_prompts.get(error_context, error_prompts["generic_error"])
|
||||
@@ -144,6 +142,6 @@ class TTSVoiceAction(BaseAction):
|
||||
await self.send_text("唔...我的思路好像卡壳了,请稍等一下哦!")
|
||||
|
||||
await self.store_action_info(
|
||||
action_prompt_display=f"语音合成失败: {str(exception)}",
|
||||
action_prompt_display=f"语音合成失败: {exception!s}",
|
||||
action_done=False
|
||||
)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""
|
||||
TTS 语音合成命令
|
||||
"""
|
||||
import re
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.command_args import CommandArgs
|
||||
from src.plugin_system.base.plus_command import PlusCommand
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
"""
|
||||
TTS Voice 插件 - 重构版
|
||||
"""
|
||||
import toml
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Tuple, Type, Dict
|
||||
from typing import Any
|
||||
|
||||
import toml
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system import BasePlugin, ComponentInfo, register_plugin
|
||||
from src.plugin_system.base.component_types import PermissionNodeField
|
||||
from src.plugin_system.base.config_types import ConfigField
|
||||
|
||||
from .actions.tts_action import TTSVoiceAction
|
||||
from .commands.tts_command import TTSVoiceCommand
|
||||
@@ -57,34 +57,34 @@ class TTSVoicePlugin(BasePlugin):
|
||||
"""
|
||||
# 需要手动加载的顶级配置节
|
||||
manual_load_keys = ["tts_styles", "spatial_effects", "tts_advanced", "tts"]
|
||||
top_key = key.split('.')[0]
|
||||
top_key = key.split(".")[0]
|
||||
|
||||
if top_key in manual_load_keys:
|
||||
try:
|
||||
plugin_file = Path(__file__).resolve()
|
||||
bot_root = plugin_file.parent.parent.parent.parent.parent
|
||||
config_file = bot_root / "config" / "plugins" / self.plugin_name / self.config_file_name
|
||||
|
||||
|
||||
if not config_file.is_file():
|
||||
logger.error(f"TTS config file not found at robustly constructed path: {config_file}")
|
||||
return default
|
||||
|
||||
|
||||
full_config = toml.loads(config_file.read_text(encoding="utf-8"))
|
||||
|
||||
# 支持点状路径访问
|
||||
value = full_config
|
||||
for k in key.split('.'):
|
||||
for k in key.split("."):
|
||||
if isinstance(value, dict):
|
||||
value = value.get(k)
|
||||
else:
|
||||
return default
|
||||
|
||||
|
||||
return value if value is not None else default
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to manually load '{key}' from config: {e}", exc_info=True)
|
||||
return default
|
||||
|
||||
|
||||
return self.get_config(key, default)
|
||||
|
||||
async def on_plugin_loaded(self):
|
||||
@@ -100,7 +100,7 @@ class TTSVoicePlugin(BasePlugin):
|
||||
register_service("tts", self.tts_service)
|
||||
logger.info("TTSService 已成功初始化并注册。")
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||
def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]:
|
||||
"""
|
||||
返回插件包含的组件列表。
|
||||
"""
|
||||
|
||||
@@ -4,9 +4,9 @@
|
||||
用于注册和获取插件内部使用的服务实例。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
_services: Dict[str, Any] = {}
|
||||
_services: dict[str, Any] = {}
|
||||
|
||||
|
||||
def register_service(name: str, instance: Any) -> None:
|
||||
|
||||
@@ -6,14 +6,14 @@ import base64
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
import soundfile as sf
|
||||
from pedalboard import Convolution, Pedalboard, Reverb
|
||||
from pedalboard.io import AudioFile
|
||||
|
||||
import aiohttp
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("tts_voice_plugin.service")
|
||||
@@ -24,7 +24,7 @@ class TTSService:
|
||||
|
||||
def __init__(self, get_config_func: Callable[[str, Any], Any]):
|
||||
self.get_config = get_config_func
|
||||
self.tts_styles: Dict[str, Any] = {}
|
||||
self.tts_styles: dict[str, Any] = {}
|
||||
self.timeout: int = 60
|
||||
self.max_text_length: int = 500
|
||||
self._load_config()
|
||||
@@ -43,12 +43,12 @@ class TTSService:
|
||||
except Exception as e:
|
||||
logger.error(f"TTS服务配置加载失败: {e}", exc_info=True)
|
||||
|
||||
def _load_tts_styles(self) -> Dict[str, Dict[str, Any]]:
|
||||
def _load_tts_styles(self) -> dict[str, dict[str, Any]]:
|
||||
"""加载 TTS 风格配置"""
|
||||
styles = {}
|
||||
global_server = self.get_config("tts.server", "http://127.0.0.1:9880")
|
||||
tts_styles_config = self.get_config("tts_styles", [])
|
||||
|
||||
|
||||
if not isinstance(tts_styles_config, list):
|
||||
logger.error(f"tts_styles 配置不是一个列表, 而是 {type(tts_styles_config)}")
|
||||
return styles
|
||||
@@ -57,7 +57,7 @@ class TTSService:
|
||||
if not default_cfg:
|
||||
logger.error("在 tts_styles 配置中未找到 'default' 风格,这是必需的。")
|
||||
return styles
|
||||
|
||||
|
||||
default_refer_wav = default_cfg.get("refer_wav_path", "")
|
||||
default_prompt_text = default_cfg.get("prompt_text", "")
|
||||
default_gpt_weights = default_cfg.get("gpt_weights", "")
|
||||
@@ -68,7 +68,7 @@ class TTSService:
|
||||
|
||||
for style_cfg in tts_styles_config:
|
||||
if not isinstance(style_cfg, dict): continue
|
||||
|
||||
|
||||
style_name = style_cfg.get("style_name")
|
||||
if not style_name: continue
|
||||
|
||||
@@ -86,9 +86,9 @@ class TTSService:
|
||||
|
||||
# ... [其他方法保持不变] ...
|
||||
def _detect_language(self, text: str) -> str:
|
||||
chinese_chars = len(re.findall(r'[\u4e00-\u9fff]', text))
|
||||
english_chars = len(re.findall(r'[a-zA-Z]', text))
|
||||
japanese_chars = len(re.findall(r'[\u3040-\u309f\u30a0-\u30ff]', text))
|
||||
chinese_chars = len(re.findall(r"[\u4e00-\u9fff]", text))
|
||||
english_chars = len(re.findall(r"[a-zA-Z]", text))
|
||||
japanese_chars = len(re.findall(r"[\u3040-\u309f\u30a0-\u30ff]", text))
|
||||
total_chars = chinese_chars + english_chars + japanese_chars
|
||||
if total_chars == 0: return "zh"
|
||||
if chinese_chars / total_chars > 0.3: return "zh"
|
||||
@@ -98,41 +98,41 @@ class TTSService:
|
||||
|
||||
def _clean_text_for_tts(self, text: str) -> str:
|
||||
# 1. 基本清理
|
||||
text = re.sub(r'[\((\[【].*?[\))\]】]', '', text)
|
||||
text = re.sub(r'([,。!?、;:,.!?;:~\-`])\1+', r'\1', text)
|
||||
text = re.sub(r'~{2,}|~{2,}', ',', text)
|
||||
text = re.sub(r'\.{3,}|…{1,}', '。', text)
|
||||
text = re.sub(r"[\((\[【].*?[\))\]】]", "", text)
|
||||
text = re.sub(r"([,。!?、;:,.!?;:~\-`])\1+", r"\1", text)
|
||||
text = re.sub(r"~{2,}|~{2,}", ",", text)
|
||||
text = re.sub(r"\.{3,}|…{1,}", "。", text)
|
||||
|
||||
# 2. 词语替换
|
||||
replacements = {'www': '哈哈哈', 'hhh': '哈哈', '233': '哈哈', '666': '厉害', '88': '拜拜'}
|
||||
replacements = {"www": "哈哈哈", "hhh": "哈哈", "233": "哈哈", "666": "厉害", "88": "拜拜"}
|
||||
for old, new in replacements.items():
|
||||
text = text.replace(old, new)
|
||||
|
||||
# 3. 移除不必要的字符 (恢复使用更安全的原版正则,避免误删)
|
||||
text = re.sub(r'[^\u4e00-\u9fff\u3040-\u309f\u30a0-\u30ffa-zA-Z0-9\s,。!?、;:,.!?;:~~]', '', text)
|
||||
|
||||
text = re.sub(r"[^\u4e00-\u9fff\u3040-\u309f\u30a0-\u30ffa-zA-Z0-9\s,。!?、;:,.!?;:~~]", "", text)
|
||||
|
||||
# 4. 确保结尾有标点
|
||||
if text and not text.endswith(tuple(',。!?、;:,.!?;:')):
|
||||
text += '。'
|
||||
if text and not text.endswith(tuple(",。!?、;:,.!?;:")):
|
||||
text += "。"
|
||||
|
||||
# 5. 智能截断 (保留改进的截断逻辑)
|
||||
if len(text) > self.max_text_length:
|
||||
cut_text = text[:self.max_text_length]
|
||||
punctuation = '。!?.…'
|
||||
punctuation = "。!?.…"
|
||||
last_punc_pos = max(cut_text.rfind(p) for p in punctuation)
|
||||
|
||||
if last_punc_pos != -1:
|
||||
text = cut_text[:last_punc_pos + 1]
|
||||
else:
|
||||
last_comma_pos = max(cut_text.rfind(p) for p in ',、;,;')
|
||||
last_comma_pos = max(cut_text.rfind(p) for p in ",、;,;")
|
||||
if last_comma_pos != -1:
|
||||
text = cut_text[:last_comma_pos + 1]
|
||||
else:
|
||||
text = cut_text
|
||||
|
||||
|
||||
return text.strip()
|
||||
|
||||
async def _call_tts_api(self, server_config: Dict, text: str, text_language: str, **kwargs) -> Optional[bytes]:
|
||||
async def _call_tts_api(self, server_config: dict, text: str, text_language: str, **kwargs) -> bytes | None:
|
||||
"""
|
||||
最终修复版:先切换模型,然后仅通过路径发送合成请求。
|
||||
"""
|
||||
@@ -144,7 +144,7 @@ class TTSService:
|
||||
base_url = server_config["url"].rstrip("/")
|
||||
|
||||
# --- 步骤一:像稳定版一样,先切换模型 ---
|
||||
async def switch_model_weights(weights_path: Optional[str], weight_type: str):
|
||||
async def switch_model_weights(weights_path: str | None, weight_type: str):
|
||||
if not weights_path: return
|
||||
api_endpoint = f"/set_{weight_type}_weights"
|
||||
switch_url = f"{base_url}{api_endpoint}"
|
||||
@@ -173,12 +173,12 @@ class TTSService:
|
||||
# "gpt_model_path": kwargs.get("gpt_weights"),
|
||||
# "sovits_model_path": kwargs.get("sovits_weights"),
|
||||
}
|
||||
|
||||
|
||||
# 合并高级配置
|
||||
advanced_config = self.get_config("tts_advanced", {})
|
||||
if isinstance(advanced_config, dict):
|
||||
data.update({k: v for k, v in advanced_config.items() if v is not None})
|
||||
|
||||
|
||||
# 优先使用风格特定的语速
|
||||
if server_config.get("speed_factor") is not None:
|
||||
data["speed_factor"] = server_config["speed_factor"]
|
||||
@@ -202,7 +202,7 @@ class TTSService:
|
||||
logger.error(f"TTS API调用异常: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
async def _apply_spatial_audio_effect(self, audio_data: bytes) -> Optional[bytes]:
|
||||
async def _apply_spatial_audio_effect(self, audio_data: bytes) -> bytes | None:
|
||||
"""根据配置应用空间效果(混响和卷积)"""
|
||||
try:
|
||||
effects_config = self.get_config("spatial_effects", {})
|
||||
@@ -249,9 +249,9 @@ class TTSService:
|
||||
# 将处理后的音频数据写回内存中的字节流
|
||||
with io.BytesIO() as output_stream:
|
||||
# 使用 soundfile 写入,因为它更稳定
|
||||
sf.write(output_stream, effected.T, f.samplerate, format='WAV')
|
||||
sf.write(output_stream, effected.T, f.samplerate, format="WAV")
|
||||
processed_audio_data = output_stream.getvalue()
|
||||
|
||||
|
||||
logger.info("成功应用空间效果。")
|
||||
return processed_audio_data
|
||||
|
||||
@@ -259,9 +259,9 @@ class TTSService:
|
||||
logger.error(f"应用空间效果时出错: {e}", exc_info=True)
|
||||
return audio_data # 如果出错,返回原始音频
|
||||
|
||||
async def generate_voice(self, text: str, style_hint: str = "default") -> Optional[str]:
|
||||
async def generate_voice(self, text: str, style_hint: str = "default") -> str | None:
|
||||
self._load_config()
|
||||
|
||||
|
||||
if not self.tts_styles:
|
||||
logger.error("TTS风格配置为空,无法生成语音。")
|
||||
return None
|
||||
@@ -306,5 +306,5 @@ class TTSService:
|
||||
else:
|
||||
logger.warning("空间音频效果应用失败,将使用原始音频。")
|
||||
|
||||
return base64.b64encode(audio_data).decode('utf-8')
|
||||
return base64.b64encode(audio_data).decode("utf-8")
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user