refactor(core): 优化类型提示与代码风格

本次提交对项目代码进行了广泛的重构,主要集中在以下几个方面:

1.  **类型提示现代化**:
    -   将 `typing` 模块中的 `Optional[T]`、`List[T]`、`Dict[K, V]` 等旧式类型提示更新为现代的 `T | None`、`list[T]`、`dict[K, V]` 语法。
    -   这提高了代码的可读性,并与较新 Python 版本的风格保持一致。

2.  **代码风格统一**:
    -   移除了多余的空行和不必要的空格,使代码更加紧凑和规范。
    -   统一了部分日志输出的格式,增强了日志的可读性。

3.  **导入语句优化**:
    -   调整了部分模块的 `import` 语句顺序,使其符合 PEP 8 规范。

这些更改不涉及任何功能性变动,旨在提升代码库的整体质量、可维护性和开发体验。
This commit is contained in:
minecraft1024a
2025-10-31 20:56:17 +08:00
parent 926adf16dd
commit a29be48091
47 changed files with 923 additions and 933 deletions

View File

@@ -5,7 +5,6 @@
"""
import os
import time
from typing import Dict, List, Optional, Tuple
from src.common.logger import get_logger
@@ -17,7 +16,7 @@ logger = get_logger("expressor.style_learner")
class StyleLearner:
"""单个聊天室的表达风格学习器"""
def __init__(self, chat_id: str, model_config: Optional[Dict] = None):
def __init__(self, chat_id: str, model_config: dict | None = None):
"""
Args:
chat_id: 聊天室ID
@@ -37,9 +36,9 @@ class StyleLearner:
# 动态风格管理
self.max_styles = 2000 # 每个chat_id最多2000个风格
self.style_to_id: Dict[str, str] = {} # style文本 -> style_id
self.id_to_style: Dict[str, str] = {} # style_id -> style文本
self.id_to_situation: Dict[str, str] = {} # style_id -> situation文本
self.style_to_id: dict[str, str] = {} # style文本 -> style_id
self.id_to_style: dict[str, str] = {} # style_id -> style文本
self.id_to_situation: dict[str, str] = {} # style_id -> situation文本
self.next_style_id = 0
# 学习统计
@@ -51,7 +50,7 @@ class StyleLearner:
logger.info(f"StyleLearner初始化成功: chat_id={chat_id}")
def add_style(self, style: str, situation: Optional[str] = None) -> bool:
def add_style(self, style: str, situation: str | None = None) -> bool:
"""
动态添加一个新的风格
@@ -130,7 +129,7 @@ class StyleLearner:
logger.error(f"学习映射失败: {e}")
return False
def predict_style(self, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
def predict_style(self, up_content: str, top_k: int = 5) -> tuple[str | None, dict[str, float]]:
"""
根据up_content预测最合适的style
@@ -146,7 +145,7 @@ class StyleLearner:
if not self.style_to_id:
logger.debug(f"StyleLearner还没有任何训练数据: chat_id={self.chat_id}")
return None, {}
best_style_id, scores = self.expressor.predict(up_content, k=top_k)
if best_style_id is None:
@@ -155,7 +154,7 @@ class StyleLearner:
# 将style_id转换为style文本
best_style = self.id_to_style.get(best_style_id)
if best_style is None:
logger.warning(
f"style_id无法转换为style文本: style_id={best_style_id}, "
@@ -171,7 +170,7 @@ class StyleLearner:
style_scores[style_text] = score
else:
logger.warning(f"跳过无法转换的style_id: {sid}")
logger.debug(
f"预测成功: up_content={up_content[:30]}..., "
f"best_style={best_style}, top3_scores={list(style_scores.items())[:3]}"
@@ -183,7 +182,7 @@ class StyleLearner:
logger.error(f"预测style失败: {e}", exc_info=True)
return None, {}
def get_style_info(self, style: str) -> Tuple[Optional[str], Optional[str]]:
def get_style_info(self, style: str) -> tuple[str | None, str | None]:
"""
获取style的完整信息
@@ -200,7 +199,7 @@ class StyleLearner:
situation = self.id_to_situation.get(style_id)
return style_id, situation
def get_all_styles(self) -> List[str]:
def get_all_styles(self) -> list[str]:
"""
获取所有风格列表
@@ -209,7 +208,7 @@ class StyleLearner:
"""
return list(self.style_to_id.keys())
def apply_decay(self, factor: Optional[float] = None):
def apply_decay(self, factor: float | None = None):
"""
应用知识衰减
@@ -304,7 +303,7 @@ class StyleLearner:
logger.error(f"加载StyleLearner失败: {e}")
return False
def get_stats(self) -> Dict:
def get_stats(self) -> dict:
"""获取统计信息"""
model_stats = self.expressor.get_stats()
return {
@@ -324,7 +323,7 @@ class StyleLearnerManager:
Args:
model_save_path: 模型保存路径
"""
self.learners: Dict[str, StyleLearner] = {}
self.learners: dict[str, StyleLearner] = {}
self.model_save_path = model_save_path
# 确保保存目录存在
@@ -332,7 +331,7 @@ class StyleLearnerManager:
logger.info(f"StyleLearnerManager初始化成功, 模型保存路径: {model_save_path}")
def get_learner(self, chat_id: str, model_config: Optional[Dict] = None) -> StyleLearner:
def get_learner(self, chat_id: str, model_config: dict | None = None) -> StyleLearner:
"""
获取或创建指定chat_id的学习器
@@ -369,7 +368,7 @@ class StyleLearnerManager:
learner = self.get_learner(chat_id)
return learner.learn_mapping(up_content, style)
def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> tuple[str | None, dict[str, float]]:
"""
预测最合适的风格
@@ -399,7 +398,7 @@ class StyleLearnerManager:
logger.info(f"保存所有StyleLearner {'成功' if success else '部分失败'}")
return success
def apply_decay_all(self, factor: Optional[float] = None):
def apply_decay_all(self, factor: float | None = None):
"""
对所有学习器应用知识衰减
@@ -409,9 +408,9 @@ class StyleLearnerManager:
for learner in self.learners.values():
learner.apply_decay(factor)
logger.info(f"对所有StyleLearner应用知识衰减")
logger.info("对所有StyleLearner应用知识衰减")
def get_all_stats(self) -> Dict[str, Dict]:
def get_all_stats(self) -> dict[str, dict]:
"""
获取所有学习器的统计信息