Mofox-Core/src/chat/express/expression_selector.py

import json
import time
import random
import hashlib
from typing import List, Dict, Tuple, Optional, Any

from json_repair import repair_json
from sqlalchemy import select

from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.common.database.sqlalchemy_models import Expression
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.common.database.sqlalchemy_database_api import get_session

session = get_session()
logger = get_logger("expression_selector")


def init_prompt():
    expression_evaluation_prompt = """
以下是正在进行的聊天内容:
{chat_observe_info}
你的名字是{bot_name}{target_message}
以下是可选的表达情境:
{all_situations}
请你分析聊天内容的语境、情绪、话题类型,从上述情境中选择最适合当前聊天情境的{min_num}-{max_num}个情境。
考虑因素包括:
1. 聊天的情绪氛围(轻松、严肃、幽默等)
2. 话题类型(日常、技术、游戏、情感等)
3. 情境与当前语境的匹配度
{target_message_extra_block}
请以JSON格式输出,只需要输出选中的情境编号:
例如:
{{
    "selected_situations": [2, 3, 5, 7, 19, 22, 25, 38, 39, 45, 48, 64]
}}
请严格按照JSON格式输出,不要包含其他内容:
"""
    Prompt(expression_evaluation_prompt, "expression_evaluation_prompt")


def weighted_sample(population: List[Dict], weights: List[float], k: int) -> List[Dict]:
    """Weighted random sampling without replacement: draw up to k items from population."""
    if not population or not weights or k <= 0:
        return []
    if len(population) <= k:
        return population.copy()

    # Repeatedly draw one item according to the remaining weights, then remove it
    selected = []
    population_copy = population.copy()
    weights_copy = weights.copy()
    for _ in range(k):
        if not population_copy:
            break
        # Pick one index, weighted by the current weights
        chosen_idx = random.choices(range(len(population_copy)), weights=weights_copy)[0]
        selected.append(population_copy.pop(chosen_idx))
        weights_copy.pop(chosen_idx)
    return selected
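
# Usage sketch (illustration only, not part of the original module; the dicts and weights
# below are made-up values): with "count" used as the weight, higher-count expressions are
# drawn more often, e.g.
#   weighted_sample([{"id": "a"}, {"id": "b"}], [3.0, 1.0], 1)
#   returns [{"id": "a"}] in roughly 75% of calls and [{"id": "b"}] in roughly 25%.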


class ExpressionSelector:
    def __init__(self):
        self.llm_model = LLMRequest(
            model_set=model_config.model_task_config.utils_small, request_type="expression.selector"
        )

    def can_use_expression_for_chat(self, chat_id: str) -> bool:
        """Check whether the given chat stream is allowed to use expressions.

        Args:
            chat_id: chat stream ID

        Returns:
            bool: True if expressions may be used in this chat stream
        """
        try:
            use_expression, _, _ = global_config.expression.get_expression_config_for_chat(chat_id)
            return use_expression
        except Exception as e:
            logger.error(f"检查表达使用权限失败: {e}")
            return False

    @staticmethod
    def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
        """Parse a 'platform:id:type' string into a chat_id, consistent with get_stream_id."""
        try:
            parts = stream_config_str.split(":")
            if len(parts) != 3:
                return None
            platform = parts[0]
            id_str = parts[1]
            stream_type = parts[2]
            is_group = stream_type == "group"
            if is_group:
                components = [platform, str(id_str)]
            else:
                components = [platform, str(id_str), "private"]
            key = "_".join(components)
            return hashlib.md5(key.encode()).hexdigest()
        except Exception:
            return None
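
    # Example (illustrative values, assuming the hashing scheme above): "qq:123456:group"
    # maps to md5("qq_123456"), while "qq:123456:private" maps to md5("qq_123456_private"),
    # which is what keeps the result comparable with chat_ids produced by get_stream_id.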

    def get_related_chat_ids(self, chat_id: str) -> List[str]:
        """Return all chat_ids related to the current one (including itself), per the expression_groups config."""
        groups = global_config.expression.expression_groups
        for group in groups:
            group_chat_ids = []
            for stream_config_str in group:
                if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
                    group_chat_ids.append(chat_id_candidate)
            if chat_id in group_chat_ids:
                return group_chat_ids
        return [chat_id]

    def get_random_expressions(
        self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float
    ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        # sourcery skip: extract-duplicate-method, move-assign
        # Supports sampling across multiple related chat_ids at once
        related_chat_ids = self.get_related_chat_ids(chat_id)

        # Optimization: fetch the expressions of all related chat_ids in a single query per type
        style_query = session.execute(
            select(Expression).where(
                (Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style")
            )
        )
        grammar_query = session.execute(
            select(Expression).where(
                (Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "grammar")
            )
        )
        style_exprs = [
            {
                "situation": expr.situation,
                "style": expr.style,
                "count": expr.count,
                "last_active_time": expr.last_active_time,
                "source_id": expr.chat_id,
                "type": "style",
                "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
            }
            for expr in style_query.scalars()
        ]
        grammar_exprs = [
            {
                "situation": expr.situation,
                "style": expr.style,
                "count": expr.count,
                "last_active_time": expr.last_active_time,
                "source_id": expr.chat_id,
                "type": "grammar",
                "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
            }
            for expr in grammar_query.scalars()
        ]
        style_num = int(total_num * style_percentage)
        grammar_num = int(total_num * grammar_percentage)

        # Weighted sampling, using count as the weight
        if style_exprs:
            style_weights = [expr.get("count", 1) for expr in style_exprs]
            selected_style = weighted_sample(style_exprs, style_weights, style_num)
        else:
            selected_style = []
        if grammar_exprs:
            grammar_weights = [expr.get("count", 1) for expr in grammar_exprs]
            selected_grammar = weighted_sample(grammar_exprs, grammar_weights, grammar_num)
        else:
            selected_grammar = []
        return selected_style, selected_grammar
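
    # Sketch of the typical call (the numbers mirror select_suitable_expressions_llm below):
    # get_random_expressions(chat_id, 30, 0.5, 0.5) asks for int(30 * 0.5) = 15 "style"
    # and 15 "grammar" expressions, or fewer if the database holds less than that.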

    def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, Any]], increment: float = 0.1):
        """Update the count of a batch of expressions: deduplicate by (chat_id, type, situation, style), then write to the database."""
        if not expressions_to_update:
            return
        updates_by_key = {}
        for expr in expressions_to_update:
            source_id: str = expr.get("source_id")  # type: ignore
            expr_type: str = expr.get("type", "style")
            situation: str = expr.get("situation")  # type: ignore
            style: str = expr.get("style")  # type: ignore
            if not source_id or not situation or not style:
                logger.warning(f"表达方式缺少必要字段,无法更新: {expr}")
                continue
            key = (source_id, expr_type, situation, style)
            if key not in updates_by_key:
                updates_by_key[key] = expr

        for chat_id, expr_type, situation, style in updates_by_key:
            query = session.execute(
                select(Expression).where(
                    (Expression.chat_id == chat_id)
                    & (Expression.type == expr_type)
                    & (Expression.situation == situation)
                    & (Expression.style == style)
                )
            ).scalar()
            if query:
                expr_obj = query
                current_count = expr_obj.count
                new_count = min(current_count + increment, 5.0)
                expr_obj.count = new_count
                expr_obj.last_active_time = time.time()
                session.commit()
                logger.debug(
                    f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db"
                )
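
    # Back-of-the-envelope example (assuming a starting count of 1.0): with the increment
    # of 0.006 used by select_suitable_expressions_llm, an expression must be selected
    # roughly (5.0 - 1.0) / 0.006 ≈ 667 times before it reaches the 5.0 cap.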

    async def select_suitable_expressions_llm(
        self,
        chat_id: str,
        chat_info: str,
        max_num: int = 10,
        min_num: int = 5,
        target_message: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        # sourcery skip: inline-variable, list-comprehension
        """Use the LLM to pick the expressions best suited to the current chat context."""
        # Check whether expressions are allowed in this chat stream
        if not self.can_use_expression_for_chat(chat_id):
            logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
            return []

        # 1. Draw 30 random expressions, weighted by count
        style_exprs, grammar_exprs = self.get_random_expressions(chat_id, 30, 0.5, 0.5)

        # 2. Build the index and situation list for all candidate expressions
        all_expressions = []
        all_situations = []
        # Add style expressions
        for expr in style_exprs:
            if isinstance(expr, dict) and "situation" in expr and "style" in expr:
                expr_with_type = expr.copy()
                expr_with_type["type"] = "style"
                all_expressions.append(expr_with_type)
                all_situations.append(f"{len(all_expressions)}.{expr['situation']}")
        # Add grammar expressions
        for expr in grammar_exprs:
            if isinstance(expr, dict) and "situation" in expr and "style" in expr:
                expr_with_type = expr.copy()
                expr_with_type["type"] = "grammar"
                all_expressions.append(expr_with_type)
                all_situations.append(f"{len(all_expressions)}.{expr['situation']}")
        if not all_expressions:
            logger.warning("没有找到可用的表达方式")
            return []
        all_situations_str = "\n".join(all_situations)

        if target_message:
            target_message_str = f",现在你想要回复消息:{target_message}"
            target_message_extra_block = "4.考虑你要回复的目标消息"
        else:
            target_message_str = ""
            target_message_extra_block = ""

        # 3. Build the prompt (it only contains the situations, not the full expressions)
        prompt = (await global_prompt_manager.get_prompt_async("expression_evaluation_prompt")).format(
            bot_name=global_config.bot.nickname,
            chat_observe_info=chat_info,
            all_situations=all_situations_str,
            min_num=min_num,
            max_num=max_num,
            target_message=target_message_str,
            target_message_extra_block=target_message_extra_block,
        )

        # 4. Call the LLM
        try:
            content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
            if not content:
                logger.warning("LLM返回空结果")
                return []

            # 5. Parse the result (repair_json tolerates slightly malformed JSON)
            result = repair_json(content)
            if isinstance(result, str):
                result = json.loads(result)
            if not isinstance(result, dict) or "selected_situations" not in result:
                logger.error("LLM返回格式错误")
                logger.info(f"LLM返回结果: \n{content}")
                return []
            selected_indices = result["selected_situations"]

            # Map the selected indices back to the full expressions
            valid_expressions = []
            for idx in selected_indices:
                if isinstance(idx, int) and 1 <= idx <= len(all_expressions):
                    expression = all_expressions[idx - 1]  # indices are 1-based
                    valid_expressions.append(expression)

            # Bump the count of all selected expressions in one batch
            if valid_expressions:
                self.update_expressions_count_batch(valid_expressions, 0.006)
            return valid_expressions
        except Exception as e:
            logger.error(f"LLM处理表达方式选择时出错: {e}")
            return []


init_prompt()

try:
    expression_selector = ExpressionSelector()
except Exception as e:
    print(f"ExpressionSelector初始化失败: {e}")