feat: merge prompt construction for normal and focus
@@ -8,10 +8,12 @@
     success, reply_set = await generator_api.generate_reply(chat_stream, action_data, reasoning)
 """

-from typing import Tuple, Any, Dict, List
+from typing import Tuple, Any, Dict, List, Optional
 from src.common.logger import get_logger
 from src.chat.replyer.default_generator import DefaultReplyer
-from src.chat.message_receive.chat_stream import get_chat_manager
+from src.chat.message_receive.chat_stream import ChatStream
+from src.chat.utils.utils import process_llm_response
+from src.chat.replyer.replyer_manager import replyer_manager

 logger = get_logger("generator_api")

@@ -21,46 +23,36 @@ logger = get_logger("generator_api")
 # =============================================================================


-def get_replyer(chat_stream=None, chat_id: str = None) -> DefaultReplyer:
+def get_replyer(
+    chat_stream: Optional[ChatStream] = None,
+    chat_id: Optional[str] = None,
+    model_configs: Optional[List[Dict[str, Any]]] = None,
+    request_type: str = "replyer"
+) -> Optional[DefaultReplyer]:
     """Get a replyer instance

-    Prefers chat_stream; falls back to a direct lookup by chat_id
+    Prefers chat_stream; falls back to a direct lookup by chat_id.
+    Instances are managed by ReplyerManager to avoid creating duplicates.

     Args:
         chat_stream: chat stream object (preferred)
         chat_id: chat ID (in practice, the stream_id)
+        model_configs: list of model configurations
+        request_type: request type

     Returns:
-        Optional[Any]: the replyer instance, or None if it could not be obtained
+        Optional[DefaultReplyer]: the replyer instance, or None if it could not be obtained
     """
     try:
-        # Prefer the chat stream
-        if chat_stream:
-            logger.debug("[GeneratorAPI] Using chat stream to get replyer")
-            return DefaultReplyer(chat_stream=chat_stream)
-
-        # Look up directly by chat_id (chat_id is the stream_id)
-        if chat_id:
-            logger.debug("[GeneratorAPI] Using chat_id to get replyer")
-            chat_manager = get_chat_manager()
-            if not chat_manager:
-                logger.warning("[GeneratorAPI] Could not get chat manager")
-                return None
-
-            # Use chat_id directly as the stream_id
-            target_stream = chat_manager.get_stream(chat_id)
-
-            if target_stream is None:
-                logger.warning(f"[GeneratorAPI] No matching chat stream found, chat_id={chat_id}")
-                return None
-
-            return DefaultReplyer(chat_stream=target_stream)
-
-        logger.warning("[GeneratorAPI] Missing required arguments, cannot get replyer")
-        return None
-
+        logger.debug(f"[GeneratorAPI] Getting replyer, chat_id: {chat_id}, chat_stream: {'yes' if chat_stream else 'no'}")
+        return replyer_manager.get_replyer(
+            chat_stream=chat_stream,
+            chat_id=chat_id,
+            model_configs=model_configs,
+            request_type=request_type
+        )
     except Exception as e:
-        logger.error(f"[GeneratorAPI] Failed to get replyer: {e}")
+        logger.error(f"[GeneratorAPI] Unexpected error while getting replyer: {e}", exc_info=True)
         return None


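get_replyer now just delegates to replyer_manager. As a point of reference, a minimal sketch of the caching behaviour the new docstring describes might look like the code below; the real ReplyerManager lives in src.chat.replyer.replyer_manager and its internals (how it threads model_configs and request_type into DefaultReplyer, locking, eviction) are not part of this diff, and the stream_id attribute name is assumed.

# Hedged sketch only; not the actual ReplyerManager implementation.
from typing import Any, Dict, List, Optional

from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.replyer.default_generator import DefaultReplyer


class ReplyerManagerSketch:
    """Hands out at most one DefaultReplyer per chat stream."""

    def __init__(self) -> None:
        self._replyers: Dict[str, DefaultReplyer] = {}  # stream_id -> replyer

    def get_replyer(
        self,
        chat_stream: Optional[ChatStream] = None,
        chat_id: Optional[str] = None,
        model_configs: Optional[List[Dict[str, Any]]] = None,
        request_type: str = "replyer",
    ) -> Optional[DefaultReplyer]:
        # Resolve the stream: prefer chat_stream, otherwise treat chat_id as the
        # stream_id, mirroring the branch removed from get_replyer above.
        stream = chat_stream
        if stream is None and chat_id:
            manager = get_chat_manager()
            stream = manager.get_stream(chat_id) if manager else None
        if stream is None:
            return None
        key = getattr(stream, "stream_id", chat_id)  # attribute name assumed
        if key not in self._replyers:
            self._replyers[key] = DefaultReplyer(chat_stream=stream)
        return self._replyers[key]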
@@ -71,10 +63,18 @@ def get_replyer(chat_stream=None, chat_id: str = None) -> DefaultReplyer:

 async def generate_reply(
     chat_stream=None,
-    action_data: Dict[str, Any] = None,
     chat_id: str = None,
+    action_data: Dict[str, Any] = None,
+    reply_to: str = "",
+    relation_info: str = "",
+    structured_info: str = "",
+    extra_info: str = "",
+    available_actions: List[str] = None,
     enable_splitter: bool = True,
     enable_chinese_typo: bool = True,
+    return_prompt: bool = False,
+    model_configs: Optional[List[Dict[str, Any]]] = None,
+    request_type: str = "",
 ) -> Tuple[bool, List[Tuple[str, Any]]]:
     """Generate a reply

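The expanded signature keeps the call pattern from the module docstring. A hedged usage sketch follows; the action_data payload and the reply_to format are assumptions, and generator_api refers to this module as imported by a plugin.

success, reply_set = await generator_api.generate_reply(
    chat_stream=chat_stream,
    action_data={"reason": "follow-up question"},   # hypothetical payload
    reply_to="user_123: what about tomorrow?",      # hypothetical format
    extra_info="keep the answer short",
    enable_splitter=True,
    enable_chinese_typo=False,
)

With return_prompt=True the same call returns a third element, the prompt that was sent to the model, as the next hunks show.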
@@ -84,13 +84,13 @@ async def generate_reply(
         chat_id: chat ID (fallback)
         enable_splitter: whether to enable the message splitter
         enable_chinese_typo: whether to enable the typo generator
-
+        return_prompt: whether to return the prompt
     Returns:
         Tuple[bool, List[Tuple[str, Any]]]: (success, reply set)
     """
     try:
         # Get the replyer
-        replyer = get_replyer(chat_stream, chat_id)
+        replyer = get_replyer(chat_stream, chat_id, model_configs=model_configs, request_type=request_type)
         if not replyer:
             logger.error("[GeneratorAPI] Could not get replyer")
             return False, []
@@ -98,16 +98,26 @@ async def generate_reply(
         logger.info("[GeneratorAPI] Starting reply generation")

         # Call the replyer to generate a reply
-        success, reply_set = await replyer.generate_reply_with_context(
-            reply_data=action_data or {}, enable_splitter=enable_splitter, enable_chinese_typo=enable_chinese_typo
+        success, content, prompt = await replyer.generate_reply_with_context(
+            reply_data=action_data or {},
+            reply_to=reply_to,
+            relation_info=relation_info,
+            structured_info=structured_info,
+            extra_info=extra_info,
+            available_actions=available_actions,
         )

+        reply_set = await process_human_text(content, enable_splitter, enable_chinese_typo)
+
         if success:
             logger.info(f"[GeneratorAPI] Reply generated successfully, produced {len(reply_set)} reply segments")
         else:
             logger.warning("[GeneratorAPI] Reply generation failed")

-        return success, reply_set or []
+        if return_prompt:
+            return success, reply_set or [], prompt
+        else:
+            return success, reply_set or []

     except Exception as e:
         logger.error(f"[GeneratorAPI] Error while generating reply: {e}")
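Because process_human_text turns the generated text into ("text", segment) tuples, a caller that also asks for the prompt back might unpack the three-element return as sketched below; send_segment is a hypothetical sender, not part of this module.

success, reply_set, prompt = await generator_api.generate_reply(
    chat_stream=chat_stream,
    action_data=action_data,
    return_prompt=True,
)
if success:
    for seg_type, seg_text in reply_set:        # each entry is ("text", segment)
        await send_segment(seg_type, seg_text)  # hypothetical sender
    logger.debug(f"prompt used: {prompt}")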
@@ -120,6 +130,7 @@ async def rewrite_reply(
     chat_id: str = None,
     enable_splitter: bool = True,
     enable_chinese_typo: bool = True,
+    model_configs: Optional[List[Dict[str, Any]]] = None,
 ) -> Tuple[bool, List[Tuple[str, Any]]]:
     """Rewrite a reply

@@ -135,7 +146,7 @@ async def rewrite_reply(
     """
     try:
         # Get the replyer
-        replyer = get_replyer(chat_stream, chat_id)
+        replyer = get_replyer(chat_stream, chat_id, model_configs=model_configs)
         if not replyer:
             logger.error("[GeneratorAPI] Could not get replyer")
             return False, []
@@ -143,9 +154,11 @@ async def rewrite_reply(
         logger.info("[GeneratorAPI] Starting reply rewrite")

         # Call the replyer to rewrite the reply
-        success, reply_set = await replyer.rewrite_reply_with_context(
-            reply_data=reply_data or {}, enable_splitter=enable_splitter, enable_chinese_typo=enable_chinese_typo
+        success, content = await replyer.rewrite_reply_with_context(
+            reply_data=reply_data or {}
         )

+        reply_set = await process_human_text(content, enable_splitter, enable_chinese_typo)
+
         if success:
             logger.info(f"[GeneratorAPI] Reply rewritten successfully, produced {len(reply_set)} reply segments")
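For rewrite_reply the only signature change is the optional model_configs. A hedged call sketch, with the reply_data keys as assumptions:

success, reply_set = await generator_api.rewrite_reply(
    chat_stream=chat_stream,
    reply_data={"raw_reply": "previous draft", "reason": "tone too formal"},  # assumed keys
    model_configs=model_configs,
)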
@@ -157,3 +170,30 @@ async def rewrite_reply(
     except Exception as e:
         logger.error(f"[GeneratorAPI] Error while rewriting reply: {e}")
         return False, []
+
+
+async def process_human_text(
+    content: str,
+    enable_splitter: bool,
+    enable_chinese_typo: bool
+) -> List[Tuple[str, Any]]:
+    """Process text into a more human-like form
+
+    Args:
+        content: text content
+        enable_splitter: whether to enable the message splitter
+        enable_chinese_typo: whether to enable the typo generator
+    """
+    try:
+        processed_response = process_llm_response(content, enable_splitter, enable_chinese_typo)
+
+        reply_set = []
+        for text in processed_response:
+            reply_seg = ("text", text)
+            reply_set.append(reply_seg)
+
+        return reply_set
+
+    except Exception as e:
+        logger.error(f"[GeneratorAPI] Error while processing human-like text: {e}")
+        return []
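For reference, the helper's output is a list of ("text", segment) tuples. An illustrative call, assuming process_llm_response splits the text into two segments:

reply_set = await process_human_text(
    "Sure, tomorrow works. See you at 10.",
    enable_splitter=True,
    enable_chinese_typo=False,
)
# e.g. [("text", "Sure, tomorrow works."), ("text", "See you at 10.")]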