feat(expression_selector): 添加温度采样功能以优化表达选择

feat(official_configs): 新增模型温度配置项以支持表达模型采样
chore(bot_config_template): 更新版本号并添加模型温度说明
This commit is contained in:
Windpicker-owo
2025-12-11 13:57:17 +08:00
parent cc531d1b97
commit c75cc88fb5
4 changed files with 61 additions and 10 deletions

View File

@@ -1,5 +1,6 @@
import asyncio
import hashlib
import math
import random
import time
from typing import Any
@@ -76,6 +77,45 @@ def weighted_sample(population: list[dict], weights: list[float], k: int) -> lis
class ExpressionSelector:
@staticmethod
def _sample_with_temperature(
    candidates: list[tuple[Any, float, float, str]],
    max_num: int,
    temperature: float,
) -> list[tuple[Any, float, float, str]]:
    """
    Sample candidate expressions using temperature-scaled softmax weights.

    Each candidate is scored as ``similarity * sqrt(count)`` (floored at
    1e-8). The highest-scoring candidate is always kept; the remaining
    slots are drawn without replacement with probability proportional to
    ``exp((score - max_score) / temperature)``, so a higher temperature
    gives a more uniform selection.

    Args:
        candidates: List of ``(expr, similarity, count, best_predicted)``
            tuples.
        max_num: Maximum number of candidates to return.
        temperature: Sampling temperature. A value <= 0 means greedy
            selection: the first ``max_num`` candidates are returned as
            given (presumably already sorted by the caller).

    Returns:
        At most ``max_num`` candidates; an empty list when ``max_num`` <= 0
        or ``candidates`` is empty.
    """
    if max_num <= 0 or not candidates:
        return []
    if temperature <= 0:
        return candidates[:max_num]

    # Keep the divisor away from zero for tiny positive temperatures.
    adjusted_temp = max(temperature, 1e-6)

    # Same score as the caller's ranking; subtracting the maximum before
    # exponentiating keeps exp() from overflowing.
    scores = [max(c[1] * (c[2] ** 0.5), 1e-8) for c in candidates]
    max_score = max(scores)
    weights = [math.exp((s - max_score) / adjusted_temp) for s in scores]

    # Always keep the top-scoring candidate; sample the rest by temperature
    # so the selection does not become overly concentrated.
    best_idx = scores.index(max_score)
    selected = [candidates[best_idx]]
    remaining_indices = [i for i in range(len(candidates)) if i != best_idx]

    while remaining_indices and len(selected) < max_num:
        current_weights = [weights[i] for i in remaining_indices]

        # With a small temperature every remaining weight can underflow to
        # 0.0, and random.choices() raises ValueError when the total weight
        # is not positive. In that limit the distribution is effectively
        # greedy, so fill the remaining slots in descending score order.
        # (`not (> 0)` also catches NaN weights.)
        if not sum(current_weights) > 0.0:
            needed = max_num - len(selected)
            by_score = sorted(remaining_indices, key=lambda i: scores[i], reverse=True)
            selected.extend(candidates[i] for i in by_score[:needed])
            break

        picked_pos = random.choices(range(len(remaining_indices)), weights=current_weights, k=1)[0]
        picked_idx = remaining_indices.pop(picked_pos)
        selected.append(candidates[picked_idx])

    return selected
def __init__(self, chat_id: str = ""):
self.chat_id = chat_id
if model_config is None:
@@ -517,12 +557,21 @@ class ExpressionSelector:
)
return []
# 按照相似度*count排序选择最佳匹配
# 按照相似度*count排序并根据温度采样,避免过度集中
matched_expressions.sort(key=lambda x: x[1] * (x[2] ** 0.5), reverse=True)
expressions_objs = [e[0] for e in matched_expressions[:max_num]]
temperature = getattr(global_config.expression, "model_temperature", 0.0)
sampled_matches = self._sample_with_temperature(
candidates=matched_expressions,
max_num=max_num,
temperature=temperature,
)
expressions_objs = [e[0] for e in sampled_matches]
# 显示最佳匹配的详细信息
logger.debug(f"模糊匹配成功: 找到 {len(expressions_objs)} 个表达方式")
logger.debug(
f"模糊匹配成功: 找到 {len(expressions_objs)} 个表达方式 "
f"(候选 {len(matched_expressions)}temperature={temperature})"
)
# 转换为字典格式
expressions = [