refactor(core): 优化类型提示与代码风格

本次提交对项目代码进行了广泛的重构,主要集中在以下几个方面:

1.  **类型提示现代化**:
    -   将 `typing` 模块中的 `Optional[T]`、`List[T]`、`Dict[K, V]` 等旧式类型提示更新为现代的 `T | None`、`list[T]`、`dict[K, V]` 语法。
    -   这提高了代码的可读性,并与较新 Python 版本的风格保持一致。

2.  **代码风格统一**:
    -   移除了多余的空行和不必要的空格,使代码更加紧凑和规范。
    -   统一了部分日志输出的格式,增强了日志的可读性。

3.  **导入语句优化**:
    -   调整了部分模块的 `import` 语句顺序,使其符合 PEP 8 规范。

这些更改不涉及任何功能性变动,旨在提升代码库的整体质量、可维护性和开发体验。
This commit is contained in:
minecraft1024a
2025-10-31 20:56:17 +08:00
parent 926adf16dd
commit a29be48091
47 changed files with 923 additions and 933 deletions

View File

@@ -9,24 +9,25 @@ from pathlib import Path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from sqlalchemy import select, func
from sqlalchemy import func, select
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import Expression
async def check_database():
"""检查表达方式数据库状态"""
print("=" * 60)
print("表达方式数据库诊断报告")
print("=" * 60)
async with get_db_session() as session:
# 1. 统计总数
total_count = await session.execute(select(func.count()).select_from(Expression))
total = total_count.scalar()
print(f"\n📊 总表达方式数量: {total}")
if total == 0:
print("\n⚠️ 数据库为空!")
print("\n可能的原因:")
@@ -38,7 +39,7 @@ async def check_database():
print("- 查看日志中是否有表达学习相关的错误")
print("- 确认聊天流的 learn_expression 配置为 true")
return
# 2. 按 chat_id 统计
print("\n📝 按聊天流统计:")
chat_counts = await session.execute(
@@ -47,7 +48,7 @@ async def check_database():
)
for chat_id, count in chat_counts:
print(f" - {chat_id}: {count} 个表达方式")
# 3. 按 type 统计
print("\n📝 按类型统计:")
type_counts = await session.execute(
@@ -56,7 +57,7 @@ async def check_database():
)
for expr_type, count in type_counts:
print(f" - {expr_type}: {count}")
# 4. 检查 situation 和 style 字段是否有空值
print("\n🔍 字段完整性检查:")
null_situation = await session.execute(
@@ -69,30 +70,30 @@ async def check_database():
.select_from(Expression)
.where(Expression.style == None)
)
null_sit_count = null_situation.scalar()
null_sty_count = null_style.scalar()
print(f" - situation 为空: {null_sit_count}")
print(f" - style 为空: {null_sty_count}")
if null_sit_count > 0 or null_sty_count > 0:
print(" ⚠️ 发现空值!这会导致匹配失败")
# 5. 显示一些样例数据
print("\n📋 样例数据 (前10条):")
samples = await session.execute(
select(Expression)
.limit(10)
)
for i, expr in enumerate(samples.scalars(), 1):
print(f"\n [{i}] Chat: {expr.chat_id}")
print(f" Type: {expr.type}")
print(f" Situation: {expr.situation}")
print(f" Style: {expr.style}")
print(f" Count: {expr.count}")
# 6. 检查 style 字段的唯一值
print("\n📋 Style 字段样例 (前20个):")
unique_styles = await session.execute(
@@ -100,13 +101,13 @@ async def check_database():
.distinct()
.limit(20)
)
styles = [s for s in unique_styles.scalars()]
for style in styles:
print(f" - {style}")
print(f"\n (共 {len(styles)} 个不同的 style)")
print("\n" + "=" * 60)
print("诊断完成")
print("=" * 60)

View File

@@ -9,27 +9,28 @@ project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from sqlalchemy import select
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import Expression
async def analyze_style_fields():
"""分析 style 字段的内容"""
print("=" * 60)
print("Style 字段内容分析")
print("=" * 60)
async with get_db_session() as session:
# 获取所有表达方式
result = await session.execute(select(Expression).limit(30))
expressions = result.scalars().all()
print(f"\n总共检查 {len(expressions)} 条记录\n")
# 按类型分类
style_examples = []
for expr in expressions:
if expr.type == "style":
style_examples.append({
@@ -37,7 +38,7 @@ async def analyze_style_fields():
"style": expr.style,
"length": len(expr.style) if expr.style else 0
})
print("📋 Style 类型样例 (前15条):")
print("="*60)
for i, ex in enumerate(style_examples[:15], 1):
@@ -45,17 +46,17 @@ async def analyze_style_fields():
print(f" Situation: {ex['situation']}")
print(f" Style: {ex['style']}")
print(f" 长度: {ex['length']} 字符")
# 判断是具体表达还是风格描述
if ex['length'] <= 20 and any(word in ex['style'] for word in ['简洁', '短句', '陈述', '疑问', '感叹', '省略', '完整']):
if ex["length"] <= 20 and any(word in ex["style"] for word in ["简洁", "短句", "陈述", "疑问", "感叹", "省略", "完整"]):
style_type = "✓ 风格描述"
elif ex['length'] <= 10:
elif ex["length"] <= 10:
style_type = "? 可能是具体表达(较短)"
else:
style_type = "✗ 具体表达内容"
print(f" 类型判断: {style_type}")
print("\n" + "="*60)
print("分析完成")
print("="*60)

View File

@@ -16,28 +16,28 @@ logger = get_logger("debug_style_learner")
def check_style_learner_status(chat_id: str):
"""检查指定 chat_id 的 StyleLearner 状态"""
print("=" * 60)
print(f"StyleLearner 状态诊断 - Chat ID: {chat_id}")
print("=" * 60)
# 获取 learner
learner = style_learner_manager.get_learner(chat_id)
# 1. 基本信息
print(f"\n📊 基本信息:")
print("\n📊 基本信息:")
print(f" Chat ID: {learner.chat_id}")
print(f" 风格数量: {len(learner.style_to_id)}")
print(f" 下一个ID: {learner.next_style_id}")
print(f" 最大风格数: {learner.max_styles}")
# 2. 学习统计
print(f"\n📈 学习统计:")
print("\n📈 学习统计:")
print(f" 总样本数: {learner.learning_stats['total_samples']}")
print(f" 最后更新: {learner.learning_stats.get('last_update', 'N/A')}")
# 3. 风格列表前20个
print(f"\n📋 已学习的风格 (前20个):")
print("\n📋 已学习的风格 (前20个):")
all_styles = learner.get_all_styles()
if not all_styles:
print(" ⚠️ 没有任何风格!模型尚未训练")
@@ -47,9 +47,9 @@ def check_style_learner_status(chat_id: str):
situation = learner.id_to_situation.get(style_id, "N/A")
print(f" [{i}] {style}")
print(f" (ID: {style_id}, Situation: {situation})")
# 4. 测试预测
print(f"\n🔮 测试预测功能:")
print("\n🔮 测试预测功能:")
if not all_styles:
print(" ⚠️ 无法测试,模型没有训练数据")
else:
@@ -58,19 +58,19 @@ def check_style_learner_status(chat_id: str):
"讨论游戏",
"表达赞同"
]
for test_sit in test_situations:
print(f"\n 测试输入: '{test_sit}'")
best_style, scores = learner.predict_style(test_sit, top_k=3)
if best_style:
print(f" ✓ 最佳匹配: {best_style}")
print(f" Top 3:")
print(" Top 3:")
for style, score in list(scores.items())[:3]:
print(f" - {style}: {score:.4f}")
else:
print(f" ✗ 预测失败")
print(" ✗ 预测失败")
print("\n" + "=" * 60)
print("诊断完成")
print("=" * 60)
@@ -82,7 +82,7 @@ if __name__ == "__main__":
"52fb94af9f500a01e023ea780e43606e", # 有78个表达方式
"46c8714c8a9b7ee169941fe99fcde07d", # 有22个表达方式
]
for chat_id in test_chat_ids:
check_style_learner_status(chat_id)
print("\n")