feat(expression): 添加表达方式选择模式支持与DatabaseMessages兼容性改进

- 新增统一的表达方式选择入口,支持classic和exp_model两种模式
- 添加StyleLearner模型预测模式,可基于机器学习模型选择表达风格
- 改进多个模块对DatabaseMessages数据模型的兼容性处理
- 优化消息处理逻辑,统一处理字典和DatabaseMessages对象
- 在配置中添加expression.mode字段控制表达选择模式
This commit is contained in:
Windpicker-owo
2025-10-29 22:52:32 +08:00
parent f2d7af6d87
commit f6349f278d
16 changed files with 1419 additions and 54 deletions

View File

@@ -15,6 +15,9 @@ from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
# 导入StyleLearner管理器
from .style_learner import style_learner_manager
logger = get_logger("expression_selector")
@@ -236,6 +239,181 @@ class ExpressionSelector:
)
await session.commit()
async def select_suitable_expressions(
    self,
    chat_id: str,
    chat_history: list | str,
    target_message: str | None = None,
    max_num: int = 10,
    min_num: int = 5,
) -> list[dict[str, Any]]:
    """
    Unified entry point for expression selection; dispatches on the configured mode.

    Args:
        chat_id: Chat identifier, forwarded to the selected strategy.
        chat_history: Either already-rendered text, or a list of message
            entries. List entries are rendered one per line as
            ``"<sender>: <content>"``.
        target_message: Target message, forwarded unchanged.
        max_num: Maximum number of expressions to return (forwarded).
        min_num: Minimum number of expressions to return (forwarded).

    Returns:
        The list of expression dicts produced by the selected strategy.
    """
    # Normalise chat_history into a single text blob.
    if isinstance(chat_history, list):
        rendered_lines = []
        for msg in chat_history:
            # Plain-text entries are used verbatim instead of crashing on `.get`.
            if isinstance(msg, str):
                rendered_lines.append(msg)
                continue

            getter = getattr(msg, "get", None)
            if callable(getter):
                # Dict-like entry: same lookup the method has always performed.
                sender = getter("sender", "Unknown")
                content = getter("content", "")
            else:
                # Object entry (e.g. a message data model) has no `.get`; read
                # attributes instead of raising AttributeError.
                # NOTE(review): the attribute names mirror the dict keys and are
                # an assumption -- confirm against the message model's fields.
                sender = getattr(msg, "sender", "Unknown")
                content = getattr(msg, "content", "")
            rendered_lines.append(f"{sender}: {content}")
        chat_info = "\n".join(rendered_lines)
    else:
        chat_info = chat_history

    # Pick the strategy from configuration; anything other than "exp_model"
    # is treated as classic mode.
    mode = global_config.expression.mode
    logger.debug(f"[ExpressionSelector] 使用模式: {mode}")

    if mode == "exp_model":
        strategy = self._select_expressions_model_only
    else:  # classic mode
        strategy = self._select_expressions_classic

    return await strategy(
        chat_id=chat_id,
        chat_info=chat_info,
        target_message=target_message,
        max_num=max_num,
        min_num=min_num,
    )
async def _select_expressions_classic(
    self,
    chat_id: str,
    chat_info: str,
    target_message: str | None = None,
    max_num: int = 10,
    min_num: int = 5,
) -> list[dict[str, Any]]:
    """
    Classic mode: delegate the selection to ``select_suitable_expressions_llm``.

    All arguments are forwarded by keyword without modification, and the
    delegate's result is returned as-is. (Per the original note, the delegate
    performs random sampling followed by LLM evaluation; that code is not
    part of this method.)
    """
    # Bundle the arguments once so the delegate call stays on one line.
    llm_kwargs = {
        "chat_id": chat_id,
        "chat_info": chat_info,
        "max_num": max_num,
        "min_num": min_num,
        "target_message": target_message,
    }
    logger.debug("[Classic模式] 使用LLM评估表达方式")
    selected = await self.select_suitable_expressions_llm(**llm_kwargs)
    return selected
async def _select_expressions_model_only(
    self,
    chat_id: str,
    chat_info: str,
    target_message: str | None = None,
    max_num: int = 10,
    min_num: int = 5,
) -> list[dict[str, Any]]:
    """
    Model mode: choose expressions from the styles predicted by StyleLearner.

    Returns an empty list when ``can_use_expression_for_chat`` rejects the
    chat. Falls back to ``_select_expressions_classic`` when the learner
    yields no prediction or when no stored expression matches the predicted
    styles. ``target_message`` and ``min_num`` are only used by that fallback.
    """

    async def _classic_fallback() -> list[dict[str, Any]]:
        # Single place for the classic-mode hand-off used by both fallbacks.
        return await self._select_expressions_classic(
            chat_id=chat_id,
            chat_info=chat_info,
            target_message=target_message,
            max_num=max_num,
            min_num=min_num,
        )

    logger.debug("[Exp_model模式] 使用StyleLearner预测表达方式")

    # Guard: expressions may be disabled for this chat stream.
    allowed = self.can_use_expression_for_chat(chat_id)
    if not allowed:
        logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
        return []

    # Ask the per-chat learner for its best style and the score table.
    chat_learner = style_learner_manager.get_learner(chat_id)
    top_style, score_table = chat_learner.predict_style(chat_info, top_k=max_num)

    if not (top_style and score_table):
        logger.warning("StyleLearner未返回预测结果可能模型未训练回退到经典模式")
        return await _classic_fallback()

    # Rank (style, score) pairs, highest score first.
    ranked_styles = list(score_table.items())
    ranked_styles.sort(key=lambda pair: pair[1], reverse=True)

    # Look up stored expressions matching the predicted styles.
    matched = await self.get_model_predicted_expressions(
        chat_id=chat_id,
        predicted_styles=ranked_styles,
        max_num=max_num,
    )

    if matched:
        logger.debug(f"[Exp_model模式] 成功返回 {len(matched)} 个表达方式")
        return matched

    logger.warning("未找到匹配预测风格的表达方式,回退到经典模式")
    return await _classic_fallback()
async def get_model_predicted_expressions(
    self,
    chat_id: str,
    predicted_styles: list[tuple[str, float]],
    max_num: int = 10
) -> list[dict[str, Any]]:
    """
    Fetch stored expressions whose style matches the predicted styles.

    Args:
        chat_id: Chat identifier used to filter ``Expression`` rows.
        predicted_styles: ``(style, score)`` pairs; only the first three
            entries are used for matching.
        max_num: Maximum number of rows to fetch.

    Returns:
        A list of dicts with the keys ``situation``, ``style``, ``type``,
        ``count`` and ``last_active_time``; empty when there is nothing to
        match or no row is found.
    """
    if not predicted_styles:
        return []

    # Only the leading three predictions take part in the lookup.
    top_three = predicted_styles[:3]
    style_names = [name for name, _score in top_three]
    best_name = style_names[0] if style_names else 'None'
    logger.debug(f"预测最佳风格: {best_name}, Top3分数: {top_three}")

    async with get_db_session() as session:
        # Rows for this chat with a matching style, highest count first.
        query = select(Expression)
        query = query.where(Expression.chat_id == chat_id)
        query = query.where(Expression.style.in_(style_names))
        query = query.order_by(Expression.count.desc()).limit(max_num)

        fetched = await session.execute(query)
        rows = fetched.scalars().all()

        if not rows:
            logger.debug(f"数据库中没有找到风格 {style_names} 的表达方式")
            return []

        # Convert rows to plain dicts while the session is still open,
        # substituting defaults for empty/missing column values.
        as_dicts = [
            {
                "situation": row.situation or "",
                "style": row.style or "",
                "type": row.type or "style",
                "count": float(row.count) if row.count else 0.0,
                "last_active_time": row.last_active_time or 0.0,
            }
            for row in rows
        ]

    logger.debug(f"从数据库获取了 {len(as_dicts)} 个表达方式")
    return as_dicts
async def select_suitable_expressions_llm(
self,
chat_id: str,