refactor(core): 优化类型提示与代码风格
本次提交对项目代码进行了广泛的重构,主要集中在以下几个方面:
1. **类型提示现代化**:
- 将 `typing` 模块中的 `Optional[T]`、`List[T]`、`Dict[K, V]` 等旧式类型提示更新为现代的 `T | None`、`list[T]`、`dict[K, V]` 语法。
- 这提高了代码的可读性,并与较新 Python 版本的风格保持一致。
2. **代码风格统一**:
- 移除了多余的空行和不必要的空格,使代码更加紧凑和规范。
- 统一了部分日志输出的格式,增强了日志的可读性。
3. **导入语句优化**:
- 调整了部分模块的 `import` 语句顺序,使其符合 PEP 8 规范。
这些更改不涉及任何功能性变动,旨在提升代码库的整体质量、可维护性和开发体验。
This commit is contained in:
@@ -5,7 +5,6 @@
|
||||
"""
|
||||
import os
|
||||
import time
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -17,7 +16,7 @@ logger = get_logger("expressor.style_learner")
|
||||
class StyleLearner:
|
||||
"""单个聊天室的表达风格学习器"""
|
||||
|
||||
def __init__(self, chat_id: str, model_config: Optional[Dict] = None):
|
||||
def __init__(self, chat_id: str, model_config: dict | None = None):
|
||||
"""
|
||||
Args:
|
||||
chat_id: 聊天室ID
|
||||
@@ -37,9 +36,9 @@ class StyleLearner:
|
||||
|
||||
# 动态风格管理
|
||||
self.max_styles = 2000 # 每个chat_id最多2000个风格
|
||||
self.style_to_id: Dict[str, str] = {} # style文本 -> style_id
|
||||
self.id_to_style: Dict[str, str] = {} # style_id -> style文本
|
||||
self.id_to_situation: Dict[str, str] = {} # style_id -> situation文本
|
||||
self.style_to_id: dict[str, str] = {} # style文本 -> style_id
|
||||
self.id_to_style: dict[str, str] = {} # style_id -> style文本
|
||||
self.id_to_situation: dict[str, str] = {} # style_id -> situation文本
|
||||
self.next_style_id = 0
|
||||
|
||||
# 学习统计
|
||||
@@ -51,7 +50,7 @@ class StyleLearner:
|
||||
|
||||
logger.info(f"StyleLearner初始化成功: chat_id={chat_id}")
|
||||
|
||||
def add_style(self, style: str, situation: Optional[str] = None) -> bool:
|
||||
def add_style(self, style: str, situation: str | None = None) -> bool:
|
||||
"""
|
||||
动态添加一个新的风格
|
||||
|
||||
@@ -130,7 +129,7 @@ class StyleLearner:
|
||||
logger.error(f"学习映射失败: {e}")
|
||||
return False
|
||||
|
||||
def predict_style(self, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
|
||||
def predict_style(self, up_content: str, top_k: int = 5) -> tuple[str | None, dict[str, float]]:
|
||||
"""
|
||||
根据up_content预测最合适的style
|
||||
|
||||
@@ -146,7 +145,7 @@ class StyleLearner:
|
||||
if not self.style_to_id:
|
||||
logger.debug(f"StyleLearner还没有任何训练数据: chat_id={self.chat_id}")
|
||||
return None, {}
|
||||
|
||||
|
||||
best_style_id, scores = self.expressor.predict(up_content, k=top_k)
|
||||
|
||||
if best_style_id is None:
|
||||
@@ -155,7 +154,7 @@ class StyleLearner:
|
||||
|
||||
# 将style_id转换为style文本
|
||||
best_style = self.id_to_style.get(best_style_id)
|
||||
|
||||
|
||||
if best_style is None:
|
||||
logger.warning(
|
||||
f"style_id无法转换为style文本: style_id={best_style_id}, "
|
||||
@@ -171,7 +170,7 @@ class StyleLearner:
|
||||
style_scores[style_text] = score
|
||||
else:
|
||||
logger.warning(f"跳过无法转换的style_id: {sid}")
|
||||
|
||||
|
||||
logger.debug(
|
||||
f"预测成功: up_content={up_content[:30]}..., "
|
||||
f"best_style={best_style}, top3_scores={list(style_scores.items())[:3]}"
|
||||
@@ -183,7 +182,7 @@ class StyleLearner:
|
||||
logger.error(f"预测style失败: {e}", exc_info=True)
|
||||
return None, {}
|
||||
|
||||
def get_style_info(self, style: str) -> Tuple[Optional[str], Optional[str]]:
|
||||
def get_style_info(self, style: str) -> tuple[str | None, str | None]:
|
||||
"""
|
||||
获取style的完整信息
|
||||
|
||||
@@ -200,7 +199,7 @@ class StyleLearner:
|
||||
situation = self.id_to_situation.get(style_id)
|
||||
return style_id, situation
|
||||
|
||||
def get_all_styles(self) -> List[str]:
|
||||
def get_all_styles(self) -> list[str]:
|
||||
"""
|
||||
获取所有风格列表
|
||||
|
||||
@@ -209,7 +208,7 @@ class StyleLearner:
|
||||
"""
|
||||
return list(self.style_to_id.keys())
|
||||
|
||||
def apply_decay(self, factor: Optional[float] = None):
|
||||
def apply_decay(self, factor: float | None = None):
|
||||
"""
|
||||
应用知识衰减
|
||||
|
||||
@@ -304,7 +303,7 @@ class StyleLearner:
|
||||
logger.error(f"加载StyleLearner失败: {e}")
|
||||
return False
|
||||
|
||||
def get_stats(self) -> Dict:
|
||||
def get_stats(self) -> dict:
|
||||
"""获取统计信息"""
|
||||
model_stats = self.expressor.get_stats()
|
||||
return {
|
||||
@@ -324,7 +323,7 @@ class StyleLearnerManager:
|
||||
Args:
|
||||
model_save_path: 模型保存路径
|
||||
"""
|
||||
self.learners: Dict[str, StyleLearner] = {}
|
||||
self.learners: dict[str, StyleLearner] = {}
|
||||
self.model_save_path = model_save_path
|
||||
|
||||
# 确保保存目录存在
|
||||
@@ -332,7 +331,7 @@ class StyleLearnerManager:
|
||||
|
||||
logger.info(f"StyleLearnerManager初始化成功, 模型保存路径: {model_save_path}")
|
||||
|
||||
def get_learner(self, chat_id: str, model_config: Optional[Dict] = None) -> StyleLearner:
|
||||
def get_learner(self, chat_id: str, model_config: dict | None = None) -> StyleLearner:
|
||||
"""
|
||||
获取或创建指定chat_id的学习器
|
||||
|
||||
@@ -369,7 +368,7 @@ class StyleLearnerManager:
|
||||
learner = self.get_learner(chat_id)
|
||||
return learner.learn_mapping(up_content, style)
|
||||
|
||||
def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
|
||||
def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> tuple[str | None, dict[str, float]]:
|
||||
"""
|
||||
预测最合适的风格
|
||||
|
||||
@@ -399,7 +398,7 @@ class StyleLearnerManager:
|
||||
logger.info(f"保存所有StyleLearner {'成功' if success else '部分失败'}")
|
||||
return success
|
||||
|
||||
def apply_decay_all(self, factor: Optional[float] = None):
|
||||
def apply_decay_all(self, factor: float | None = None):
|
||||
"""
|
||||
对所有学习器应用知识衰减
|
||||
|
||||
@@ -409,9 +408,9 @@ class StyleLearnerManager:
|
||||
for learner in self.learners.values():
|
||||
learner.apply_decay(factor)
|
||||
|
||||
logger.info(f"对所有StyleLearner应用知识衰减")
|
||||
logger.info("对所有StyleLearner应用知识衰减")
|
||||
|
||||
def get_all_stats(self) -> Dict[str, Dict]:
|
||||
def get_all_stats(self) -> dict[str, dict]:
|
||||
"""
|
||||
获取所有学习器的统计信息
|
||||
|
||||
|
||||
Reference in New Issue
Block a user