This commit is contained in:
SengokuCola
2025-04-09 20:10:58 +08:00
22 changed files with 614 additions and 692 deletions

2
bot.py
View File

@@ -196,7 +196,7 @@ def raw_main():
# 安装崩溃日志处理器 # 安装崩溃日志处理器
install_crash_handler() install_crash_handler()
check_eula() check_eula()
print("检查EULA和隐私条款完成") print("检查EULA和隐私条款完成")
easter_egg() easter_egg()

View File

@@ -4,69 +4,66 @@ import logging
from pathlib import Path from pathlib import Path
from logging.handlers import RotatingFileHandler from logging.handlers import RotatingFileHandler
def setup_crash_logger(): def setup_crash_logger():
"""设置崩溃日志记录器""" """设置崩溃日志记录器"""
# 创建logs/crash目录如果不存在 # 创建logs/crash目录如果不存在
crash_log_dir = Path("logs/crash") crash_log_dir = Path("logs/crash")
crash_log_dir.mkdir(parents=True, exist_ok=True) crash_log_dir.mkdir(parents=True, exist_ok=True)
# 创建日志记录器 # 创建日志记录器
crash_logger = logging.getLogger('crash_logger') crash_logger = logging.getLogger("crash_logger")
crash_logger.setLevel(logging.ERROR) crash_logger.setLevel(logging.ERROR)
# 设置日志格式 # 设置日志格式
formatter = logging.Formatter( formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s\n' "%(asctime)s - %(name)s - %(levelname)s\n异常类型: %(exc_info)s\n详细信息:\n%(message)s\n-------------------\n"
'异常类型: %(exc_info)s\n'
'详细信息:\n%(message)s\n'
'-------------------\n'
) )
# 创建按大小轮转的文件处理器最大10MB保留5个备份 # 创建按大小轮转的文件处理器最大10MB保留5个备份
log_file = crash_log_dir / "crash.log" log_file = crash_log_dir / "crash.log"
file_handler = RotatingFileHandler( file_handler = RotatingFileHandler(
log_file, log_file,
maxBytes=10*1024*1024, # 10MB maxBytes=10 * 1024 * 1024, # 10MB
backupCount=5, backupCount=5,
encoding='utf-8' encoding="utf-8",
) )
file_handler.setFormatter(formatter) file_handler.setFormatter(formatter)
crash_logger.addHandler(file_handler) crash_logger.addHandler(file_handler)
return crash_logger return crash_logger
def log_crash(exc_type, exc_value, exc_traceback): def log_crash(exc_type, exc_value, exc_traceback):
"""记录崩溃信息到日志文件""" """记录崩溃信息到日志文件"""
if exc_type is None: if exc_type is None:
return return
# 获取崩溃日志记录器 # 获取崩溃日志记录器
crash_logger = logging.getLogger('crash_logger') crash_logger = logging.getLogger("crash_logger")
# 获取完整的异常堆栈信息 # 获取完整的异常堆栈信息
stack_trace = ''.join(traceback.format_exception(exc_type, exc_value, exc_traceback)) stack_trace = "".join(traceback.format_exception(exc_type, exc_value, exc_traceback))
# 记录崩溃信息 # 记录崩溃信息
crash_logger.error( crash_logger.error(stack_trace, exc_info=(exc_type, exc_value, exc_traceback))
stack_trace,
exc_info=(exc_type, exc_value, exc_traceback)
)
def install_crash_handler(): def install_crash_handler():
"""安装全局异常处理器""" """安装全局异常处理器"""
# 设置崩溃日志记录器 # 设置崩溃日志记录器
setup_crash_logger() setup_crash_logger()
# 保存原始的异常处理器 # 保存原始的异常处理器
original_hook = sys.excepthook original_hook = sys.excepthook
def exception_handler(exc_type, exc_value, exc_traceback): def exception_handler(exc_type, exc_value, exc_traceback):
"""全局异常处理器""" """全局异常处理器"""
# 记录崩溃信息 # 记录崩溃信息
log_crash(exc_type, exc_value, exc_traceback) log_crash(exc_type, exc_value, exc_traceback)
# 调用原始的异常处理器 # 调用原始的异常处理器
original_hook(exc_type, exc_value, exc_traceback) original_hook(exc_type, exc_value, exc_traceback)
# 设置全局异常处理器 # 设置全局异常处理器
sys.excepthook = exception_handler sys.excepthook = exception_handler

73
src/common/server.py Normal file
View File

@@ -0,0 +1,73 @@
from fastapi import FastAPI, APIRouter
from typing import Optional, Union
from uvicorn import Config, Server as UvicornServer
import os
class Server:
def __init__(self, host: Optional[str] = None, port: Optional[int] = None, app_name: str = "MaiMCore"):
self.app = FastAPI(title=app_name)
self._host: str = "127.0.0.1"
self._port: int = 8080
self._server: Optional[UvicornServer] = None
self.set_address(host, port)
def register_router(self, router: APIRouter, prefix: str = ""):
"""注册路由
APIRouter 用于对相关的路由端点进行分组和模块化管理:
1. 可以将相关的端点组织在一起,便于管理
2. 支持添加统一的路由前缀
3. 可以为一组路由添加共同的依赖项、标签等
示例:
router = APIRouter()
@router.get("/users")
def get_users():
return {"users": [...]}
@router.post("/users")
def create_user():
return {"msg": "user created"}
# 注册路由,添加前缀 "/api/v1"
server.register_router(router, prefix="/api/v1")
"""
self.app.include_router(router, prefix=prefix)
def set_address(self, host: Optional[str] = None, port: Optional[int] = None):
"""设置服务器地址和端口"""
if host:
self._host = host
if port:
self._port = port
async def run(self):
"""启动服务器"""
config = Config(app=self.app, host=self._host, port=self._port)
self._server = UvicornServer(config=config)
try:
await self._server.serve()
except KeyboardInterrupt:
await self.shutdown()
raise
except Exception as e:
await self.shutdown()
raise RuntimeError(f"服务器运行错误: {str(e)}") from e
finally:
await self.shutdown()
async def shutdown(self):
"""安全关闭服务器"""
if self._server:
self._server.should_exit = True
await self._server.shutdown()
self._server = None
def get_app(self) -> FastAPI:
"""获取 FastAPI 实例"""
return self.app
global_server = Server(host=os.environ["HOST"], port=int(os.environ["PORT"]))

View File

@@ -16,7 +16,7 @@ from .plugins.chat.bot import chat_bot
from .common.logger import get_module_logger from .common.logger import get_module_logger
from .plugins.remote import heartbeat_thread # noqa: F401 from .plugins.remote import heartbeat_thread # noqa: F401
from .individuality.individuality import Individuality from .individuality.individuality import Individuality
from .common.server import global_server
logger = get_module_logger("main") logger = get_module_logger("main")
@@ -33,6 +33,7 @@ class MainSystem:
from .plugins.message import global_api from .plugins.message import global_api
self.app = global_api self.app = global_api
self.server = global_server
async def initialize(self): async def initialize(self):
"""初始化系统组件""" """初始化系统组件"""
@@ -126,6 +127,7 @@ class MainSystem:
emoji_manager.start_periodic_check_register(), emoji_manager.start_periodic_check_register(),
# emoji_manager.start_periodic_register(), # emoji_manager.start_periodic_register(),
self.app.run(), self.app.run(),
self.server.run(),
] ]
await asyncio.gather(*tasks) await asyncio.gather(*tasks)

View File

@@ -10,6 +10,7 @@ from .conversation_info import ConversationInfo
logger = get_module_logger("action_planner") logger = get_module_logger("action_planner")
class ActionPlannerInfo: class ActionPlannerInfo:
def __init__(self): def __init__(self):
self.done_action = [] self.done_action = []
@@ -20,68 +21,57 @@ class ActionPlannerInfo:
class ActionPlanner: class ActionPlanner:
"""行动规划器""" """行动规划器"""
def __init__(self, stream_id: str): def __init__(self, stream_id: str):
self.llm = LLM_request( self.llm = LLM_request(
model=global_config.llm_normal, model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="action_planning"
temperature=0.7,
max_tokens=1000,
request_type="action_planning"
) )
self.personality_info = Individuality.get_instance().get_prompt(type = "personality", x_person = 2, level = 2) self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2)
self.name = global_config.BOT_NICKNAME self.name = global_config.BOT_NICKNAME
self.chat_observer = ChatObserver.get_instance(stream_id) self.chat_observer = ChatObserver.get_instance(stream_id)
async def plan( async def plan(self, observation_info: ObservationInfo, conversation_info: ConversationInfo) -> Tuple[str, str]:
self,
observation_info: ObservationInfo,
conversation_info: ConversationInfo
) -> Tuple[str, str]:
"""规划下一步行动 """规划下一步行动
Args: Args:
observation_info: 决策信息 observation_info: 决策信息
conversation_info: 对话信息 conversation_info: 对话信息
Returns: Returns:
Tuple[str, str]: (行动类型, 行动原因) Tuple[str, str]: (行动类型, 行动原因)
""" """
# 构建提示词 # 构建提示词
logger.debug(f"开始规划行动:当前目标: {conversation_info.goal_list}") logger.debug(f"开始规划行动:当前目标: {conversation_info.goal_list}")
#构建对话目标 # 构建对话目标
if conversation_info.goal_list: if conversation_info.goal_list:
goal, reasoning = conversation_info.goal_list[-1] goal, reasoning = conversation_info.goal_list[-1]
else: else:
goal = "目前没有明确对话目标" goal = "目前没有明确对话目标"
reasoning = "目前没有明确对话目标,最好思考一个对话目标" reasoning = "目前没有明确对话目标,最好思考一个对话目标"
# 获取聊天历史记录 # 获取聊天历史记录
chat_history_list = observation_info.chat_history chat_history_list = observation_info.chat_history
chat_history_text = "" chat_history_text = ""
for msg in chat_history_list: for msg in chat_history_list:
chat_history_text += f"{msg}\n" chat_history_text += f"{msg}\n"
if observation_info.new_messages_count > 0: if observation_info.new_messages_count > 0:
new_messages_list = observation_info.unprocessed_messages new_messages_list = observation_info.unprocessed_messages
chat_history_text += f"{observation_info.new_messages_count}条新消息:\n" chat_history_text += f"{observation_info.new_messages_count}条新消息:\n"
for msg in new_messages_list: for msg in new_messages_list:
chat_history_text += f"{msg}\n" chat_history_text += f"{msg}\n"
observation_info.clear_unprocessed_messages() observation_info.clear_unprocessed_messages()
personality_text = f"你的名字是{self.name}{self.personality_info}" personality_text = f"你的名字是{self.name}{self.personality_info}"
# 构建action历史文本 # 构建action历史文本
action_history_list = conversation_info.done_action action_history_list = conversation_info.done_action
action_history_text = "你之前做的事情是:" action_history_text = "你之前做的事情是:"
for action in action_history_list: for action in action_history_list:
action_history_text += f"{action}\n" action_history_text += f"{action}\n"
prompt = f"""{personality_text}。现在你在参与一场QQ聊天请分析以下内容根据信息决定下一步行动 prompt = f"""{personality_text}。现在你在参与一场QQ聊天请分析以下内容根据信息决定下一步行动
@@ -111,29 +101,27 @@ rethink_goal: 重新思考对话目标,当发现对话目标不合适时选择
try: try:
content, _ = await self.llm.generate_response_async(prompt) content, _ = await self.llm.generate_response_async(prompt)
logger.debug(f"LLM原始返回内容: {content}") logger.debug(f"LLM原始返回内容: {content}")
# 使用简化函数提取JSON内容 # 使用简化函数提取JSON内容
success, result = get_items_from_json( success, result = get_items_from_json(
content, content, "action", "reason", default_values={"action": "direct_reply", "reason": "没有明确原因"}
"action", "reason",
default_values={"action": "direct_reply", "reason": "没有明确原因"}
) )
if not success: if not success:
return "direct_reply", "JSON解析失败选择直接回复" return "direct_reply", "JSON解析失败选择直接回复"
action = result["action"] action = result["action"]
reason = result["reason"] reason = result["reason"]
# 验证action类型 # 验证action类型
if action not in ["direct_reply", "fetch_knowledge", "wait", "listening", "rethink_goal"]: if action not in ["direct_reply", "fetch_knowledge", "wait", "listening", "rethink_goal"]:
logger.warning(f"未知的行动类型: {action}默认使用listening") logger.warning(f"未知的行动类型: {action}默认使用listening")
action = "listening" action = "listening"
logger.info(f"规划的行动: {action}") logger.info(f"规划的行动: {action}")
logger.info(f"行动原因: {reason}") logger.info(f"行动原因: {reason}")
return action, reason return action, reason
except Exception as e: except Exception as e:
logger.error(f"规划行动时出错: {str(e)}") logger.error(f"规划行动时出错: {str(e)}")
return "direct_reply", "发生错误,选择直接回复" return "direct_reply", "发生错误,选择直接回复"

View File

@@ -17,20 +17,20 @@ class ChatObserver:
_instances: Dict[str, "ChatObserver"] = {} _instances: Dict[str, "ChatObserver"] = {}
@classmethod @classmethod
def get_instance(cls, stream_id: str, message_storage: Optional[MessageStorage] = None) -> 'ChatObserver': def get_instance(cls, stream_id: str, message_storage: Optional[MessageStorage] = None) -> "ChatObserver":
"""获取或创建观察器实例 """获取或创建观察器实例
Args: Args:
stream_id: 聊天流ID stream_id: 聊天流ID
message_storage: 消息存储实现如果为None则使用MongoDB实现 message_storage: 消息存储实现如果为None则使用MongoDB实现
Returns: Returns:
ChatObserver: 观察器实例 ChatObserver: 观察器实例
""" """
if stream_id not in cls._instances: if stream_id not in cls._instances:
cls._instances[stream_id] = cls(stream_id, message_storage) cls._instances[stream_id] = cls(stream_id, message_storage)
return cls._instances[stream_id] return cls._instances[stream_id]
def __init__(self, stream_id: str, message_storage: Optional[MessageStorage] = None): def __init__(self, stream_id: str, message_storage: Optional[MessageStorage] = None):
"""初始化观察器 """初始化观察器
@@ -43,15 +43,15 @@ class ChatObserver:
self.stream_id = stream_id self.stream_id = stream_id
self.message_storage = message_storage or MongoDBMessageStorage() self.message_storage = message_storage or MongoDBMessageStorage()
self.last_user_speak_time: Optional[float] = None # 对方上次发言时间 self.last_user_speak_time: Optional[float] = None # 对方上次发言时间
self.last_bot_speak_time: Optional[float] = None # 机器人上次发言时间 self.last_bot_speak_time: Optional[float] = None # 机器人上次发言时间
self.last_check_time: float = time.time() # 上次查看聊天记录时间 self.last_check_time: float = time.time() # 上次查看聊天记录时间
self.last_message_read: Optional[str] = None # 最后读取的消息ID self.last_message_read: Optional[str] = None # 最后读取的消息ID
self.last_message_time: Optional[float] = None # 最后一条消息的时间戳 self.last_message_time: Optional[float] = None # 最后一条消息的时间戳
self.waiting_start_time: float = time.time() # 等待开始时间,初始化为当前时间 self.waiting_start_time: float = time.time() # 等待开始时间,初始化为当前时间
# 消息历史记录 # 消息历史记录
self.message_history: List[Dict[str, Any]] = [] # 所有消息历史 self.message_history: List[Dict[str, Any]] = [] # 所有消息历史
self.last_message_id: Optional[str] = None # 最后一条消息的ID self.last_message_id: Optional[str] = None # 最后一条消息的ID
@@ -62,20 +62,20 @@ class ChatObserver:
self._task: Optional[asyncio.Task] = None self._task: Optional[asyncio.Task] = None
self._update_event = asyncio.Event() # 触发更新的事件 self._update_event = asyncio.Event() # 触发更新的事件
self._update_complete = asyncio.Event() # 更新完成的事件 self._update_complete = asyncio.Event() # 更新完成的事件
# 通知管理器 # 通知管理器
self.notification_manager = NotificationManager() self.notification_manager = NotificationManager()
# 冷场检查配置 # 冷场检查配置
self.cold_chat_threshold: float = 60.0 # 60秒无消息判定为冷场 self.cold_chat_threshold: float = 60.0 # 60秒无消息判定为冷场
self.last_cold_chat_check: float = time.time() self.last_cold_chat_check: float = time.time()
self.is_cold_chat_state: bool = False self.is_cold_chat_state: bool = False
self.update_event = asyncio.Event() self.update_event = asyncio.Event()
self.update_interval = 5 # 更新间隔(秒) self.update_interval = 5 # 更新间隔(秒)
self.message_cache = [] self.message_cache = []
self.update_running = False self.update_running = False
async def check(self) -> bool: async def check(self) -> bool:
"""检查距离上一次观察之后是否有了新消息 """检查距离上一次观察之后是否有了新消息
@@ -83,21 +83,18 @@ class ChatObserver:
bool: 是否有新消息 bool: 是否有新消息
""" """
logger.debug(f"检查距离上一次观察之后是否有了新消息: {self.last_check_time}") logger.debug(f"检查距离上一次观察之后是否有了新消息: {self.last_check_time}")
new_message_exists = await self.message_storage.has_new_messages( new_message_exists = await self.message_storage.has_new_messages(self.stream_id, self.last_check_time)
self.stream_id,
self.last_check_time
)
if new_message_exists: if new_message_exists:
logger.debug("发现新消息") logger.debug("发现新消息")
self.last_check_time = time.time() self.last_check_time = time.time()
return new_message_exists return new_message_exists
async def _add_message_to_history(self, message: Dict[str, Any]): async def _add_message_to_history(self, message: Dict[str, Any]):
"""添加消息到历史记录并发送通知 """添加消息到历史记录并发送通知
Args: Args:
message: 消息数据 message: 消息数据
""" """
@@ -112,76 +109,65 @@ class ChatObserver:
self.last_bot_speak_time = message["time"] self.last_bot_speak_time = message["time"]
else: else:
self.last_user_speak_time = message["time"] self.last_user_speak_time = message["time"]
# 发送新消息通知 # 发送新消息通知
notification = create_new_message_notification( notification = create_new_message_notification(sender="chat_observer", target="pfc", message=message)
sender="chat_observer",
target="pfc",
message=message
)
await self.notification_manager.send_notification(notification) await self.notification_manager.send_notification(notification)
# 检查并更新冷场状态 # 检查并更新冷场状态
await self._check_cold_chat() await self._check_cold_chat()
async def _check_cold_chat(self): async def _check_cold_chat(self):
"""检查是否处于冷场状态并发送通知""" """检查是否处于冷场状态并发送通知"""
current_time = time.time() current_time = time.time()
# 每10秒检查一次冷场状态 # 每10秒检查一次冷场状态
if current_time - self.last_cold_chat_check < 10: if current_time - self.last_cold_chat_check < 10:
return return
self.last_cold_chat_check = current_time self.last_cold_chat_check = current_time
# 判断是否冷场 # 判断是否冷场
is_cold = False is_cold = False
if self.last_message_time is None: if self.last_message_time is None:
is_cold = True is_cold = True
else: else:
is_cold = (current_time - self.last_message_time) > self.cold_chat_threshold is_cold = (current_time - self.last_message_time) > self.cold_chat_threshold
# 如果冷场状态发生变化,发送通知 # 如果冷场状态发生变化,发送通知
if is_cold != self.is_cold_chat_state: if is_cold != self.is_cold_chat_state:
self.is_cold_chat_state = is_cold self.is_cold_chat_state = is_cold
notification = create_cold_chat_notification( notification = create_cold_chat_notification(sender="chat_observer", target="pfc", is_cold=is_cold)
sender="chat_observer",
target="pfc",
is_cold=is_cold
)
await self.notification_manager.send_notification(notification) await self.notification_manager.send_notification(notification)
async def get_new_message(self) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: async def get_new_message(self) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
"""获取上一次观察的时间点后的新消息,插入到历史记录中,并返回新消息和历史记录两个对象""" """获取上一次观察的时间点后的新消息,插入到历史记录中,并返回新消息和历史记录两个对象"""
messages = await self.message_storage.get_messages_after( messages = await self.message_storage.get_messages_after(self.stream_id, self.last_message_read)
self.stream_id,
self.last_message_read
)
for message in messages: for message in messages:
await self._add_message_to_history(message) await self._add_message_to_history(message)
return messages, self.message_history return messages, self.message_history
def new_message_after(self, time_point: float) -> bool: def new_message_after(self, time_point: float) -> bool:
"""判断是否在指定时间点后有新消息 """判断是否在指定时间点后有新消息
Args: Args:
time_point: 时间戳 time_point: 时间戳
Returns: Returns:
bool: 是否有新消息 bool: 是否有新消息
""" """
if time_point is None: if time_point is None:
logger.warning("time_point 为 None返回 False") logger.warning("time_point 为 None返回 False")
return False return False
if self.last_message_time is None: if self.last_message_time is None:
logger.debug("没有最后消息时间,返回 False") logger.debug("没有最后消息时间,返回 False")
return False return False
has_new = self.last_message_time > time_point has_new = self.last_message_time > time_point
logger.debug(f"判断是否在指定时间点后有新消息: {self.last_message_time} > {time_point} = {has_new}") logger.debug(f"判断是否在指定时间点后有新消息: {self.last_message_time} > {time_point} = {has_new}")
return has_new return has_new
def get_message_history( def get_message_history(
self, self,
start_time: Optional[float] = None, start_time: Optional[float] = None,
@@ -224,11 +210,8 @@ class ChatObserver:
Returns: Returns:
List[Dict[str, Any]]: 新消息列表 List[Dict[str, Any]]: 新消息列表
""" """
new_messages = await self.message_storage.get_messages_after( new_messages = await self.message_storage.get_messages_after(self.stream_id, self.last_message_read)
self.stream_id,
self.last_message_read
)
if new_messages: if new_messages:
self.last_message_read = new_messages[-1]["message_id"] self.last_message_read = new_messages[-1]["message_id"]
@@ -243,17 +226,15 @@ class ChatObserver:
Returns: Returns:
List[Dict[str, Any]]: 最多5条消息 List[Dict[str, Any]]: 最多5条消息
""" """
new_messages = await self.message_storage.get_messages_before( new_messages = await self.message_storage.get_messages_before(self.stream_id, time_point)
self.stream_id,
time_point
)
if new_messages: if new_messages:
self.last_message_read = new_messages[-1]["message_id"] self.last_message_read = new_messages[-1]["message_id"]
return new_messages return new_messages
'''主要观察循环''' """主要观察循环"""
async def _update_loop(self): async def _update_loop(self):
"""更新循环""" """更新循环"""
try: try:
@@ -282,7 +263,7 @@ class ChatObserver:
# 处理新消息 # 处理新消息
for message in new_messages: for message in new_messages:
await self._add_message_to_history(message) await self._add_message_to_history(message)
# 设置完成事件 # 设置完成事件
self._update_complete.set() self._update_complete.set()
@@ -379,7 +360,7 @@ class ChatObserver:
if not self.update_running: if not self.update_running:
self.update_running = True self.update_running = True
asyncio.create_task(self._periodic_update()) asyncio.create_task(self._periodic_update())
async def _periodic_update(self): async def _periodic_update(self):
"""定期更新消息历史""" """定期更新消息历史"""
try: try:
@@ -388,53 +369,52 @@ class ChatObserver:
await asyncio.sleep(self.update_interval) await asyncio.sleep(self.update_interval)
except Exception as e: except Exception as e:
logger.error(f"定期更新消息历史时出错: {str(e)}") logger.error(f"定期更新消息历史时出错: {str(e)}")
async def _update_message_history(self) -> bool: async def _update_message_history(self) -> bool:
"""更新消息历史 """更新消息历史
Returns: Returns:
bool: 是否有新消息 bool: 是否有新消息
""" """
try: try:
messages = await self.message_storage.get_messages_for_stream( messages = await self.message_storage.get_messages_for_stream(self.stream_id, limit=50)
self.stream_id,
limit=50
)
if not messages: if not messages:
return False return False
# 检查是否有新消息 # 检查是否有新消息
has_new_messages = False has_new_messages = False
if messages and (not self.message_cache or messages[0]["message_id"] != self.message_cache[0]["message_id"]): if messages and (
not self.message_cache or messages[0]["message_id"] != self.message_cache[0]["message_id"]
):
has_new_messages = True has_new_messages = True
self.message_cache = messages self.message_cache = messages
if has_new_messages: if has_new_messages:
self.update_event.set() self.update_event.set()
self.update_event.clear() self.update_event.clear()
return True return True
return False return False
except Exception as e: except Exception as e:
logger.error(f"更新消息历史时出错: {str(e)}") logger.error(f"更新消息历史时出错: {str(e)}")
return False return False
def get_cached_messages(self, limit: int = 50) -> List[Dict[str, Any]]: def get_cached_messages(self, limit: int = 50) -> List[Dict[str, Any]]:
"""获取缓存的消息历史 """获取缓存的消息历史
Args: Args:
limit: 获取的最大消息数量默认50 limit: 获取的最大消息数量默认50
Returns: Returns:
List[Dict[str, Any]]: 缓存的消息历史列表 List[Dict[str, Any]]: 缓存的消息历史列表
""" """
return self.message_cache[:limit] return self.message_cache[:limit]
def get_last_message(self) -> Optional[Dict[str, Any]]: def get_last_message(self) -> Optional[Dict[str, Any]]:
"""获取最后一条消息 """获取最后一条消息
Returns: Returns:
Optional[Dict[str, Any]]: 最后一条消息如果没有则返回None Optional[Dict[str, Any]]: 最后一条消息如果没有则返回None
""" """

View File

@@ -4,32 +4,38 @@ from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
class ChatState(Enum): class ChatState(Enum):
"""聊天状态枚举""" """聊天状态枚举"""
NORMAL = auto() # 正常状态
NEW_MESSAGE = auto() # 有新消息 NORMAL = auto() # 正常状态
COLD_CHAT = auto() # 冷场状态 NEW_MESSAGE = auto() # 有新消息
ACTIVE_CHAT = auto() # 活跃状态 COLD_CHAT = auto() # 冷场状态
BOT_SPEAKING = auto() # 机器人正在说话 ACTIVE_CHAT = auto() # 活跃状态
USER_SPEAKING = auto() # 用户正在说话 BOT_SPEAKING = auto() # 机器人正在说话
SILENT = auto() # 沉默状态 USER_SPEAKING = auto() # 用户正在说话
ERROR = auto() # 错误状态 SILENT = auto() # 沉默状态
ERROR = auto() # 错误状态
class NotificationType(Enum): class NotificationType(Enum):
"""通知类型枚举""" """通知类型枚举"""
NEW_MESSAGE = auto() # 新消息通知
COLD_CHAT = auto() # 冷场通知 NEW_MESSAGE = auto() # 新消息通知
ACTIVE_CHAT = auto() # 活跃通知 COLD_CHAT = auto() # 冷场通知
BOT_SPEAKING = auto() # 机器人说话通知 ACTIVE_CHAT = auto() # 活跃通知
USER_SPEAKING = auto() # 用户说话通知 BOT_SPEAKING = auto() # 机器人说话通知
MESSAGE_DELETED = auto() # 消息删除通知 USER_SPEAKING = auto() # 用户说话通知
USER_JOINED = auto() # 用户加入通知 MESSAGE_DELETED = auto() # 消息删除通知
USER_LEFT = auto() # 用户离开通知 USER_JOINED = auto() # 用户加入通知
ERROR = auto() # 错误通知 USER_LEFT = auto() # 用户离开通知
ERROR = auto() # 错误通知
@dataclass @dataclass
class ChatStateInfo: class ChatStateInfo:
"""聊天状态信息""" """聊天状态信息"""
state: ChatState state: ChatState
last_message_time: Optional[float] = None last_message_time: Optional[float] = None
last_message_content: Optional[str] = None last_message_content: Optional[str] = None
@@ -38,53 +44,55 @@ class ChatStateInfo:
cold_duration: float = 0.0 # 冷场持续时间(秒) cold_duration: float = 0.0 # 冷场持续时间(秒)
active_duration: float = 0.0 # 活跃持续时间(秒) active_duration: float = 0.0 # 活跃持续时间(秒)
@dataclass @dataclass
class Notification: class Notification:
"""通知基类""" """通知基类"""
type: NotificationType type: NotificationType
timestamp: float timestamp: float
sender: str # 发送者标识 sender: str # 发送者标识
target: str # 接收者标识 target: str # 接收者标识
data: Dict[str, Any] data: Dict[str, Any]
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
"""转换为字典格式""" """转换为字典格式"""
return { return {"type": self.type.name, "timestamp": self.timestamp, "data": self.data}
"type": self.type.name,
"timestamp": self.timestamp,
"data": self.data
}
@dataclass @dataclass
class StateNotification(Notification): class StateNotification(Notification):
"""持续状态通知""" """持续状态通知"""
is_active: bool = True is_active: bool = True
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
base_dict = super().to_dict() base_dict = super().to_dict()
base_dict["is_active"] = self.is_active base_dict["is_active"] = self.is_active
return base_dict return base_dict
class NotificationHandler(ABC): class NotificationHandler(ABC):
"""通知处理器接口""" """通知处理器接口"""
@abstractmethod @abstractmethod
async def handle_notification(self, notification: Notification): async def handle_notification(self, notification: Notification):
"""处理通知""" """处理通知"""
pass pass
class NotificationManager: class NotificationManager:
"""通知管理器""" """通知管理器"""
def __init__(self): def __init__(self):
# 按接收者和通知类型存储处理器 # 按接收者和通知类型存储处理器
self._handlers: Dict[str, Dict[NotificationType, List[NotificationHandler]]] = {} self._handlers: Dict[str, Dict[NotificationType, List[NotificationHandler]]] = {}
self._active_states: Set[NotificationType] = set() self._active_states: Set[NotificationType] = set()
self._notification_history: List[Notification] = [] self._notification_history: List[Notification] = []
def register_handler(self, target: str, notification_type: NotificationType, handler: NotificationHandler): def register_handler(self, target: str, notification_type: NotificationType, handler: NotificationHandler):
"""注册通知处理器 """注册通知处理器
Args: Args:
target: 接收者标识(例如:"pfc" target: 接收者标识(例如:"pfc"
notification_type: 要处理的通知类型 notification_type: 要处理的通知类型
@@ -95,10 +103,10 @@ class NotificationManager:
if notification_type not in self._handlers[target]: if notification_type not in self._handlers[target]:
self._handlers[target][notification_type] = [] self._handlers[target][notification_type] = []
self._handlers[target][notification_type].append(handler) self._handlers[target][notification_type].append(handler)
def unregister_handler(self, target: str, notification_type: NotificationType, handler: NotificationHandler): def unregister_handler(self, target: str, notification_type: NotificationType, handler: NotificationHandler):
"""注销通知处理器 """注销通知处理器
Args: Args:
target: 接收者标识 target: 接收者标识
notification_type: 通知类型 notification_type: 通知类型
@@ -114,56 +122,56 @@ class NotificationManager:
# 如果该目标没有任何处理器,删除该目标 # 如果该目标没有任何处理器,删除该目标
if not self._handlers[target]: if not self._handlers[target]:
del self._handlers[target] del self._handlers[target]
async def send_notification(self, notification: Notification): async def send_notification(self, notification: Notification):
"""发送通知""" """发送通知"""
self._notification_history.append(notification) self._notification_history.append(notification)
# 如果是状态通知,更新活跃状态 # 如果是状态通知,更新活跃状态
if isinstance(notification, StateNotification): if isinstance(notification, StateNotification):
if notification.is_active: if notification.is_active:
self._active_states.add(notification.type) self._active_states.add(notification.type)
else: else:
self._active_states.discard(notification.type) self._active_states.discard(notification.type)
# 调用目标接收者的处理器 # 调用目标接收者的处理器
target = notification.target target = notification.target
if target in self._handlers: if target in self._handlers:
handlers = self._handlers[target].get(notification.type, []) handlers = self._handlers[target].get(notification.type, [])
for handler in handlers: for handler in handlers:
await handler.handle_notification(notification) await handler.handle_notification(notification)
def get_active_states(self) -> Set[NotificationType]: def get_active_states(self) -> Set[NotificationType]:
"""获取当前活跃的状态""" """获取当前活跃的状态"""
return self._active_states.copy() return self._active_states.copy()
def is_state_active(self, state_type: NotificationType) -> bool: def is_state_active(self, state_type: NotificationType) -> bool:
"""检查特定状态是否活跃""" """检查特定状态是否活跃"""
return state_type in self._active_states return state_type in self._active_states
def get_notification_history(self, def get_notification_history(
sender: Optional[str] = None, self, sender: Optional[str] = None, target: Optional[str] = None, limit: Optional[int] = None
target: Optional[str] = None, ) -> List[Notification]:
limit: Optional[int] = None) -> List[Notification]:
"""获取通知历史 """获取通知历史
Args: Args:
sender: 过滤特定发送者的通知 sender: 过滤特定发送者的通知
target: 过滤特定接收者的通知 target: 过滤特定接收者的通知
limit: 限制返回数量 limit: 限制返回数量
""" """
history = self._notification_history history = self._notification_history
if sender: if sender:
history = [n for n in history if n.sender == sender] history = [n for n in history if n.sender == sender]
if target: if target:
history = [n for n in history if n.target == target] history = [n for n in history if n.target == target]
if limit is not None: if limit is not None:
history = history[-limit:] history = history[-limit:]
return history return history
# 一些常用的通知创建函数 # 一些常用的通知创建函数
def create_new_message_notification(sender: str, target: str, message: Dict[str, Any]) -> Notification: def create_new_message_notification(sender: str, target: str, message: Dict[str, Any]) -> Notification:
"""创建新消息通知""" """创建新消息通知"""
@@ -176,10 +184,11 @@ def create_new_message_notification(sender: str, target: str, message: Dict[str,
"message_id": message.get("message_id"), "message_id": message.get("message_id"),
"content": message.get("content"), "content": message.get("content"),
"sender": message.get("sender"), "sender": message.get("sender"),
"time": message.get("time") "time": message.get("time"),
} },
) )
def create_cold_chat_notification(sender: str, target: str, is_cold: bool) -> StateNotification: def create_cold_chat_notification(sender: str, target: str, is_cold: bool) -> StateNotification:
"""创建冷场状态通知""" """创建冷场状态通知"""
return StateNotification( return StateNotification(
@@ -188,9 +197,10 @@ def create_cold_chat_notification(sender: str, target: str, is_cold: bool) -> St
sender=sender, sender=sender,
target=target, target=target,
data={"is_cold": is_cold}, data={"is_cold": is_cold},
is_active=is_cold is_active=is_cold,
) )
def create_active_chat_notification(sender: str, target: str, is_active: bool) -> StateNotification: def create_active_chat_notification(sender: str, target: str, is_active: bool) -> StateNotification:
"""创建活跃状态通知""" """创建活跃状态通知"""
return StateNotification( return StateNotification(
@@ -199,69 +209,70 @@ def create_active_chat_notification(sender: str, target: str, is_active: bool) -
sender=sender, sender=sender,
target=target, target=target,
data={"is_active": is_active}, data={"is_active": is_active},
is_active=is_active is_active=is_active,
) )
class ChatStateManager: class ChatStateManager:
"""聊天状态管理器""" """聊天状态管理器"""
def __init__(self): def __init__(self):
self.current_state = ChatState.NORMAL self.current_state = ChatState.NORMAL
self.state_info = ChatStateInfo(state=ChatState.NORMAL) self.state_info = ChatStateInfo(state=ChatState.NORMAL)
self.state_history: list[ChatStateInfo] = [] self.state_history: list[ChatStateInfo] = []
def update_state(self, new_state: ChatState, **kwargs): def update_state(self, new_state: ChatState, **kwargs):
"""更新聊天状态 """更新聊天状态
Args: Args:
new_state: 新的状态 new_state: 新的状态
**kwargs: 其他状态信息 **kwargs: 其他状态信息
""" """
self.current_state = new_state self.current_state = new_state
self.state_info.state = new_state self.state_info.state = new_state
# 更新其他状态信息 # 更新其他状态信息
for key, value in kwargs.items(): for key, value in kwargs.items():
if hasattr(self.state_info, key): if hasattr(self.state_info, key):
setattr(self.state_info, key, value) setattr(self.state_info, key, value)
# 记录状态历史 # 记录状态历史
self.state_history.append(self.state_info) self.state_history.append(self.state_info)
def get_current_state_info(self) -> ChatStateInfo: def get_current_state_info(self) -> ChatStateInfo:
"""获取当前状态信息""" """获取当前状态信息"""
return self.state_info return self.state_info
def get_state_history(self) -> list[ChatStateInfo]: def get_state_history(self) -> list[ChatStateInfo]:
"""获取状态历史""" """获取状态历史"""
return self.state_history return self.state_history
def is_cold_chat(self, threshold: float = 60.0) -> bool: def is_cold_chat(self, threshold: float = 60.0) -> bool:
"""判断是否处于冷场状态 """判断是否处于冷场状态
Args: Args:
threshold: 冷场阈值(秒) threshold: 冷场阈值(秒)
Returns: Returns:
bool: 是否冷场 bool: 是否冷场
""" """
if not self.state_info.last_message_time: if not self.state_info.last_message_time:
return True return True
current_time = datetime.now().timestamp() current_time = datetime.now().timestamp()
return (current_time - self.state_info.last_message_time) > threshold return (current_time - self.state_info.last_message_time) > threshold
def is_active_chat(self, threshold: float = 5.0) -> bool: def is_active_chat(self, threshold: float = 5.0) -> bool:
"""判断是否处于活跃状态 """判断是否处于活跃状态
Args: Args:
threshold: 活跃阈值(秒) threshold: 活跃阈值(秒)
Returns: Returns:
bool: 是否活跃 bool: 是否活跃
""" """
if not self.state_info.last_message_time: if not self.state_info.last_message_time:
return False return False
current_time = datetime.now().timestamp() current_time = datetime.now().timestamp()
return (current_time - self.state_info.last_message_time) <= threshold return (current_time - self.state_info.last_message_time) <= threshold

View File

@@ -20,23 +20,23 @@ logger = get_module_logger("pfc_conversation")
class Conversation: class Conversation:
"""对话类,负责管理单个对话的状态和行为""" """对话类,负责管理单个对话的状态和行为"""
def __init__(self, stream_id: str): def __init__(self, stream_id: str):
"""初始化对话实例 """初始化对话实例
Args: Args:
stream_id: 聊天流ID stream_id: 聊天流ID
""" """
self.stream_id = stream_id self.stream_id = stream_id
self.state = ConversationState.INIT self.state = ConversationState.INIT
self.should_continue = False self.should_continue = False
# 回复相关 # 回复相关
self.generated_reply = "" self.generated_reply = ""
async def _initialize(self): async def _initialize(self):
"""初始化实例,注册所有组件""" """初始化实例,注册所有组件"""
try: try:
self.action_planner = ActionPlanner(self.stream_id) self.action_planner = ActionPlanner(self.stream_id)
self.goal_analyzer = GoalAnalyzer(self.stream_id) self.goal_analyzer = GoalAnalyzer(self.stream_id)
@@ -44,37 +44,35 @@ class Conversation:
self.knowledge_fetcher = KnowledgeFetcher() self.knowledge_fetcher = KnowledgeFetcher()
self.waiter = Waiter(self.stream_id) self.waiter = Waiter(self.stream_id)
self.direct_sender = DirectMessageSender() self.direct_sender = DirectMessageSender()
# 获取聊天流信息 # 获取聊天流信息
self.chat_stream = chat_manager.get_stream(self.stream_id) self.chat_stream = chat_manager.get_stream(self.stream_id)
self.stop_action_planner = False self.stop_action_planner = False
except Exception as e: except Exception as e:
logger.error(f"初始化对话实例:注册运行组件失败: {e}") logger.error(f"初始化对话实例:注册运行组件失败: {e}")
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
raise raise
try: try:
#决策所需要的信息,包括自身自信和观察信息两部分 # 决策所需要的信息,包括自身自信和观察信息两部分
#注册观察器和观测信息 # 注册观察器和观测信息
self.chat_observer = ChatObserver.get_instance(self.stream_id) self.chat_observer = ChatObserver.get_instance(self.stream_id)
self.chat_observer.start() self.chat_observer.start()
self.observation_info = ObservationInfo() self.observation_info = ObservationInfo()
self.observation_info.bind_to_chat_observer(self.stream_id) self.observation_info.bind_to_chat_observer(self.stream_id)
#对话信息 # 对话信息
self.conversation_info = ConversationInfo() self.conversation_info = ConversationInfo()
except Exception as e: except Exception as e:
logger.error(f"初始化对话实例:注册信息组件失败: {e}") logger.error(f"初始化对话实例:注册信息组件失败: {e}")
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
raise raise
# 组件准备完成,启动该论对话 # 组件准备完成,启动该论对话
self.should_continue = True self.should_continue = True
asyncio.create_task(self.start()) asyncio.create_task(self.start())
async def start(self): async def start(self):
"""开始对话流程""" """开始对话流程"""
try: try:
@@ -83,17 +81,13 @@ class Conversation:
except Exception as e: except Exception as e:
logger.error(f"启动对话系统失败: {e}") logger.error(f"启动对话系统失败: {e}")
raise raise
async def _plan_and_action_loop(self): async def _plan_and_action_loop(self):
"""思考步PFC核心循环模块""" """思考步PFC核心循环模块"""
# 获取最近的消息历史 # 获取最近的消息历史
while self.should_continue: while self.should_continue:
# 使用决策信息来辅助行动规划 # 使用决策信息来辅助行动规划
action, reason = await self.action_planner.plan( action, reason = await self.action_planner.plan(self.observation_info, self.conversation_info)
self.observation_info,
self.conversation_info
)
if self._check_new_messages_after_planning(): if self._check_new_messages_after_planning():
continue continue
@@ -107,93 +101,90 @@ class Conversation:
# 如果需要,可以在这里添加逻辑来根据新消息重新决定行动 # 如果需要,可以在这里添加逻辑来根据新消息重新决定行动
return True return True
return False return False
def _convert_to_message(self, msg_dict: Dict[str, Any]) -> Message: def _convert_to_message(self, msg_dict: Dict[str, Any]) -> Message:
"""将消息字典转换为Message对象""" """将消息字典转换为Message对象"""
try: try:
chat_info = msg_dict.get("chat_info", {}) chat_info = msg_dict.get("chat_info", {})
chat_stream = ChatStream.from_dict(chat_info) chat_stream = ChatStream.from_dict(chat_info)
user_info = UserInfo.from_dict(msg_dict.get("user_info", {})) user_info = UserInfo.from_dict(msg_dict.get("user_info", {}))
return Message( return Message(
message_id=msg_dict["message_id"], message_id=msg_dict["message_id"],
chat_stream=chat_stream, chat_stream=chat_stream,
time=msg_dict["time"], time=msg_dict["time"],
user_info=user_info, user_info=user_info,
processed_plain_text=msg_dict.get("processed_plain_text", ""), processed_plain_text=msg_dict.get("processed_plain_text", ""),
detailed_plain_text=msg_dict.get("detailed_plain_text", "") detailed_plain_text=msg_dict.get("detailed_plain_text", ""),
) )
except Exception as e: except Exception as e:
logger.warning(f"转换消息时出错: {e}") logger.warning(f"转换消息时出错: {e}")
raise raise
async def _handle_action(self, action: str, reason: str, observation_info: ObservationInfo, conversation_info: ConversationInfo): async def _handle_action(
self, action: str, reason: str, observation_info: ObservationInfo, conversation_info: ConversationInfo
):
"""处理规划的行动""" """处理规划的行动"""
logger.info(f"执行行动: {action}, 原因: {reason}") logger.info(f"执行行动: {action}, 原因: {reason}")
# 记录action历史先设置为stop完成后再设置为done # 记录action历史先设置为stop完成后再设置为done
conversation_info.done_action.append({ conversation_info.done_action.append(
"action": action, {
"reason": reason, "action": action,
"status": "start", "reason": reason,
"time": datetime.datetime.now().strftime("%H:%M:%S") "status": "start",
}) "time": datetime.datetime.now().strftime("%H:%M:%S"),
}
)
if action == "direct_reply": if action == "direct_reply":
self.state = ConversationState.GENERATING self.state = ConversationState.GENERATING
self.generated_reply = await self.reply_generator.generate( self.generated_reply = await self.reply_generator.generate(observation_info, conversation_info)
observation_info,
conversation_info
)
# # 检查回复是否合适 # # 检查回复是否合适
# is_suitable, reason, need_replan = await self.reply_generator.check_reply( # is_suitable, reason, need_replan = await self.reply_generator.check_reply(
# self.generated_reply, # self.generated_reply,
# self.current_goal # self.current_goal
# ) # )
if self._check_new_messages_after_planning(): if self._check_new_messages_after_planning():
return None return None
await self._send_reply() await self._send_reply()
conversation_info.done_action.append({ conversation_info.done_action.append(
"action": action, {
"reason": reason, "action": action,
"status": "done", "reason": reason,
"time": datetime.datetime.now().strftime("%H:%M:%S") "status": "done",
}) "time": datetime.datetime.now().strftime("%H:%M:%S"),
}
)
elif action == "fetch_knowledge": elif action == "fetch_knowledge":
self.state = ConversationState.FETCHING self.state = ConversationState.FETCHING
knowledge = "TODO:知识" knowledge = "TODO:知识"
topic = "TODO:关键词" topic = "TODO:关键词"
logger.info(f"假装获取到知识{knowledge},关键词是: {topic}") logger.info(f"假装获取到知识{knowledge},关键词是: {topic}")
if knowledge: if knowledge:
if topic not in self.conversation_info.knowledge_list: if topic not in self.conversation_info.knowledge_list:
self.conversation_info.knowledge_list.append({ self.conversation_info.knowledge_list.append({"topic": topic, "knowledge": knowledge})
"topic": topic,
"knowledge": knowledge
})
else: else:
self.conversation_info.knowledge_list[topic] += knowledge self.conversation_info.knowledge_list[topic] += knowledge
elif action == "rethink_goal": elif action == "rethink_goal":
self.state = ConversationState.RETHINKING self.state = ConversationState.RETHINKING
await self.goal_analyzer.analyze_goal(conversation_info, observation_info) await self.goal_analyzer.analyze_goal(conversation_info, observation_info)
elif action == "listening": elif action == "listening":
self.state = ConversationState.LISTENING self.state = ConversationState.LISTENING
logger.info("倾听对方发言...") logger.info("倾听对方发言...")
if await self.waiter.wait(): # 如果返回True表示超时 if await self.waiter.wait(): # 如果返回True表示超时
await self._send_timeout_message() await self._send_timeout_message()
await self._stop_conversation() await self._stop_conversation()
else: # wait else: # wait
self.state = ConversationState.WAITING self.state = ConversationState.WAITING
logger.info("等待更多信息...") logger.info("等待更多信息...")
@@ -207,12 +198,10 @@ class Conversation:
messages = self.chat_observer.get_cached_messages(limit=1) messages = self.chat_observer.get_cached_messages(limit=1)
if not messages: if not messages:
return return
latest_message = self._convert_to_message(messages[0]) latest_message = self._convert_to_message(messages[0])
await self.direct_sender.send_message( await self.direct_sender.send_message(
chat_stream=self.chat_stream, chat_stream=self.chat_stream, content="TODO:超时消息", reply_to_message=latest_message
content="TODO:超时消息",
reply_to_message=latest_message
) )
except Exception as e: except Exception as e:
logger.error(f"发送超时消息失败: {str(e)}") logger.error(f"发送超时消息失败: {str(e)}")
@@ -222,24 +211,22 @@ class Conversation:
if not self.generated_reply: if not self.generated_reply:
logger.warning("没有生成回复") logger.warning("没有生成回复")
return return
messages = self.chat_observer.get_cached_messages(limit=1) messages = self.chat_observer.get_cached_messages(limit=1)
if not messages: if not messages:
logger.warning("没有最近的消息可以回复") logger.warning("没有最近的消息可以回复")
return return
latest_message = self._convert_to_message(messages[0]) latest_message = self._convert_to_message(messages[0])
try: try:
await self.direct_sender.send_message( await self.direct_sender.send_message(
chat_stream=self.chat_stream, chat_stream=self.chat_stream, content=self.generated_reply, reply_to_message=latest_message
content=self.generated_reply,
reply_to_message=latest_message
) )
self.chat_observer.trigger_update() # 触发立即更新 self.chat_observer.trigger_update() # 触发立即更新
if not await self.chat_observer.wait_for_update(): if not await self.chat_observer.wait_for_update():
logger.warning("等待消息更新超时") logger.warning("等待消息更新超时")
self.state = ConversationState.ANALYZING self.state = ConversationState.ANALYZING
except Exception as e: except Exception as e:
logger.error(f"发送消息失败: {str(e)}") logger.error(f"发送消息失败: {str(e)}")
self.state = ConversationState.ANALYZING self.state = ConversationState.ANALYZING

View File

@@ -1,8 +1,6 @@
class ConversationInfo: class ConversationInfo:
def __init__(self): def __init__(self):
self.done_action = [] self.done_action = []
self.goal_list = [] self.goal_list = []
self.knowledge_list = [] self.knowledge_list = []
self.memory_list = [] self.memory_list = []

View File

@@ -7,12 +7,13 @@ from src.plugins.chat.message import MessageSending
logger = get_module_logger("message_sender") logger = get_module_logger("message_sender")
class DirectMessageSender: class DirectMessageSender:
"""直接消息发送器""" """直接消息发送器"""
def __init__(self): def __init__(self):
pass pass
async def send_message( async def send_message(
self, self,
chat_stream: ChatStream, chat_stream: ChatStream,
@@ -20,7 +21,7 @@ class DirectMessageSender:
reply_to_message: Optional[Message] = None, reply_to_message: Optional[Message] = None,
) -> None: ) -> None:
"""发送消息到聊天流 """发送消息到聊天流
Args: Args:
chat_stream: 聊天流 chat_stream: 聊天流
content: 消息内容 content: 消息内容
@@ -29,21 +30,18 @@ class DirectMessageSender:
try: try:
# 创建消息内容 # 创建消息内容
segments = [Seg(type="text", data={"text": content})] segments = [Seg(type="text", data={"text": content})]
# 检查是否需要引用回复 # 检查是否需要引用回复
if reply_to_message: if reply_to_message:
reply_id = reply_to_message.message_id reply_id = reply_to_message.message_id
message_sending = MessageSending( message_sending = MessageSending(segments=segments, reply_to_id=reply_id)
segments=segments,
reply_to_id=reply_id
)
else: else:
message_sending = MessageSending(segments=segments) message_sending = MessageSending(segments=segments)
# 发送消息 # 发送消息
await chat_stream.send_message(message_sending) await chat_stream.send_message(message_sending)
logger.info(f"消息已发送: {content}") logger.info(f"消息已发送: {content}")
except Exception as e: except Exception as e:
logger.error(f"发送消息失败: {str(e)}") logger.error(f"发送消息失败: {str(e)}")
raise raise

View File

@@ -2,133 +2,126 @@ from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional from typing import List, Dict, Any, Optional
from src.common.database import db from src.common.database import db
class MessageStorage(ABC): class MessageStorage(ABC):
"""消息存储接口""" """消息存储接口"""
@abstractmethod @abstractmethod
async def get_messages_after(self, chat_id: str, message_id: Optional[str] = None) -> List[Dict[str, Any]]: async def get_messages_after(self, chat_id: str, message_id: Optional[str] = None) -> List[Dict[str, Any]]:
"""获取指定消息ID之后的所有消息 """获取指定消息ID之后的所有消息
Args: Args:
chat_id: 聊天ID chat_id: 聊天ID
message_id: 消息ID如果为None则获取所有消息 message_id: 消息ID如果为None则获取所有消息
Returns: Returns:
List[Dict[str, Any]]: 消息列表 List[Dict[str, Any]]: 消息列表
""" """
pass pass
@abstractmethod @abstractmethod
async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]: async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:
"""获取指定时间点之前的消息 """获取指定时间点之前的消息
Args: Args:
chat_id: 聊天ID chat_id: 聊天ID
time_point: 时间戳 time_point: 时间戳
limit: 最大消息数量 limit: 最大消息数量
Returns: Returns:
List[Dict[str, Any]]: 消息列表 List[Dict[str, Any]]: 消息列表
""" """
pass pass
@abstractmethod @abstractmethod
async def has_new_messages(self, chat_id: str, after_time: float) -> bool: async def has_new_messages(self, chat_id: str, after_time: float) -> bool:
"""检查是否有新消息 """检查是否有新消息
Args: Args:
chat_id: 聊天ID chat_id: 聊天ID
after_time: 时间戳 after_time: 时间戳
Returns: Returns:
bool: 是否有新消息 bool: 是否有新消息
""" """
pass pass
class MongoDBMessageStorage(MessageStorage): class MongoDBMessageStorage(MessageStorage):
"""MongoDB消息存储实现""" """MongoDB消息存储实现"""
def __init__(self): def __init__(self):
self.db = db self.db = db
async def get_messages_after(self, chat_id: str, message_id: Optional[str] = None) -> List[Dict[str, Any]]: async def get_messages_after(self, chat_id: str, message_id: Optional[str] = None) -> List[Dict[str, Any]]:
query = {"chat_id": chat_id} query = {"chat_id": chat_id}
if message_id: if message_id:
# 获取ID大于message_id的消息 # 获取ID大于message_id的消息
last_message = self.db.messages.find_one({"message_id": message_id}) last_message = self.db.messages.find_one({"message_id": message_id})
if last_message: if last_message:
query["time"] = {"$gt": last_message["time"]} query["time"] = {"$gt": last_message["time"]}
return list( return list(self.db.messages.find(query).sort("time", 1))
self.db.messages.find(query).sort("time", 1)
)
async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]: async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:
query = { query = {"chat_id": chat_id, "time": {"$lt": time_point}}
"chat_id": chat_id,
"time": {"$lt": time_point} messages = list(self.db.messages.find(query).sort("time", -1).limit(limit))
}
messages = list(
self.db.messages.find(query).sort("time", -1).limit(limit)
)
# 将消息按时间正序排列 # 将消息按时间正序排列
messages.reverse() messages.reverse()
return messages return messages
async def has_new_messages(self, chat_id: str, after_time: float) -> bool: async def has_new_messages(self, chat_id: str, after_time: float) -> bool:
query = { query = {"chat_id": chat_id, "time": {"$gt": after_time}}
"chat_id": chat_id,
"time": {"$gt": after_time}
}
return self.db.messages.find_one(query) is not None return self.db.messages.find_one(query) is not None
# # 创建一个内存消息存储实现,用于测试 # # 创建一个内存消息存储实现,用于测试
# class InMemoryMessageStorage(MessageStorage): # class InMemoryMessageStorage(MessageStorage):
# """内存消息存储实现,主要用于测试""" # """内存消息存储实现,主要用于测试"""
# def __init__(self): # def __init__(self):
# self.messages: Dict[str, List[Dict[str, Any]]] = {} # self.messages: Dict[str, List[Dict[str, Any]]] = {}
# async def get_messages_after(self, chat_id: str, message_id: Optional[str] = None) -> List[Dict[str, Any]]: # async def get_messages_after(self, chat_id: str, message_id: Optional[str] = None) -> List[Dict[str, Any]]:
# if chat_id not in self.messages: # if chat_id not in self.messages:
# return [] # return []
# messages = self.messages[chat_id] # messages = self.messages[chat_id]
# if not message_id: # if not message_id:
# return messages # return messages
# # 找到message_id的索引 # # 找到message_id的索引
# try: # try:
# index = next(i for i, m in enumerate(messages) if m["message_id"] == message_id) # index = next(i for i, m in enumerate(messages) if m["message_id"] == message_id)
# return messages[index + 1:] # return messages[index + 1:]
# except StopIteration: # except StopIteration:
# return [] # return []
# async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]: # async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:
# if chat_id not in self.messages: # if chat_id not in self.messages:
# return [] # return []
# messages = [ # messages = [
# m for m in self.messages[chat_id] # m for m in self.messages[chat_id]
# if m["time"] < time_point # if m["time"] < time_point
# ] # ]
# return messages[-limit:] # return messages[-limit:]
# async def has_new_messages(self, chat_id: str, after_time: float) -> bool: # async def has_new_messages(self, chat_id: str, after_time: float) -> bool:
# if chat_id not in self.messages: # if chat_id not in self.messages:
# return False # return False
# return any(m["time"] > after_time for m in self.messages[chat_id]) # return any(m["time"] > after_time for m in self.messages[chat_id])
# # 测试辅助方法 # # 测试辅助方法
# def add_message(self, chat_id: str, message: Dict[str, Any]): # def add_message(self, chat_id: str, message: Dict[str, Any]):
# """添加测试消息""" # """添加测试消息"""
# if chat_id not in self.messages: # if chat_id not in self.messages:
# self.messages[chat_id] = [] # self.messages[chat_id] = []
# self.messages[chat_id].append(message) # self.messages[chat_id].append(message)
# self.messages[chat_id].sort(key=lambda m: m["time"]) # self.messages[chat_id].sort(key=lambda m: m["time"])

View File

@@ -7,25 +7,26 @@ if TYPE_CHECKING:
logger = get_module_logger("notification_handler") logger = get_module_logger("notification_handler")
class PFCNotificationHandler(NotificationHandler): class PFCNotificationHandler(NotificationHandler):
"""PFC通知处理器""" """PFC通知处理器"""
def __init__(self, conversation: 'Conversation'): def __init__(self, conversation: "Conversation"):
"""初始化PFC通知处理器 """初始化PFC通知处理器
Args: Args:
conversation: 对话实例 conversation: 对话实例
""" """
self.conversation = conversation self.conversation = conversation
async def handle_notification(self, notification: Notification): async def handle_notification(self, notification: Notification):
"""处理通知 """处理通知
Args: Args:
notification: 通知对象 notification: 通知对象
""" """
logger.debug(f"收到通知: {notification.type.name}, 数据: {notification.data}") logger.debug(f"收到通知: {notification.type.name}, 数据: {notification.data}")
# 根据通知类型执行不同的处理 # 根据通知类型执行不同的处理
if notification.type == NotificationType.NEW_MESSAGE: if notification.type == NotificationType.NEW_MESSAGE:
# 新消息通知 # 新消息通知
@@ -38,34 +39,33 @@ class PFCNotificationHandler(NotificationHandler):
await self._handle_command(notification) await self._handle_command(notification)
else: else:
logger.warning(f"未知的通知类型: {notification.type.name}") logger.warning(f"未知的通知类型: {notification.type.name}")
async def _handle_new_message(self, notification: Notification): async def _handle_new_message(self, notification: Notification):
"""处理新消息通知 """处理新消息通知
Args: Args:
notification: 通知对象 notification: 通知对象
""" """
# 更新决策信息 # 更新决策信息
observation_info = self.conversation.observation_info observation_info = self.conversation.observation_info
observation_info.last_message_time = notification.data.get("time", 0) observation_info.last_message_time = notification.data.get("time", 0)
observation_info.add_unprocessed_message(notification.data) observation_info.add_unprocessed_message(notification.data)
# 手动触发观察器更新 # 手动触发观察器更新
self.conversation.chat_observer.trigger_update() self.conversation.chat_observer.trigger_update()
async def _handle_cold_chat(self, notification: Notification): async def _handle_cold_chat(self, notification: Notification):
"""处理冷聊天通知 """处理冷聊天通知
Args: Args:
notification: 通知对象 notification: 通知对象
""" """
# 获取冷聊天信息 # 获取冷聊天信息
cold_duration = notification.data.get("duration", 0) cold_duration = notification.data.get("duration", 0)
# 更新决策信息 # 更新决策信息
observation_info = self.conversation.observation_info observation_info = self.conversation.observation_info
observation_info.conversation_cold_duration = cold_duration observation_info.conversation_cold_duration = cold_duration
logger.info(f"对话已冷: {cold_duration}") logger.info(f"对话已冷: {cold_duration}")

View File

@@ -1,5 +1,5 @@
#Programmable Friendly Conversationalist # Programmable Friendly Conversationalist
#Prefrontal cortex # Prefrontal cortex
from typing import List, Optional, Dict, Any, Set from typing import List, Optional, Dict, Any, Set
from ..message.message_base import UserInfo from ..message.message_base import UserInfo
import time import time
@@ -10,26 +10,27 @@ from .chat_states import NotificationHandler
logger = get_module_logger("observation_info") logger = get_module_logger("observation_info")
class ObservationInfoHandler(NotificationHandler): class ObservationInfoHandler(NotificationHandler):
"""ObservationInfo的通知处理器""" """ObservationInfo的通知处理器"""
def __init__(self, observation_info: 'ObservationInfo'): def __init__(self, observation_info: "ObservationInfo"):
"""初始化处理器 """初始化处理器
Args: Args:
observation_info: 要更新的ObservationInfo实例 observation_info: 要更新的ObservationInfo实例
""" """
self.observation_info = observation_info self.observation_info = observation_info
async def handle_notification(self, notification: Dict[str, Any]): async def handle_notification(self, notification: Dict[str, Any]):
"""处理通知 """处理通知
Args: Args:
notification: 通知数据 notification: 通知数据
""" """
notification_type = notification.get("type") notification_type = notification.get("type")
data = notification.get("data", {}) data = notification.get("data", {})
if notification_type == "NEW_MESSAGE": if notification_type == "NEW_MESSAGE":
# 处理新消息通知 # 处理新消息通知
logger.debug(f"收到新消息通知data: {data}") logger.debug(f"收到新消息通知data: {data}")
@@ -37,62 +38,62 @@ class ObservationInfoHandler(NotificationHandler):
self.observation_info.update_from_message(message) self.observation_info.update_from_message(message)
# self.observation_info.has_unread_messages = True # self.observation_info.has_unread_messages = True
# self.observation_info.new_unread_message.append(message.get("processed_plain_text", "")) # self.observation_info.new_unread_message.append(message.get("processed_plain_text", ""))
elif notification_type == "COLD_CHAT": elif notification_type == "COLD_CHAT":
# 处理冷场通知 # 处理冷场通知
is_cold = data.get("is_cold", False) is_cold = data.get("is_cold", False)
self.observation_info.update_cold_chat_status(is_cold, time.time()) self.observation_info.update_cold_chat_status(is_cold, time.time())
elif notification_type == "ACTIVE_CHAT": elif notification_type == "ACTIVE_CHAT":
# 处理活跃通知 # 处理活跃通知
is_active = data.get("is_active", False) is_active = data.get("is_active", False)
self.observation_info.is_cold = not is_active self.observation_info.is_cold = not is_active
elif notification_type == "BOT_SPEAKING": elif notification_type == "BOT_SPEAKING":
# 处理机器人说话通知 # 处理机器人说话通知
self.observation_info.is_typing = False self.observation_info.is_typing = False
self.observation_info.last_bot_speak_time = time.time() self.observation_info.last_bot_speak_time = time.time()
elif notification_type == "USER_SPEAKING": elif notification_type == "USER_SPEAKING":
# 处理用户说话通知 # 处理用户说话通知
self.observation_info.is_typing = False self.observation_info.is_typing = False
self.observation_info.last_user_speak_time = time.time() self.observation_info.last_user_speak_time = time.time()
elif notification_type == "MESSAGE_DELETED": elif notification_type == "MESSAGE_DELETED":
# 处理消息删除通知 # 处理消息删除通知
message_id = data.get("message_id") message_id = data.get("message_id")
self.observation_info.unprocessed_messages = [ self.observation_info.unprocessed_messages = [
msg for msg in self.observation_info.unprocessed_messages msg for msg in self.observation_info.unprocessed_messages if msg.get("message_id") != message_id
if msg.get("message_id") != message_id
] ]
elif notification_type == "USER_JOINED": elif notification_type == "USER_JOINED":
# 处理用户加入通知 # 处理用户加入通知
user_id = data.get("user_id") user_id = data.get("user_id")
if user_id: if user_id:
self.observation_info.active_users.add(user_id) self.observation_info.active_users.add(user_id)
elif notification_type == "USER_LEFT": elif notification_type == "USER_LEFT":
# 处理用户离开通知 # 处理用户离开通知
user_id = data.get("user_id") user_id = data.get("user_id")
if user_id: if user_id:
self.observation_info.active_users.discard(user_id) self.observation_info.active_users.discard(user_id)
elif notification_type == "ERROR": elif notification_type == "ERROR":
# 处理错误通知 # 处理错误通知
error_msg = data.get("error", "") error_msg = data.get("error", "")
logger.error(f"收到错误通知: {error_msg}") logger.error(f"收到错误通知: {error_msg}")
@dataclass @dataclass
class ObservationInfo: class ObservationInfo:
"""决策信息类用于收集和管理来自chat_observer的通知信息""" """决策信息类用于收集和管理来自chat_observer的通知信息"""
#data_list # data_list
chat_history: List[str] = field(default_factory=list) chat_history: List[str] = field(default_factory=list)
unprocessed_messages: List[Dict[str, Any]] = field(default_factory=list) unprocessed_messages: List[Dict[str, Any]] = field(default_factory=list)
active_users: Set[str] = field(default_factory=set) active_users: Set[str] = field(default_factory=set)
#data # data
last_bot_speak_time: Optional[float] = None last_bot_speak_time: Optional[float] = None
last_user_speak_time: Optional[float] = None last_user_speak_time: Optional[float] = None
last_message_time: Optional[float] = None last_message_time: Optional[float] = None
@@ -101,78 +102,70 @@ class ObservationInfo:
bot_id: Optional[str] = None bot_id: Optional[str] = None
new_messages_count: int = 0 new_messages_count: int = 0
cold_chat_duration: float = 0.0 cold_chat_duration: float = 0.0
#state # state
is_typing: bool = False is_typing: bool = False
has_unread_messages: bool = False has_unread_messages: bool = False
is_cold_chat: bool = False is_cold_chat: bool = False
changed: bool = False changed: bool = False
# #spec # #spec
# meta_plan_trigger: bool = False # meta_plan_trigger: bool = False
def __post_init__(self): def __post_init__(self):
"""初始化后创建handler""" """初始化后创建handler"""
self.chat_observer = None self.chat_observer = None
self.handler = ObservationInfoHandler(self) self.handler = ObservationInfoHandler(self)
def bind_to_chat_observer(self, stream_id: str): def bind_to_chat_observer(self, stream_id: str):
"""绑定到指定的chat_observer """绑定到指定的chat_observer
Args: Args:
stream_id: 聊天流ID stream_id: 聊天流ID
""" """
self.chat_observer = ChatObserver.get_instance(stream_id) self.chat_observer = ChatObserver.get_instance(stream_id)
self.chat_observer.notification_manager.register_handler( self.chat_observer.notification_manager.register_handler(
target="observation_info", target="observation_info", notification_type="NEW_MESSAGE", handler=self.handler
notification_type="NEW_MESSAGE",
handler=self.handler
) )
self.chat_observer.notification_manager.register_handler( self.chat_observer.notification_manager.register_handler(
target="observation_info", target="observation_info", notification_type="COLD_CHAT", handler=self.handler
notification_type="COLD_CHAT",
handler=self.handler
) )
def unbind_from_chat_observer(self): def unbind_from_chat_observer(self):
"""解除与chat_observer的绑定""" """解除与chat_observer的绑定"""
if self.chat_observer: if self.chat_observer:
self.chat_observer.notification_manager.unregister_handler( self.chat_observer.notification_manager.unregister_handler(
target="observation_info", target="observation_info", notification_type="NEW_MESSAGE", handler=self.handler
notification_type="NEW_MESSAGE",
handler=self.handler
) )
self.chat_observer.notification_manager.unregister_handler( self.chat_observer.notification_manager.unregister_handler(
target="observation_info", target="observation_info", notification_type="COLD_CHAT", handler=self.handler
notification_type="COLD_CHAT",
handler=self.handler
) )
self.chat_observer = None self.chat_observer = None
def update_from_message(self, message: Dict[str, Any]): def update_from_message(self, message: Dict[str, Any]):
"""从消息更新信息 """从消息更新信息
Args: Args:
message: 消息数据 message: 消息数据
""" """
logger.debug(f"更新信息from_message: {message}") logger.debug(f"更新信息from_message: {message}")
self.last_message_time = message["time"] self.last_message_time = message["time"]
self.last_message_content = message.get("processed_plain_text", "") self.last_message_content = message.get("processed_plain_text", "")
user_info = UserInfo.from_dict(message.get("user_info", {})) user_info = UserInfo.from_dict(message.get("user_info", {}))
self.last_message_sender = user_info.user_id self.last_message_sender = user_info.user_id
if user_info.user_id == self.bot_id: if user_info.user_id == self.bot_id:
self.last_bot_speak_time = message["time"] self.last_bot_speak_time = message["time"]
else: else:
self.last_user_speak_time = message["time"] self.last_user_speak_time = message["time"]
self.active_users.add(user_info.user_id) self.active_users.add(user_info.user_id)
self.new_messages_count += 1 self.new_messages_count += 1
self.unprocessed_messages.append(message) self.unprocessed_messages.append(message)
self.update_changed() self.update_changed()
def update_changed(self): def update_changed(self):
"""更新changed状态""" """更新changed状态"""
self.changed = True self.changed = True
@@ -180,7 +173,7 @@ class ObservationInfo:
def update_cold_chat_status(self, is_cold: bool, current_time: float): def update_cold_chat_status(self, is_cold: bool, current_time: float):
"""更新冷场状态 """更新冷场状态
Args: Args:
is_cold: 是否冷场 is_cold: 是否冷场
current_time: 当前时间 current_time: 当前时间
@@ -188,37 +181,37 @@ class ObservationInfo:
self.is_cold_chat = is_cold self.is_cold_chat = is_cold
if is_cold and self.last_message_time: if is_cold and self.last_message_time:
self.cold_chat_duration = current_time - self.last_message_time self.cold_chat_duration = current_time - self.last_message_time
def get_active_duration(self) -> float: def get_active_duration(self) -> float:
"""获取当前活跃时长 """获取当前活跃时长
Returns: Returns:
float: 最后一条消息到现在的时长(秒) float: 最后一条消息到现在的时长(秒)
""" """
if not self.last_message_time: if not self.last_message_time:
return 0.0 return 0.0
return time.time() - self.last_message_time return time.time() - self.last_message_time
def get_user_response_time(self) -> Optional[float]: def get_user_response_time(self) -> Optional[float]:
"""获取用户响应时间 """获取用户响应时间
Returns: Returns:
Optional[float]: 用户最后发言到现在的时长如果没有用户发言则返回None Optional[float]: 用户最后发言到现在的时长如果没有用户发言则返回None
""" """
if not self.last_user_speak_time: if not self.last_user_speak_time:
return None return None
return time.time() - self.last_user_speak_time return time.time() - self.last_user_speak_time
def get_bot_response_time(self) -> Optional[float]: def get_bot_response_time(self) -> Optional[float]:
"""获取机器人响应时间 """获取机器人响应时间
Returns: Returns:
Optional[float]: 机器人最后发言到现在的时长如果没有机器人发言则返回None Optional[float]: 机器人最后发言到现在的时长如果没有机器人发言则返回None
""" """
if not self.last_bot_speak_time: if not self.last_bot_speak_time:
return None return None
return time.time() - self.last_bot_speak_time return time.time() - self.last_bot_speak_time
def clear_unprocessed_messages(self): def clear_unprocessed_messages(self):
"""清空未处理消息列表""" """清空未处理消息列表"""
# 将未处理消息添加到历史记录中 # 将未处理消息添加到历史记录中
@@ -229,10 +222,10 @@ class ObservationInfo:
self.has_unread_messages = False self.has_unread_messages = False
self.unprocessed_messages.clear() self.unprocessed_messages.clear()
self.new_messages_count = 0 self.new_messages_count = 0
def add_unprocessed_message(self, message: Dict[str, Any]): def add_unprocessed_message(self, message: Dict[str, Any]):
"""添加未处理的消息 """添加未处理的消息
Args: Args:
message: 消息数据 message: 消息数据
""" """
@@ -241,6 +234,6 @@ class ObservationInfo:
if message_id and not any(m.get("message_id") == message_id for m in self.unprocessed_messages): if message_id and not any(m.get("message_id") == message_id for m in self.unprocessed_messages):
self.unprocessed_messages.append(message) self.unprocessed_messages.append(message)
self.new_messages_count += 1 self.new_messages_count += 1
# 同时更新其他消息相关信息 # 同时更新其他消息相关信息
self.update_from_message(message) self.update_from_message(message)

View File

@@ -49,43 +49,40 @@ class GoalAnalyzer:
Args: Args:
conversation_info: 对话信息 conversation_info: 对话信息
observation_info: 观察信息 observation_info: 观察信息
Returns: Returns:
Tuple[str, str, str]: (目标, 方法, 原因) Tuple[str, str, str]: (目标, 方法, 原因)
""" """
#构建对话目标 # 构建对话目标
goal_list = conversation_info.goal_list goal_list = conversation_info.goal_list
goal_text = "" goal_text = ""
for goal, reason in goal_list: for goal, reason in goal_list:
goal_text += f"目标:{goal};" goal_text += f"目标:{goal};"
goal_text += f"原因:{reason}\n" goal_text += f"原因:{reason}\n"
# 获取聊天历史记录 # 获取聊天历史记录
chat_history_list = observation_info.chat_history chat_history_list = observation_info.chat_history
chat_history_text = "" chat_history_text = ""
for msg in chat_history_list: for msg in chat_history_list:
chat_history_text += f"{msg}\n" chat_history_text += f"{msg}\n"
if observation_info.new_messages_count > 0: if observation_info.new_messages_count > 0:
new_messages_list = observation_info.unprocessed_messages new_messages_list = observation_info.unprocessed_messages
chat_history_text += f"{observation_info.new_messages_count}条新消息:\n" chat_history_text += f"{observation_info.new_messages_count}条新消息:\n"
for msg in new_messages_list: for msg in new_messages_list:
chat_history_text += f"{msg}\n" chat_history_text += f"{msg}\n"
observation_info.clear_unprocessed_messages() observation_info.clear_unprocessed_messages()
personality_text = f"你的名字是{self.name}{self.personality_info}" personality_text = f"你的名字是{self.name}{self.personality_info}"
# 构建action历史文本 # 构建action历史文本
action_history_list = conversation_info.done_action action_history_list = conversation_info.done_action
action_history_text = "你之前做的事情是:" action_history_text = "你之前做的事情是:"
for action in action_history_list: for action in action_history_list:
action_history_text += f"{action}\n" action_history_text += f"{action}\n"
prompt = f"""{personality_text}。现在你在参与一场QQ聊天请分析以下聊天记录并根据你的性格特征确定多个明确的对话目标。 prompt = f"""{personality_text}。现在你在参与一场QQ聊天请分析以下聊天记录并根据你的性格特征确定多个明确的对话目标。
这些目标应该反映出对话的不同方面和意图。 这些目标应该反映出对话的不同方面和意图。
@@ -119,20 +116,15 @@ class GoalAnalyzer:
logger.debug(f"发送到LLM的提示词: {prompt}") logger.debug(f"发送到LLM的提示词: {prompt}")
content, _ = await self.llm.generate_response_async(prompt) content, _ = await self.llm.generate_response_async(prompt)
logger.debug(f"LLM原始返回内容: {content}") logger.debug(f"LLM原始返回内容: {content}")
# 使用简化函数提取JSON内容 # 使用简化函数提取JSON内容
success, result = get_items_from_json( success, result = get_items_from_json(
content, content, "goal", "reasoning", required_types={"goal": str, "reasoning": str}
"goal", "reasoning",
required_types={"goal": str, "reasoning": str}
) )
#TODO # TODO
conversation_info.goal_list.append(result) conversation_info.goal_list.append(result)
async def _update_goals(self, new_goal: str, method: str, reasoning: str): async def _update_goals(self, new_goal: str, method: str, reasoning: str):
"""更新目标列表 """更新目标列表
@@ -229,24 +221,26 @@ class GoalAnalyzer:
try: try:
content, _ = await self.llm.generate_response_async(prompt) content, _ = await self.llm.generate_response_async(prompt)
logger.debug(f"LLM原始返回内容: {content}") logger.debug(f"LLM原始返回内容: {content}")
# 尝试解析JSON # 尝试解析JSON
success, result = get_items_from_json( success, result = get_items_from_json(
content, content,
"goal_achieved", "stop_conversation", "reason", "goal_achieved",
required_types={"goal_achieved": bool, "stop_conversation": bool, "reason": str} "stop_conversation",
"reason",
required_types={"goal_achieved": bool, "stop_conversation": bool, "reason": str},
) )
if not success: if not success:
logger.error("无法解析对话分析结果JSON") logger.error("无法解析对话分析结果JSON")
return False, False, "解析结果失败" return False, False, "解析结果失败"
goal_achieved = result["goal_achieved"] goal_achieved = result["goal_achieved"]
stop_conversation = result["stop_conversation"] stop_conversation = result["stop_conversation"]
reason = result["reason"] reason = result["reason"]
return goal_achieved, stop_conversation, reason return goal_achieved, stop_conversation, reason
except Exception as e: except Exception as e:
logger.error(f"分析对话状态时出错: {str(e)}") logger.error(f"分析对话状态时出错: {str(e)}")
return False, False, f"分析出错: {str(e)}" return False, False, f"分析出错: {str(e)}"
@@ -269,23 +263,22 @@ class Waiter:
# 使用当前时间作为等待开始时间 # 使用当前时间作为等待开始时间
wait_start_time = time.time() wait_start_time = time.time()
self.chat_observer.waiting_start_time = wait_start_time # 设置等待开始时间 self.chat_observer.waiting_start_time = wait_start_time # 设置等待开始时间
while True: while True:
# 检查是否有新消息 # 检查是否有新消息
if self.chat_observer.new_message_after(wait_start_time): if self.chat_observer.new_message_after(wait_start_time):
logger.info("等待结束,收到新消息") logger.info("等待结束,收到新消息")
return False return False
# 检查是否超时 # 检查是否超时
if time.time() - wait_start_time > 300: if time.time() - wait_start_time > 300:
logger.info("等待超过300秒结束对话") logger.info("等待超过300秒结束对话")
return True return True
await asyncio.sleep(1) await asyncio.sleep(1)
logger.info("等待中...") logger.info("等待中...")
class DirectMessageSender: class DirectMessageSender:
"""直接发送消息到平台的发送器""" """直接发送消息到平台的发送器"""

View File

@@ -5,33 +5,34 @@ import traceback
logger = get_module_logger("pfc_manager") logger = get_module_logger("pfc_manager")
class PFCManager: class PFCManager:
"""PFC对话管理器负责管理所有对话实例""" """PFC对话管理器负责管理所有对话实例"""
# 单例模式 # 单例模式
_instance = None _instance = None
# 会话实例管理 # 会话实例管理
_instances: Dict[str, Conversation] = {} _instances: Dict[str, Conversation] = {}
_initializing: Dict[str, bool] = {} _initializing: Dict[str, bool] = {}
@classmethod @classmethod
def get_instance(cls) -> 'PFCManager': def get_instance(cls) -> "PFCManager":
"""获取管理器单例 """获取管理器单例
Returns: Returns:
PFCManager: 管理器实例 PFCManager: 管理器实例
""" """
if cls._instance is None: if cls._instance is None:
cls._instance = PFCManager() cls._instance = PFCManager()
return cls._instance return cls._instance
async def get_or_create_conversation(self, stream_id: str) -> Optional[Conversation]: async def get_or_create_conversation(self, stream_id: str) -> Optional[Conversation]:
"""获取或创建对话实例 """获取或创建对话实例
Args: Args:
stream_id: 聊天流ID stream_id: 聊天流ID
Returns: Returns:
Optional[Conversation]: 对话实例创建失败则返回None Optional[Conversation]: 对话实例创建失败则返回None
""" """
@@ -39,11 +40,11 @@ class PFCManager:
if stream_id in self._initializing and self._initializing[stream_id]: if stream_id in self._initializing and self._initializing[stream_id]:
logger.debug(f"会话实例正在初始化中: {stream_id}") logger.debug(f"会话实例正在初始化中: {stream_id}")
return None return None
if stream_id in self._instances: if stream_id in self._instances:
logger.debug(f"使用现有会话实例: {stream_id}") logger.debug(f"使用现有会话实例: {stream_id}")
return self._instances[stream_id] return self._instances[stream_id]
try: try:
# 创建新实例 # 创建新实例
logger.info(f"创建新的对话实例: {stream_id}") logger.info(f"创建新的对话实例: {stream_id}")
@@ -51,47 +52,45 @@ class PFCManager:
# 创建实例 # 创建实例
conversation_instance = Conversation(stream_id) conversation_instance = Conversation(stream_id)
self._instances[stream_id] = conversation_instance self._instances[stream_id] = conversation_instance
# 启动实例初始化 # 启动实例初始化
await self._initialize_conversation(conversation_instance) await self._initialize_conversation(conversation_instance)
except Exception as e: except Exception as e:
logger.error(f"创建会话实例失败: {stream_id}, 错误: {e}") logger.error(f"创建会话实例失败: {stream_id}, 错误: {e}")
return None return None
return conversation_instance return conversation_instance
async def _initialize_conversation(self, conversation: Conversation): async def _initialize_conversation(self, conversation: Conversation):
"""初始化会话实例 """初始化会话实例
Args: Args:
conversation: 要初始化的会话实例 conversation: 要初始化的会话实例
""" """
stream_id = conversation.stream_id stream_id = conversation.stream_id
try: try:
logger.info(f"开始初始化会话实例: {stream_id}") logger.info(f"开始初始化会话实例: {stream_id}")
# 启动初始化流程 # 启动初始化流程
await conversation._initialize() await conversation._initialize()
# 标记初始化完成 # 标记初始化完成
self._initializing[stream_id] = False self._initializing[stream_id] = False
logger.info(f"会话实例 {stream_id} 初始化完成") logger.info(f"会话实例 {stream_id} 初始化完成")
except Exception as e: except Exception as e:
logger.error(f"管理器初始化会话实例失败: {stream_id}, 错误: {e}") logger.error(f"管理器初始化会话实例失败: {stream_id}, 错误: {e}")
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
# 清理失败的初始化 # 清理失败的初始化
async def get_conversation(self, stream_id: str) -> Optional[Conversation]: async def get_conversation(self, stream_id: str) -> Optional[Conversation]:
"""获取已存在的会话实例 """获取已存在的会话实例
Args: Args:
stream_id: 聊天流ID stream_id: 聊天流ID
Returns: Returns:
Optional[Conversation]: 会话实例不存在则返回None Optional[Conversation]: 会话实例不存在则返回None
""" """
return self._instances.get(stream_id) return self._instances.get(stream_id)

View File

@@ -4,6 +4,7 @@ from typing import Literal
class ConversationState(Enum): class ConversationState(Enum):
"""对话状态""" """对话状态"""
INIT = "初始化" INIT = "初始化"
RETHINKING = "重新思考" RETHINKING = "重新思考"
ANALYZING = "分析历史" ANALYZING = "分析历史"
@@ -18,4 +19,4 @@ class ConversationState(Enum):
JUDGING = "判断" JUDGING = "判断"
ActionType = Literal["direct_reply", "fetch_knowledge", "wait"] ActionType = Literal["direct_reply", "fetch_knowledge", "wait"]

View File

@@ -13,33 +13,26 @@ logger = get_module_logger("reply_generator")
class ReplyGenerator: class ReplyGenerator:
"""回复生成器""" """回复生成器"""
def __init__(self, stream_id: str): def __init__(self, stream_id: str):
self.llm = LLM_request( self.llm = LLM_request(
model=global_config.llm_normal, model=global_config.llm_normal, temperature=0.7, max_tokens=300, request_type="reply_generation"
temperature=0.7,
max_tokens=300,
request_type="reply_generation"
) )
self.personality_info = Individuality.get_instance().get_prompt(type = "personality", x_person = 2, level = 2) self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2)
self.name = global_config.BOT_NICKNAME self.name = global_config.BOT_NICKNAME
self.chat_observer = ChatObserver.get_instance(stream_id) self.chat_observer = ChatObserver.get_instance(stream_id)
self.reply_checker = ReplyChecker(stream_id) self.reply_checker = ReplyChecker(stream_id)
async def generate( async def generate(self, observation_info: ObservationInfo, conversation_info: ConversationInfo) -> str:
self,
observation_info: ObservationInfo,
conversation_info: ConversationInfo
) -> str:
"""生成回复 """生成回复
Args: Args:
goal: 对话目标 goal: 对话目标
chat_history: 聊天历史 chat_history: 聊天历史
knowledge_cache: 知识缓存 knowledge_cache: 知识缓存
previous_reply: 上一次生成的回复(如果有) previous_reply: 上一次生成的回复(如果有)
retry_count: 当前重试次数 retry_count: 当前重试次数
Returns: Returns:
str: 生成的回复 str: 生成的回复
""" """
@@ -51,22 +44,21 @@ class ReplyGenerator:
for goal, reason in goal_list: for goal, reason in goal_list:
goal_text += f"目标:{goal};" goal_text += f"目标:{goal};"
goal_text += f"原因:{reason}\n" goal_text += f"原因:{reason}\n"
# 获取聊天历史记录 # 获取聊天历史记录
chat_history_list = observation_info.chat_history chat_history_list = observation_info.chat_history
chat_history_text = "" chat_history_text = ""
for msg in chat_history_list: for msg in chat_history_list:
chat_history_text += f"{msg}\n" chat_history_text += f"{msg}\n"
# 整理知识缓存 # 整理知识缓存
knowledge_text = "" knowledge_text = ""
knowledge_list = conversation_info.knowledge_list knowledge_list = conversation_info.knowledge_list
for knowledge in knowledge_list: for knowledge in knowledge_list:
knowledge_text += f"知识:{knowledge}\n" knowledge_text += f"知识:{knowledge}\n"
personality_text = f"你的名字是{self.name}{self.personality_info}" personality_text = f"你的名字是{self.name}{self.personality_info}"
prompt = f"""{personality_text}。现在你在参与一场QQ聊天请根据以下信息生成回复 prompt = f"""{personality_text}。现在你在参与一场QQ聊天请根据以下信息生成回复
当前对话目标:{goal_text} 当前对话目标:{goal_text}
@@ -92,7 +84,7 @@ class ReplyGenerator:
logger.info(f"生成的回复: {content}") logger.info(f"生成的回复: {content}")
# is_new = self.chat_observer.check() # is_new = self.chat_observer.check()
# logger.debug(f"再看一眼聊天记录,{'有' if is_new else '没有'}新消息") # logger.debug(f"再看一眼聊天记录,{'有' if is_new else '没有'}新消息")
# 如果有新消息,重新生成回复 # 如果有新消息,重新生成回复
# if is_new: # if is_new:
# logger.info("检测到新消息,重新生成回复") # logger.info("检测到新消息,重新生成回复")
@@ -100,27 +92,22 @@ class ReplyGenerator:
# goal, chat_history, knowledge_cache, # goal, chat_history, knowledge_cache,
# None, retry_count # None, retry_count
# ) # )
return content return content
except Exception as e: except Exception as e:
logger.error(f"生成回复时出错: {e}") logger.error(f"生成回复时出错: {e}")
return "抱歉,我现在有点混乱,让我重新思考一下..." return "抱歉,我现在有点混乱,让我重新思考一下..."
async def check_reply( async def check_reply(self, reply: str, goal: str, retry_count: int = 0) -> Tuple[bool, str, bool]:
self,
reply: str,
goal: str,
retry_count: int = 0
) -> Tuple[bool, str, bool]:
"""检查回复是否合适 """检查回复是否合适
Args: Args:
reply: 生成的回复 reply: 生成的回复
goal: 对话目标 goal: 对话目标
retry_count: 当前重试次数 retry_count: 当前重试次数
Returns: Returns:
Tuple[bool, str, bool]: (是否合适, 原因, 是否需要重新规划) Tuple[bool, str, bool]: (是否合适, 原因, 是否需要重新规划)
""" """
return await self.reply_checker.check(reply, goal, retry_count) return await self.reply_checker.check(reply, goal, retry_count)

View File

@@ -3,43 +3,44 @@ from .chat_observer import ChatObserver
logger = get_module_logger("waiter") logger = get_module_logger("waiter")
class Waiter: class Waiter:
"""等待器,用于等待对话流中的事件""" """等待器,用于等待对话流中的事件"""
def __init__(self, stream_id: str): def __init__(self, stream_id: str):
self.stream_id = stream_id self.stream_id = stream_id
self.chat_observer = ChatObserver.get_instance(stream_id) self.chat_observer = ChatObserver.get_instance(stream_id)
async def wait(self, timeout: float = 20.0) -> bool: async def wait(self, timeout: float = 20.0) -> bool:
"""等待用户回复或超时 """等待用户回复或超时
Args: Args:
timeout: 超时时间(秒) timeout: 超时时间(秒)
Returns: Returns:
bool: 如果因为超时返回则为True否则为False bool: 如果因为超时返回则为True否则为False
""" """
try: try:
message_before = self.chat_observer.get_last_message() message_before = self.chat_observer.get_last_message()
# 等待新消息 # 等待新消息
logger.debug(f"等待新消息,超时时间: {timeout}") logger.debug(f"等待新消息,超时时间: {timeout}")
is_timeout = await self.chat_observer.wait_for_update(timeout=timeout) is_timeout = await self.chat_observer.wait_for_update(timeout=timeout)
if is_timeout: if is_timeout:
logger.debug("等待超时,没有收到新消息") logger.debug("等待超时,没有收到新消息")
return True return True
# 检查是否是新消息 # 检查是否是新消息
message_after = self.chat_observer.get_last_message() message_after = self.chat_observer.get_last_message()
if message_before and message_after and message_before.get("message_id") == message_after.get("message_id"): if message_before and message_after and message_before.get("message_id") == message_after.get("message_id"):
# 如果消息ID相同说明没有新消息 # 如果消息ID相同说明没有新消息
logger.debug("没有收到新消息") logger.debug("没有收到新消息")
return True return True
logger.debug("收到新消息") logger.debug("收到新消息")
return False return False
except Exception as e: except Exception as e:
logger.error(f"等待时出错: {str(e)}") logger.error(f"等待时出错: {str(e)}")
return True return True

View File

@@ -30,7 +30,7 @@ class ChatBot:
self.think_flow_chat = ThinkFlowChat() self.think_flow_chat = ThinkFlowChat()
self.reasoning_chat = ReasoningChat() self.reasoning_chat = ReasoningChat()
self.only_process_chat = MessageProcessor() self.only_process_chat = MessageProcessor()
# 创建初始化PFC管理器的任务会在_ensure_started时执行 # 创建初始化PFC管理器的任务会在_ensure_started时执行
self.pfc_manager = PFCManager.get_instance() self.pfc_manager = PFCManager.get_instance()
@@ -38,7 +38,7 @@ class ChatBot:
"""确保所有任务已启动""" """确保所有任务已启动"""
if not self._started: if not self._started:
logger.info("确保ChatBot所有任务已启动") logger.info("确保ChatBot所有任务已启动")
self._started = True self._started = True
async def _create_PFC_chat(self, message: MessageRecv): async def _create_PFC_chat(self, message: MessageRecv):
@@ -46,7 +46,6 @@ class ChatBot:
chat_id = str(message.chat_stream.stream_id) chat_id = str(message.chat_stream.stream_id)
if global_config.enable_pfc_chatting: if global_config.enable_pfc_chatting:
await self.pfc_manager.get_or_create_conversation(chat_id) await self.pfc_manager.get_or_create_conversation(chat_id)
except Exception as e: except Exception as e:
@@ -80,7 +79,7 @@ class ChatBot:
try: try:
# 确保所有任务已启动 # 确保所有任务已启动
await self._ensure_started() await self._ensure_started()
message = MessageRecv(message_data) message = MessageRecv(message_data)
groupinfo = message.message_info.group_info groupinfo = message.message_info.group_info
userinfo = message.message_info.user_info userinfo = message.message_info.user_info

View File

@@ -24,7 +24,7 @@ config_config = LogConfig(
# 配置主程序日志格式 # 配置主程序日志格式
logger = get_module_logger("config", config=config_config) logger = get_module_logger("config", config=config_config)
#考虑到实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码 # 考虑到实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
is_test = True is_test = True
mai_version_main = "0.6.2" mai_version_main = "0.6.2"
mai_version_fix = "snapshot-1" mai_version_fix = "snapshot-1"

View File

@@ -2,7 +2,7 @@
__version__ = "0.1.0" __version__ = "0.1.0"
from .api import BaseMessageAPI, global_api from .api import global_api
from .message_base import ( from .message_base import (
Seg, Seg,
GroupInfo, GroupInfo,
@@ -14,7 +14,6 @@ from .message_base import (
) )
__all__ = [ __all__ = [
"BaseMessageAPI",
"Seg", "Seg",
"global_api", "global_api",
"GroupInfo", "GroupInfo",

View File

@@ -1,7 +1,8 @@
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from typing import Dict, Any, Callable, List, Set from typing import Dict, Any, Callable, List, Set, Optional
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
from src.plugins.message.message_base import MessageBase from src.plugins.message.message_base import MessageBase
from src.common.server import global_server
import aiohttp import aiohttp
import asyncio import asyncio
import uvicorn import uvicorn
@@ -49,13 +50,22 @@ class MessageServer(BaseMessageHandler):
_class_handlers: List[Callable] = [] # 类级别的消息处理器 _class_handlers: List[Callable] = [] # 类级别的消息处理器
def __init__(self, host: str = "0.0.0.0", port: int = 18000, enable_token=False): def __init__(
self,
host: str = "0.0.0.0",
port: int = 18000,
enable_token=False,
app: Optional[FastAPI] = None,
path: str = "/ws",
):
super().__init__() super().__init__()
# 将类级别的处理器添加到实例处理器中 # 将类级别的处理器添加到实例处理器中
self.message_handlers.extend(self._class_handlers) self.message_handlers.extend(self._class_handlers)
self.app = FastAPI()
self.host = host self.host = host
self.port = port self.port = port
self.path = path
self.app = app or FastAPI()
self.own_app = app is None # 标记是否使用自己创建的app
self.active_websockets: Set[WebSocket] = set() self.active_websockets: Set[WebSocket] = set()
self.platform_websockets: Dict[str, WebSocket] = {} # 平台到websocket的映射 self.platform_websockets: Dict[str, WebSocket] = {} # 平台到websocket的映射
self.valid_tokens: Set[str] = set() self.valid_tokens: Set[str] = set()
@@ -63,28 +73,6 @@ class MessageServer(BaseMessageHandler):
self._setup_routes() self._setup_routes()
self._running = False self._running = False
@classmethod
def register_class_handler(cls, handler: Callable):
"""注册类级别的消息处理器"""
if handler not in cls._class_handlers:
cls._class_handlers.append(handler)
def register_message_handler(self, handler: Callable):
"""注册实例级别的消息处理器"""
if handler not in self.message_handlers:
self.message_handlers.append(handler)
async def verify_token(self, token: str) -> bool:
if not self.enable_token:
return True
return token in self.valid_tokens
def add_valid_token(self, token: str):
self.valid_tokens.add(token)
def remove_valid_token(self, token: str):
self.valid_tokens.discard(token)
def _setup_routes(self): def _setup_routes(self):
@self.app.post("/api/message") @self.app.post("/api/message")
async def handle_message(message: Dict[str, Any]): async def handle_message(message: Dict[str, Any]):
@@ -125,6 +113,90 @@ class MessageServer(BaseMessageHandler):
finally: finally:
self._remove_websocket(websocket, platform) self._remove_websocket(websocket, platform)
@classmethod
def register_class_handler(cls, handler: Callable):
"""注册类级别的消息处理器"""
if handler not in cls._class_handlers:
cls._class_handlers.append(handler)
def register_message_handler(self, handler: Callable):
"""注册实例级别的消息处理器"""
if handler not in self.message_handlers:
self.message_handlers.append(handler)
async def verify_token(self, token: str) -> bool:
if not self.enable_token:
return True
return token in self.valid_tokens
def add_valid_token(self, token: str):
self.valid_tokens.add(token)
def remove_valid_token(self, token: str):
self.valid_tokens.discard(token)
def run_sync(self):
"""同步方式运行服务器"""
if not self.own_app:
raise RuntimeError("当使用外部FastAPI实例时请使用该实例的运行方法")
uvicorn.run(self.app, host=self.host, port=self.port)
async def run(self):
"""异步方式运行服务器"""
self._running = True
try:
if self.own_app:
# 如果使用自己的 FastAPI 实例,运行 uvicorn 服务器
config = uvicorn.Config(self.app, host=self.host, port=self.port, loop="asyncio")
self.server = uvicorn.Server(config)
await self.server.serve()
else:
# 如果使用外部 FastAPI 实例,保持运行状态以处理消息
while self._running:
await asyncio.sleep(1)
except KeyboardInterrupt:
await self.stop()
raise
except Exception as e:
await self.stop()
raise RuntimeError(f"服务器运行错误: {str(e)}") from e
finally:
await self.stop()
async def start_server(self):
"""启动服务器的异步方法"""
if not self._running:
self._running = True
await self.run()
async def stop(self):
"""停止服务器"""
# 清理platform映射
self.platform_websockets.clear()
# 取消所有后台任务
for task in self.background_tasks:
task.cancel()
# 等待所有任务完成
await asyncio.gather(*self.background_tasks, return_exceptions=True)
self.background_tasks.clear()
# 关闭所有WebSocket连接
for websocket in self.active_websockets:
await websocket.close()
self.active_websockets.clear()
if hasattr(self, "server") and self.own_app:
self._running = False
# 正确关闭 uvicorn 服务器
self.server.should_exit = True
await self.server.shutdown()
# 等待服务器完全停止
if hasattr(self.server, "started") and self.server.started:
await self.server.main_loop()
# 清理处理程序
self.message_handlers.clear()
def _remove_websocket(self, websocket: WebSocket, platform: str): def _remove_websocket(self, websocket: WebSocket, platform: str):
"""从所有集合中移除websocket""" """从所有集合中移除websocket"""
if websocket in self.active_websockets: if websocket in self.active_websockets:
@@ -161,54 +233,6 @@ class MessageServer(BaseMessageHandler):
async def send_message(self, message: MessageBase): async def send_message(self, message: MessageBase):
await self.broadcast_to_platform(message.message_info.platform, message.to_dict()) await self.broadcast_to_platform(message.message_info.platform, message.to_dict())
def run_sync(self):
"""同步方式运行服务器"""
uvicorn.run(self.app, host=self.host, port=self.port)
async def run(self):
"""异步方式运行服务器"""
config = uvicorn.Config(self.app, host=self.host, port=self.port, loop="asyncio")
self.server = uvicorn.Server(config)
try:
await self.server.serve()
except KeyboardInterrupt as e:
await self.stop()
raise KeyboardInterrupt from e
async def start_server(self):
"""启动服务器的异步方法"""
if not self._running:
self._running = True
await self.run()
async def stop(self):
"""停止服务器"""
# 清理platform映射
self.platform_websockets.clear()
# 取消所有后台任务
for task in self.background_tasks:
task.cancel()
# 等待所有任务完成
await asyncio.gather(*self.background_tasks, return_exceptions=True)
self.background_tasks.clear()
# 关闭所有WebSocket连接
for websocket in self.active_websockets:
await websocket.close()
self.active_websockets.clear()
if hasattr(self, "server"):
self._running = False
# 正确关闭 uvicorn 服务器
self.server.should_exit = True
await self.server.shutdown()
# 等待服务器完全停止
if hasattr(self.server, "started") and self.server.started:
await self.server.main_loop()
# 清理处理程序
self.message_handlers.clear()
async def send_message_REST(self, url: str, data: Dict[str, Any]) -> Dict[str, Any]: async def send_message_REST(self, url: str, data: Dict[str, Any]) -> Dict[str, Any]:
"""发送消息到指定端点""" """发送消息到指定端点"""
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
@@ -219,105 +243,4 @@ class MessageServer(BaseMessageHandler):
raise e raise e
class BaseMessageAPI: global_api = MessageServer(host=os.environ["HOST"], port=int(os.environ["PORT"]), app=global_server.get_app())
def __init__(self, host: str = "0.0.0.0", port: int = 18000):
self.app = FastAPI()
self.host = host
self.port = port
self.message_handlers: List[Callable] = []
self.cache = []
self._setup_routes()
self._running = False
def _setup_routes(self):
"""设置基础路由"""
@self.app.post("/api/message")
async def handle_message(message: Dict[str, Any]):
try:
# 创建后台任务处理消息
asyncio.create_task(self._background_message_handler(message))
return {"status": "success"}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e
async def _background_message_handler(self, message: Dict[str, Any]):
"""后台处理单个消息"""
try:
await self.process_single_message(message)
except Exception as e:
logger.error(f"Background message processing failed: {str(e)}")
logger.error(traceback.format_exc())
def register_message_handler(self, handler: Callable):
"""注册消息处理函数"""
self.message_handlers.append(handler)
async def send_message(self, url: str, data: Dict[str, Any]) -> Dict[str, Any]:
"""发送消息到指定端点"""
async with aiohttp.ClientSession() as session:
try:
async with session.post(url, json=data, headers={"Content-Type": "application/json"}) as response:
return await response.json()
except Exception:
# logger.error(f"发送消息失败: {str(e)}")
pass
async def process_single_message(self, message: Dict[str, Any]):
"""处理单条消息"""
tasks = []
for handler in self.message_handlers:
try:
tasks.append(handler(message))
except Exception as e:
logger.error(str(e))
logger.error(traceback.format_exc())
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
def run_sync(self):
"""同步方式运行服务器"""
uvicorn.run(self.app, host=self.host, port=self.port)
async def run(self):
"""异步方式运行服务器"""
config = uvicorn.Config(self.app, host=self.host, port=self.port, loop="asyncio")
self.server = uvicorn.Server(config)
try:
await self.server.serve()
except KeyboardInterrupt as e:
await self.stop()
raise KeyboardInterrupt from e
async def start_server(self):
"""启动服务器的异步方法"""
if not self._running:
self._running = True
await self.run()
async def stop(self):
"""停止服务器"""
if hasattr(self, "server"):
self._running = False
# 正确关闭 uvicorn 服务器
self.server.should_exit = True
await self.server.shutdown()
# 等待服务器完全停止
if hasattr(self.server, "started") and self.server.started:
await self.server.main_loop()
# 清理处理程序
self.message_handlers.clear()
def start(self):
"""启动服务器的便捷方法"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(self.start_server())
except KeyboardInterrupt:
pass
finally:
loop.close()
global_api = MessageServer(host=os.environ["HOST"], port=int(os.environ["PORT"]))