优化表达方式学习
This commit is contained in:
@@ -207,31 +207,20 @@ class ExpressionSelector:
|
||||
select(Expression).where((Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "grammar"))
|
||||
)
|
||||
|
||||
style_exprs = [
|
||||
{
|
||||
# 🔥 优化:提前定义转换函数,避免重复代码
|
||||
def expr_to_dict(expr, expr_type: str) -> dict[str, Any]:
|
||||
return {
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"count": expr.count,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": expr.chat_id,
|
||||
"type": "style",
|
||||
"type": expr_type,
|
||||
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
||||
}
|
||||
for expr in style_query.scalars()
|
||||
]
|
||||
|
||||
grammar_exprs = [
|
||||
{
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"count": expr.count,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": expr.chat_id,
|
||||
"type": "grammar",
|
||||
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
||||
}
|
||||
for expr in grammar_query.scalars()
|
||||
]
|
||||
|
||||
style_exprs = [expr_to_dict(expr, "style") for expr in style_query.scalars()]
|
||||
grammar_exprs = [expr_to_dict(expr, "grammar") for expr in grammar_query.scalars()]
|
||||
|
||||
style_num = int(total_num * style_percentage)
|
||||
grammar_num = int(total_num * grammar_percentage)
|
||||
@@ -251,9 +240,14 @@ class ExpressionSelector:
|
||||
|
||||
@staticmethod
|
||||
async def update_expressions_count_batch(expressions_to_update: list[dict[str, Any]], increment: float = 0.1):
|
||||
"""对一批表达方式更新count值,按chat_id+type分组后一次性写入数据库"""
|
||||
"""对一批表达方式更新count值,按chat_id+type分组后一次性写入数据库
|
||||
|
||||
🔥 优化:合并所有更新到一个事务中,减少数据库连接开销
|
||||
"""
|
||||
if not expressions_to_update:
|
||||
return
|
||||
|
||||
# 去重处理
|
||||
updates_by_key = {}
|
||||
affected_chat_ids = set()
|
||||
for expr in expressions_to_update:
|
||||
@@ -269,9 +263,15 @@ class ExpressionSelector:
|
||||
updates_by_key[key] = expr
|
||||
affected_chat_ids.add(source_id)
|
||||
|
||||
for chat_id, expr_type, situation, style in updates_by_key:
|
||||
async with get_db_session() as session:
|
||||
query = await session.execute(
|
||||
if not updates_by_key:
|
||||
return
|
||||
|
||||
# 🔥 优化:使用单个 session 批量处理所有更新
|
||||
current_time = time.time()
|
||||
async with get_db_session() as session:
|
||||
updated_count = 0
|
||||
for chat_id, expr_type, situation, style in updates_by_key:
|
||||
query_result = await session.execute(
|
||||
select(Expression).where(
|
||||
(Expression.chat_id == chat_id)
|
||||
& (Expression.type == expr_type)
|
||||
@@ -279,25 +279,26 @@ class ExpressionSelector:
|
||||
& (Expression.style == style)
|
||||
)
|
||||
)
|
||||
query = query.scalar()
|
||||
if query:
|
||||
expr_obj = query
|
||||
expr_obj = query_result.scalar()
|
||||
if expr_obj:
|
||||
current_count = expr_obj.count
|
||||
new_count = min(current_count + increment, 5.0)
|
||||
expr_obj.count = new_count
|
||||
expr_obj.last_active_time = time.time()
|
||||
expr_obj.last_active_time = current_time
|
||||
updated_count += 1
|
||||
|
||||
logger.debug(
|
||||
f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db"
|
||||
)
|
||||
# 批量提交所有更改
|
||||
if updated_count > 0:
|
||||
await session.commit()
|
||||
logger.debug(f"批量更新了 {updated_count} 个表达方式的count值")
|
||||
|
||||
# 清除所有受影响的chat_id的缓存
|
||||
from src.common.database.optimization.cache_manager import get_cache
|
||||
from src.common.database.utils.decorators import generate_cache_key
|
||||
cache = await get_cache()
|
||||
for chat_id in affected_chat_ids:
|
||||
await cache.delete(generate_cache_key("chat_expressions", chat_id))
|
||||
if affected_chat_ids:
|
||||
from src.common.database.optimization.cache_manager import get_cache
|
||||
from src.common.database.utils.decorators import generate_cache_key
|
||||
cache = await get_cache()
|
||||
for chat_id in affected_chat_ids:
|
||||
await cache.delete(generate_cache_key("chat_expressions", chat_id))
|
||||
|
||||
async def select_suitable_expressions(
|
||||
self,
|
||||
@@ -518,29 +519,41 @@ class ExpressionSelector:
|
||||
logger.warning("数据库中完全没有任何表达方式,需要先学习")
|
||||
return []
|
||||
|
||||
# 🔥 使用模糊匹配而不是精确匹配
|
||||
# 计算每个预测style与数据库style的相似度
|
||||
# 🔥 优化:使用更高效的模糊匹配算法
|
||||
from difflib import SequenceMatcher
|
||||
|
||||
# 预处理:提前计算所有预测 style 的小写版本,避免重复计算
|
||||
predicted_styles_lower = [(s.lower(), score) for s, score in predicted_styles[:20]]
|
||||
|
||||
matched_expressions = []
|
||||
for expr in all_expressions:
|
||||
db_style = expr.style or ""
|
||||
db_style_lower = db_style.lower()
|
||||
max_similarity = 0.0
|
||||
best_predicted = ""
|
||||
|
||||
# 与每个预测的style计算相似度
|
||||
for predicted_style, pred_score in predicted_styles[:20]: # 考虑前20个预测
|
||||
# 计算字符串相似度
|
||||
similarity = SequenceMatcher(None, predicted_style, db_style).ratio()
|
||||
|
||||
# 也检查包含关系(如果一个是另一个的子串,给更高分)
|
||||
if len(predicted_style) >= 2 and len(db_style) >= 2:
|
||||
if predicted_style in db_style or db_style in predicted_style:
|
||||
similarity = max(similarity, 0.7)
|
||||
|
||||
for predicted_style_lower, pred_score in predicted_styles_lower:
|
||||
# 快速检查:完全匹配
|
||||
if predicted_style_lower == db_style_lower:
|
||||
max_similarity = 1.0
|
||||
best_predicted = predicted_style_lower
|
||||
break
|
||||
|
||||
# 快速检查:子串匹配
|
||||
if len(predicted_style_lower) >= 2 and len(db_style_lower) >= 2:
|
||||
if predicted_style_lower in db_style_lower or db_style_lower in predicted_style_lower:
|
||||
similarity = 0.7
|
||||
if similarity > max_similarity:
|
||||
max_similarity = similarity
|
||||
best_predicted = predicted_style_lower
|
||||
continue
|
||||
|
||||
# 计算字符串相似度(较慢,只在必要时使用)
|
||||
similarity = SequenceMatcher(None, predicted_style_lower, db_style_lower).ratio()
|
||||
if similarity > max_similarity:
|
||||
max_similarity = similarity
|
||||
best_predicted = predicted_style
|
||||
best_predicted = predicted_style_lower
|
||||
|
||||
# 🔥 降低阈值到30%,因为StyleLearner预测质量较差
|
||||
if max_similarity >= 0.3: # 30%相似度阈值
|
||||
@@ -573,14 +586,15 @@ class ExpressionSelector:
|
||||
f"(候选 {len(matched_expressions)},temperature={temperature})"
|
||||
)
|
||||
|
||||
# 转换为字典格式
|
||||
# 🔥 优化:使用列表推导式和预定义函数减少开销
|
||||
expressions = [
|
||||
{
|
||||
"situation": expr.situation or "",
|
||||
"style": expr.style or "",
|
||||
"type": expr.type or "style",
|
||||
"count": float(expr.count) if expr.count else 0.0,
|
||||
"last_active_time": expr.last_active_time or 0.0
|
||||
"last_active_time": expr.last_active_time or 0.0,
|
||||
"source_id": expr.chat_id # 添加 source_id 以便后续更新
|
||||
}
|
||||
for expr in expressions_objs
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user