Files
Mofox-Core/src/chat/express/expression_selector.py
LuiKlee 0feb878830 ruff
2025-12-16 16:18:59 +08:00

719 lines
28 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import hashlib
import math
import random
import time
from typing import Any
import orjson
from json_repair import repair_json
from sqlalchemy import select
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.database.compatibility import get_db_session
from src.common.database.core.models import Expression
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
# 导入StyleLearner管理器和情境提取器
from .situation_extractor import situation_extractor
from .style_learner import style_learner_manager
logger = get_logger("expression_selector")
def init_prompt():
    """Build the situation-selection prompt template and wrap it in a ``Prompt``.

    The template is created under the name ``expression_evaluation_prompt``,
    which is the name later requested from the prompt manager when the LLM is
    asked to pick situations. Doubled braces in the template are literal
    braces that survive ``str.format``.
    """
    # The template text is runtime data: it is kept flush-left so that no
    # extra leading whitespace ends up in the prompt sent to the model.
    situation_picker_template = """
以下是正在进行的聊天内容:
{chat_observe_info}
你的名字是{bot_name}{target_message}
以下是可选的表达情境:
{all_situations}
请你分析聊天内容的语境、情绪、话题类型,从上述情境中选择最适合当前聊天情境的{min_num}-{max_num}个情境。
考虑因素包括:
1. 聊天的情绪氛围(轻松、严肃、幽默等)
2. 话题类型(日常、技术、游戏、情感等)
3. 情境与当前语境的匹配度
{target_message_extra_block}
请以JSON格式输出只需要输出选中的情境编号
例如:
{{
"selected_situations": [2, 3, 5, 7, 19, 22, 25, 38, 39, 45, 48, 64]
}}
请严格按照JSON格式输出不要包含其他内容
"""
    Prompt(situation_picker_template, "expression_evaluation_prompt")
def weighted_sample(population: list[dict], weights: list[float], k: int) -> list[dict]:
    """Draw up to ``k`` distinct items from ``population``, biased by ``weights``.

    Sampling is done without replacement: each round picks one of the
    remaining items with probability proportional to its weight and removes it
    from the pool.

    Args:
        population: Items to sample from. The list itself is never mutated.
        weights: One weight per item. ``None`` and non-positive weights are
            treated as zero.
        k: Number of items wanted.

    Returns:
        An empty list when ``population`` or ``weights`` is empty or ``k <= 0``;
        a shallow copy of ``population`` when it has at most ``k`` items;
        otherwise a list of ``k`` sampled items.

    Raises:
        ValueError: If sampling is required and ``weights`` does not have the
            same length as ``population``.
    """
    if not population or not weights or k <= 0:
        return []
    if len(population) <= k:
        return population.copy()
    if len(weights) != len(population):
        raise ValueError("The number of weights does not match the population")

    pool = population.copy()
    pool_weights = []
    for raw_weight in weights:
        weight = float(raw_weight or 0.0)
        # Comparison form (rather than max()) also maps NaN to zero.
        pool_weights.append(weight if weight > 0 else 0.0)

    selected = []
    # len(pool) > k is guaranteed here, so the pool cannot run dry in the loop.
    for _ in range(k):
        if sum(pool_weights) > 0:
            chosen_idx = random.choices(range(len(pool)), weights=pool_weights)[0]
        else:
            # random.choices raises ValueError when the total weight is zero;
            # with no weight information left, fall back to a uniform pick.
            chosen_idx = random.randrange(len(pool))
        selected.append(pool.pop(chosen_idx))
        pool_weights.pop(chosen_idx)
    return selected
class ExpressionSelector:
@staticmethod
def _sample_with_temperature(
    candidates: list[tuple[Any, float, float, str]],
    max_num: int,
    temperature: float,
) -> list[tuple[Any, float, float, str]]:
    """Pick up to ``max_num`` candidates using temperature-scaled sampling.

    Each candidate is scored as ``similarity * sqrt(count)`` and the scores are
    turned into softmax weights divided by the temperature, so a higher
    temperature gives a more uniform pick. The top-scoring candidate is always
    kept; the rest are drawn without replacement.

    Args:
        candidates: ``(expr, similarity, count, best_predicted)`` tuples. For
            the greedy path they are returned in the given order, so callers
            are expected to pass them already sorted by score.
        max_num: Maximum number of candidates to return.
        temperature: Sampling temperature; ``<= 0`` means greedy selection
            (the first ``max_num`` candidates).

    Returns:
        The selected candidate tuples, best-scoring one first when sampling.
    """
    if max_num <= 0 or not candidates:
        return []
    if temperature <= 0:
        return candidates[:max_num]

    adjusted_temp = max(temperature, 1e-6)

    def score_of(candidate: tuple[Any, float, float, str]) -> float:
        # A missing or negative count contributes nothing instead of raising
        # (None ** 0.5) or producing a complex number (negative ** 0.5).
        count = candidate[2]
        root = count ** 0.5 if count and count > 0 else 0.0
        return max(candidate[1] * root, 1e-8)

    scores = [score_of(c) for c in candidates]
    max_score = max(scores)
    # Subtracting max_score keeps every exponent <= 0, so exp() cannot overflow.
    weights = [math.exp((s - max_score) / adjusted_temp) for s in scores]

    best_idx = scores.index(max_score)
    selected = [candidates[best_idx]]
    remaining_indices = [i for i in range(len(candidates)) if i != best_idx]

    while remaining_indices and len(selected) < max_num:
        current_weights = [weights[i] for i in remaining_indices]
        if sum(current_weights) > 0.0:
            picked_pos = random.choices(range(len(remaining_indices)), weights=current_weights, k=1)[0]
        else:
            # At low temperatures every remaining weight underflows to 0.0 and
            # random.choices would raise ValueError. The low-temperature limit
            # of softmax sampling is greedy, so take the best remaining score.
            picked_pos = max(range(len(remaining_indices)), key=lambda pos: scores[remaining_indices[pos]])
        picked_idx = remaining_indices.pop(picked_pos)
        selected.append(candidates[picked_idx])
    return selected
def __init__(self, chat_id: str = ""):
    """Store the chat id and create the LLM client used for selection.

    Args:
        chat_id: Chat stream id kept on the instance as ``self.chat_id``.

    Raises:
        RuntimeError: If the module-level ``model_config`` is ``None``.
    """
    self.chat_id = chat_id
    if model_config is None:
        raise RuntimeError("Model config is not initialized")

    # The selector uses the "utils_small" model set from the task config.
    small_model_set = model_config.model_task_config.utils_small
    self.llm_model = LLMRequest(model_set=small_model_set, request_type="expression.selector")
@staticmethod
def can_use_expression_for_chat(chat_id: str) -> bool:
    """Tell whether expressions may be used in the given chat stream.

    Args:
        chat_id: Chat stream id.

    Returns:
        The first element of the per-chat expression config. ``False`` when
        the global config is missing or the lookup raises.
    """
    try:
        if global_config is None:
            return False
        # The config lookup yields a 3-tuple; only its first element is used.
        allowed, _second, _third = global_config.expression.get_expression_config_for_chat(chat_id)
    except Exception as e:
        logger.error(f"检查表达使用权限失败: {e}")
        return False
    return allowed
@staticmethod
def _parse_stream_config_to_chat_id(stream_config_str: str) -> str | None:
"""解析'platform:id:type'为chat_id与get_stream_id一致"""
try:
parts = stream_config_str.split(":")
if len(parts) != 3:
return None
platform = parts[0]
id_str = parts[1]
stream_type = parts[2]
is_group = stream_type == "group"
if is_group:
components = [platform, str(id_str)]
else:
components = [platform, str(id_str), "private"]
key = "_".join(components)
return hashlib.md5(key.encode()).hexdigest()
except Exception:
return None
def get_related_chat_ids(self, chat_id: str) -> list[str]:
    """Collect every chat id that shares an expression group with ``chat_id``.

    Groups come from ``global_config.expression.rules``. The given ``chat_id``
    is always the first element of the result, so a lookup with the returned
    ids covers at least the chat's own data.

    Args:
        chat_id: Chat id to look up.

    Returns:
        ``[chat_id]`` followed by the ids of the other streams configured in
        the same group, without duplicates.
    """
    if global_config is None:
        return [chat_id]

    rules = global_config.expression.rules
    to_chat_id = self._parse_stream_config_to_chat_id

    # Group of the first rule whose stream resolves to this chat id, if any.
    current_group = next(
        (
            rule.group
            for rule in rules
            if rule.chat_stream_id and to_chat_id(rule.chat_stream_id) == chat_id
        ),
        None,
    )

    related_chat_ids = [chat_id]
    if not current_group:
        return related_chat_ids

    for rule in rules:
        if rule.group == current_group and rule.chat_stream_id:
            candidate = to_chat_id(rule.chat_stream_id)
            if candidate and candidate not in related_chat_ids:
                related_chat_ids.append(candidate)
    return related_chat_ids
async def get_random_expressions(
    self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Sample stored expressions for a chat, weighted by their ``count``.

    Expressions of all related chat ids (see ``get_related_chat_ids``) are
    loaded, split into "style" and "grammar", and each pool is sampled with
    ``weighted_sample``.

    Args:
        chat_id: Chat id whose (and whose group's) expressions are used.
        total_num: Base number the two percentages are applied to.
        style_percentage: Share of ``total_num`` drawn from style expressions.
        grammar_percentage: Share of ``total_num`` drawn from grammar expressions.

    Returns:
        ``(selected_style, selected_grammar)`` lists of expression dicts.
    """
    # sourcery skip: extract-duplicate-method, move-assign

    def as_record(row, kind: str) -> dict[str, Any]:
        # Rows without a creation date fall back to their last-active time.
        created = row.create_date if row.create_date is not None else row.last_active_time
        return {
            "situation": row.situation,
            "style": row.style,
            "count": row.count,
            "last_active_time": row.last_active_time,
            "source_id": row.chat_id,
            "type": kind,
            "create_date": created,
        }

    def draw(pool: list[dict[str, Any]], quota: int) -> list[dict[str, Any]]:
        # An empty pool short-circuits; otherwise "count" is the sampling weight.
        if not pool:
            return []
        return weighted_sample(pool, [item.get("count", 1) for item in pool], quota)

    related_chat_ids = self.get_related_chat_ids(chat_id)

    # One IN-query per expression type covers every related chat id at once.
    async with get_db_session() as session:
        style_rows = await session.execute(
            select(Expression).where(Expression.chat_id.in_(related_chat_ids)).where(Expression.type == "style")
        )
        grammar_rows = await session.execute(
            select(Expression).where(Expression.chat_id.in_(related_chat_ids)).where(Expression.type == "grammar")
        )
        style_pool = [as_record(row, "style") for row in style_rows.scalars()]
        grammar_pool = [as_record(row, "grammar") for row in grammar_rows.scalars()]

    style_quota = int(total_num * style_percentage)
    grammar_quota = int(total_num * grammar_percentage)

    selected_style = draw(style_pool, style_quota)
    selected_grammar = draw(grammar_pool, grammar_quota)
    return selected_style, selected_grammar
@staticmethod
async def update_expressions_count_batch(expressions_to_update: list[dict[str, Any]], increment: float = 0.1):
    """Bump the ``count`` of a batch of expressions in a single DB session.

    Entries are de-duplicated by ``(source_id, type, situation, style)``.
    Every matching row gets ``count`` increased by ``increment`` (capped at
    5.0) and ``last_active_time`` set to the current time. Afterwards the
    ``chat_expressions`` cache entry of every affected chat id is deleted.

    Args:
        expressions_to_update: Expression dicts; entries missing
            ``source_id``, ``situation`` or ``style`` are skipped with a warning.
        increment: Amount added to each matched row's ``count``.
    """
    if not expressions_to_update:
        return

    # First occurrence wins for duplicate keys.
    unique_updates: dict[tuple[str, str, str, str], dict[str, Any]] = {}
    touched_chat_ids: set[str] = set()
    for entry in expressions_to_update:
        source_id: str = entry.get("source_id")  # type: ignore
        expr_type: str = entry.get("type", "style")
        situation: str = entry.get("situation")  # type: ignore
        style: str = entry.get("style")  # type: ignore
        if not (source_id and situation and style):
            logger.warning(f"表达方式缺少必要字段,无法更新: {entry}")
            continue
        unique_updates.setdefault((source_id, expr_type, situation, style), entry)
        touched_chat_ids.add(source_id)

    if not unique_updates:
        return

    now = time.time()
    async with get_db_session() as session:
        updated_count = 0
        for chat_id, expr_type, situation, style in unique_updates:
            lookup = select(Expression).where(
                (Expression.chat_id == chat_id)
                & (Expression.type == expr_type)
                & (Expression.situation == situation)
                & (Expression.style == style)
            )
            row = (await session.execute(lookup)).scalar()
            if not row:
                continue
            row.count = min(row.count + increment, 5.0)
            row.last_active_time = now
            updated_count += 1

        # Commit once for the whole batch, and only if something changed.
        if updated_count > 0:
            await session.commit()
            logger.debug(f"批量更新了 {updated_count} 个表达方式的count值")

    # Drop the cached expression lists of every chat that was touched.
    if touched_chat_ids:
        from src.common.database.optimization.cache_manager import get_cache
        from src.common.database.utils.decorators import generate_cache_key

        cache = await get_cache()
        for chat_id in touched_chat_ids:
            await cache.delete(generate_cache_key("chat_expressions", chat_id))
async def select_suitable_expressions(
    self,
    chat_id: str,
    chat_history: list | str,
    target_message: str | None = None,
    max_num: int = 10,
    min_num: int = 5,
) -> list[dict[str, Any]]:
    """Single entry point for expression selection; dispatches on the configured mode.

    Args:
        chat_id: Chat id.
        chat_history: Chat history, either a list of message dicts (rendered
            as ``sender: content`` lines) or an already formatted string.
        target_message: Message being replied to, if any.
        max_num: Upper bound on the number of expressions wanted.
        min_num: Lower bound on the number of expressions wanted.

    Returns:
        The selected expression dicts.

    Raises:
        RuntimeError: If the global config is not initialised.
    """
    if isinstance(chat_history, list):
        rendered_lines = [f"{msg.get('sender', 'Unknown')}: {msg.get('content', '')}" for msg in chat_history]
        chat_info = "\n".join(rendered_lines)
    else:
        chat_info = chat_history

    if global_config is None:
        raise RuntimeError("Global config is not initialized")

    mode = global_config.expression.mode
    logger.debug(f"使用表达选择模式: {mode}")

    # "exp_model" uses the model-driven path; anything else is classic mode.
    if mode == "exp_model":
        strategy = self._select_expressions_model_only
    else:
        strategy = self._select_expressions_classic
    return await strategy(
        chat_id=chat_id,
        chat_info=chat_info,
        target_message=target_message,
        max_num=max_num,
        min_num=min_num,
    )
async def _select_expressions_classic(
    self,
    chat_id: str,
    chat_info: str,
    target_message: str | None = None,
    max_num: int = 10,
    min_num: int = 5,
) -> list[dict[str, Any]]:
    """Classic mode: delegate to the LLM-based selection.

    All arguments are forwarded unchanged to
    ``select_suitable_expressions_llm``, whose result is returned.
    """
    logger.debug("使用LLM评估表达方式")
    selected = await self.select_suitable_expressions_llm(
        chat_id=chat_id,
        chat_info=chat_info,
        max_num=max_num,
        min_num=min_num,
        target_message=target_message,
    )
    return selected
async def _select_expressions_model_only(
    self,
    chat_id: str,
    chat_info: str,
    target_message: str | None = None,
    max_num: int = 10,
    min_num: int = 5,
) -> list[dict[str, Any]]:
    """Model mode: extract situations, predict styles, then load matching expressions.

    Falls back to classic mode whenever a stage yields nothing: no situations
    extracted, no styles predicted, or no stored expression matching the
    predictions. Returns an empty list when the chat may not use expressions.
    """

    async def fall_back_to_classic() -> list[dict[str, Any]]:
        # Shared fallback used by every stage below.
        return await self._select_expressions_classic(
            chat_id=chat_id,
            chat_info=chat_info,
            target_message=target_message,
            max_num=max_num,
            min_num=min_num,
        )

    logger.debug("使用情境提取 + StyleLearner预测表达方式")

    if not self.can_use_expression_for_chat(chat_id):
        logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
        return []

    # Step 1: extract up to three situations from the chat.
    situations = await situation_extractor.extract_situations(
        chat_history=chat_info,
        target_message=target_message,
        max_situations=3,
    )
    if not situations:
        logger.debug("无法提取聊天情境,回退到经典模式")
        return await fall_back_to_classic()
    logger.debug(f"提取到 {len(situations)} 个情境")

    # Step 2: predict styles per situation, keeping the highest score per style.
    learner = style_learner_manager.get_learner(chat_id)
    best_score_by_style = {}
    for position, situation in enumerate(situations, 1):
        logger.debug(f"为情境 {position} 预测风格: {situation}")
        best_style, scores = learner.predict_style(situation, top_k=max_num)
        if not (best_style and scores):
            logger.debug("该情境未返回预测结果")
            continue
        logger.debug(f"预测最佳风格: {best_style}")
        for style, score in scores.items():
            if style not in best_score_by_style or score > best_score_by_style[style]:
                best_score_by_style[style] = score

    if not best_score_by_style:
        logger.debug("StyleLearner未返回预测结果回退到经典模式")
        return await fall_back_to_classic()

    # Convert to [(style, score), ...], best score first.
    predicted_styles = sorted(best_score_by_style.items(), key=lambda pair: pair[1], reverse=True)
    logger.debug(f"预测到 {len(predicted_styles)} 个风格")

    # Step 3: fetch stored expressions that match the predicted styles.
    logger.debug("从数据库查询表达方式")
    expressions = await self.get_model_predicted_expressions(
        chat_id=chat_id,
        predicted_styles=predicted_styles,
        max_num=max_num,
    )
    if not expressions:
        logger.debug("未找到匹配预测风格的表达方式,回退到经典模式")
        return await fall_back_to_classic()

    logger.debug(f"返回 {len(expressions)} 个表达方式")
    return expressions
async def get_model_predicted_expressions(
    self,
    chat_id: str,
    predicted_styles: list[tuple[str, float]],
    max_num: int = 10
) -> list[dict[str, Any]]:
    """Load stored "style" expressions that fuzzily match the predicted styles.

    Each stored expression is compared case-insensitively with the first 20
    predicted styles: an exact match scores 1.0, a substring match (both
    strings at least two characters long) scores 0.7, anything else scores its
    ``difflib.SequenceMatcher`` ratio. Expressions scoring at least 0.3 are
    ranked by ``similarity * sqrt(count)`` and passed to
    ``_sample_with_temperature``.

    Args:
        chat_id: Chat id; expressions of related chat ids are included.
        predicted_styles: ``[(style, score), ...]``. Only the style names are
            used for matching; the scores are not consulted here.
        max_num: Maximum number of expressions to return.

    Returns:
        Expression dicts with ``situation``, ``style``, ``type``, ``count``,
        ``last_active_time`` and ``source_id``; empty when nothing matches or
        no expressions are stored.
    """
    if not predicted_styles:
        return []

    # Only the three best style names are reported in logs.
    style_names = [style for style, _ in predicted_styles[:3]]
    logger.debug(f"预测最佳风格: {style_names[0] if style_names else 'None'}")

    # Related chat ids allow expressions to be shared within a group.
    related_chat_ids = self.get_related_chat_ids(chat_id)
    logger.debug(f"查询相关的chat_ids: {len(related_chat_ids)}")

    async with get_db_session() as session:
        # Diagnostic only: how many chat ids have any style expression stored.
        # NOTE(review): this extra query feeds a debug log and nothing else.
        db_chat_ids_result = await session.execute(
            select(Expression.chat_id)
            .where(Expression.type == "style")
            .distinct()
        )
        db_chat_ids = list(db_chat_ids_result.scalars())
        logger.debug(f"数据库中有表达方式的chat_ids: {len(db_chat_ids)}")

        all_expressions_result = await session.execute(
            select(Expression)
            .where(Expression.chat_id.in_(related_chat_ids))
            .where(Expression.type == "style")
        )
        all_expressions = list(all_expressions_result.scalars())
        logger.debug(f"配置的相关chat_id的表达方式数量: {len(all_expressions)}")

        # Fallback: no data for the related chats, so use every chat's data.
        if not all_expressions:
            logger.debug("相关chat_id没有数据尝试从所有chat_id查询")
            all_expressions_result = await session.execute(
                select(Expression)
                .where(Expression.type == "style")
            )
            all_expressions = list(all_expressions_result.scalars())
            logger.debug(f"数据库中所有表达方式数量: {len(all_expressions)}")

        if not all_expressions:
            logger.warning("数据库中完全没有任何表达方式,需要先学习")
            return []

        from difflib import SequenceMatcher

        # Lower-case the predictions once instead of once per stored expression.
        predicted_lower = [style.lower() for style, _ in predicted_styles[:20]]

        def best_match(db_style_lower: str) -> tuple[float, str]:
            """Return (similarity, predicted style) of the closest prediction."""
            best_similarity = 0.0
            best_predicted = ""
            for candidate in predicted_lower:
                # An exact match cannot be beaten, so stop early.
                if candidate == db_style_lower:
                    return 1.0, candidate
                is_substring_hit = (
                    len(candidate) >= 2
                    and len(db_style_lower) >= 2
                    and (candidate in db_style_lower or db_style_lower in candidate)
                )
                if is_substring_hit:
                    similarity = 0.7
                else:
                    # Slower path, only taken when the cheap checks fail.
                    similarity = SequenceMatcher(None, candidate, db_style_lower).ratio()
                if similarity > best_similarity:
                    best_similarity = similarity
                    best_predicted = candidate
            return best_similarity, best_predicted

        matched_expressions = []
        for expr in all_expressions:
            similarity, best_predicted = best_match((expr.style or "").lower())
            # Deliberately low 30% threshold.
            if similarity >= 0.3:
                # count may be NULL in the DB; None (or a negative value) would
                # break the sqrt-based ranking below, so normalise it here.
                count = float(expr.count) if expr.count and expr.count > 0 else 0.0
                matched_expressions.append((expr, similarity, count, best_predicted))

        if not matched_expressions:
            # Include a few stored styles in the log to help debugging.
            all_styles = [e.style for e in all_expressions[:10]]
            logger.warning(
                f"数据库中没有找到匹配的表达方式相似度阈值30%:\n"
                f" 预测的style (前3个): {style_names}\n"
                f" 数据库中存在的style样例: {all_styles}\n"
                f" 提示: StyleLearner预测质量差建议重新训练或使用classic模式"
            )
            return []

        # Rank by similarity * sqrt(count), then sample by temperature.
        matched_expressions.sort(key=lambda match: match[1] * (match[2] ** 0.5), reverse=True)
        # Tolerate a missing global config or setting: fall back to greedy (0.0).
        expression_config = getattr(global_config, "expression", None)
        temperature = getattr(expression_config, "model_temperature", 0.0) or 0.0
        sampled_matches = self._sample_with_temperature(
            candidates=matched_expressions,
            max_num=max_num,
            temperature=temperature,
        )
        expressions_objs = [match[0] for match in sampled_matches]

        logger.debug(
            f"模糊匹配成功: 找到 {len(expressions_objs)} 个表达方式 "
            f"(候选 {len(matched_expressions)}temperature={temperature})"
        )

        expressions = [
            {
                "situation": expr.situation or "",
                "style": expr.style or "",
                "type": expr.type or "style",
                "count": float(expr.count) if expr.count else 0.0,
                "last_active_time": expr.last_active_time or 0.0,
                # source_id lets callers update the row's count later on.
                "source_id": expr.chat_id,
            }
            for expr in expressions_objs
        ]
        logger.debug(f"从数据库获取了 {len(expressions)} 个表达方式")
        return expressions
async def select_suitable_expressions_llm(
    self,
    chat_id: str,
    chat_info: str,
    max_num: int = 10,
    min_num: int = 5,
    target_message: str | None = None,
) -> list[dict[str, Any]]:
    """Let the LLM choose fitting expressions from a weighted random sample.

    A sample of stored expressions is numbered and shown to the LLM by
    situation only. The LLM answers with a JSON object holding the chosen
    numbers, which are mapped back to the full expression dicts. The chosen
    expressions have their ``count`` bumped in a background task.

    Args:
        chat_id: Chat id.
        chat_info: Chat history as text.
        max_num: Upper bound given to the LLM; when positive it also caps the
            number of expressions returned.
        min_num: Lower bound given to the LLM.
        target_message: Message being replied to, if any.

    Returns:
        The chosen expression dicts without duplicates. Empty when expressions
        are disabled for the chat, none are stored, or the LLM call or its
        output cannot be used.

    Raises:
        RuntimeError: If the global config is not initialised.
    """
    # sourcery skip: inline-variable, list-comprehension
    if not self.can_use_expression_for_chat(chat_id):
        logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
        return []

    # 1. Draw 30 expressions by weight, split evenly between style and grammar.
    style_exprs, grammar_exprs = await self.get_random_expressions(chat_id, 30, 0.5, 0.5)

    # 2. Number the candidates (1-based) and collect their situations.
    all_expressions = []
    all_situations = []
    for expr_type, pool in (("style", style_exprs), ("grammar", grammar_exprs)):
        for expr in pool:
            if isinstance(expr, dict) and "situation" in expr and "style" in expr:
                all_expressions.append({**expr, "type": expr_type})
                all_situations.append(f"{len(all_expressions)}.{expr['situation']}")

    if not all_expressions:
        logger.warning("没有找到可用的表达方式")
        return []

    all_situations_str = "\n".join(all_situations)

    if target_message:
        target_message_str = f",现在你想要回复消息:{target_message}"
        target_message_extra_block = "4.考虑你要回复的目标消息"
    else:
        target_message_str = ""
        target_message_extra_block = ""

    if global_config is None:
        raise RuntimeError("Global config is not initialized")

    # 3. Build the prompt: situations only, not the full expressions.
    prompt = (await global_prompt_manager.get_prompt_async("expression_evaluation_prompt")).format(
        bot_name=global_config.bot.nickname,
        chat_observe_info=chat_info,
        all_situations=all_situations_str,
        min_num=min_num,
        max_num=max_num,
        target_message=target_message_str,
        target_message_extra_block=target_message_extra_block,
    )

    # 4. Ask the LLM and parse its answer. Any failure yields an empty result.
    try:
        content, (_reasoning_content, _model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)

        if not content:
            logger.warning("LLM返回空结果")
            return []

        # 5. Repair the JSON, then decode it if the repair step returned text.
        result = repair_json(content)
        if isinstance(result, str):
            result = orjson.loads(result)

        if not isinstance(result, dict) or "selected_situations" not in result:
            logger.error("LLM返回格式错误")
            logger.info(f"LLM返回结果: \n{content}")
            return []

        selected_indices = result["selected_situations"]

        valid_expressions = []
        seen_indices = set()
        for idx in selected_indices:
            # bool is a subclass of int; True must not be read as index 1.
            if isinstance(idx, bool) or not isinstance(idx, int):
                continue
            if not 1 <= idx <= len(all_expressions) or idx in seen_indices:
                continue
            # Honour max_num even if the LLM returns more than requested.
            if 0 < max_num <= len(valid_expressions):
                break
            seen_indices.add(idx)
            valid_expressions.append(all_expressions[idx - 1])  # indices are 1-based

        # Bump the count of everything that was chosen, without blocking the reply.
        if valid_expressions:
            update_task = asyncio.create_task(self.update_expressions_count_batch(valid_expressions, 0.006))
            # The event loop keeps only weak references to tasks, so hold a
            # strong one until the task finishes or it may be collected early.
            pending_updates = getattr(self, "_pending_count_updates", None)
            if pending_updates is None:
                pending_updates = self._pending_count_updates = set()
            pending_updates.add(update_task)
            update_task.add_done_callback(pending_updates.discard)

        return valid_expressions

    except Exception as e:
        logger.error(f"LLM处理表达方式选择时出错: {e}")
        return []
# Build the prompt template at import time, before any selector can ask for it.
init_prompt()

# Shared module-level selector instance. Construction can fail (for example
# when the model config is not initialised); the failure is reported through
# the module logger instead of print() so it reaches the configured log
# outputs. As before, the name stays undefined in that case.
try:
    expression_selector = ExpressionSelector()
except Exception as e:
    logger.error(f"ExpressionSelector初始化失败: {e}")