feat(expression): 添加表达方式选择模式支持与DatabaseMessages兼容性改进
- 新增统一的表达方式选择入口,支持classic和exp_model两种模式 - 添加StyleLearner模型预测模式,可基于机器学习模型选择表达风格 - 改进多个模块对DatabaseMessages数据模型的兼容性处理 - 优化消息处理逻辑,统一处理字典和DatabaseMessages对象 - 在配置中添加expression.mode字段控制表达选择模式
This commit is contained in:
@@ -15,6 +15,9 @@ from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
# 导入StyleLearner管理器
|
||||
from .style_learner import style_learner_manager
|
||||
|
||||
logger = get_logger("expression_selector")
|
||||
|
||||
|
||||
@@ -236,6 +239,181 @@ class ExpressionSelector:
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async def select_suitable_expressions(
|
||||
self,
|
||||
chat_id: str,
|
||||
chat_history: list | str,
|
||||
target_message: str | None = None,
|
||||
max_num: int = 10,
|
||||
min_num: int = 5,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
统一的表达方式选择入口,根据配置自动选择模式
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
chat_history: 聊天历史(列表或字符串)
|
||||
target_message: 目标消息
|
||||
max_num: 最多返回数量
|
||||
min_num: 最少返回数量
|
||||
|
||||
Returns:
|
||||
选中的表达方式列表
|
||||
"""
|
||||
# 转换chat_history为字符串
|
||||
if isinstance(chat_history, list):
|
||||
chat_info = "\n".join([f"{msg.get('sender', 'Unknown')}: {msg.get('content', '')}" for msg in chat_history])
|
||||
else:
|
||||
chat_info = chat_history
|
||||
|
||||
# 根据配置选择模式
|
||||
mode = global_config.expression.mode
|
||||
logger.debug(f"[ExpressionSelector] 使用模式: {mode}")
|
||||
|
||||
if mode == "exp_model":
|
||||
return await self._select_expressions_model_only(
|
||||
chat_id=chat_id,
|
||||
chat_info=chat_info,
|
||||
target_message=target_message,
|
||||
max_num=max_num,
|
||||
min_num=min_num
|
||||
)
|
||||
else: # classic mode
|
||||
return await self._select_expressions_classic(
|
||||
chat_id=chat_id,
|
||||
chat_info=chat_info,
|
||||
target_message=target_message,
|
||||
max_num=max_num,
|
||||
min_num=min_num
|
||||
)
|
||||
|
||||
async def _select_expressions_classic(
|
||||
self,
|
||||
chat_id: str,
|
||||
chat_info: str,
|
||||
target_message: str | None = None,
|
||||
max_num: int = 10,
|
||||
min_num: int = 5,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""经典模式:随机抽样 + LLM评估"""
|
||||
logger.debug(f"[Classic模式] 使用LLM评估表达方式")
|
||||
return await self.select_suitable_expressions_llm(
|
||||
chat_id=chat_id,
|
||||
chat_info=chat_info,
|
||||
max_num=max_num,
|
||||
min_num=min_num,
|
||||
target_message=target_message
|
||||
)
|
||||
|
||||
async def _select_expressions_model_only(
|
||||
self,
|
||||
chat_id: str,
|
||||
chat_info: str,
|
||||
target_message: str | None = None,
|
||||
max_num: int = 10,
|
||||
min_num: int = 5,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""模型预测模式:使用StyleLearner预测最合适的表达风格"""
|
||||
logger.debug(f"[Exp_model模式] 使用StyleLearner预测表达方式")
|
||||
|
||||
# 检查是否允许在此聊天流中使用表达
|
||||
if not self.can_use_expression_for_chat(chat_id):
|
||||
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
|
||||
return []
|
||||
|
||||
# 获取或创建StyleLearner实例
|
||||
learner = style_learner_manager.get_learner(chat_id)
|
||||
|
||||
# 使用StyleLearner预测最合适的风格
|
||||
best_style, all_scores = learner.predict_style(chat_info, top_k=max_num)
|
||||
|
||||
if not best_style or not all_scores:
|
||||
logger.warning(f"StyleLearner未返回预测结果(可能模型未训练),回退到经典模式")
|
||||
return await self._select_expressions_classic(
|
||||
chat_id=chat_id,
|
||||
chat_info=chat_info,
|
||||
target_message=target_message,
|
||||
max_num=max_num,
|
||||
min_num=min_num
|
||||
)
|
||||
|
||||
# 将分数字典转换为列表格式 [(style, score), ...]
|
||||
predicted_styles = sorted(all_scores.items(), key=lambda x: x[1], reverse=True)
|
||||
|
||||
# 根据预测的风格从数据库获取表达方式
|
||||
expressions = await self.get_model_predicted_expressions(
|
||||
chat_id=chat_id,
|
||||
predicted_styles=predicted_styles,
|
||||
max_num=max_num
|
||||
)
|
||||
|
||||
if not expressions:
|
||||
logger.warning(f"未找到匹配预测风格的表达方式,回退到经典模式")
|
||||
return await self._select_expressions_classic(
|
||||
chat_id=chat_id,
|
||||
chat_info=chat_info,
|
||||
target_message=target_message,
|
||||
max_num=max_num,
|
||||
min_num=min_num
|
||||
)
|
||||
|
||||
logger.debug(f"[Exp_model模式] 成功返回 {len(expressions)} 个表达方式")
|
||||
return expressions
|
||||
|
||||
async def get_model_predicted_expressions(
|
||||
self,
|
||||
chat_id: str,
|
||||
predicted_styles: list[tuple[str, float]],
|
||||
max_num: int = 10
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
根据StyleLearner预测的风格获取表达方式
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
predicted_styles: 预测的风格列表,格式: [(style, score), ...]
|
||||
max_num: 最多返回数量
|
||||
|
||||
Returns:
|
||||
表达方式列表
|
||||
"""
|
||||
if not predicted_styles:
|
||||
return []
|
||||
|
||||
# 提取风格名称(前3个最佳匹配)
|
||||
style_names = [style for style, _ in predicted_styles[:min(3, len(predicted_styles))]]
|
||||
logger.debug(f"预测最佳风格: {style_names[0] if style_names else 'None'}, Top3分数: {predicted_styles[:3]}")
|
||||
|
||||
async with get_db_session() as session:
|
||||
# 查询匹配这些风格的表达方式
|
||||
stmt = (
|
||||
select(Expression)
|
||||
.where(Expression.chat_id == chat_id)
|
||||
.where(Expression.style.in_(style_names))
|
||||
.order_by(Expression.count.desc())
|
||||
.limit(max_num)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
expressions_objs = result.scalars().all()
|
||||
|
||||
if not expressions_objs:
|
||||
logger.debug(f"数据库中没有找到风格 {style_names} 的表达方式")
|
||||
return []
|
||||
|
||||
# 转换为字典格式
|
||||
expressions = []
|
||||
for expr in expressions_objs:
|
||||
expressions.append({
|
||||
"situation": expr.situation or "",
|
||||
"style": expr.style or "",
|
||||
"type": expr.type or "style",
|
||||
"count": float(expr.count) if expr.count else 0.0,
|
||||
"last_active_time": expr.last_active_time or 0.0
|
||||
})
|
||||
|
||||
logger.debug(f"从数据库获取了 {len(expressions)} 个表达方式")
|
||||
return expressions
|
||||
|
||||
async def select_suitable_expressions_llm(
|
||||
self,
|
||||
chat_id: str,
|
||||
|
||||
Reference in New Issue
Block a user