better:精简rewrite的参数,添加注释

This commit is contained in:
SengokuCola
2025-07-27 00:46:34 +08:00
parent 3ab9b8def5
commit ab71d30437
2 changed files with 122 additions and 31 deletions

View File

@@ -158,7 +158,17 @@ class DefaultReplyer:
enable_timeout: bool = False, enable_timeout: bool = False,
) -> Tuple[bool, Optional[str], Optional[str]]: ) -> Tuple[bool, Optional[str], Optional[str]]:
""" """
回复器 (Replier): 核心逻辑,负责生成回复文本。 回复器 (Replier): 负责生成回复文本的核心逻辑
Args:
reply_to: 回复对象,格式为 "发送者:消息内容"
extra_info: 额外信息,用于补充上下文
available_actions: 可用的动作信息字典
enable_tool: 是否启用工具调用
enable_timeout: 是否启用超时处理
Returns:
Tuple[bool, Optional[str], Optional[str]]: (是否成功, 生成的回复内容, 使用的prompt)
""" """
prompt = None prompt = None
if available_actions is None: if available_actions is None:
@@ -219,25 +229,30 @@ class DefaultReplyer:
async def rewrite_reply_with_context( async def rewrite_reply_with_context(
self, self,
reply_data: Dict[str, Any],
raw_reply: str = "", raw_reply: str = "",
reason: str = "", reason: str = "",
reply_to: str = "", reply_to: str = "",
relation_info: str = "",
) -> Tuple[bool, Optional[str]]: ) -> Tuple[bool, Optional[str]]:
""" """
表达器 (Expressor): 核心逻辑,负责生成回复文本。 表达器 (Expressor): 负责重写和优化回复文本。
Args:
raw_reply: 原始回复内容
reason: 回复原因
reply_to: 回复对象,格式为 "发送者:消息内容"
relation_info: 关系信息
Returns:
Tuple[bool, Optional[str]]: (是否成功, 重写后的回复内容)
""" """
try: try:
if not reply_data:
reply_data = {
"reply_to": reply_to,
"relation_info": relation_info,
}
with Timer("构建Prompt", {}): # 内部计时器,可选保留 with Timer("构建Prompt", {}): # 内部计时器,可选保留
prompt = await self.build_prompt_rewrite_context( prompt = await self.build_prompt_rewrite_context(
reply_data=reply_data, raw_reply=raw_reply,
reason=reason,
reply_to=reply_to,
) )
content = None content = None
@@ -296,7 +311,16 @@ class DefaultReplyer:
return await relationship_fetcher.build_relation_info(person_id, points_num=5) return await relationship_fetcher.build_relation_info(person_id, points_num=5)
async def build_expression_habits(self, chat_history, target): async def build_expression_habits(self, chat_history: str, target: str) -> str:
"""构建表达习惯块
Args:
chat_history: 聊天历史记录
target: 目标消息内容
Returns:
str: 表达习惯信息字符串
"""
if not global_config.expression.enable_expression: if not global_config.expression.enable_expression:
return "" return ""
@@ -346,7 +370,16 @@ class DefaultReplyer:
return expression_habits_block return expression_habits_block
async def build_memory_block(self, chat_history, target): async def build_memory_block(self, chat_history: str, target: str) -> str:
"""构建记忆块
Args:
chat_history: 聊天历史记录
target: 目标消息内容
Returns:
str: 记忆信息字符串
"""
if not global_config.memory.enable_memory: if not global_config.memory.enable_memory:
return "" return ""
@@ -374,12 +407,13 @@ class DefaultReplyer:
return memory_str return memory_str
async def build_tool_info(self, chat_history, reply_to: str = "", enable_tool: bool = True): async def build_tool_info(self, chat_history: str, reply_to: str = "", enable_tool: bool = True) -> str:
"""构建工具信息块 """构建工具信息块
Args: Args:
reply_data: 回复数据,包含要回复的消息内容 chat_history: 聊天历史记录
chat_history: 聊天历史 reply_to: 回复对象,格式为 "发送者:消息内容"
enable_tool: 是否启用工具调用
Returns: Returns:
str: 工具信息字符串 str: 工具信息字符串
@@ -423,7 +457,15 @@ class DefaultReplyer:
logger.error(f"工具信息获取失败: {e}") logger.error(f"工具信息获取失败: {e}")
return "" return ""
def _parse_reply_target(self, target_message: str) -> tuple: def _parse_reply_target(self, target_message: str) -> Tuple[str, str]:
"""解析回复目标消息
Args:
target_message: 目标消息,格式为 "发送者:消息内容""发送者:消息内容"
Returns:
Tuple[str, str]: (发送者名称, 消息内容)
"""
sender = "" sender = ""
target = "" target = ""
# 添加None检查防止NoneType错误 # 添加None检查防止NoneType错误
@@ -437,7 +479,15 @@ class DefaultReplyer:
target = parts[1].strip() target = parts[1].strip()
return sender, target return sender, target
async def build_keywords_reaction_prompt(self, target): async def build_keywords_reaction_prompt(self, target: Optional[str]) -> str:
"""构建关键词反应提示
Args:
target: 目标消息内容
Returns:
str: 关键词反应提示字符串
"""
# 关键词检测与反应 # 关键词检测与反应
keywords_reaction_prompt = "" keywords_reaction_prompt = ""
try: try:
@@ -471,15 +521,23 @@ class DefaultReplyer:
return keywords_reaction_prompt return keywords_reaction_prompt
async def _time_and_run_task(self, coroutine, name: str): async def _time_and_run_task(self, coroutine, name: str) -> Tuple[str, Any, float]:
"""一个简单的帮助函数,用于计时运行异步任务,返回任务名、结果和耗时""" """计时运行异步任务的辅助函数
Args:
coroutine: 要执行的协程
name: 任务名称
Returns:
Tuple[str, Any, float]: (任务名称, 任务结果, 执行耗时)
"""
start_time = time.time() start_time = time.time()
result = await coroutine result = await coroutine
end_time = time.time() end_time = time.time()
duration = end_time - start_time duration = end_time - start_time
return name, result, duration return name, result, duration
def build_s4u_chat_history_prompts(self, message_list_before_now: list, target_user_id: str) -> tuple[str, str]: def build_s4u_chat_history_prompts(self, message_list_before_now: List[Dict[str, Any]], target_user_id: str) -> Tuple[str, str]:
""" """
构建 s4u 风格的分离对话 prompt 构建 s4u 风格的分离对话 prompt
@@ -488,7 +546,7 @@ class DefaultReplyer:
target_user_id: 目标用户ID当前对话对象 target_user_id: 目标用户ID当前对话对象
Returns: Returns:
tuple: (核心对话prompt, 背景对话prompt) Tuple[str, str]: (核心对话prompt, 背景对话prompt)
""" """
core_dialogue_list = [] core_dialogue_list = []
background_dialogue_list = [] background_dialogue_list = []
@@ -507,7 +565,7 @@ class DefaultReplyer:
# 其他用户的对话 # 其他用户的对话
background_dialogue_list.append(msg_dict) background_dialogue_list.append(msg_dict)
except Exception as e: except Exception as e:
logger.error(f"![1753364551656](image/default_generator/1753364551656.png)记录: {msg_dict}, 错误: {e}") logger.error(f"处理消息记录时出错: {msg_dict}, 错误: {e}")
# 构建背景对话 prompt # 构建背景对话 prompt
background_dialogue_prompt = "" background_dialogue_prompt = ""
@@ -552,8 +610,25 @@ class DefaultReplyer:
sender: str, sender: str,
target: str, target: str,
chat_info: str, chat_info: str,
): ) -> Any:
"""构建 mai_think 上下文信息""" """构建 mai_think 上下文信息
Args:
chat_id: 聊天ID
memory_block: 记忆块内容
relation_info: 关系信息
time_block: 时间块内容
chat_target_1: 聊天目标1
chat_target_2: 聊天目标2
mood_prompt: 情绪提示
identity_block: 身份块内容
sender: 发送者名称
target: 目标消息内容
chat_info: 聊天信息
Returns:
Any: mai_think 实例
"""
mai_think = mai_thinking_manager.get_mai_think(chat_id) mai_think = mai_thinking_manager.get_mai_think(chat_id)
mai_think.memory_block = memory_block mai_think.memory_block = memory_block
mai_think.relation_info_block = relation_info mai_think.relation_info_block = relation_info
@@ -799,15 +874,14 @@ class DefaultReplyer:
async def build_prompt_rewrite_context( async def build_prompt_rewrite_context(
self, self,
reply_data: Dict[str, Any], raw_reply: str,
reason: str,
reply_to: str,
) -> str: ) -> str:
chat_stream = self.chat_stream chat_stream = self.chat_stream
chat_id = chat_stream.stream_id chat_id = chat_stream.stream_id
is_group_chat = bool(chat_stream.group_info) is_group_chat = bool(chat_stream.group_info)
reply_to = reply_data.get("reply_to", "none")
raw_reply = reply_data.get("raw_reply", "")
reason = reply_data.get("reason", "")
sender, target = self._parse_reply_target(reply_to) sender, target = self._parse_reply_target(reply_to)
# 添加情绪状态获取 # 添加情绪状态获取
@@ -834,7 +908,7 @@ class DefaultReplyer:
# 并行执行2个构建任务 # 并行执行2个构建任务
expression_habits_block, relation_info = await asyncio.gather( expression_habits_block, relation_info = await asyncio.gather(
self.build_expression_habits(chat_talking_prompt_half, target), self.build_expression_habits(chat_talking_prompt_half, target),
self.build_relation_info(reply_data), self.build_relation_info(reply_to),
) )
keywords_reaction_prompt = await self.build_keywords_reaction_prompt(target) keywords_reaction_prompt = await self.build_keywords_reaction_prompt(target)

View File

@@ -151,15 +151,22 @@ async def rewrite_reply(
enable_splitter: bool = True, enable_splitter: bool = True,
enable_chinese_typo: bool = True, enable_chinese_typo: bool = True,
model_configs: Optional[List[Dict[str, Any]]] = None, model_configs: Optional[List[Dict[str, Any]]] = None,
raw_reply: str = "",
reason: str = "",
reply_to: str = "",
) -> Tuple[bool, List[Tuple[str, Any]]]: ) -> Tuple[bool, List[Tuple[str, Any]]]:
"""重写回复 """重写回复
Args: Args:
chat_stream: 聊天流对象(优先) chat_stream: 聊天流对象(优先)
reply_data: 回复数据 reply_data: 回复数据字典(备用,当其他参数缺失时从此获取)
chat_id: 聊天ID备用 chat_id: 聊天ID备用
enable_splitter: 是否启用消息分割器 enable_splitter: 是否启用消息分割器
enable_chinese_typo: 是否启用错字生成器 enable_chinese_typo: 是否启用错字生成器
model_configs: 模型配置列表
raw_reply: 原始回复内容
reason: 回复原因
reply_to: 回复对象
Returns: Returns:
Tuple[bool, List[Tuple[str, Any]]]: (是否成功, 回复集合) Tuple[bool, List[Tuple[str, Any]]]: (是否成功, 回复集合)
@@ -173,8 +180,18 @@ async def rewrite_reply(
logger.info("[GeneratorAPI] 开始重写回复") logger.info("[GeneratorAPI] 开始重写回复")
# 如果参数缺失从reply_data中获取
if reply_data:
raw_reply = raw_reply or reply_data.get("raw_reply", "")
reason = reason or reply_data.get("reason", "")
reply_to = reply_to or reply_data.get("reply_to", "")
# 调用回复器重写回复 # 调用回复器重写回复
success, content = await replyer.rewrite_reply_with_context(reply_data=reply_data or {}) success, content = await replyer.rewrite_reply_with_context(
raw_reply=raw_reply,
reason=reason,
reply_to=reply_to,
)
reply_set = [] reply_set = []
if content: if content:
reply_set = await process_human_text(content, enable_splitter, enable_chinese_typo) reply_set = await process_human_text(content, enable_splitter, enable_chinese_typo)