fix:修复表达提取无法提高count的问题
This commit is contained in:
@@ -108,55 +108,64 @@ class ExpressionSelector:
|
||||
|
||||
return selected_style, selected_grammar, selected_personality
|
||||
|
||||
def update_expression_count(self, chat_id: str, expression: Dict[str, str], increment: float = 0.1):
|
||||
"""更新表达方式的count值
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
expression: 表达方式字典
|
||||
increment: 增量值,默认0.1
|
||||
"""
|
||||
if expression.get("type") == "style_personality":
|
||||
# personality表达方式存储在全局文件中
|
||||
file_path = os.path.join("data", "expression", "personality", "expressions.json")
|
||||
else:
|
||||
# style和grammar表达方式存储在对应chat_id目录中
|
||||
expr_type = expression.get("type", "style")
|
||||
if expr_type == "style":
|
||||
file_path = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json")
|
||||
elif expr_type == "grammar":
|
||||
file_path = os.path.join("data", "expression", "learnt_grammar", str(chat_id), "expressions.json")
|
||||
else:
|
||||
return
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, str]], increment: float = 0.1):
|
||||
"""对一批表达方式更新count值,按文件分组后一次性写入"""
|
||||
if not expressions_to_update:
|
||||
return
|
||||
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
expressions = json.load(f)
|
||||
updates_by_file = {}
|
||||
for expr in expressions_to_update:
|
||||
source_id = expr.get("source_id")
|
||||
if not source_id:
|
||||
logger.warning(f"表达方式缺少source_id,无法更新: {expr}")
|
||||
continue
|
||||
|
||||
# 找到匹配的表达方式并更新count
|
||||
for expr in expressions:
|
||||
if expr.get("situation") == expression.get("situation") and expr.get("style") == expression.get(
|
||||
"style"
|
||||
):
|
||||
current_count = expr.get("count", 1)
|
||||
file_path = ""
|
||||
if source_id == "personality":
|
||||
file_path = os.path.join("data", "expression", "personality", "expressions.json")
|
||||
else:
|
||||
chat_id = source_id
|
||||
expr_type = expr.get("type", "style")
|
||||
if expr_type == "style":
|
||||
file_path = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json")
|
||||
elif expr_type == "grammar":
|
||||
file_path = os.path.join("data", "expression", "learnt_grammar", str(chat_id), "expressions.json")
|
||||
|
||||
# 简单加0.1,但限制最高为5
|
||||
new_count = min(current_count + increment, 5.0)
|
||||
expr["count"] = new_count
|
||||
expr["last_active_time"] = time.time()
|
||||
if file_path:
|
||||
if file_path not in updates_by_file:
|
||||
updates_by_file[file_path] = []
|
||||
updates_by_file[file_path].append(expr)
|
||||
|
||||
logger.info(f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f}")
|
||||
break
|
||||
for file_path, updates in updates_by_file.items():
|
||||
if not os.path.exists(file_path):
|
||||
continue
|
||||
|
||||
# 保存更新后的文件
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(expressions, f, ensure_ascii=False, indent=2)
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
all_expressions = json.load(f)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新表达方式count失败: {e}")
|
||||
# Create a dictionary for quick lookup
|
||||
expr_map = {(e.get("situation"), e.get("style")): e for e in all_expressions}
|
||||
|
||||
# Update counts in memory
|
||||
for expr_to_update in updates:
|
||||
key = (expr_to_update.get("situation"), expr_to_update.get("style"))
|
||||
if key in expr_map:
|
||||
expr_in_map = expr_map[key]
|
||||
current_count = expr_in_map.get("count", 1)
|
||||
new_count = min(current_count + increment, 5.0)
|
||||
expr_in_map["count"] = new_count
|
||||
expr_in_map["last_active_time"] = time.time()
|
||||
logger.info(
|
||||
f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in {file_path}"
|
||||
)
|
||||
|
||||
# Save the updated list once for this file
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(all_expressions, f, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"批量更新表达方式count失败 for {file_path}: {e}")
|
||||
|
||||
async def select_suitable_expressions_llm(
|
||||
self, chat_id: str, chat_info: str, max_num: int = 10, min_num: int = 5
|
||||
@@ -237,8 +246,9 @@ class ExpressionSelector:
|
||||
expression = all_expressions[idx - 1] # 索引从1开始
|
||||
valid_expressions.append(expression)
|
||||
|
||||
# 对选中的表达方式count数+0.1
|
||||
self.update_expression_count(chat_id, expression, 0.001)
|
||||
# 对选中的所有表达方式,一次性更新count数
|
||||
if valid_expressions:
|
||||
self.update_expressions_count_batch(valid_expressions, 0.003)
|
||||
|
||||
# logger.info(f"LLM从{len(all_expressions)}个情境中选择了{len(valid_expressions)}个")
|
||||
return valid_expressions
|
||||
|
||||
Reference in New Issue
Block a user