fix:修复get_replyer无法用chat_id获取
This commit is contained in:
@@ -37,6 +37,7 @@ def init_prompt():
|
||||
{chat_info}
|
||||
{reply_target_block}
|
||||
{identity}
|
||||
|
||||
你需要使用合适的语言习惯和句法,参考聊天内容,组织一条日常且口语化的回复。注意不要复读你说过的话。
|
||||
{config_expression_style}
|
||||
{keywords_reaction_prompt}
|
||||
|
||||
@@ -21,16 +21,14 @@ logger = get_logger("generator_api")
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_replyer(chat_stream=None, platform: str = None, chat_id: str = None, is_group: bool = True) -> DefaultReplyer:
|
||||
def get_replyer(chat_stream=None, chat_id: str = None) -> DefaultReplyer:
|
||||
"""获取回复器对象
|
||||
|
||||
优先使用chat_stream,如果没有则使用platform和chat_id组合
|
||||
优先使用chat_stream,如果没有则使用chat_id直接查找
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流对象(优先)
|
||||
platform: 平台名称,如"qq"
|
||||
chat_id: 聊天ID(群ID或用户ID)
|
||||
is_group: 是否为群聊
|
||||
chat_id: 聊天ID(实际上就是stream_id)
|
||||
|
||||
Returns:
|
||||
Optional[Any]: 回复器对象,如果获取失败则返回None
|
||||
@@ -41,26 +39,20 @@ def get_replyer(chat_stream=None, platform: str = None, chat_id: str = None, is_
|
||||
logger.debug("[GeneratorAPI] 使用聊天流获取回复器")
|
||||
return DefaultReplyer(chat_stream=chat_stream)
|
||||
|
||||
# 使用平台和ID组合
|
||||
if platform and chat_id:
|
||||
logger.debug("[GeneratorAPI] 使用平台和ID获取回复器")
|
||||
# 使用chat_id直接查找(chat_id即为stream_id)
|
||||
if chat_id:
|
||||
logger.debug("[GeneratorAPI] 使用chat_id获取回复器")
|
||||
chat_manager = get_chat_manager()
|
||||
if not chat_manager:
|
||||
logger.warning("[GeneratorAPI] 无法获取聊天管理器")
|
||||
return None
|
||||
|
||||
# 查找对应的聊天流
|
||||
target_stream = None
|
||||
for _stream_id, stream in chat_manager.streams.items():
|
||||
if stream.platform == platform:
|
||||
if is_group and stream.group_info:
|
||||
if str(stream.group_info.group_id) == str(chat_id):
|
||||
target_stream = stream
|
||||
break
|
||||
elif not is_group and stream.user_info:
|
||||
if str(stream.user_info.user_id) == str(chat_id):
|
||||
target_stream = stream
|
||||
break
|
||||
# 直接使用chat_id作为stream_id查找
|
||||
target_stream = chat_manager.get_stream(chat_id)
|
||||
|
||||
if target_stream is None:
|
||||
logger.warning(f"[GeneratorAPI] 未找到匹配的聊天流 chat_id={chat_id}")
|
||||
return None
|
||||
|
||||
return DefaultReplyer(chat_stream=target_stream)
|
||||
|
||||
@@ -80,29 +72,21 @@ def get_replyer(chat_stream=None, platform: str = None, chat_id: str = None, is_
|
||||
async def generate_reply(
|
||||
chat_stream=None,
|
||||
action_data: Dict[str, Any] = None,
|
||||
platform: str = None,
|
||||
chat_id: str = None,
|
||||
is_group: bool = True,
|
||||
) -> Tuple[bool, List[Tuple[str, Any]]]:
|
||||
"""生成回复
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流对象(优先)
|
||||
action_data: 动作数据
|
||||
reasoning: 推理原因
|
||||
thinking_id: 思考ID
|
||||
cycle_timers: 循环计时器
|
||||
anchor_message: 锚点消息
|
||||
platform: 平台名称(备用)
|
||||
chat_id: 聊天ID(备用)
|
||||
is_group: 是否为群聊(备用)
|
||||
|
||||
Returns:
|
||||
Tuple[bool, List[Tuple[str, Any]]]: (是否成功, 回复集合)
|
||||
"""
|
||||
try:
|
||||
# 获取回复器
|
||||
replyer = get_replyer(chat_stream, platform, chat_id, is_group)
|
||||
replyer = get_replyer(chat_stream, chat_id)
|
||||
if not replyer:
|
||||
logger.error("[GeneratorAPI] 无法获取回复器")
|
||||
return False, []
|
||||
@@ -129,25 +113,21 @@ async def generate_reply(
|
||||
async def rewrite_reply(
|
||||
chat_stream=None,
|
||||
reply_data: Dict[str, Any] = None,
|
||||
platform: str = None,
|
||||
chat_id: str = None,
|
||||
is_group: bool = True,
|
||||
) -> Tuple[bool, List[Tuple[str, Any]]]:
|
||||
"""重写回复
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流对象(优先)
|
||||
action_data: 动作数据
|
||||
platform: 平台名称(备用)
|
||||
reply_data: 回复数据
|
||||
chat_id: 聊天ID(备用)
|
||||
is_group: 是否为群聊(备用)
|
||||
|
||||
Returns:
|
||||
Tuple[bool, List[Tuple[str, Any]]]: (是否成功, 回复集合)
|
||||
"""
|
||||
try:
|
||||
# 获取回复器
|
||||
replyer = get_replyer(chat_stream, platform, chat_id, is_group)
|
||||
replyer = get_replyer(chat_stream, chat_id)
|
||||
if not replyer:
|
||||
logger.error("[GeneratorAPI] 无法获取回复器")
|
||||
return False, []
|
||||
|
||||
@@ -62,11 +62,8 @@ class ReplyAction(BaseAction):
|
||||
|
||||
try:
|
||||
success, reply_set = await generator_api.generate_reply(
|
||||
chat_stream=self.chat_stream,
|
||||
action_data=self.action_data,
|
||||
platform=self.platform,
|
||||
chat_id=self.chat_id,
|
||||
is_group=self.is_group,
|
||||
)
|
||||
|
||||
# 检查从start_time以来的新消息数量
|
||||
|
||||
Reference in New Issue
Block a user