🤖 自动格式化代码 [skip ci]
This commit is contained in:
@@ -12,6 +12,7 @@ from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
|
||||
logger = get_logger("expression_selector")
|
||||
|
||||
|
||||
def init_prompt():
|
||||
expression_evaluation_prompt = """
|
||||
你的名字是{bot_name}
|
||||
@@ -42,30 +43,32 @@ def init_prompt():
|
||||
"""
|
||||
Prompt(expression_evaluation_prompt, "expression_evaluation_prompt")
|
||||
|
||||
|
||||
def weighted_sample(population: List[Dict], weights: List[float], k: int) -> List[Dict]:
|
||||
"""按权重随机抽样"""
|
||||
if not population or not weights or k <= 0:
|
||||
return []
|
||||
|
||||
|
||||
if len(population) <= k:
|
||||
return population.copy()
|
||||
|
||||
|
||||
# 使用累积权重的方法进行加权抽样
|
||||
selected = []
|
||||
population_copy = population.copy()
|
||||
weights_copy = weights.copy()
|
||||
|
||||
|
||||
for _ in range(k):
|
||||
if not population_copy:
|
||||
break
|
||||
|
||||
|
||||
# 选择一个元素
|
||||
chosen_idx = random.choices(range(len(population_copy)), weights=weights_copy)[0]
|
||||
selected.append(population_copy.pop(chosen_idx))
|
||||
weights_copy.pop(chosen_idx)
|
||||
|
||||
|
||||
return selected
|
||||
|
||||
|
||||
class ExpressionSelector:
|
||||
def __init__(self):
|
||||
self.expression_learner = get_expression_learner()
|
||||
@@ -75,7 +78,9 @@ class ExpressionSelector:
|
||||
request_type="expression.selector",
|
||||
)
|
||||
|
||||
def get_random_expressions(self, chat_id: str, style_num: int, grammar_num: int, personality_num: int) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
|
||||
def get_random_expressions(
|
||||
self, chat_id: str, style_num: int, grammar_num: int, personality_num: int
|
||||
) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
|
||||
(
|
||||
learnt_style_expressions,
|
||||
learnt_grammar_expressions,
|
||||
@@ -88,13 +93,13 @@ class ExpressionSelector:
|
||||
selected_style = weighted_sample(learnt_style_expressions, style_weights, style_num)
|
||||
else:
|
||||
selected_style = []
|
||||
|
||||
|
||||
if learnt_grammar_expressions:
|
||||
grammar_weights = [expr.get("count", 1) for expr in learnt_grammar_expressions]
|
||||
selected_grammar = weighted_sample(learnt_grammar_expressions, grammar_weights, grammar_num)
|
||||
else:
|
||||
selected_grammar = []
|
||||
|
||||
|
||||
if personality_expressions:
|
||||
personality_weights = [expr.get("count", 1) for expr in personality_expressions]
|
||||
selected_personality = weighted_sample(personality_expressions, personality_weights, personality_num)
|
||||
@@ -102,7 +107,7 @@ class ExpressionSelector:
|
||||
selected_personality = []
|
||||
|
||||
return selected_style, selected_grammar, selected_personality
|
||||
|
||||
|
||||
def update_expression_count(self, chat_id: str, expression: Dict[str, str], multiplier: float = 1.5):
|
||||
"""更新表达方式的count值"""
|
||||
if expression.get("type") == "style_personality":
|
||||
@@ -117,29 +122,30 @@ class ExpressionSelector:
|
||||
file_path = os.path.join("data", "expression", "learnt_grammar", str(chat_id), "expressions.json")
|
||||
else:
|
||||
return
|
||||
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
expressions = json.load(f)
|
||||
|
||||
|
||||
# 找到匹配的表达方式并更新count
|
||||
for expr in expressions:
|
||||
if (expr.get("situation") == expression.get("situation") and
|
||||
expr.get("style") == expression.get("style")):
|
||||
if expr.get("situation") == expression.get("situation") and expr.get("style") == expression.get(
|
||||
"style"
|
||||
):
|
||||
expr["count"] = expr.get("count", 1) * multiplier
|
||||
expr["last_active_time"] = time.time()
|
||||
break
|
||||
|
||||
|
||||
# 保存更新后的文件
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(expressions, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新表达方式count失败: {e}")
|
||||
|
||||
|
||||
async def select_suitable_expressions_llm(self, chat_id: str, chat_info: str) -> List[Dict[str, str]]:
|
||||
"""使用LLM选择适合的表达方式"""
|
||||
|
||||
@@ -188,7 +194,7 @@ class ExpressionSelector:
|
||||
)
|
||||
|
||||
print(prompt)
|
||||
|
||||
|
||||
# 4. 调用LLM
|
||||
try:
|
||||
content, (_, _) = await self.llm_model.generate_response_async(prompt=prompt)
|
||||
@@ -216,7 +222,7 @@ class ExpressionSelector:
|
||||
if isinstance(idx, int) and 1 <= idx <= len(all_expressions):
|
||||
expression = all_expressions[idx - 1] # 索引从1开始
|
||||
valid_expressions.append(expression)
|
||||
|
||||
|
||||
# 对选中的表达方式count数*1.5
|
||||
self.update_expression_count(chat_id, expression, 1.5)
|
||||
|
||||
@@ -226,7 +232,7 @@ class ExpressionSelector:
|
||||
except Exception as e:
|
||||
logger.error(f"LLM处理表达方式选择时出错: {e}")
|
||||
return []
|
||||
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
||||
@@ -234,10 +240,3 @@ try:
|
||||
expression_selector = ExpressionSelector()
|
||||
except Exception as e:
|
||||
print(f"ExpressionSelector初始化失败: {e}")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user