refactor(core): 优化类型提示与代码风格
本次提交对项目代码进行了广泛的重构,主要集中在以下几个方面:
1. **类型提示现代化**:
- 将 `typing` 模块中的 `Optional[T]`、`List[T]`、`Dict[K, V]` 等旧式类型提示更新为现代的 `T | None`、`list[T]`、`dict[K, V]` 语法。
- 这提高了代码的可读性,并与较新 Python 版本的风格保持一致。
2. **代码风格统一**:
- 移除了多余的空行和不必要的空格,使代码更加紧凑和规范。
- 统一了部分日志输出的格式(例如去掉不含占位符的日志字符串上多余的 `f` 前缀),增强了日志的可读性。
3. **导入语句优化**:
- 调整了部分模块的 `import` 语句顺序,使其符合 PEP 8 规范。
这些更改不涉及任何功能性变动,旨在提升代码库的整体质量、可维护性和开发体验。
This commit is contained in:
@@ -267,11 +267,11 @@ class ExpressionSelector:
|
||||
chat_info = "\n".join([f"{msg.get('sender', 'Unknown')}: {msg.get('content', '')}" for msg in chat_history])
|
||||
else:
|
||||
chat_info = chat_history
|
||||
|
||||
|
||||
# 根据配置选择模式
|
||||
mode = global_config.expression.mode
|
||||
logger.debug(f"[ExpressionSelector] 使用模式: {mode}")
|
||||
|
||||
|
||||
if mode == "exp_model":
|
||||
return await self._select_expressions_model_only(
|
||||
chat_id=chat_id,
|
||||
@@ -288,7 +288,7 @@ class ExpressionSelector:
|
||||
max_num=max_num,
|
||||
min_num=min_num
|
||||
)
|
||||
|
||||
|
||||
async def _select_expressions_classic(
|
||||
self,
|
||||
chat_id: str,
|
||||
@@ -298,7 +298,7 @@ class ExpressionSelector:
|
||||
min_num: int = 5,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""经典模式:随机抽样 + LLM评估"""
|
||||
logger.debug(f"[Classic模式] 使用LLM评估表达方式")
|
||||
logger.debug("[Classic模式] 使用LLM评估表达方式")
|
||||
return await self.select_suitable_expressions_llm(
|
||||
chat_id=chat_id,
|
||||
chat_info=chat_info,
|
||||
@@ -306,7 +306,7 @@ class ExpressionSelector:
|
||||
min_num=min_num,
|
||||
target_message=target_message
|
||||
)
|
||||
|
||||
|
||||
async def _select_expressions_model_only(
|
||||
self,
|
||||
chat_id: str,
|
||||
@@ -316,22 +316,22 @@ class ExpressionSelector:
|
||||
min_num: int = 5,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""模型预测模式:先提取情境,再使用StyleLearner预测表达风格"""
|
||||
logger.debug(f"[Exp_model模式] 使用情境提取 + StyleLearner预测表达方式")
|
||||
|
||||
logger.debug("[Exp_model模式] 使用情境提取 + StyleLearner预测表达方式")
|
||||
|
||||
# 检查是否允许在此聊天流中使用表达
|
||||
if not self.can_use_expression_for_chat(chat_id):
|
||||
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
|
||||
return []
|
||||
|
||||
|
||||
# 步骤1: 提取聊天情境
|
||||
situations = await situation_extractor.extract_situations(
|
||||
chat_history=chat_info,
|
||||
target_message=target_message,
|
||||
max_situations=3
|
||||
)
|
||||
|
||||
|
||||
if not situations:
|
||||
logger.warning(f"无法提取聊天情境,回退到经典模式")
|
||||
logger.warning("无法提取聊天情境,回退到经典模式")
|
||||
return await self._select_expressions_classic(
|
||||
chat_id=chat_id,
|
||||
chat_info=chat_info,
|
||||
@@ -339,17 +339,17 @@ class ExpressionSelector:
|
||||
max_num=max_num,
|
||||
min_num=min_num
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"[Exp_model模式] 步骤1完成 - 提取到 {len(situations)} 个情境: {situations}")
|
||||
|
||||
|
||||
# 步骤2: 使用 StyleLearner 为每个情境预测合适的表达方式
|
||||
learner = style_learner_manager.get_learner(chat_id)
|
||||
|
||||
|
||||
all_predicted_styles = {}
|
||||
for i, situation in enumerate(situations, 1):
|
||||
logger.debug(f"[Exp_model模式] 步骤2.{i} - 为情境预测风格: {situation}")
|
||||
best_style, scores = learner.predict_style(situation, top_k=max_num)
|
||||
|
||||
|
||||
if best_style and scores:
|
||||
logger.debug(f" 预测结果: best={best_style}, scores数量={len(scores)}")
|
||||
# 合并分数(取最高分)
|
||||
@@ -357,10 +357,10 @@ class ExpressionSelector:
|
||||
if style not in all_predicted_styles or score > all_predicted_styles[style]:
|
||||
all_predicted_styles[style] = score
|
||||
else:
|
||||
logger.debug(f" 该情境未返回预测结果")
|
||||
|
||||
logger.debug(" 该情境未返回预测结果")
|
||||
|
||||
if not all_predicted_styles:
|
||||
logger.warning(f"[Exp_model模式] StyleLearner未返回预测结果(可能模型未训练),回退到经典模式")
|
||||
logger.warning("[Exp_model模式] StyleLearner未返回预测结果(可能模型未训练),回退到经典模式")
|
||||
return await self._select_expressions_classic(
|
||||
chat_id=chat_id,
|
||||
chat_info=chat_info,
|
||||
@@ -368,22 +368,22 @@ class ExpressionSelector:
|
||||
max_num=max_num,
|
||||
min_num=min_num
|
||||
)
|
||||
|
||||
|
||||
# 将分数字典转换为列表格式 [(style, score), ...]
|
||||
predicted_styles = sorted(all_predicted_styles.items(), key=lambda x: x[1], reverse=True)
|
||||
|
||||
|
||||
logger.info(f"[Exp_model模式] 步骤2完成 - 预测到 {len(predicted_styles)} 个风格, Top3: {predicted_styles[:3]}")
|
||||
|
||||
|
||||
# 步骤3: 根据预测的风格从数据库获取表达方式
|
||||
logger.debug(f"[Exp_model模式] 步骤3 - 从数据库查询表达方式")
|
||||
logger.debug("[Exp_model模式] 步骤3 - 从数据库查询表达方式")
|
||||
expressions = await self.get_model_predicted_expressions(
|
||||
chat_id=chat_id,
|
||||
predicted_styles=predicted_styles,
|
||||
max_num=max_num
|
||||
)
|
||||
|
||||
|
||||
if not expressions:
|
||||
logger.warning(f"[Exp_model模式] 未找到匹配预测风格的表达方式,回退到经典模式")
|
||||
logger.warning("[Exp_model模式] 未找到匹配预测风格的表达方式,回退到经典模式")
|
||||
return await self._select_expressions_classic(
|
||||
chat_id=chat_id,
|
||||
chat_info=chat_info,
|
||||
@@ -391,10 +391,10 @@ class ExpressionSelector:
|
||||
max_num=max_num,
|
||||
min_num=min_num
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"[Exp_model模式] 成功! 返回 {len(expressions)} 个表达方式")
|
||||
return expressions
|
||||
|
||||
|
||||
async def get_model_predicted_expressions(
|
||||
self,
|
||||
chat_id: str,
|
||||
@@ -414,15 +414,15 @@ class ExpressionSelector:
|
||||
"""
|
||||
if not predicted_styles:
|
||||
return []
|
||||
|
||||
|
||||
# 提取风格名称(前3个最佳匹配)
|
||||
style_names = [style for style, _ in predicted_styles[:min(3, len(predicted_styles))]]
|
||||
logger.debug(f"预测最佳风格: {style_names[0] if style_names else 'None'}, Top3分数: {predicted_styles[:3]}")
|
||||
|
||||
|
||||
# 🔥 使用 get_related_chat_ids 获取所有相关的 chat_id(支持共享表达方式)
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
logger.info(f"查询相关的chat_ids ({len(related_chat_ids)}个): {related_chat_ids}")
|
||||
|
||||
|
||||
async with get_db_session() as session:
|
||||
# 🔍 先检查数据库中实际有哪些 chat_id 的数据
|
||||
db_chat_ids_result = await session.execute(
|
||||
@@ -432,7 +432,7 @@ class ExpressionSelector:
|
||||
)
|
||||
db_chat_ids = [cid for cid in db_chat_ids_result.scalars()]
|
||||
logger.info(f"数据库中有表达方式的chat_ids ({len(db_chat_ids)}个): {db_chat_ids}")
|
||||
|
||||
|
||||
# 获取所有相关 chat_id 的表达方式(用于模糊匹配)
|
||||
all_expressions_result = await session.execute(
|
||||
select(Expression)
|
||||
@@ -440,51 +440,51 @@ class ExpressionSelector:
|
||||
.where(Expression.type == "style")
|
||||
)
|
||||
all_expressions = list(all_expressions_result.scalars())
|
||||
|
||||
|
||||
logger.info(f"配置的相关chat_id的表达方式数量: {len(all_expressions)}")
|
||||
|
||||
|
||||
# 🔥 智能回退:如果相关 chat_id 没有数据,尝试查询所有 chat_id
|
||||
if not all_expressions:
|
||||
logger.info(f"相关chat_id没有数据,尝试从所有chat_id查询")
|
||||
logger.info("相关chat_id没有数据,尝试从所有chat_id查询")
|
||||
all_expressions_result = await session.execute(
|
||||
select(Expression)
|
||||
.where(Expression.type == "style")
|
||||
)
|
||||
all_expressions = list(all_expressions_result.scalars())
|
||||
logger.debug(f"数据库中所有表达方式数量: {len(all_expressions)}")
|
||||
|
||||
|
||||
if not all_expressions:
|
||||
logger.warning(f"数据库中完全没有任何表达方式,需要先学习")
|
||||
logger.warning("数据库中完全没有任何表达方式,需要先学习")
|
||||
return []
|
||||
|
||||
|
||||
# 🔥 使用模糊匹配而不是精确匹配
|
||||
# 计算每个预测style与数据库style的相似度
|
||||
from difflib import SequenceMatcher
|
||||
|
||||
|
||||
matched_expressions = []
|
||||
for expr in all_expressions:
|
||||
db_style = expr.style or ""
|
||||
max_similarity = 0.0
|
||||
best_predicted = ""
|
||||
|
||||
|
||||
# 与每个预测的style计算相似度
|
||||
for predicted_style, pred_score in predicted_styles[:20]: # 考虑前20个预测
|
||||
# 计算字符串相似度
|
||||
similarity = SequenceMatcher(None, predicted_style, db_style).ratio()
|
||||
|
||||
|
||||
# 也检查包含关系(如果一个是另一个的子串,给更高分)
|
||||
if len(predicted_style) >= 2 and len(db_style) >= 2:
|
||||
if predicted_style in db_style or db_style in predicted_style:
|
||||
similarity = max(similarity, 0.7)
|
||||
|
||||
|
||||
if similarity > max_similarity:
|
||||
max_similarity = similarity
|
||||
best_predicted = predicted_style
|
||||
|
||||
|
||||
# 🔥 降低阈值到30%,因为StyleLearner预测质量较差
|
||||
if max_similarity >= 0.3: # 30%相似度阈值
|
||||
matched_expressions.append((expr, max_similarity, expr.count, best_predicted))
|
||||
|
||||
|
||||
if not matched_expressions:
|
||||
# 收集数据库中的style样例用于调试
|
||||
all_styles = [e.style for e in all_expressions[:10]]
|
||||
@@ -495,11 +495,11 @@ class ExpressionSelector:
|
||||
f" 提示: StyleLearner预测质量差,建议重新训练或使用classic模式"
|
||||
)
|
||||
return []
|
||||
|
||||
|
||||
# 按照相似度*count排序,选择最佳匹配
|
||||
matched_expressions.sort(key=lambda x: x[1] * (x[2] ** 0.5), reverse=True)
|
||||
expressions_objs = [e[0] for e in matched_expressions[:max_num]]
|
||||
|
||||
|
||||
# 显示最佳匹配的详细信息
|
||||
top_matches = [f"{e[3]}->{e[0].style}({e[1]:.2f})" for e in matched_expressions[:3]]
|
||||
logger.info(
|
||||
@@ -507,7 +507,7 @@ class ExpressionSelector:
|
||||
f" 相似度范围: {matched_expressions[0][1]:.2f} ~ {matched_expressions[min(len(matched_expressions)-1, max_num-1)][1]:.2f}\n"
|
||||
f" Top3匹配: {top_matches}"
|
||||
)
|
||||
|
||||
|
||||
# 转换为字典格式
|
||||
expressions = []
|
||||
for expr in expressions_objs:
|
||||
@@ -518,7 +518,7 @@ class ExpressionSelector:
|
||||
"count": float(expr.count) if expr.count else 0.0,
|
||||
"last_active_time": expr.last_active_time or 0.0
|
||||
})
|
||||
|
||||
|
||||
logger.debug(f"从数据库获取了 {len(expressions)} 个表达方式")
|
||||
return expressions
|
||||
|
||||
|
||||
Reference in New Issue
Block a user