refactor(core): 优化类型提示与代码风格

本次提交对项目代码进行了广泛的重构,主要集中在以下几个方面:

1.  **类型提示现代化**:
    -   将 `typing` 模块中的 `Optional[T]`、`List[T]`、`Dict[K, V]` 等旧式类型提示更新为现代的 `T | None`、`list[T]`、`dict[K, V]` 语法。
    -   这提高了代码的可读性,并与较新 Python 版本的风格保持一致。

2.  **代码风格统一**:
    -   移除了多余的空行和不必要的空格,使代码更加紧凑和规范。
    -   统一了部分日志输出的格式,增强了日志的可读性。

3.  **导入语句优化**:
    -   调整了部分模块的 `import` 语句顺序,使其符合 PEP 8 规范。

这些更改不涉及任何功能性变动,旨在提升代码库的整体质量、可维护性和开发体验。
This commit is contained in:
minecraft1024a
2025-10-31 20:56:17 +08:00
parent 926adf16dd
commit a29be48091
47 changed files with 923 additions and 933 deletions

View File

@@ -5,14 +5,14 @@
import difflib
import random
import re
from typing import Any, Dict, List, Optional
from typing import Any
from src.common.logger import get_logger
logger = get_logger("express_utils")
def filter_message_content(content: Optional[str]) -> str:
def filter_message_content(content: str | None) -> str:
"""
过滤消息内容,移除回复、@、图片等格式
@@ -51,7 +51,7 @@ def calculate_similarity(text1: str, text2: str) -> float:
return difflib.SequenceMatcher(None, text1, text2).ratio()
def weighted_sample(population: List[Dict], k: int, weight_key: Optional[str] = None) -> List[Dict]:
def weighted_sample(population: list[dict], k: int, weight_key: str | None = None) -> list[dict]:
"""
加权随机抽样函数
@@ -108,7 +108,7 @@ def normalize_text(text: str) -> str:
return text.strip()
def extract_keywords(text: str, max_keywords: int = 10) -> List[str]:
def extract_keywords(text: str, max_keywords: int = 10) -> list[str]:
"""
简单的关键词提取(基于词频)
@@ -135,7 +135,7 @@ def extract_keywords(text: str, max_keywords: int = 10) -> List[str]:
return words[:max_keywords]
def format_expression_pair(situation: str, style: str, index: Optional[int] = None) -> str:
def format_expression_pair(situation: str, style: str, index: int | None = None) -> str:
"""
格式化表达方式对
@@ -153,7 +153,7 @@ def format_expression_pair(situation: str, style: str, index: Optional[int] = No
return f'"{situation}"时,使用"{style}"'
def parse_expression_pair(text: str) -> Optional[tuple[str, str]]:
def parse_expression_pair(text: str) -> tuple[str, str] | None:
"""
解析表达方式对文本
@@ -170,7 +170,7 @@ def parse_expression_pair(text: str) -> Optional[tuple[str, str]]:
return None
def batch_filter_duplicates(expressions: List[Dict[str, Any]], key_fields: List[str]) -> List[Dict[str, Any]]:
def batch_filter_duplicates(expressions: list[dict[str, Any]], key_fields: list[str]) -> list[dict[str, Any]]:
"""
批量去重表达方式
@@ -219,8 +219,8 @@ def calculate_time_weight(last_active_time: float, current_time: float, half_lif
def merge_expressions_from_multiple_chats(
expressions_dict: Dict[str, List[Dict[str, Any]]], max_total: int = 100
) -> List[Dict[str, Any]]:
expressions_dict: dict[str, list[dict[str, Any]]], max_total: int = 100
) -> list[dict[str, Any]]:
"""
合并多个聊天室的表达方式

View File

@@ -438,9 +438,9 @@ class ExpressionLearner:
try:
# 获取 StyleLearner 实例
learner = style_learner_manager.get_learner(chat_id)
logger.info(f"开始训练 StyleLearner: chat_id={chat_id}, 样本数={len(expr_list)}")
# 为每个学习到的表达方式训练模型
# 使用 situation 作为输入style 作为目标
# 这是最符合语义的方式:场景 -> 表达方式
@@ -448,25 +448,25 @@ class ExpressionLearner:
for expr in expr_list:
situation = expr["situation"]
style = expr["style"]
# 训练映射关系: situation -> style
if learner.learn_mapping(situation, style):
success_count += 1
else:
logger.warning(f"训练失败: {situation} -> {style}")
logger.info(
f"StyleLearner 训练完成: {success_count}/{len(expr_list)} 成功, "
f"当前风格总数={len(learner.get_all_styles())}, "
f"总样本数={learner.learning_stats['total_samples']}"
)
# 保存模型
if learner.save(style_learner_manager.model_save_path):
logger.info(f"StyleLearner 模型保存成功: {chat_id}")
else:
logger.error(f"StyleLearner 模型保存失败: {chat_id}")
except Exception as e:
logger.error(f"训练 StyleLearner 失败: {e}", exc_info=True)
@@ -527,7 +527,7 @@ class ExpressionLearner:
logger.debug(f"学习{type_str}的response: {response}")
expressions: list[tuple[str, str, str]] = self.parse_expression_response(response, chat_id)
if not expressions:
logger.warning(f"从LLM响应中未能解析出任何{type_str}。请检查LLM输出格式是否正确。")
logger.info(f"LLM完整响应:\n{response}")
@@ -542,26 +542,26 @@ class ExpressionLearner:
"""
expressions: list[tuple[str, str, str]] = []
failed_lines = []
for line_num, line in enumerate(response.splitlines(), 1):
line = line.strip()
if not line:
continue
# 替换中文引号为英文引号,便于统一处理
line_normalized = line.replace('"', '"').replace('"', '"').replace("'", '"').replace("'", '"')
# 查找"当"和下一个引号
idx_when = line_normalized.find('"')
if idx_when == -1:
# 尝试不带引号的格式: 当xxx时
idx_when = line_normalized.find('')
idx_when = line_normalized.find("")
if idx_when == -1:
failed_lines.append((line_num, line, "找不到''关键字"))
continue
# 提取"当"和"时"之间的内容
idx_shi = line_normalized.find('', idx_when)
idx_shi = line_normalized.find("", idx_when)
if idx_shi == -1:
failed_lines.append((line_num, line, "找不到''关键字"))
continue
@@ -575,20 +575,20 @@ class ExpressionLearner:
continue
situation = line_normalized[idx_quote1 + 1 : idx_quote2]
search_start = idx_quote2
# 查找"使用"或"可以"
idx_use = line_normalized.find('使用"', search_start)
if idx_use == -1:
idx_use = line_normalized.find('可以"', search_start)
if idx_use == -1:
# 尝试不带引号的格式
idx_use = line_normalized.find('使用', search_start)
idx_use = line_normalized.find("使用", search_start)
if idx_use == -1:
idx_use = line_normalized.find('可以', search_start)
idx_use = line_normalized.find("可以", search_start)
if idx_use == -1:
failed_lines.append((line_num, line, "找不到'使用''可以'关键字"))
continue
# 提取剩余部分作为style
style = line_normalized[idx_use + 2:].strip('"\'"",。')
if not style:
@@ -610,24 +610,24 @@ class ExpressionLearner:
style = line_normalized[idx_quote3 + 1:].strip('"\'""')
else:
style = line_normalized[idx_quote3 + 1 : idx_quote4]
# 清理并验证
situation = situation.strip()
style = style.strip()
if not situation or not style:
failed_lines.append((line_num, line, f"situation或style为空: situation='{situation}', style='{style}'"))
continue
expressions.append((chat_id, situation, style))
# 记录解析失败的行
if failed_lines:
logger.warning(f"解析表达方式时有 {len(failed_lines)} 行失败:")
for line_num, line, reason in failed_lines[:5]: # 只显示前5个
logger.warning(f"{line_num}: {reason}")
logger.debug(f" 原文: {line}")
if not expressions:
logger.warning(f"LLM返回了内容但无法解析任何表达方式。响应预览:\n{response[:500]}")
else:

View File

@@ -267,11 +267,11 @@ class ExpressionSelector:
chat_info = "\n".join([f"{msg.get('sender', 'Unknown')}: {msg.get('content', '')}" for msg in chat_history])
else:
chat_info = chat_history
# 根据配置选择模式
mode = global_config.expression.mode
logger.debug(f"[ExpressionSelector] 使用模式: {mode}")
if mode == "exp_model":
return await self._select_expressions_model_only(
chat_id=chat_id,
@@ -288,7 +288,7 @@ class ExpressionSelector:
max_num=max_num,
min_num=min_num
)
async def _select_expressions_classic(
self,
chat_id: str,
@@ -298,7 +298,7 @@ class ExpressionSelector:
min_num: int = 5,
) -> list[dict[str, Any]]:
"""经典模式:随机抽样 + LLM评估"""
logger.debug(f"[Classic模式] 使用LLM评估表达方式")
logger.debug("[Classic模式] 使用LLM评估表达方式")
return await self.select_suitable_expressions_llm(
chat_id=chat_id,
chat_info=chat_info,
@@ -306,7 +306,7 @@ class ExpressionSelector:
min_num=min_num,
target_message=target_message
)
async def _select_expressions_model_only(
self,
chat_id: str,
@@ -316,22 +316,22 @@ class ExpressionSelector:
min_num: int = 5,
) -> list[dict[str, Any]]:
"""模型预测模式先提取情境再使用StyleLearner预测表达风格"""
logger.debug(f"[Exp_model模式] 使用情境提取 + StyleLearner预测表达方式")
logger.debug("[Exp_model模式] 使用情境提取 + StyleLearner预测表达方式")
# 检查是否允许在此聊天流中使用表达
if not self.can_use_expression_for_chat(chat_id):
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
return []
# 步骤1: 提取聊天情境
situations = await situation_extractor.extract_situations(
chat_history=chat_info,
target_message=target_message,
max_situations=3
)
if not situations:
logger.warning(f"无法提取聊天情境,回退到经典模式")
logger.warning("无法提取聊天情境,回退到经典模式")
return await self._select_expressions_classic(
chat_id=chat_id,
chat_info=chat_info,
@@ -339,17 +339,17 @@ class ExpressionSelector:
max_num=max_num,
min_num=min_num
)
logger.info(f"[Exp_model模式] 步骤1完成 - 提取到 {len(situations)} 个情境: {situations}")
# 步骤2: 使用 StyleLearner 为每个情境预测合适的表达方式
learner = style_learner_manager.get_learner(chat_id)
all_predicted_styles = {}
for i, situation in enumerate(situations, 1):
logger.debug(f"[Exp_model模式] 步骤2.{i} - 为情境预测风格: {situation}")
best_style, scores = learner.predict_style(situation, top_k=max_num)
if best_style and scores:
logger.debug(f" 预测结果: best={best_style}, scores数量={len(scores)}")
# 合并分数(取最高分)
@@ -357,10 +357,10 @@ class ExpressionSelector:
if style not in all_predicted_styles or score > all_predicted_styles[style]:
all_predicted_styles[style] = score
else:
logger.debug(f" 该情境未返回预测结果")
logger.debug(" 该情境未返回预测结果")
if not all_predicted_styles:
logger.warning(f"[Exp_model模式] StyleLearner未返回预测结果可能模型未训练回退到经典模式")
logger.warning("[Exp_model模式] StyleLearner未返回预测结果可能模型未训练回退到经典模式")
return await self._select_expressions_classic(
chat_id=chat_id,
chat_info=chat_info,
@@ -368,22 +368,22 @@ class ExpressionSelector:
max_num=max_num,
min_num=min_num
)
# 将分数字典转换为列表格式 [(style, score), ...]
predicted_styles = sorted(all_predicted_styles.items(), key=lambda x: x[1], reverse=True)
logger.info(f"[Exp_model模式] 步骤2完成 - 预测到 {len(predicted_styles)} 个风格, Top3: {predicted_styles[:3]}")
# 步骤3: 根据预测的风格从数据库获取表达方式
logger.debug(f"[Exp_model模式] 步骤3 - 从数据库查询表达方式")
logger.debug("[Exp_model模式] 步骤3 - 从数据库查询表达方式")
expressions = await self.get_model_predicted_expressions(
chat_id=chat_id,
predicted_styles=predicted_styles,
max_num=max_num
)
if not expressions:
logger.warning(f"[Exp_model模式] 未找到匹配预测风格的表达方式,回退到经典模式")
logger.warning("[Exp_model模式] 未找到匹配预测风格的表达方式,回退到经典模式")
return await self._select_expressions_classic(
chat_id=chat_id,
chat_info=chat_info,
@@ -391,10 +391,10 @@ class ExpressionSelector:
max_num=max_num,
min_num=min_num
)
logger.info(f"[Exp_model模式] 成功! 返回 {len(expressions)} 个表达方式")
return expressions
async def get_model_predicted_expressions(
self,
chat_id: str,
@@ -414,15 +414,15 @@ class ExpressionSelector:
"""
if not predicted_styles:
return []
# 提取风格名称前3个最佳匹配
style_names = [style for style, _ in predicted_styles[:min(3, len(predicted_styles))]]
logger.debug(f"预测最佳风格: {style_names[0] if style_names else 'None'}, Top3分数: {predicted_styles[:3]}")
# 🔥 使用 get_related_chat_ids 获取所有相关的 chat_id支持共享表达方式
related_chat_ids = self.get_related_chat_ids(chat_id)
logger.info(f"查询相关的chat_ids ({len(related_chat_ids)}个): {related_chat_ids}")
async with get_db_session() as session:
# 🔍 先检查数据库中实际有哪些 chat_id 的数据
db_chat_ids_result = await session.execute(
@@ -432,7 +432,7 @@ class ExpressionSelector:
)
db_chat_ids = [cid for cid in db_chat_ids_result.scalars()]
logger.info(f"数据库中有表达方式的chat_ids ({len(db_chat_ids)}个): {db_chat_ids}")
# 获取所有相关 chat_id 的表达方式(用于模糊匹配)
all_expressions_result = await session.execute(
select(Expression)
@@ -440,51 +440,51 @@ class ExpressionSelector:
.where(Expression.type == "style")
)
all_expressions = list(all_expressions_result.scalars())
logger.info(f"配置的相关chat_id的表达方式数量: {len(all_expressions)}")
# 🔥 智能回退:如果相关 chat_id 没有数据,尝试查询所有 chat_id
if not all_expressions:
logger.info(f"相关chat_id没有数据尝试从所有chat_id查询")
logger.info("相关chat_id没有数据尝试从所有chat_id查询")
all_expressions_result = await session.execute(
select(Expression)
.where(Expression.type == "style")
)
all_expressions = list(all_expressions_result.scalars())
logger.debug(f"数据库中所有表达方式数量: {len(all_expressions)}")
if not all_expressions:
logger.warning(f"数据库中完全没有任何表达方式,需要先学习")
logger.warning("数据库中完全没有任何表达方式,需要先学习")
return []
# 🔥 使用模糊匹配而不是精确匹配
# 计算每个预测style与数据库style的相似度
from difflib import SequenceMatcher
matched_expressions = []
for expr in all_expressions:
db_style = expr.style or ""
max_similarity = 0.0
best_predicted = ""
# 与每个预测的style计算相似度
for predicted_style, pred_score in predicted_styles[:20]: # 考虑前20个预测
# 计算字符串相似度
similarity = SequenceMatcher(None, predicted_style, db_style).ratio()
# 也检查包含关系(如果一个是另一个的子串,给更高分)
if len(predicted_style) >= 2 and len(db_style) >= 2:
if predicted_style in db_style or db_style in predicted_style:
similarity = max(similarity, 0.7)
if similarity > max_similarity:
max_similarity = similarity
best_predicted = predicted_style
# 🔥 降低阈值到30%因为StyleLearner预测质量较差
if max_similarity >= 0.3: # 30%相似度阈值
matched_expressions.append((expr, max_similarity, expr.count, best_predicted))
if not matched_expressions:
# 收集数据库中的style样例用于调试
all_styles = [e.style for e in all_expressions[:10]]
@@ -495,11 +495,11 @@ class ExpressionSelector:
f" 提示: StyleLearner预测质量差建议重新训练或使用classic模式"
)
return []
# 按照相似度*count排序选择最佳匹配
matched_expressions.sort(key=lambda x: x[1] * (x[2] ** 0.5), reverse=True)
expressions_objs = [e[0] for e in matched_expressions[:max_num]]
# 显示最佳匹配的详细信息
top_matches = [f"{e[3]}->{e[0].style}({e[1]:.2f})" for e in matched_expressions[:3]]
logger.info(
@@ -507,7 +507,7 @@ class ExpressionSelector:
f" 相似度范围: {matched_expressions[0][1]:.2f} ~ {matched_expressions[min(len(matched_expressions)-1, max_num-1)][1]:.2f}\n"
f" Top3匹配: {top_matches}"
)
# 转换为字典格式
expressions = []
for expr in expressions_objs:
@@ -518,7 +518,7 @@ class ExpressionSelector:
"count": float(expr.count) if expr.count else 0.0,
"last_active_time": expr.last_active_time or 0.0
})
logger.debug(f"从数据库获取了 {len(expressions)} 个表达方式")
return expressions

View File

@@ -5,7 +5,6 @@
import os
import pickle
from collections import Counter, defaultdict
from typing import Dict, Optional, Tuple
from src.common.logger import get_logger
@@ -36,14 +35,14 @@ class ExpressorModel:
self.nb = OnlineNaiveBayes(alpha=alpha, beta=beta, gamma=gamma, vocab_size=vocab_size)
# 候选表达管理
self._candidates: Dict[str, str] = {} # cid -> text (style)
self._situations: Dict[str, str] = {} # cid -> situation (不参与计算)
self._candidates: dict[str, str] = {} # cid -> text (style)
self._situations: dict[str, str] = {} # cid -> situation (不参与计算)
logger.info(
f"ExpressorModel初始化完成 (alpha={alpha}, beta={beta}, gamma={gamma}, vocab_size={vocab_size}, use_jieba={use_jieba})"
)
def add_candidate(self, cid: str, text: str, situation: Optional[str] = None):
def add_candidate(self, cid: str, text: str, situation: str | None = None):
"""
添加候选文本和对应的situation
@@ -62,7 +61,7 @@ class ExpressorModel:
if cid not in self.nb.token_counts:
self.nb.token_counts[cid] = defaultdict(float)
def predict(self, text: str, k: int = None) -> Tuple[Optional[str], Dict[str, float]]:
def predict(self, text: str, k: int = None) -> tuple[str | None, dict[str, float]]:
"""
直接对所有候选进行朴素贝叶斯评分
@@ -113,7 +112,7 @@ class ExpressorModel:
tf = Counter(toks)
self.nb.update_positive(tf, cid)
def decay(self, factor: Optional[float] = None):
def decay(self, factor: float | None = None):
"""
应用知识衰减
@@ -122,7 +121,7 @@ class ExpressorModel:
"""
self.nb.decay(factor)
def get_candidate_info(self, cid: str) -> Tuple[Optional[str], Optional[str]]:
def get_candidate_info(self, cid: str) -> tuple[str | None, str | None]:
"""
获取候选信息
@@ -136,7 +135,7 @@ class ExpressorModel:
situation = self._situations.get(cid)
return style, situation
def get_all_candidates(self) -> Dict[str, Tuple[str, str]]:
def get_all_candidates(self) -> dict[str, tuple[str, str]]:
"""
获取所有候选
@@ -205,7 +204,7 @@ class ExpressorModel:
logger.info(f"模型已从 {path} 加载")
def get_stats(self) -> Dict:
def get_stats(self) -> dict:
"""获取模型统计信息"""
nb_stats = self.nb.get_stats()
return {

View File

@@ -4,7 +4,6 @@
"""
import math
from collections import Counter, defaultdict
from typing import Dict, List, Optional
from src.common.logger import get_logger
@@ -28,15 +27,15 @@ class OnlineNaiveBayes:
self.V = vocab_size
# 类别统计
self.cls_counts: Dict[str, float] = defaultdict(float) # cid -> total token count
self.token_counts: Dict[str, Dict[str, float]] = defaultdict(
self.cls_counts: dict[str, float] = defaultdict(float) # cid -> total token count
self.token_counts: dict[str, dict[str, float]] = defaultdict(
lambda: defaultdict(float)
) # cid -> term -> count
# 缓存
self._logZ: Dict[str, float] = {} # cache log(∑counts + Vα)
self._logZ: dict[str, float] = {} # cache log(∑counts + Vα)
def score_batch(self, tf: Counter, cids: List[str]) -> Dict[str, float]:
def score_batch(self, tf: Counter, cids: list[str]) -> dict[str, float]:
"""
批量计算候选的贝叶斯分数
@@ -51,7 +50,7 @@ class OnlineNaiveBayes:
n_cls = max(1, len(self.cls_counts))
denom_prior = math.log(total_cls + self.beta * n_cls)
out: Dict[str, float] = {}
out: dict[str, float] = {}
for cid in cids:
# 计算先验概率 log P(c)
prior = math.log(self.cls_counts[cid] + self.beta) - denom_prior
@@ -88,7 +87,7 @@ class OnlineNaiveBayes:
self.cls_counts[cid] += inc
self._invalidate(cid)
def decay(self, factor: Optional[float] = None):
def decay(self, factor: float | None = None):
"""
知识衰减(遗忘机制)
@@ -133,7 +132,7 @@ class OnlineNaiveBayes:
if cid in self._logZ:
del self._logZ[cid]
def get_stats(self) -> Dict:
def get_stats(self) -> dict:
"""获取统计信息"""
return {
"n_classes": len(self.cls_counts),

View File

@@ -1,7 +1,6 @@
"""
文本分词器支持中文Jieba分词
"""
from typing import List
from src.common.logger import get_logger
@@ -30,7 +29,7 @@ class Tokenizer:
logger.warning("Jieba未安装将使用字符级分词")
self.use_jieba = False
def tokenize(self, text: str) -> List[str]:
def tokenize(self, text: str) -> list[str]:
"""
分词并返回token列表

View File

@@ -2,7 +2,6 @@
情境提取器
从聊天历史中提取当前的情境situation用于 StyleLearner 预测
"""
from typing import Optional
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.logger import get_logger
@@ -41,17 +40,17 @@ def init_prompt():
class SituationExtractor:
"""情境提取器,从聊天历史中提取当前情境"""
def __init__(self):
self.llm_model = LLMRequest(
model_set=model_config.model_task_config.utils_small,
request_type="expression.situation_extractor"
)
async def extract_situations(
self,
chat_history: list | str,
target_message: Optional[str] = None,
target_message: str | None = None,
max_situations: int = 3
) -> list[str]:
"""
@@ -68,18 +67,18 @@ class SituationExtractor:
# 转换chat_history为字符串
if isinstance(chat_history, list):
chat_info = "\n".join([
f"{msg.get('sender', 'Unknown')}: {msg.get('content', '')}"
f"{msg.get('sender', 'Unknown')}: {msg.get('content', '')}"
for msg in chat_history
])
else:
chat_info = chat_history
# 构建目标消息信息
if target_message:
target_message_info = f",现在你想要回复消息:{target_message}"
else:
target_message_info = ""
# 构建 prompt
try:
prompt = (await global_prompt_manager.get_prompt_async("situation_extraction_prompt")).format(
@@ -87,31 +86,31 @@ class SituationExtractor:
chat_history=chat_info,
target_message_info=target_message_info
)
# 调用 LLM
response, _ = await self.llm_model.generate_response_async(
prompt=prompt,
temperature=0.3
)
if not response or not response.strip():
logger.warning("LLM返回空响应无法提取情境")
return []
# 解析响应
situations = self._parse_situations(response, max_situations)
if situations:
logger.debug(f"提取到 {len(situations)} 个情境: {situations}")
else:
logger.warning(f"无法从LLM响应中解析出情境。响应:\n{response}")
return situations
except Exception as e:
logger.error(f"提取情境失败: {e}")
return []
@staticmethod
def _parse_situations(response: str, max_situations: int) -> list[str]:
"""
@@ -125,33 +124,33 @@ class SituationExtractor:
情境描述列表
"""
situations = []
for line in response.splitlines():
line = line.strip()
if not line:
continue
# 移除可能的序号、引号等
line = line.lstrip('0123456789.、-*>)】] \t"\'""''')
line = line.rstrip('"\'""''')
line = line.strip()
if not line:
continue
# 过滤掉明显不是情境描述的内容
if len(line) > 30: # 太长
continue
if len(line) < 2: # 太短
continue
if any(keyword in line.lower() for keyword in ['例如', '注意', '', '分析', '总结']):
if any(keyword in line.lower() for keyword in ["例如", "注意", "", "分析", "总结"]):
continue
situations.append(line)
if len(situations) >= max_situations:
break
return situations

View File

@@ -5,7 +5,6 @@
"""
import os
import time
from typing import Dict, List, Optional, Tuple
from src.common.logger import get_logger
@@ -17,7 +16,7 @@ logger = get_logger("expressor.style_learner")
class StyleLearner:
"""单个聊天室的表达风格学习器"""
def __init__(self, chat_id: str, model_config: Optional[Dict] = None):
def __init__(self, chat_id: str, model_config: dict | None = None):
"""
Args:
chat_id: 聊天室ID
@@ -37,9 +36,9 @@ class StyleLearner:
# 动态风格管理
self.max_styles = 2000 # 每个chat_id最多2000个风格
self.style_to_id: Dict[str, str] = {} # style文本 -> style_id
self.id_to_style: Dict[str, str] = {} # style_id -> style文本
self.id_to_situation: Dict[str, str] = {} # style_id -> situation文本
self.style_to_id: dict[str, str] = {} # style文本 -> style_id
self.id_to_style: dict[str, str] = {} # style_id -> style文本
self.id_to_situation: dict[str, str] = {} # style_id -> situation文本
self.next_style_id = 0
# 学习统计
@@ -51,7 +50,7 @@ class StyleLearner:
logger.info(f"StyleLearner初始化成功: chat_id={chat_id}")
def add_style(self, style: str, situation: Optional[str] = None) -> bool:
def add_style(self, style: str, situation: str | None = None) -> bool:
"""
动态添加一个新的风格
@@ -130,7 +129,7 @@ class StyleLearner:
logger.error(f"学习映射失败: {e}")
return False
def predict_style(self, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
def predict_style(self, up_content: str, top_k: int = 5) -> tuple[str | None, dict[str, float]]:
"""
根据up_content预测最合适的style
@@ -146,7 +145,7 @@ class StyleLearner:
if not self.style_to_id:
logger.debug(f"StyleLearner还没有任何训练数据: chat_id={self.chat_id}")
return None, {}
best_style_id, scores = self.expressor.predict(up_content, k=top_k)
if best_style_id is None:
@@ -155,7 +154,7 @@ class StyleLearner:
# 将style_id转换为style文本
best_style = self.id_to_style.get(best_style_id)
if best_style is None:
logger.warning(
f"style_id无法转换为style文本: style_id={best_style_id}, "
@@ -171,7 +170,7 @@ class StyleLearner:
style_scores[style_text] = score
else:
logger.warning(f"跳过无法转换的style_id: {sid}")
logger.debug(
f"预测成功: up_content={up_content[:30]}..., "
f"best_style={best_style}, top3_scores={list(style_scores.items())[:3]}"
@@ -183,7 +182,7 @@ class StyleLearner:
logger.error(f"预测style失败: {e}", exc_info=True)
return None, {}
def get_style_info(self, style: str) -> Tuple[Optional[str], Optional[str]]:
def get_style_info(self, style: str) -> tuple[str | None, str | None]:
"""
获取style的完整信息
@@ -200,7 +199,7 @@ class StyleLearner:
situation = self.id_to_situation.get(style_id)
return style_id, situation
def get_all_styles(self) -> List[str]:
def get_all_styles(self) -> list[str]:
"""
获取所有风格列表
@@ -209,7 +208,7 @@ class StyleLearner:
"""
return list(self.style_to_id.keys())
def apply_decay(self, factor: Optional[float] = None):
def apply_decay(self, factor: float | None = None):
"""
应用知识衰减
@@ -304,7 +303,7 @@ class StyleLearner:
logger.error(f"加载StyleLearner失败: {e}")
return False
def get_stats(self) -> Dict:
def get_stats(self) -> dict:
"""获取统计信息"""
model_stats = self.expressor.get_stats()
return {
@@ -324,7 +323,7 @@ class StyleLearnerManager:
Args:
model_save_path: 模型保存路径
"""
self.learners: Dict[str, StyleLearner] = {}
self.learners: dict[str, StyleLearner] = {}
self.model_save_path = model_save_path
# 确保保存目录存在
@@ -332,7 +331,7 @@ class StyleLearnerManager:
logger.info(f"StyleLearnerManager初始化成功, 模型保存路径: {model_save_path}")
def get_learner(self, chat_id: str, model_config: Optional[Dict] = None) -> StyleLearner:
def get_learner(self, chat_id: str, model_config: dict | None = None) -> StyleLearner:
"""
获取或创建指定chat_id的学习器
@@ -369,7 +368,7 @@ class StyleLearnerManager:
learner = self.get_learner(chat_id)
return learner.learn_mapping(up_content, style)
def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> tuple[str | None, dict[str, float]]:
"""
预测最合适的风格
@@ -399,7 +398,7 @@ class StyleLearnerManager:
logger.info(f"保存所有StyleLearner {'成功' if success else '部分失败'}")
return success
def apply_decay_all(self, factor: Optional[float] = None):
def apply_decay_all(self, factor: float | None = None):
"""
对所有学习器应用知识衰减
@@ -409,9 +408,9 @@ class StyleLearnerManager:
for learner in self.learners.values():
learner.apply_decay(factor)
logger.info(f"对所有StyleLearner应用知识衰减")
logger.info("对所有StyleLearner应用知识衰减")
def get_all_stats(self) -> Dict[str, Dict]:
def get_all_stats(self) -> dict[str, dict]:
"""
获取所有学习器的统计信息