221 lines
8.6 KiB
Python
221 lines
8.6 KiB
Python
from typing import Optional, List, Dict, Any, Tuple
|
||
from src.common.logger_manager import get_logger
|
||
from src.chat.message_receive.chat_stream import ChatManager, ChatStream
|
||
from src.chat.focus_chat.hfc_utils import parse_thinking_id_to_timestamp
|
||
import asyncio
|
||
|
||
logger = get_logger("stream_api")
|
||
|
||
|
||
class StreamAPI:
|
||
"""聊天流API模块
|
||
|
||
提供了获取聊天流、通过群ID查找聊天流等功能
|
||
"""
|
||
|
||
def get_chat_stream_by_group_id(self, group_id: str, platform: str = "qq") -> Optional[ChatStream]:
|
||
"""通过QQ群ID获取聊天流
|
||
|
||
Args:
|
||
group_id: QQ群ID
|
||
platform: 平台标识,默认为"qq"
|
||
|
||
Returns:
|
||
Optional[ChatStream]: 找到的聊天流对象,如果未找到则返回None
|
||
"""
|
||
try:
|
||
chat_manager = ChatManager()
|
||
|
||
# 遍历所有已加载的聊天流,查找匹配的群ID
|
||
for stream_id, stream in chat_manager.streams.items():
|
||
if (
|
||
stream.group_info
|
||
and str(stream.group_info.group_id) == str(group_id)
|
||
and stream.platform == platform
|
||
):
|
||
logger.info(f"{self.log_prefix} 通过群ID {group_id} 找到聊天流: {stream_id}")
|
||
return stream
|
||
|
||
logger.warning(f"{self.log_prefix} 未找到群ID为 {group_id} 的聊天流")
|
||
return None
|
||
|
||
except Exception as e:
|
||
logger.error(f"{self.log_prefix} 通过群ID获取聊天流时出错: {e}")
|
||
return None
|
||
|
||
def get_all_group_chat_streams(self, platform: str = "qq") -> List[ChatStream]:
|
||
"""获取所有群聊的聊天流
|
||
|
||
Args:
|
||
platform: 平台标识,默认为"qq"
|
||
|
||
Returns:
|
||
List[ChatStream]: 所有群聊的聊天流列表
|
||
"""
|
||
try:
|
||
chat_manager = ChatManager()
|
||
group_streams = []
|
||
|
||
for stream in chat_manager.streams.values():
|
||
if stream.group_info and stream.platform == platform:
|
||
group_streams.append(stream)
|
||
|
||
logger.info(f"{self.log_prefix} 找到 {len(group_streams)} 个群聊聊天流")
|
||
return group_streams
|
||
|
||
except Exception as e:
|
||
logger.error(f"{self.log_prefix} 获取所有群聊聊天流时出错: {e}")
|
||
return []
|
||
|
||
def get_chat_stream_by_user_id(self, user_id: str, platform: str = "qq") -> Optional[ChatStream]:
|
||
"""通过用户ID获取私聊聊天流
|
||
|
||
Args:
|
||
user_id: 用户ID
|
||
platform: 平台标识,默认为"qq"
|
||
|
||
Returns:
|
||
Optional[ChatStream]: 找到的私聊聊天流对象,如果未找到则返回None
|
||
"""
|
||
try:
|
||
chat_manager = ChatManager()
|
||
|
||
# 遍历所有已加载的聊天流,查找匹配的用户ID(私聊)
|
||
for stream_id, stream in chat_manager.streams.items():
|
||
if (
|
||
not stream.group_info # 私聊没有群信息
|
||
and stream.user_info
|
||
and str(stream.user_info.user_id) == str(user_id)
|
||
and stream.platform == platform
|
||
):
|
||
logger.info(f"{self.log_prefix} 通过用户ID {user_id} 找到私聊聊天流: {stream_id}")
|
||
return stream
|
||
|
||
logger.warning(f"{self.log_prefix} 未找到用户ID为 {user_id} 的私聊聊天流")
|
||
return None
|
||
|
||
except Exception as e:
|
||
logger.error(f"{self.log_prefix} 通过用户ID获取私聊聊天流时出错: {e}")
|
||
return None
|
||
|
||
def get_chat_streams_info(self) -> List[Dict[str, Any]]:
|
||
"""获取所有聊天流的基本信息
|
||
|
||
Returns:
|
||
List[Dict[str, Any]]: 包含聊天流基本信息的字典列表
|
||
"""
|
||
try:
|
||
chat_manager = ChatManager()
|
||
streams_info = []
|
||
|
||
for stream_id, stream in chat_manager.streams.items():
|
||
info = {
|
||
"stream_id": stream_id,
|
||
"platform": stream.platform,
|
||
"chat_type": "group" if stream.group_info else "private",
|
||
"create_time": stream.create_time,
|
||
"last_active_time": stream.last_active_time,
|
||
}
|
||
|
||
if stream.group_info:
|
||
info.update({"group_id": stream.group_info.group_id, "group_name": stream.group_info.group_name})
|
||
|
||
if stream.user_info:
|
||
info.update({"user_id": stream.user_info.user_id, "user_nickname": stream.user_info.user_nickname})
|
||
|
||
streams_info.append(info)
|
||
|
||
logger.info(f"{self.log_prefix} 获取到 {len(streams_info)} 个聊天流信息")
|
||
return streams_info
|
||
|
||
except Exception as e:
|
||
logger.error(f"{self.log_prefix} 获取聊天流信息时出错: {e}")
|
||
return []
|
||
|
||
async def get_chat_stream_by_group_id_async(self, group_id: str, platform: str = "qq") -> Optional[ChatStream]:
|
||
"""异步通过QQ群ID获取聊天流(包括从数据库搜索)
|
||
|
||
Args:
|
||
group_id: QQ群ID
|
||
platform: 平台标识,默认为"qq"
|
||
|
||
Returns:
|
||
Optional[ChatStream]: 找到的聊天流对象,如果未找到则返回None
|
||
"""
|
||
try:
|
||
# 首先尝试从内存中查找
|
||
stream = self.get_chat_stream_by_group_id(group_id, platform)
|
||
if stream:
|
||
return stream
|
||
|
||
# 如果内存中没有,尝试从数据库加载所有聊天流后再查找
|
||
chat_manager = ChatManager()
|
||
await chat_manager.load_all_streams()
|
||
|
||
# 再次尝试从内存中查找
|
||
stream = self.get_chat_stream_by_group_id(group_id, platform)
|
||
return stream
|
||
|
||
except Exception as e:
|
||
logger.error(f"{self.log_prefix} 异步通过群ID获取聊天流时出错: {e}")
|
||
return None
|
||
|
||
async def wait_for_new_message(self, timeout: int = 1200) -> Tuple[bool, str]:
|
||
"""等待新消息或超时
|
||
|
||
Args:
|
||
timeout: 超时时间(秒),默认1200秒
|
||
|
||
Returns:
|
||
Tuple[bool, str]: (是否收到新消息, 空字符串)
|
||
"""
|
||
try:
|
||
# 获取必要的服务对象
|
||
observations = self.get_service("observations")
|
||
if not observations:
|
||
logger.warning(f"{self.log_prefix} 无法获取observations服务,无法等待新消息")
|
||
return False, ""
|
||
|
||
# 获取第一个观察对象(通常是ChattingObservation)
|
||
observation = observations[0] if observations else None
|
||
if not observation:
|
||
logger.warning(f"{self.log_prefix} 无观察对象,无法等待新消息")
|
||
return False, ""
|
||
|
||
# 从action上下文获取thinking_id
|
||
thinking_id = self.get_action_context("thinking_id")
|
||
if not thinking_id:
|
||
logger.warning(f"{self.log_prefix} 无thinking_id,无法等待新消息")
|
||
return False, ""
|
||
|
||
logger.info(f"{self.log_prefix} 开始等待新消息... (超时: {timeout}秒)")
|
||
|
||
wait_start_time = asyncio.get_event_loop().time()
|
||
while True:
|
||
# 检查关闭标志
|
||
shutting_down = self.get_action_context("shutting_down", False)
|
||
if shutting_down:
|
||
logger.info(f"{self.log_prefix} 等待新消息时检测到关闭信号,中断等待")
|
||
return False, ""
|
||
|
||
# 检查新消息
|
||
thinking_id_timestamp = parse_thinking_id_to_timestamp(thinking_id)
|
||
if await observation.has_new_messages_since(thinking_id_timestamp):
|
||
logger.info(f"{self.log_prefix} 检测到新消息")
|
||
return True, ""
|
||
|
||
# 检查超时
|
||
if asyncio.get_event_loop().time() - wait_start_time > timeout:
|
||
logger.warning(f"{self.log_prefix} 等待新消息超时({timeout}秒)")
|
||
return False, ""
|
||
|
||
# 短暂休眠
|
||
await asyncio.sleep(0.5)
|
||
|
||
except asyncio.CancelledError:
|
||
logger.info(f"{self.log_prefix} 等待新消息被中断 (CancelledError)")
|
||
return False, ""
|
||
except Exception as e:
|
||
logger.error(f"{self.log_prefix} 等待新消息时发生错误: {e}")
|
||
return False, f"等待新消息失败: {str(e)}"
|