三次修改
This commit is contained in:
@@ -136,18 +136,18 @@ class ExpressionSelector:
|
||||
|
||||
return related_chat_ids if related_chat_ids else [chat_id]
|
||||
|
||||
def get_random_expressions(
|
||||
async def get_random_expressions(
|
||||
self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float
|
||||
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
|
||||
# sourcery skip: extract-duplicate-method, move-assign
|
||||
# 支持多chat_id合并抽选
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
# 优化:一次性查询所有相关chat_id的表达方式
|
||||
style_query = session.execute(
|
||||
style_query = await session.execute(
|
||||
select(Expression).where((Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style"))
|
||||
)
|
||||
grammar_query = session.execute(
|
||||
grammar_query = await session.execute(
|
||||
select(Expression).where((Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "grammar"))
|
||||
)
|
||||
|
||||
@@ -193,7 +193,7 @@ class ExpressionSelector:
|
||||
|
||||
return selected_style, selected_grammar
|
||||
|
||||
def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, Any]], increment: float = 0.1):
|
||||
async def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, Any]], increment: float = 0.1):
|
||||
"""对一批表达方式更新count值,按chat_id+type分组后一次性写入数据库"""
|
||||
if not expressions_to_update:
|
||||
return
|
||||
@@ -210,26 +210,27 @@ class ExpressionSelector:
|
||||
if key not in updates_by_key:
|
||||
updates_by_key[key] = expr
|
||||
for chat_id, expr_type, situation, style in updates_by_key:
|
||||
with get_db_session() as session:
|
||||
query = session.execute(
|
||||
async with get_db_session() as session:
|
||||
query = await session.execute(
|
||||
select(Expression).where(
|
||||
(Expression.chat_id == chat_id)
|
||||
& (Expression.type == expr_type)
|
||||
& (Expression.situation == situation)
|
||||
& (Expression.style == style)
|
||||
)
|
||||
).scalar()
|
||||
if query:
|
||||
expr_obj = query
|
||||
current_count = expr_obj.count
|
||||
new_count = min(current_count + increment, 5.0)
|
||||
expr_obj.count = new_count
|
||||
expr_obj.last_active_time = time.time()
|
||||
|
||||
logger.debug(
|
||||
f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db"
|
||||
)
|
||||
session.commit()
|
||||
query = query.scalar()
|
||||
if query:
|
||||
expr_obj = query
|
||||
current_count = expr_obj.count
|
||||
new_count = min(current_count + increment, 5.0)
|
||||
expr_obj.count = new_count
|
||||
expr_obj.last_active_time = time.time()
|
||||
|
||||
logger.debug(
|
||||
f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db"
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async def select_suitable_expressions_llm(
|
||||
self,
|
||||
@@ -248,7 +249,7 @@ class ExpressionSelector:
|
||||
return []
|
||||
|
||||
# 1. 获取35个随机表达方式(现在按权重抽取)
|
||||
style_exprs, grammar_exprs = self.get_random_expressions(chat_id, 30, 0.5, 0.5)
|
||||
style_exprs, grammar_exprs = await self.get_random_expressions(chat_id, 30, 0.5, 0.5)
|
||||
|
||||
# 2. 构建所有表达方式的索引和情境列表
|
||||
all_expressions = []
|
||||
@@ -334,7 +335,7 @@ class ExpressionSelector:
|
||||
|
||||
# 对选中的所有表达方式,一次性更新count数
|
||||
if valid_expressions:
|
||||
self.update_expressions_count_batch(valid_expressions, 0.006)
|
||||
await self.update_expressions_count_batch(valid_expressions, 0.006)
|
||||
|
||||
# logger.info(f"LLM从{len(all_expressions)}个情境中选择了{len(valid_expressions)}个")
|
||||
return valid_expressions
|
||||
|
||||
Reference in New Issue
Block a user