refactor(core): 优化类型提示与代码风格
本次提交对项目代码进行了广泛的重构,主要集中在以下几个方面:
1. **类型提示现代化**:
- 将 `typing` 模块中的 `Optional[T]`、`List[T]`、`Dict[K, V]` 等旧式类型提示更新为现代的 `T | None`、`list[T]`、`dict[K, V]` 语法。
- 这提高了代码的可读性,并与较新 Python 版本的风格保持一致。
2. **代码风格统一**:
- 移除了多余的空行和不必要的空格,使代码更加紧凑和规范。
- 统一了部分日志输出的格式,增强了日志的可读性。
3. **导入语句优化**:
- 调整了部分模块的 `import` 语句顺序,使其符合 PEP 8 规范。
这些更改不涉及任何功能性变动,旨在提升代码库的整体质量、可维护性和开发体验。
This commit is contained in:
@@ -9,24 +9,25 @@ from pathlib import Path
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import Expression
|
||||
|
||||
|
||||
async def check_database():
|
||||
"""检查表达方式数据库状态"""
|
||||
|
||||
|
||||
print("=" * 60)
|
||||
print("表达方式数据库诊断报告")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
async with get_db_session() as session:
|
||||
# 1. 统计总数
|
||||
total_count = await session.execute(select(func.count()).select_from(Expression))
|
||||
total = total_count.scalar()
|
||||
print(f"\n📊 总表达方式数量: {total}")
|
||||
|
||||
|
||||
if total == 0:
|
||||
print("\n⚠️ 数据库为空!")
|
||||
print("\n可能的原因:")
|
||||
@@ -38,7 +39,7 @@ async def check_database():
|
||||
print("- 查看日志中是否有表达学习相关的错误")
|
||||
print("- 确认聊天流的 learn_expression 配置为 true")
|
||||
return
|
||||
|
||||
|
||||
# 2. 按 chat_id 统计
|
||||
print("\n📝 按聊天流统计:")
|
||||
chat_counts = await session.execute(
|
||||
@@ -47,7 +48,7 @@ async def check_database():
|
||||
)
|
||||
for chat_id, count in chat_counts:
|
||||
print(f" - {chat_id}: {count} 个表达方式")
|
||||
|
||||
|
||||
# 3. 按 type 统计
|
||||
print("\n📝 按类型统计:")
|
||||
type_counts = await session.execute(
|
||||
@@ -56,7 +57,7 @@ async def check_database():
|
||||
)
|
||||
for expr_type, count in type_counts:
|
||||
print(f" - {expr_type}: {count} 个")
|
||||
|
||||
|
||||
# 4. 检查 situation 和 style 字段是否有空值
|
||||
print("\n🔍 字段完整性检查:")
|
||||
null_situation = await session.execute(
|
||||
@@ -69,30 +70,30 @@ async def check_database():
|
||||
.select_from(Expression)
|
||||
.where(Expression.style == None)
|
||||
)
|
||||
|
||||
|
||||
null_sit_count = null_situation.scalar()
|
||||
null_sty_count = null_style.scalar()
|
||||
|
||||
|
||||
print(f" - situation 为空: {null_sit_count} 个")
|
||||
print(f" - style 为空: {null_sty_count} 个")
|
||||
|
||||
|
||||
if null_sit_count > 0 or null_sty_count > 0:
|
||||
print(" ⚠️ 发现空值!这会导致匹配失败")
|
||||
|
||||
|
||||
# 5. 显示一些样例数据
|
||||
print("\n📋 样例数据 (前10条):")
|
||||
samples = await session.execute(
|
||||
select(Expression)
|
||||
.limit(10)
|
||||
)
|
||||
|
||||
|
||||
for i, expr in enumerate(samples.scalars(), 1):
|
||||
print(f"\n [{i}] Chat: {expr.chat_id}")
|
||||
print(f" Type: {expr.type}")
|
||||
print(f" Situation: {expr.situation}")
|
||||
print(f" Style: {expr.style}")
|
||||
print(f" Count: {expr.count}")
|
||||
|
||||
|
||||
# 6. 检查 style 字段的唯一值
|
||||
print("\n📋 Style 字段样例 (前20个):")
|
||||
unique_styles = await session.execute(
|
||||
@@ -100,13 +101,13 @@ async def check_database():
|
||||
.distinct()
|
||||
.limit(20)
|
||||
)
|
||||
|
||||
|
||||
styles = [s for s in unique_styles.scalars()]
|
||||
for style in styles:
|
||||
print(f" - {style}")
|
||||
|
||||
|
||||
print(f"\n (共 {len(styles)} 个不同的 style)")
|
||||
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("诊断完成")
|
||||
print("=" * 60)
|
||||
|
||||
@@ -9,27 +9,28 @@ project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import Expression
|
||||
|
||||
|
||||
async def analyze_style_fields():
|
||||
"""分析 style 字段的内容"""
|
||||
|
||||
|
||||
print("=" * 60)
|
||||
print("Style 字段内容分析")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
async with get_db_session() as session:
|
||||
# 获取所有表达方式
|
||||
result = await session.execute(select(Expression).limit(30))
|
||||
expressions = result.scalars().all()
|
||||
|
||||
|
||||
print(f"\n总共检查 {len(expressions)} 条记录\n")
|
||||
|
||||
|
||||
# 按类型分类
|
||||
style_examples = []
|
||||
|
||||
|
||||
for expr in expressions:
|
||||
if expr.type == "style":
|
||||
style_examples.append({
|
||||
@@ -37,7 +38,7 @@ async def analyze_style_fields():
|
||||
"style": expr.style,
|
||||
"length": len(expr.style) if expr.style else 0
|
||||
})
|
||||
|
||||
|
||||
print("📋 Style 类型样例 (前15条):")
|
||||
print("="*60)
|
||||
for i, ex in enumerate(style_examples[:15], 1):
|
||||
@@ -45,17 +46,17 @@ async def analyze_style_fields():
|
||||
print(f" Situation: {ex['situation']}")
|
||||
print(f" Style: {ex['style']}")
|
||||
print(f" 长度: {ex['length']} 字符")
|
||||
|
||||
|
||||
# 判断是具体表达还是风格描述
|
||||
if ex['length'] <= 20 and any(word in ex['style'] for word in ['简洁', '短句', '陈述', '疑问', '感叹', '省略', '完整']):
|
||||
if ex["length"] <= 20 and any(word in ex["style"] for word in ["简洁", "短句", "陈述", "疑问", "感叹", "省略", "完整"]):
|
||||
style_type = "✓ 风格描述"
|
||||
elif ex['length'] <= 10:
|
||||
elif ex["length"] <= 10:
|
||||
style_type = "? 可能是具体表达(较短)"
|
||||
else:
|
||||
style_type = "✗ 具体表达内容"
|
||||
|
||||
|
||||
print(f" 类型判断: {style_type}")
|
||||
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("分析完成")
|
||||
print("="*60)
|
||||
|
||||
@@ -16,28 +16,28 @@ logger = get_logger("debug_style_learner")
|
||||
|
||||
def check_style_learner_status(chat_id: str):
|
||||
"""检查指定 chat_id 的 StyleLearner 状态"""
|
||||
|
||||
|
||||
print("=" * 60)
|
||||
print(f"StyleLearner 状态诊断 - Chat ID: {chat_id}")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
# 获取 learner
|
||||
learner = style_learner_manager.get_learner(chat_id)
|
||||
|
||||
|
||||
# 1. 基本信息
|
||||
print(f"\n📊 基本信息:")
|
||||
print("\n📊 基本信息:")
|
||||
print(f" Chat ID: {learner.chat_id}")
|
||||
print(f" 风格数量: {len(learner.style_to_id)}")
|
||||
print(f" 下一个ID: {learner.next_style_id}")
|
||||
print(f" 最大风格数: {learner.max_styles}")
|
||||
|
||||
|
||||
# 2. 学习统计
|
||||
print(f"\n📈 学习统计:")
|
||||
print("\n📈 学习统计:")
|
||||
print(f" 总样本数: {learner.learning_stats['total_samples']}")
|
||||
print(f" 最后更新: {learner.learning_stats.get('last_update', 'N/A')}")
|
||||
|
||||
|
||||
# 3. 风格列表(前20个)
|
||||
print(f"\n📋 已学习的风格 (前20个):")
|
||||
print("\n📋 已学习的风格 (前20个):")
|
||||
all_styles = learner.get_all_styles()
|
||||
if not all_styles:
|
||||
print(" ⚠️ 没有任何风格!模型尚未训练")
|
||||
@@ -47,9 +47,9 @@ def check_style_learner_status(chat_id: str):
|
||||
situation = learner.id_to_situation.get(style_id, "N/A")
|
||||
print(f" [{i}] {style}")
|
||||
print(f" (ID: {style_id}, Situation: {situation})")
|
||||
|
||||
|
||||
# 4. 测试预测
|
||||
print(f"\n🔮 测试预测功能:")
|
||||
print("\n🔮 测试预测功能:")
|
||||
if not all_styles:
|
||||
print(" ⚠️ 无法测试,模型没有训练数据")
|
||||
else:
|
||||
@@ -58,19 +58,19 @@ def check_style_learner_status(chat_id: str):
|
||||
"讨论游戏",
|
||||
"表达赞同"
|
||||
]
|
||||
|
||||
|
||||
for test_sit in test_situations:
|
||||
print(f"\n 测试输入: '{test_sit}'")
|
||||
best_style, scores = learner.predict_style(test_sit, top_k=3)
|
||||
|
||||
|
||||
if best_style:
|
||||
print(f" ✓ 最佳匹配: {best_style}")
|
||||
print(f" Top 3:")
|
||||
print(" Top 3:")
|
||||
for style, score in list(scores.items())[:3]:
|
||||
print(f" - {style}: {score:.4f}")
|
||||
else:
|
||||
print(f" ✗ 预测失败")
|
||||
|
||||
print(" ✗ 预测失败")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("诊断完成")
|
||||
print("=" * 60)
|
||||
@@ -82,7 +82,7 @@ if __name__ == "__main__":
|
||||
"52fb94af9f500a01e023ea780e43606e", # 有78个表达方式
|
||||
"46c8714c8a9b7ee169941fe99fcde07d", # 有22个表达方式
|
||||
]
|
||||
|
||||
|
||||
for chat_id in test_chat_ids:
|
||||
check_style_learner_status(chat_id)
|
||||
print("\n")
|
||||
|
||||
Reference in New Issue
Block a user