feat(expression): 添加表达方式选择模式支持与DatabaseMessages兼容性改进
- 新增统一的表达方式选择入口,支持classic和exp_model两种模式 - 添加StyleLearner模型预测模式,可基于机器学习模型选择表达风格 - 改进多个模块对DatabaseMessages数据模型的兼容性处理 - 优化消息处理逻辑,统一处理字典和DatabaseMessages对象 - 在配置中添加expression.mode字段控制表达选择模式
This commit is contained in:
@@ -27,6 +27,7 @@ from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.chat.utils.prompt_params import PromptParameters
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.individuality.individuality import get_individuality
|
||||
@@ -474,10 +475,13 @@ class DefaultReplyer:
|
||||
style_habits = []
|
||||
grammar_habits = []
|
||||
|
||||
# 使用从处理器传来的选中表达方式
|
||||
# LLM模式:调用LLM选择5-10个,然后随机选5个
|
||||
selected_expressions = await expression_selector.select_suitable_expressions_llm(
|
||||
self.chat_stream.stream_id, chat_history, max_num=8, min_num=2, target_message=target
|
||||
# 使用统一的表达方式选择入口(支持classic和exp_model模式)
|
||||
selected_expressions = await expression_selector.select_suitable_expressions(
|
||||
chat_id=self.chat_stream.stream_id,
|
||||
chat_history=chat_history,
|
||||
target_message=target,
|
||||
max_num=8,
|
||||
min_num=2
|
||||
)
|
||||
|
||||
if selected_expressions:
|
||||
@@ -1206,7 +1210,7 @@ class DefaultReplyer:
|
||||
extra_info: str = "",
|
||||
available_actions: dict[str, ActionInfo] | None = None,
|
||||
enable_tool: bool = True,
|
||||
reply_message: dict[str, Any] | None = None,
|
||||
reply_message: dict[str, Any] | DatabaseMessages | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
构建回复器上下文
|
||||
@@ -1248,10 +1252,24 @@ class DefaultReplyer:
|
||||
if reply_message is None:
|
||||
logger.warning("reply_message 为 None,无法构建prompt")
|
||||
return ""
|
||||
platform = reply_message.get("chat_info_platform")
|
||||
|
||||
# 统一处理 DatabaseMessages 对象和字典
|
||||
if isinstance(reply_message, DatabaseMessages):
|
||||
platform = reply_message.chat_info.platform
|
||||
user_id = reply_message.user_info.user_id
|
||||
user_nickname = reply_message.user_info.user_nickname
|
||||
user_cardname = reply_message.user_info.user_cardname
|
||||
processed_plain_text = reply_message.processed_plain_text
|
||||
else:
|
||||
platform = reply_message.get("chat_info_platform")
|
||||
user_id = reply_message.get("user_id")
|
||||
user_nickname = reply_message.get("user_nickname")
|
||||
user_cardname = reply_message.get("user_cardname")
|
||||
processed_plain_text = reply_message.get("processed_plain_text")
|
||||
|
||||
person_id = person_info_manager.get_person_id(
|
||||
platform, # type: ignore
|
||||
reply_message.get("user_id"), # type: ignore
|
||||
user_id, # type: ignore
|
||||
)
|
||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||
|
||||
@@ -1260,22 +1278,22 @@ class DefaultReplyer:
|
||||
# 尝试从reply_message获取用户名
|
||||
await person_info_manager.first_knowing_some_one(
|
||||
platform, # type: ignore
|
||||
reply_message.get("user_id"), # type: ignore
|
||||
reply_message.get("user_nickname") or "",
|
||||
reply_message.get("user_cardname") or "",
|
||||
user_id, # type: ignore
|
||||
user_nickname or "",
|
||||
user_cardname or "",
|
||||
)
|
||||
|
||||
# 检查是否是bot自己的名字,如果是则替换为"(你)"
|
||||
bot_user_id = str(global_config.bot.qq_account)
|
||||
current_user_id = await person_info_manager.get_value(person_id, "user_id")
|
||||
current_platform = reply_message.get("chat_info_platform")
|
||||
current_platform = platform
|
||||
|
||||
if current_user_id == bot_user_id and current_platform == global_config.bot.platform:
|
||||
sender = f"{person_name}(你)"
|
||||
else:
|
||||
# 如果不是bot自己,直接使用person_name
|
||||
sender = person_name
|
||||
target = reply_message.get("processed_plain_text")
|
||||
target = processed_plain_text
|
||||
|
||||
# 最终的空值检查,确保sender和target不为None
|
||||
if sender is None:
|
||||
@@ -1609,15 +1627,22 @@ class DefaultReplyer:
|
||||
raw_reply: str,
|
||||
reason: str,
|
||||
reply_to: str,
|
||||
reply_message: dict[str, Any] | None = None,
|
||||
reply_message: dict[str, Any] | DatabaseMessages | None = None,
|
||||
) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
|
||||
chat_stream = self.chat_stream
|
||||
chat_id = chat_stream.stream_id
|
||||
is_group_chat = bool(chat_stream.group_info)
|
||||
|
||||
if reply_message:
|
||||
sender = reply_message.get("sender")
|
||||
target = reply_message.get("target")
|
||||
if isinstance(reply_message, DatabaseMessages):
|
||||
# 从 DatabaseMessages 对象获取 sender 和 target
|
||||
# 注意: DatabaseMessages 没有直接的 sender/target 字段
|
||||
# 需要根据实际情况构造
|
||||
sender = reply_message.user_info.user_nickname or reply_message.user_info.user_id
|
||||
target = reply_message.processed_plain_text or ""
|
||||
else:
|
||||
sender = reply_message.get("sender")
|
||||
target = reply_message.get("target")
|
||||
else:
|
||||
sender, target = self._parse_reply_target(reply_to)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user