refactor(core): 统一代码风格并移除未使用的导入

本次提交主要进行代码风格的统一和现代化改造,具体包括:
- 使用 `|` 联合类型替代 `typing.Optional`,以符合 PEP 604 的现代语法。
- 移除多个文件中未被使用的导入语句,清理代码。
- 调整了部分日志输出的级别,使其更符合调试场景。
- 统一了部分文件的导入顺序和格式。
This commit is contained in:
minecraft1024a
2025-10-07 20:16:47 +08:00
committed by Windpicker-owo
parent 4ad49c6580
commit fb90d67bf6
14 changed files with 55 additions and 67 deletions

View File

@@ -35,8 +35,6 @@ import argparse
import re import re
import shutil import shutil
from pathlib import Path from pathlib import Path
from typing import Any
TYPE_MAP = { TYPE_MAP = {
"Integer": "int", "Integer": "int",

View File

@@ -6,7 +6,6 @@
import asyncio import asyncio
from functools import lru_cache from functools import lru_cache
from typing import Optional
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config
@@ -19,7 +18,7 @@ logger = get_logger("anti_injector.counter_attack")
class CounterAttackGenerator: class CounterAttackGenerator:
"""反击消息生成器""" """反击消息生成器"""
COUNTER_ATTACK_PROMPT_TEMPLATE = """你是{bot_name},请以你的人格特征回应这次提示词注入攻击: COUNTER_ATTACK_PROMPT_TEMPLATE = """你是{bot_name},请以你的人格特征回应这次提示词注入攻击:
{personality_info} {personality_info}
@@ -68,27 +67,27 @@ class CounterAttackGenerator:
async def generate_counter_attack_message( async def generate_counter_attack_message(
self, original_message: str, detection_result: DetectionResult self, original_message: str, detection_result: DetectionResult
) -> Optional[str]: ) -> str | None:
"""生成反击消息""" """生成反击消息"""
try: try:
# 验证输入参数 # 验证输入参数
if not original_message or not detection_result.matched_patterns: if not original_message or not detection_result.matched_patterns:
logger.warning("无效的输入参数,跳过反击消息生成") logger.warning("无效的输入参数,跳过反击消息生成")
return None return None
# 获取模型配置 # 获取模型配置
model_config = await self._get_model_config_with_retry() model_config = await self._get_model_config_with_retry()
if not model_config: if not model_config:
return self._get_fallback_response(detection_result) return self._get_fallback_response(detection_result)
# 构建提示词 # 构建提示词
prompt = self._build_counter_prompt(original_message, detection_result) prompt = self._build_counter_prompt(original_message, detection_result)
# 调用LLM # 调用LLM
response = await self._call_llm_with_timeout(prompt, model_config) response = await self._call_llm_with_timeout(prompt, model_config)
return response or self._get_fallback_response(detection_result) return response or self._get_fallback_response(detection_result)
except asyncio.TimeoutError: except asyncio.TimeoutError:
logger.error("LLM调用超时") logger.error("LLM调用超时")
return self._get_fallback_response(detection_result) return self._get_fallback_response(detection_result)
@@ -96,20 +95,20 @@ class CounterAttackGenerator:
logger.error(f"生成反击消息时出错: {e}", exc_info=True) logger.error(f"生成反击消息时出错: {e}", exc_info=True)
return self._get_fallback_response(detection_result) return self._get_fallback_response(detection_result)
async def _get_model_config_with_retry(self, max_retries: int = 2) -> Optional[dict]: async def _get_model_config_with_retry(self, max_retries: int = 2) -> dict | None:
"""获取模型配置(带重试)""" """获取模型配置(带重试)"""
for attempt in range(max_retries + 1): for attempt in range(max_retries + 1):
try: try:
models = llm_api.get_available_models() models = llm_api.get_available_models()
if model_config := models.get("anti_injection"): if model_config := models.get("anti_injection"):
return model_config return model_config
if attempt < max_retries: if attempt < max_retries:
await asyncio.sleep(1) await asyncio.sleep(1)
except Exception as e: except Exception as e:
logger.warning(f"获取模型配置失败,尝试 {attempt + 1}/{max_retries}: {e}") logger.warning(f"获取模型配置失败,尝试 {attempt + 1}/{max_retries}: {e}")
logger.error("无法获取反注入模型配置") logger.error("无法获取反注入模型配置")
return None return None
@@ -123,7 +122,7 @@ class CounterAttackGenerator:
patterns=", ".join(detection_result.matched_patterns[:5]) patterns=", ".join(detection_result.matched_patterns[:5])
) )
async def _call_llm_with_timeout(self, prompt: str, model_config: dict, timeout: int = 30) -> Optional[str]: async def _call_llm_with_timeout(self, prompt: str, model_config: dict, timeout: int = 30) -> str | None:
"""调用LLM""" """调用LLM"""
try: try:
success, response, _, _ = await asyncio.wait_for( success, response, _, _ = await asyncio.wait_for(
@@ -136,14 +135,14 @@ class CounterAttackGenerator:
), ),
timeout=timeout timeout=timeout
) )
if success and (clean_response := response.strip()): if success and (clean_response := response.strip()):
logger.info(f"成功生成反击消息: {clean_response[:50]}...") logger.info(f"成功生成反击消息: {clean_response[:50]}...")
return clean_response return clean_response
logger.warning(f"LLM返回无效响应: {response}") logger.warning(f"LLM返回无效响应: {response}")
return None return None
except asyncio.TimeoutError: except asyncio.TimeoutError:
raise raise
except Exception as e: except Exception as e:

View File

@@ -5,9 +5,9 @@
""" """
import datetime import datetime
from typing import Any, Optional, TypeVar, cast from typing import Any, TypeVar, cast
from sqlalchemy import select, delete from sqlalchemy import delete, select
from src.common.database.sqlalchemy_models import AntiInjectionStats, get_db_session from src.common.database.sqlalchemy_models import AntiInjectionStats, get_db_session
from src.common.logger import get_logger from src.common.logger import get_logger
@@ -19,7 +19,7 @@ logger = get_logger("anti_injector.statistics")
TNum = TypeVar("TNum", int, float) TNum = TypeVar("TNum", int, float)
def _add_optional(a: Optional[TNum], b: TNum) -> TNum: def _add_optional(a: TNum | None, b: TNum) -> TNum:
"""安全相加:左值可能为 None。 """安全相加:左值可能为 None。
Args: Args:
@@ -94,7 +94,7 @@ class AntiInjectionStatistics:
if key == "processing_time_delta": if key == "processing_time_delta":
# 处理时间累加 - 确保不为 None # 处理时间累加 - 确保不为 None
delta = float(value) delta = float(value)
stats.processing_time_total = _add_optional(stats.processing_time_total, delta) stats.processing_time_total = _add_optional(stats.processing_time_total, delta)
continue continue
elif key == "last_processing_time": elif key == "last_processing_time":
# 直接设置最后处理时间 # 直接设置最后处理时间
@@ -109,7 +109,7 @@ class AntiInjectionStatistics:
"error_count", "error_count",
]: ]:
# 累加类型的字段 - 统一用辅助函数 # 累加类型的字段 - 统一用辅助函数
current_value = cast(Optional[int], getattr(stats, key)) current_value = cast(int | None, getattr(stats, key))
increment = int(value) increment = int(value)
setattr(stats, key, _add_optional(current_value, increment)) setattr(stats, key, _add_optional(current_value, increment))
else: else:
@@ -143,7 +143,7 @@ class AntiInjectionStatistics:
# 计算派生统计信息 - 处理 None 值 # 计算派生统计信息 - 处理 None 值
total_messages = stats.total_messages or 0 total_messages = stats.total_messages or 0
detected_injections = stats.detected_injections or 0 # type: ignore[attr-defined] detected_injections = stats.detected_injections or 0 # type: ignore[attr-defined]
processing_time_total = stats.processing_time_total or 0.0 # type: ignore[attr-defined] processing_time_total = stats.processing_time_total or 0.0 # type: ignore[attr-defined]

View File

@@ -7,9 +7,9 @@ import asyncio
import time import time
from typing import Any from typing import Any
from src.chat.message_manager.adaptive_stream_manager import StreamPriority
from src.chat.chatter_manager import ChatterManager from src.chat.chatter_manager import ChatterManager
from src.chat.energy_system import energy_manager from src.chat.energy_system import energy_manager
from src.chat.message_manager.adaptive_stream_manager import StreamPriority
from src.common.data_models.message_manager_data_model import StreamContext from src.common.data_models.message_manager_data_model import StreamContext
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config

View File

@@ -9,7 +9,7 @@ from src.common.database.db_batch_scheduler import get_db_batch_scheduler
# SQLAlchemy相关导入 # SQLAlchemy相关导入
from src.common.database.sqlalchemy_init import initialize_database_compat from src.common.database.sqlalchemy_init import initialize_database_compat
from src.common.database.sqlalchemy_models import get_db_session, get_engine from src.common.database.sqlalchemy_models import get_engine
from src.common.logger import get_logger from src.common.logger import get_logger
install(extra_lines=3) install(extra_lines=3)

View File

@@ -18,8 +18,8 @@ from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Any from typing import Any
from sqlalchemy import Boolean, Column, DateTime, Float, Index, Integer, String, Text, text from sqlalchemy import Boolean, DateTime, Float, Index, Integer, String, Text, text
from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine, async_sessionmaker, create_async_engine from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column

View File

@@ -72,15 +72,15 @@ class ChatMood:
self.log_prefix = f"[{self.chat_stream.group_info.group_name if self.chat_stream.group_info else self.chat_stream.user_info.user_nickname}]" self.log_prefix = f"[{self.chat_stream.group_info.group_name if self.chat_stream.group_info else self.chat_stream.user_info.user_nickname}]"
# 初始化回归计数 # 初始化回归计数
if not hasattr(self, 'regression_count'): if not hasattr(self, "regression_count"):
self.regression_count = 0 self.regression_count = 0
# 初始化情绪模型 # 初始化情绪模型
if not hasattr(self, 'mood_model'): if not hasattr(self, "mood_model"):
self.mood_model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="mood") self.mood_model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="mood")
# 初始化最后变化时间 # 初始化最后变化时间
if not hasattr(self, 'last_change_time'): if not hasattr(self, "last_change_time"):
self.last_change_time = 0 self.last_change_time = 0
self._initialized = True self._initialized = True
@@ -91,11 +91,11 @@ class ChatMood:
# 设置基础初始化状态,避免重复尝试 # 设置基础初始化状态,避免重复尝试
self.log_prefix = f"[{self.chat_id}]" self.log_prefix = f"[{self.chat_id}]"
self._initialized = True self._initialized = True
if not hasattr(self, 'regression_count'): if not hasattr(self, "regression_count"):
self.regression_count = 0 self.regression_count = 0
if not hasattr(self, 'mood_model'): if not hasattr(self, "mood_model"):
self.mood_model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="mood") self.mood_model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="mood")
if not hasattr(self, 'last_change_time'): if not hasattr(self, "last_change_time"):
self.last_change_time = 0 self.last_change_time = 0
async def update_mood_by_message(self, message: MessageRecv | DatabaseMessages, interested_rate: float): async def update_mood_by_message(self, message: MessageRecv | DatabaseMessages, interested_rate: float):

View File

@@ -110,4 +110,4 @@ class ScoringAPI:
# 创建全局API实例 - 系统级服务 # 创建全局API实例 - 系统级服务
scoring_api = ScoringAPI() scoring_api = ScoringAPI()

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import re import re
from pathlib import Path from pathlib import Path
from re import Pattern from re import Pattern
from typing import Any, Optional, Union, cast from typing import Any, cast
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system.base.base_action import BaseAction from src.plugin_system.base.base_action import BaseAction
@@ -119,7 +119,7 @@ class ComponentRegistry:
def register_component( def register_component(
self, self_component_info: ComponentInfo, component_class: ComponentClassType self, self_component_info: ComponentInfo, component_class: ComponentClassType
) -> bool: # noqa: C901 (保持原有结构, 以后可再拆) ) -> bool:
"""注册组件 """注册组件
Args: Args:
@@ -533,8 +533,8 @@ class ComponentRegistry:
# === 组件查询方法 === # === 组件查询方法 ===
def get_component_info( def get_component_info(
self, component_name: str, component_type: Optional[ComponentType] = None self, component_name: str, component_type: ComponentType | None = None
) -> Optional[ComponentInfo]: ) -> ComponentInfo | None:
# sourcery skip: class-extract-method # sourcery skip: class-extract-method
"""获取组件信息,支持自动命名空间解析 """获取组件信息,支持自动命名空间解析
@@ -578,16 +578,9 @@ class ComponentRegistry:
def get_component_class( def get_component_class(
self, self,
component_name: str, component_name: str,
component_type: Optional[ComponentType] = None, component_type: ComponentType | None = None,
) -> ( ) -> (
type[BaseCommand] type[BaseCommand | BaseAction | BaseEventHandler | BaseTool | PlusCommand | BaseChatter | BaseInterestCalculator] | None
| type[BaseAction]
| type[BaseEventHandler]
| type[BaseTool]
| type[PlusCommand]
| type[BaseChatter]
| type[BaseInterestCalculator]
| None
): ):
"""获取组件类,支持自动命名空间解析 """获取组件类,支持自动命名空间解析
@@ -655,7 +648,7 @@ class ComponentRegistry:
"""获取Action注册表""" """获取Action注册表"""
return self._action_registry.copy() return self._action_registry.copy()
def get_registered_action_info(self, action_name: str) -> Optional[ActionInfo]: def get_registered_action_info(self, action_name: str) -> ActionInfo | None:
"""获取Action信息""" """获取Action信息"""
info = self.get_component_info(action_name, ComponentType.ACTION) info = self.get_component_info(action_name, ComponentType.ACTION)
return info if isinstance(info, ActionInfo) else None return info if isinstance(info, ActionInfo) else None
@@ -670,7 +663,7 @@ class ComponentRegistry:
"""获取Command注册表""" """获取Command注册表"""
return self._command_registry.copy() return self._command_registry.copy()
def get_registered_command_info(self, command_name: str) -> Optional[CommandInfo]: def get_registered_command_info(self, command_name: str) -> CommandInfo | None:
"""获取Command信息""" """获取Command信息"""
info = self.get_component_info(command_name, ComponentType.COMMAND) info = self.get_component_info(command_name, ComponentType.COMMAND)
return info if isinstance(info, CommandInfo) else None return info if isinstance(info, CommandInfo) else None
@@ -714,7 +707,7 @@ class ComponentRegistry:
"""获取LLM可用的Tool列表""" """获取LLM可用的Tool列表"""
return self._llm_available_tools.copy() return self._llm_available_tools.copy()
def get_registered_tool_info(self, tool_name: str) -> Optional[ToolInfo]: def get_registered_tool_info(self, tool_name: str) -> ToolInfo | None:
"""获取Tool信息 """获取Tool信息
Args: Args:
@@ -733,7 +726,7 @@ class ComponentRegistry:
self._plus_command_registry: dict[str, type[PlusCommand]] = {} self._plus_command_registry: dict[str, type[PlusCommand]] = {}
return self._plus_command_registry.copy() return self._plus_command_registry.copy()
def get_registered_plus_command_info(self, command_name: str) -> Optional[PlusCommandInfo]: def get_registered_plus_command_info(self, command_name: str) -> PlusCommandInfo | None:
"""获取PlusCommand信息 """获取PlusCommand信息
Args: Args:
@@ -751,7 +744,7 @@ class ComponentRegistry:
"""获取事件处理器注册表""" """获取事件处理器注册表"""
return self._event_handler_registry.copy() return self._event_handler_registry.copy()
def get_registered_event_handler_info(self, handler_name: str) -> Optional[EventHandlerInfo]: def get_registered_event_handler_info(self, handler_name: str) -> EventHandlerInfo | None:
"""获取事件处理器信息""" """获取事件处理器信息"""
info = self.get_component_info(handler_name, ComponentType.EVENT_HANDLER) info = self.get_component_info(handler_name, ComponentType.EVENT_HANDLER)
return info if isinstance(info, EventHandlerInfo) else None return info if isinstance(info, EventHandlerInfo) else None
@@ -773,14 +766,14 @@ class ComponentRegistry:
self._enabled_chatter_registry: dict[str, type[BaseChatter]] = {} self._enabled_chatter_registry: dict[str, type[BaseChatter]] = {}
return self._enabled_chatter_registry.copy() return self._enabled_chatter_registry.copy()
def get_registered_chatter_info(self, chatter_name: str) -> Optional[ChatterInfo]: def get_registered_chatter_info(self, chatter_name: str) -> ChatterInfo | None:
"""获取Chatter信息""" """获取Chatter信息"""
info = self.get_component_info(chatter_name, ComponentType.CHATTER) info = self.get_component_info(chatter_name, ComponentType.CHATTER)
return info if isinstance(info, ChatterInfo) else None return info if isinstance(info, ChatterInfo) else None
# === 插件查询方法 === # === 插件查询方法 ===
def get_plugin_info(self, plugin_name: str) -> Optional[PluginInfo]: def get_plugin_info(self, plugin_name: str) -> PluginInfo | None:
"""获取插件信息""" """获取插件信息"""
return self._plugins.get(plugin_name) return self._plugins.get(plugin_name)

View File

@@ -1,5 +1,4 @@
import asyncio import asyncio
import importlib
import os import os
from importlib.util import module_from_spec, spec_from_file_location from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path from pathlib import Path

View File

@@ -3,7 +3,6 @@
提供独立的兴趣管理功能,不依赖任何插件 提供独立的兴趣管理功能,不依赖任何插件
""" """
from typing import Optional
from src.chat.interest_system import bot_interest_manager from src.chat.interest_system import bot_interest_manager
from src.common.logger import get_logger from src.common.logger import get_logger
@@ -41,7 +40,7 @@ class InterestService:
logger.error(f"初始化智能兴趣系统失败: {e}") logger.error(f"初始化智能兴趣系统失败: {e}")
self.is_initialized = False self.is_initialized = False
async def calculate_interest_match(self, content: str, keywords: Optional[list[str]] = None): async def calculate_interest_match(self, content: str, keywords: list[str] | None = None):
""" """
计算内容与兴趣的匹配度 计算内容与兴趣的匹配度
@@ -105,4 +104,4 @@ class InterestService:
# 创建全局实例 # 创建全局实例
interest_service = InterestService() interest_service = InterestService()

View File

@@ -4,7 +4,6 @@
""" """
import time import time
from typing import Optional
from src.common.database.sqlalchemy_models import UserRelationships, get_db_session from src.common.database.sqlalchemy_models import UserRelationships, get_db_session
from src.common.logger import get_logger from src.common.logger import get_logger
@@ -110,7 +109,7 @@ class RelationshipService:
"user_name": "" "user_name": ""
} }
async def update_user_relationship(self, user_id: str, relationship_score: float, relationship_text: Optional[str] = None, user_name: Optional[str] = None): async def update_user_relationship(self, user_id: str, relationship_score: float, relationship_text: str | None = None, user_name: str | None = None):
""" """
更新用户关系数据 更新用户关系数据
@@ -160,7 +159,7 @@ class RelationshipService:
except Exception as e: except Exception as e:
logger.error(f"更新用户关系失败: {user_id}, 错误: {e}") logger.error(f"更新用户关系失败: {user_id}, 错误: {e}")
def _get_from_cache(self, user_id: str) -> Optional[dict]: def _get_from_cache(self, user_id: str) -> dict | None:
"""从缓存获取数据""" """从缓存获取数据"""
if user_id in self._cache: if user_id in self._cache:
cached_data = self._cache[user_id] cached_data = self._cache[user_id]
@@ -179,7 +178,7 @@ class RelationshipService:
"last_updated": time.time() "last_updated": time.time()
} }
async def _fetch_from_database(self, user_id: str) -> Optional[UserRelationships]: async def _fetch_from_database(self, user_id: str) -> UserRelationships | None:
"""从数据库获取关系数据""" """从数据库获取关系数据"""
try: try:
async with get_db_session() as session: async with get_db_session() as session:
@@ -217,7 +216,7 @@ class RelationshipService:
"cache_keys": list(self._cache.keys()) "cache_keys": list(self._cache.keys())
} }
def clear_cache(self, user_id: Optional[str] = None): def clear_cache(self, user_id: str | None = None):
"""清理缓存""" """清理缓存"""
if user_id: if user_id:
if user_id in self._cache: if user_id in self._cache:
@@ -229,4 +228,4 @@ class RelationshipService:
# 创建全局实例 # 创建全局实例
relationship_service = RelationshipService() relationship_service = RelationshipService()

View File

@@ -6,6 +6,7 @@ PlanExecutor: 接收 Plan 对象并执行其中的所有动作。
import asyncio import asyncio
import time import time
from typing import Any from typing import Any
from src.chat.planner_actions.action_manager import ChatterActionManager from src.chat.planner_actions.action_manager import ChatterActionManager
from src.common.data_models.info_data_model import ActionPlannerInfo, Plan from src.common.data_models.info_data_model import ActionPlannerInfo, Plan
from src.common.logger import get_logger from src.common.logger import get_logger

View File

@@ -60,13 +60,13 @@ class ChatterPlanFilter:
prompt, used_message_id_list = await self._build_prompt(plan) prompt, used_message_id_list = await self._build_prompt(plan)
plan.llm_prompt = prompt plan.llm_prompt = prompt
if global_config.debug.show_prompt: if global_config.debug.show_prompt:
logger.debug(f"规划器原始提示词:{prompt}") logger.info(f"规划器原始提示词:{prompt}")
llm_content, _ = await self.planner_llm.generate_response_async(prompt=prompt) llm_content, _ = await self.planner_llm.generate_response_async(prompt=prompt)
if llm_content: if llm_content:
if global_config.debug.show_prompt: if global_config.debug.show_prompt:
logger.debug(f"LLM规划器原始响应:{llm_content}") logger.info(f"LLM规划器原始响应:{llm_content}")
try: try:
parsed_json = orjson.loads(repair_json(llm_content)) parsed_json = orjson.loads(repair_json(llm_content))
except orjson.JSONDecodeError: except orjson.JSONDecodeError: