This commit is contained in:
SengokuCola
2025-04-09 20:10:58 +08:00
22 changed files with 614 additions and 692 deletions

View File

@@ -4,6 +4,7 @@ import logging
from pathlib import Path from pathlib import Path
from logging.handlers import RotatingFileHandler from logging.handlers import RotatingFileHandler
def setup_crash_logger(): def setup_crash_logger():
"""设置崩溃日志记录器""" """设置崩溃日志记录器"""
# 创建logs/crash目录如果不存在 # 创建logs/crash目录如果不存在
@@ -11,46 +12,42 @@ def setup_crash_logger():
crash_log_dir.mkdir(parents=True, exist_ok=True) crash_log_dir.mkdir(parents=True, exist_ok=True)
# 创建日志记录器 # 创建日志记录器
crash_logger = logging.getLogger('crash_logger') crash_logger = logging.getLogger("crash_logger")
crash_logger.setLevel(logging.ERROR) crash_logger.setLevel(logging.ERROR)
# 设置日志格式 # 设置日志格式
formatter = logging.Formatter( formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s\n' "%(asctime)s - %(name)s - %(levelname)s\n异常类型: %(exc_info)s\n详细信息:\n%(message)s\n-------------------\n"
'异常类型: %(exc_info)s\n'
'详细信息:\n%(message)s\n'
'-------------------\n'
) )
# 创建按大小轮转的文件处理器最大10MB保留5个备份 # 创建按大小轮转的文件处理器最大10MB保留5个备份
log_file = crash_log_dir / "crash.log" log_file = crash_log_dir / "crash.log"
file_handler = RotatingFileHandler( file_handler = RotatingFileHandler(
log_file, log_file,
maxBytes=10*1024*1024, # 10MB maxBytes=10 * 1024 * 1024, # 10MB
backupCount=5, backupCount=5,
encoding='utf-8' encoding="utf-8",
) )
file_handler.setFormatter(formatter) file_handler.setFormatter(formatter)
crash_logger.addHandler(file_handler) crash_logger.addHandler(file_handler)
return crash_logger return crash_logger
def log_crash(exc_type, exc_value, exc_traceback): def log_crash(exc_type, exc_value, exc_traceback):
"""记录崩溃信息到日志文件""" """记录崩溃信息到日志文件"""
if exc_type is None: if exc_type is None:
return return
# 获取崩溃日志记录器 # 获取崩溃日志记录器
crash_logger = logging.getLogger('crash_logger') crash_logger = logging.getLogger("crash_logger")
# 获取完整的异常堆栈信息 # 获取完整的异常堆栈信息
stack_trace = ''.join(traceback.format_exception(exc_type, exc_value, exc_traceback)) stack_trace = "".join(traceback.format_exception(exc_type, exc_value, exc_traceback))
# 记录崩溃信息 # 记录崩溃信息
crash_logger.error( crash_logger.error(stack_trace, exc_info=(exc_type, exc_value, exc_traceback))
stack_trace,
exc_info=(exc_type, exc_value, exc_traceback)
)
def install_crash_handler(): def install_crash_handler():
"""安装全局异常处理器""" """安装全局异常处理器"""

73
src/common/server.py Normal file
View File

@@ -0,0 +1,73 @@
from fastapi import FastAPI, APIRouter
from typing import Optional, Union
from uvicorn import Config, Server as UvicornServer
import os
class Server:
def __init__(self, host: Optional[str] = None, port: Optional[int] = None, app_name: str = "MaiMCore"):
self.app = FastAPI(title=app_name)
self._host: str = "127.0.0.1"
self._port: int = 8080
self._server: Optional[UvicornServer] = None
self.set_address(host, port)
def register_router(self, router: APIRouter, prefix: str = ""):
"""注册路由
APIRouter 用于对相关的路由端点进行分组和模块化管理:
1. 可以将相关的端点组织在一起,便于管理
2. 支持添加统一的路由前缀
3. 可以为一组路由添加共同的依赖项、标签等
示例:
router = APIRouter()
@router.get("/users")
def get_users():
return {"users": [...]}
@router.post("/users")
def create_user():
return {"msg": "user created"}
# 注册路由,添加前缀 "/api/v1"
server.register_router(router, prefix="/api/v1")
"""
self.app.include_router(router, prefix=prefix)
def set_address(self, host: Optional[str] = None, port: Optional[int] = None):
"""设置服务器地址和端口"""
if host:
self._host = host
if port:
self._port = port
async def run(self):
"""启动服务器"""
config = Config(app=self.app, host=self._host, port=self._port)
self._server = UvicornServer(config=config)
try:
await self._server.serve()
except KeyboardInterrupt:
await self.shutdown()
raise
except Exception as e:
await self.shutdown()
raise RuntimeError(f"服务器运行错误: {str(e)}") from e
finally:
await self.shutdown()
async def shutdown(self):
"""安全关闭服务器"""
if self._server:
self._server.should_exit = True
await self._server.shutdown()
self._server = None
def get_app(self) -> FastAPI:
"""获取 FastAPI 实例"""
return self.app
global_server = Server(host=os.environ["HOST"], port=int(os.environ["PORT"]))

View File

@@ -16,7 +16,7 @@ from .plugins.chat.bot import chat_bot
from .common.logger import get_module_logger from .common.logger import get_module_logger
from .plugins.remote import heartbeat_thread # noqa: F401 from .plugins.remote import heartbeat_thread # noqa: F401
from .individuality.individuality import Individuality from .individuality.individuality import Individuality
from .common.server import global_server
logger = get_module_logger("main") logger = get_module_logger("main")
@@ -33,6 +33,7 @@ class MainSystem:
from .plugins.message import global_api from .plugins.message import global_api
self.app = global_api self.app = global_api
self.server = global_server
async def initialize(self): async def initialize(self):
"""初始化系统组件""" """初始化系统组件"""
@@ -126,6 +127,7 @@ class MainSystem:
emoji_manager.start_periodic_check_register(), emoji_manager.start_periodic_check_register(),
# emoji_manager.start_periodic_register(), # emoji_manager.start_periodic_register(),
self.app.run(), self.app.run(),
self.server.run(),
] ]
await asyncio.gather(*tasks) await asyncio.gather(*tasks)

View File

@@ -10,6 +10,7 @@ from .conversation_info import ConversationInfo
logger = get_module_logger("action_planner") logger = get_module_logger("action_planner")
class ActionPlannerInfo: class ActionPlannerInfo:
def __init__(self): def __init__(self):
self.done_action = [] self.done_action = []
@@ -23,20 +24,13 @@ class ActionPlanner:
def __init__(self, stream_id: str): def __init__(self, stream_id: str):
self.llm = LLM_request( self.llm = LLM_request(
model=global_config.llm_normal, model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="action_planning"
temperature=0.7,
max_tokens=1000,
request_type="action_planning"
) )
self.personality_info = Individuality.get_instance().get_prompt(type = "personality", x_person = 2, level = 2) self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2)
self.name = global_config.BOT_NICKNAME self.name = global_config.BOT_NICKNAME
self.chat_observer = ChatObserver.get_instance(stream_id) self.chat_observer = ChatObserver.get_instance(stream_id)
async def plan( async def plan(self, observation_info: ObservationInfo, conversation_info: ConversationInfo) -> Tuple[str, str]:
self,
observation_info: ObservationInfo,
conversation_info: ConversationInfo
) -> Tuple[str, str]:
"""规划下一步行动 """规划下一步行动
Args: Args:
@@ -49,14 +43,13 @@ class ActionPlanner:
# 构建提示词 # 构建提示词
logger.debug(f"开始规划行动:当前目标: {conversation_info.goal_list}") logger.debug(f"开始规划行动:当前目标: {conversation_info.goal_list}")
#构建对话目标 # 构建对话目标
if conversation_info.goal_list: if conversation_info.goal_list:
goal, reasoning = conversation_info.goal_list[-1] goal, reasoning = conversation_info.goal_list[-1]
else: else:
goal = "目前没有明确对话目标" goal = "目前没有明确对话目标"
reasoning = "目前没有明确对话目标,最好思考一个对话目标" reasoning = "目前没有明确对话目标,最好思考一个对话目标"
# 获取聊天历史记录 # 获取聊天历史记录
chat_history_list = observation_info.chat_history chat_history_list = observation_info.chat_history
chat_history_text = "" chat_history_text = ""
@@ -72,7 +65,6 @@ class ActionPlanner:
observation_info.clear_unprocessed_messages() observation_info.clear_unprocessed_messages()
personality_text = f"你的名字是{self.name}{self.personality_info}" personality_text = f"你的名字是{self.name}{self.personality_info}"
# 构建action历史文本 # 构建action历史文本
@@ -81,8 +73,6 @@ class ActionPlanner:
for action in action_history_list: for action in action_history_list:
action_history_text += f"{action}\n" action_history_text += f"{action}\n"
prompt = f"""{personality_text}。现在你在参与一场QQ聊天请分析以下内容根据信息决定下一步行动 prompt = f"""{personality_text}。现在你在参与一场QQ聊天请分析以下内容根据信息决定下一步行动
当前对话目标:{goal} 当前对话目标:{goal}
@@ -114,9 +104,7 @@ rethink_goal: 重新思考对话目标,当发现对话目标不合适时选择
# 使用简化函数提取JSON内容 # 使用简化函数提取JSON内容
success, result = get_items_from_json( success, result = get_items_from_json(
content, content, "action", "reason", default_values={"action": "direct_reply", "reason": "没有明确原因"}
"action", "reason",
default_values={"action": "direct_reply", "reason": "没有明确原因"}
) )
if not success: if not success:

View File

@@ -17,7 +17,7 @@ class ChatObserver:
_instances: Dict[str, "ChatObserver"] = {} _instances: Dict[str, "ChatObserver"] = {}
@classmethod @classmethod
def get_instance(cls, stream_id: str, message_storage: Optional[MessageStorage] = None) -> 'ChatObserver': def get_instance(cls, stream_id: str, message_storage: Optional[MessageStorage] = None) -> "ChatObserver":
"""获取或创建观察器实例 """获取或创建观察器实例
Args: Args:
@@ -84,10 +84,7 @@ class ChatObserver:
""" """
logger.debug(f"检查距离上一次观察之后是否有了新消息: {self.last_check_time}") logger.debug(f"检查距离上一次观察之后是否有了新消息: {self.last_check_time}")
new_message_exists = await self.message_storage.has_new_messages( new_message_exists = await self.message_storage.has_new_messages(self.stream_id, self.last_check_time)
self.stream_id,
self.last_check_time
)
if new_message_exists: if new_message_exists:
logger.debug("发现新消息") logger.debug("发现新消息")
@@ -114,11 +111,7 @@ class ChatObserver:
self.last_user_speak_time = message["time"] self.last_user_speak_time = message["time"]
# 发送新消息通知 # 发送新消息通知
notification = create_new_message_notification( notification = create_new_message_notification(sender="chat_observer", target="pfc", message=message)
sender="chat_observer",
target="pfc",
message=message
)
await self.notification_manager.send_notification(notification) await self.notification_manager.send_notification(notification)
# 检查并更新冷场状态 # 检查并更新冷场状态
@@ -144,19 +137,12 @@ class ChatObserver:
# 如果冷场状态发生变化,发送通知 # 如果冷场状态发生变化,发送通知
if is_cold != self.is_cold_chat_state: if is_cold != self.is_cold_chat_state:
self.is_cold_chat_state = is_cold self.is_cold_chat_state = is_cold
notification = create_cold_chat_notification( notification = create_cold_chat_notification(sender="chat_observer", target="pfc", is_cold=is_cold)
sender="chat_observer",
target="pfc",
is_cold=is_cold
)
await self.notification_manager.send_notification(notification) await self.notification_manager.send_notification(notification)
async def get_new_message(self) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: async def get_new_message(self) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
"""获取上一次观察的时间点后的新消息,插入到历史记录中,并返回新消息和历史记录两个对象""" """获取上一次观察的时间点后的新消息,插入到历史记录中,并返回新消息和历史记录两个对象"""
messages = await self.message_storage.get_messages_after( messages = await self.message_storage.get_messages_after(self.stream_id, self.last_message_read)
self.stream_id,
self.last_message_read
)
for message in messages: for message in messages:
await self._add_message_to_history(message) await self._add_message_to_history(message)
return messages, self.message_history return messages, self.message_history
@@ -224,10 +210,7 @@ class ChatObserver:
Returns: Returns:
List[Dict[str, Any]]: 新消息列表 List[Dict[str, Any]]: 新消息列表
""" """
new_messages = await self.message_storage.get_messages_after( new_messages = await self.message_storage.get_messages_after(self.stream_id, self.last_message_read)
self.stream_id,
self.last_message_read
)
if new_messages: if new_messages:
self.last_message_read = new_messages[-1]["message_id"] self.last_message_read = new_messages[-1]["message_id"]
@@ -243,17 +226,15 @@ class ChatObserver:
Returns: Returns:
List[Dict[str, Any]]: 最多5条消息 List[Dict[str, Any]]: 最多5条消息
""" """
new_messages = await self.message_storage.get_messages_before( new_messages = await self.message_storage.get_messages_before(self.stream_id, time_point)
self.stream_id,
time_point
)
if new_messages: if new_messages:
self.last_message_read = new_messages[-1]["message_id"] self.last_message_read = new_messages[-1]["message_id"]
return new_messages return new_messages
'''主要观察循环''' """主要观察循环"""
async def _update_loop(self): async def _update_loop(self):
"""更新循环""" """更新循环"""
try: try:
@@ -396,17 +377,16 @@ class ChatObserver:
bool: 是否有新消息 bool: 是否有新消息
""" """
try: try:
messages = await self.message_storage.get_messages_for_stream( messages = await self.message_storage.get_messages_for_stream(self.stream_id, limit=50)
self.stream_id,
limit=50
)
if not messages: if not messages:
return False return False
# 检查是否有新消息 # 检查是否有新消息
has_new_messages = False has_new_messages = False
if messages and (not self.message_cache or messages[0]["message_id"] != self.message_cache[0]["message_id"]): if messages and (
not self.message_cache or messages[0]["message_id"] != self.message_cache[0]["message_id"]
):
has_new_messages = True has_new_messages = True
self.message_cache = messages self.message_cache = messages

View File

@@ -4,8 +4,10 @@ from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
class ChatState(Enum): class ChatState(Enum):
"""聊天状态枚举""" """聊天状态枚举"""
NORMAL = auto() # 正常状态 NORMAL = auto() # 正常状态
NEW_MESSAGE = auto() # 有新消息 NEW_MESSAGE = auto() # 有新消息
COLD_CHAT = auto() # 冷场状态 COLD_CHAT = auto() # 冷场状态
@@ -15,8 +17,10 @@ class ChatState(Enum):
SILENT = auto() # 沉默状态 SILENT = auto() # 沉默状态
ERROR = auto() # 错误状态 ERROR = auto() # 错误状态
class NotificationType(Enum): class NotificationType(Enum):
"""通知类型枚举""" """通知类型枚举"""
NEW_MESSAGE = auto() # 新消息通知 NEW_MESSAGE = auto() # 新消息通知
COLD_CHAT = auto() # 冷场通知 COLD_CHAT = auto() # 冷场通知
ACTIVE_CHAT = auto() # 活跃通知 ACTIVE_CHAT = auto() # 活跃通知
@@ -27,9 +31,11 @@ class NotificationType(Enum):
USER_LEFT = auto() # 用户离开通知 USER_LEFT = auto() # 用户离开通知
ERROR = auto() # 错误通知 ERROR = auto() # 错误通知
@dataclass @dataclass
class ChatStateInfo: class ChatStateInfo:
"""聊天状态信息""" """聊天状态信息"""
state: ChatState state: ChatState
last_message_time: Optional[float] = None last_message_time: Optional[float] = None
last_message_content: Optional[str] = None last_message_content: Optional[str] = None
@@ -38,9 +44,11 @@ class ChatStateInfo:
cold_duration: float = 0.0 # 冷场持续时间(秒) cold_duration: float = 0.0 # 冷场持续时间(秒)
active_duration: float = 0.0 # 活跃持续时间(秒) active_duration: float = 0.0 # 活跃持续时间(秒)
@dataclass @dataclass
class Notification: class Notification:
"""通知基类""" """通知基类"""
type: NotificationType type: NotificationType
timestamp: float timestamp: float
sender: str # 发送者标识 sender: str # 发送者标识
@@ -49,15 +57,13 @@ class Notification:
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
"""转换为字典格式""" """转换为字典格式"""
return { return {"type": self.type.name, "timestamp": self.timestamp, "data": self.data}
"type": self.type.name,
"timestamp": self.timestamp,
"data": self.data
}
@dataclass @dataclass
class StateNotification(Notification): class StateNotification(Notification):
"""持续状态通知""" """持续状态通知"""
is_active: bool = True is_active: bool = True
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
@@ -65,6 +71,7 @@ class StateNotification(Notification):
base_dict["is_active"] = self.is_active base_dict["is_active"] = self.is_active
return base_dict return base_dict
class NotificationHandler(ABC): class NotificationHandler(ABC):
"""通知处理器接口""" """通知处理器接口"""
@@ -73,6 +80,7 @@ class NotificationHandler(ABC):
"""处理通知""" """处理通知"""
pass pass
class NotificationManager: class NotificationManager:
"""通知管理器""" """通知管理器"""
@@ -141,10 +149,9 @@ class NotificationManager:
"""检查特定状态是否活跃""" """检查特定状态是否活跃"""
return state_type in self._active_states return state_type in self._active_states
def get_notification_history(self, def get_notification_history(
sender: Optional[str] = None, self, sender: Optional[str] = None, target: Optional[str] = None, limit: Optional[int] = None
target: Optional[str] = None, ) -> List[Notification]:
limit: Optional[int] = None) -> List[Notification]:
"""获取通知历史 """获取通知历史
Args: Args:
@@ -164,6 +171,7 @@ class NotificationManager:
return history return history
# 一些常用的通知创建函数 # 一些常用的通知创建函数
def create_new_message_notification(sender: str, target: str, message: Dict[str, Any]) -> Notification: def create_new_message_notification(sender: str, target: str, message: Dict[str, Any]) -> Notification:
"""创建新消息通知""" """创建新消息通知"""
@@ -176,10 +184,11 @@ def create_new_message_notification(sender: str, target: str, message: Dict[str,
"message_id": message.get("message_id"), "message_id": message.get("message_id"),
"content": message.get("content"), "content": message.get("content"),
"sender": message.get("sender"), "sender": message.get("sender"),
"time": message.get("time") "time": message.get("time"),
} },
) )
def create_cold_chat_notification(sender: str, target: str, is_cold: bool) -> StateNotification: def create_cold_chat_notification(sender: str, target: str, is_cold: bool) -> StateNotification:
"""创建冷场状态通知""" """创建冷场状态通知"""
return StateNotification( return StateNotification(
@@ -188,9 +197,10 @@ def create_cold_chat_notification(sender: str, target: str, is_cold: bool) -> St
sender=sender, sender=sender,
target=target, target=target,
data={"is_cold": is_cold}, data={"is_cold": is_cold},
is_active=is_cold is_active=is_cold,
) )
def create_active_chat_notification(sender: str, target: str, is_active: bool) -> StateNotification: def create_active_chat_notification(sender: str, target: str, is_active: bool) -> StateNotification:
"""创建活跃状态通知""" """创建活跃状态通知"""
return StateNotification( return StateNotification(
@@ -199,9 +209,10 @@ def create_active_chat_notification(sender: str, target: str, is_active: bool) -
sender=sender, sender=sender,
target=target, target=target,
data={"is_active": is_active}, data={"is_active": is_active},
is_active=is_active is_active=is_active,
) )
class ChatStateManager: class ChatStateManager:
"""聊天状态管理器""" """聊天状态管理器"""

View File

@@ -54,16 +54,15 @@ class Conversation:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
raise raise
try: try:
#决策所需要的信息,包括自身自信和观察信息两部分 # 决策所需要的信息,包括自身自信和观察信息两部分
#注册观察器和观测信息 # 注册观察器和观测信息
self.chat_observer = ChatObserver.get_instance(self.stream_id) self.chat_observer = ChatObserver.get_instance(self.stream_id)
self.chat_observer.start() self.chat_observer.start()
self.observation_info = ObservationInfo() self.observation_info = ObservationInfo()
self.observation_info.bind_to_chat_observer(self.stream_id) self.observation_info.bind_to_chat_observer(self.stream_id)
#对话信息 # 对话信息
self.conversation_info = ConversationInfo() self.conversation_info = ConversationInfo()
except Exception as e: except Exception as e:
logger.error(f"初始化对话实例:注册信息组件失败: {e}") logger.error(f"初始化对话实例:注册信息组件失败: {e}")
@@ -74,7 +73,6 @@ class Conversation:
self.should_continue = True self.should_continue = True
asyncio.create_task(self.start()) asyncio.create_task(self.start())
async def start(self): async def start(self):
"""开始对话流程""" """开始对话流程"""
try: try:
@@ -84,16 +82,12 @@ class Conversation:
logger.error(f"启动对话系统失败: {e}") logger.error(f"启动对话系统失败: {e}")
raise raise
async def _plan_and_action_loop(self): async def _plan_and_action_loop(self):
"""思考步PFC核心循环模块""" """思考步PFC核心循环模块"""
# 获取最近的消息历史 # 获取最近的消息历史
while self.should_continue: while self.should_continue:
# 使用决策信息来辅助行动规划 # 使用决策信息来辅助行动规划
action, reason = await self.action_planner.plan( action, reason = await self.action_planner.plan(self.observation_info, self.conversation_info)
self.observation_info,
self.conversation_info
)
if self._check_new_messages_after_planning(): if self._check_new_messages_after_planning():
continue continue
@@ -108,7 +102,6 @@ class Conversation:
return True return True
return False return False
def _convert_to_message(self, msg_dict: Dict[str, Any]) -> Message: def _convert_to_message(self, msg_dict: Dict[str, Any]) -> Message:
"""将消息字典转换为Message对象""" """将消息字典转换为Message对象"""
try: try:
@@ -122,31 +115,31 @@ class Conversation:
time=msg_dict["time"], time=msg_dict["time"],
user_info=user_info, user_info=user_info,
processed_plain_text=msg_dict.get("processed_plain_text", ""), processed_plain_text=msg_dict.get("processed_plain_text", ""),
detailed_plain_text=msg_dict.get("detailed_plain_text", "") detailed_plain_text=msg_dict.get("detailed_plain_text", ""),
) )
except Exception as e: except Exception as e:
logger.warning(f"转换消息时出错: {e}") logger.warning(f"转换消息时出错: {e}")
raise raise
async def _handle_action(self, action: str, reason: str, observation_info: ObservationInfo, conversation_info: ConversationInfo): async def _handle_action(
self, action: str, reason: str, observation_info: ObservationInfo, conversation_info: ConversationInfo
):
"""处理规划的行动""" """处理规划的行动"""
logger.info(f"执行行动: {action}, 原因: {reason}") logger.info(f"执行行动: {action}, 原因: {reason}")
# 记录action历史先设置为stop完成后再设置为done # 记录action历史先设置为stop完成后再设置为done
conversation_info.done_action.append({ conversation_info.done_action.append(
{
"action": action, "action": action,
"reason": reason, "reason": reason,
"status": "start", "status": "start",
"time": datetime.datetime.now().strftime("%H:%M:%S") "time": datetime.datetime.now().strftime("%H:%M:%S"),
}) }
)
if action == "direct_reply": if action == "direct_reply":
self.state = ConversationState.GENERATING self.state = ConversationState.GENERATING
self.generated_reply = await self.reply_generator.generate( self.generated_reply = await self.reply_generator.generate(observation_info, conversation_info)
observation_info,
conversation_info
)
# # 检查回复是否合适 # # 检查回复是否合适
# is_suitable, reason, need_replan = await self.reply_generator.check_reply( # is_suitable, reason, need_replan = await self.reply_generator.check_reply(
@@ -159,12 +152,14 @@ class Conversation:
await self._send_reply() await self._send_reply()
conversation_info.done_action.append({ conversation_info.done_action.append(
{
"action": action, "action": action,
"reason": reason, "reason": reason,
"status": "done", "status": "done",
"time": datetime.datetime.now().strftime("%H:%M:%S") "time": datetime.datetime.now().strftime("%H:%M:%S"),
}) }
)
elif action == "fetch_knowledge": elif action == "fetch_knowledge":
self.state = ConversationState.FETCHING self.state = ConversationState.FETCHING
@@ -175,10 +170,7 @@ class Conversation:
if knowledge: if knowledge:
if topic not in self.conversation_info.knowledge_list: if topic not in self.conversation_info.knowledge_list:
self.conversation_info.knowledge_list.append({ self.conversation_info.knowledge_list.append({"topic": topic, "knowledge": knowledge})
"topic": topic,
"knowledge": knowledge
})
else: else:
self.conversation_info.knowledge_list[topic] += knowledge self.conversation_info.knowledge_list[topic] += knowledge
@@ -186,7 +178,6 @@ class Conversation:
self.state = ConversationState.RETHINKING self.state = ConversationState.RETHINKING
await self.goal_analyzer.analyze_goal(conversation_info, observation_info) await self.goal_analyzer.analyze_goal(conversation_info, observation_info)
elif action == "listening": elif action == "listening":
self.state = ConversationState.LISTENING self.state = ConversationState.LISTENING
logger.info("倾听对方发言...") logger.info("倾听对方发言...")
@@ -210,9 +201,7 @@ class Conversation:
latest_message = self._convert_to_message(messages[0]) latest_message = self._convert_to_message(messages[0])
await self.direct_sender.send_message( await self.direct_sender.send_message(
chat_stream=self.chat_stream, chat_stream=self.chat_stream, content="TODO:超时消息", reply_to_message=latest_message
content="TODO:超时消息",
reply_to_message=latest_message
) )
except Exception as e: except Exception as e:
logger.error(f"发送超时消息失败: {str(e)}") logger.error(f"发送超时消息失败: {str(e)}")
@@ -231,9 +220,7 @@ class Conversation:
latest_message = self._convert_to_message(messages[0]) latest_message = self._convert_to_message(messages[0])
try: try:
await self.direct_sender.send_message( await self.direct_sender.send_message(
chat_stream=self.chat_stream, chat_stream=self.chat_stream, content=self.generated_reply, reply_to_message=latest_message
content=self.generated_reply,
reply_to_message=latest_message
) )
self.chat_observer.trigger_update() # 触发立即更新 self.chat_observer.trigger_update() # 触发立即更新
if not await self.chat_observer.wait_for_update(): if not await self.chat_observer.wait_for_update():

View File

@@ -1,5 +1,3 @@
class ConversationInfo: class ConversationInfo:
def __init__(self): def __init__(self):
self.done_action = [] self.done_action = []

View File

@@ -7,6 +7,7 @@ from src.plugins.chat.message import MessageSending
logger = get_module_logger("message_sender") logger = get_module_logger("message_sender")
class DirectMessageSender: class DirectMessageSender:
"""直接消息发送器""" """直接消息发送器"""
@@ -33,10 +34,7 @@ class DirectMessageSender:
# 检查是否需要引用回复 # 检查是否需要引用回复
if reply_to_message: if reply_to_message:
reply_id = reply_to_message.message_id reply_id = reply_to_message.message_id
message_sending = MessageSending( message_sending = MessageSending(segments=segments, reply_to_id=reply_id)
segments=segments,
reply_to_id=reply_id
)
else: else:
message_sending = MessageSending(segments=segments) message_sending = MessageSending(segments=segments)

View File

@@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional from typing import List, Dict, Any, Optional
from src.common.database import db from src.common.database import db
class MessageStorage(ABC): class MessageStorage(ABC):
"""消息存储接口""" """消息存储接口"""
@@ -45,6 +46,7 @@ class MessageStorage(ABC):
""" """
pass pass
class MongoDBMessageStorage(MessageStorage): class MongoDBMessageStorage(MessageStorage):
"""MongoDB消息存储实现""" """MongoDB消息存储实现"""
@@ -60,32 +62,23 @@ class MongoDBMessageStorage(MessageStorage):
if last_message: if last_message:
query["time"] = {"$gt": last_message["time"]} query["time"] = {"$gt": last_message["time"]}
return list( return list(self.db.messages.find(query).sort("time", 1))
self.db.messages.find(query).sort("time", 1)
)
async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]: async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:
query = { query = {"chat_id": chat_id, "time": {"$lt": time_point}}
"chat_id": chat_id,
"time": {"$lt": time_point}
}
messages = list( messages = list(self.db.messages.find(query).sort("time", -1).limit(limit))
self.db.messages.find(query).sort("time", -1).limit(limit)
)
# 将消息按时间正序排列 # 将消息按时间正序排列
messages.reverse() messages.reverse()
return messages return messages
async def has_new_messages(self, chat_id: str, after_time: float) -> bool: async def has_new_messages(self, chat_id: str, after_time: float) -> bool:
query = { query = {"chat_id": chat_id, "time": {"$gt": after_time}}
"chat_id": chat_id,
"time": {"$gt": after_time}
}
return self.db.messages.find_one(query) is not None return self.db.messages.find_one(query) is not None
# # 创建一个内存消息存储实现,用于测试 # # 创建一个内存消息存储实现,用于测试
# class InMemoryMessageStorage(MessageStorage): # class InMemoryMessageStorage(MessageStorage):
# """内存消息存储实现,主要用于测试""" # """内存消息存储实现,主要用于测试"""

View File

@@ -7,10 +7,11 @@ if TYPE_CHECKING:
logger = get_module_logger("notification_handler") logger = get_module_logger("notification_handler")
class PFCNotificationHandler(NotificationHandler): class PFCNotificationHandler(NotificationHandler):
"""PFC通知处理器""" """PFC通知处理器"""
def __init__(self, conversation: 'Conversation'): def __init__(self, conversation: "Conversation"):
"""初始化PFC通知处理器 """初始化PFC通知处理器
Args: Args:
@@ -68,4 +69,3 @@ class PFCNotificationHandler(NotificationHandler):
observation_info.conversation_cold_duration = cold_duration observation_info.conversation_cold_duration = cold_duration
logger.info(f"对话已冷: {cold_duration}") logger.info(f"对话已冷: {cold_duration}")

View File

@@ -1,5 +1,5 @@
#Programmable Friendly Conversationalist # Programmable Friendly Conversationalist
#Prefrontal cortex # Prefrontal cortex
from typing import List, Optional, Dict, Any, Set from typing import List, Optional, Dict, Any, Set
from ..message.message_base import UserInfo from ..message.message_base import UserInfo
import time import time
@@ -10,10 +10,11 @@ from .chat_states import NotificationHandler
logger = get_module_logger("observation_info") logger = get_module_logger("observation_info")
class ObservationInfoHandler(NotificationHandler): class ObservationInfoHandler(NotificationHandler):
"""ObservationInfo的通知处理器""" """ObservationInfo的通知处理器"""
def __init__(self, observation_info: 'ObservationInfo'): def __init__(self, observation_info: "ObservationInfo"):
"""初始化处理器 """初始化处理器
Args: Args:
@@ -62,8 +63,7 @@ class ObservationInfoHandler(NotificationHandler):
# 处理消息删除通知 # 处理消息删除通知
message_id = data.get("message_id") message_id = data.get("message_id")
self.observation_info.unprocessed_messages = [ self.observation_info.unprocessed_messages = [
msg for msg in self.observation_info.unprocessed_messages msg for msg in self.observation_info.unprocessed_messages if msg.get("message_id") != message_id
if msg.get("message_id") != message_id
] ]
elif notification_type == "USER_JOINED": elif notification_type == "USER_JOINED":
@@ -83,16 +83,17 @@ class ObservationInfoHandler(NotificationHandler):
error_msg = data.get("error", "") error_msg = data.get("error", "")
logger.error(f"收到错误通知: {error_msg}") logger.error(f"收到错误通知: {error_msg}")
@dataclass @dataclass
class ObservationInfo: class ObservationInfo:
"""决策信息类用于收集和管理来自chat_observer的通知信息""" """决策信息类用于收集和管理来自chat_observer的通知信息"""
#data_list # data_list
chat_history: List[str] = field(default_factory=list) chat_history: List[str] = field(default_factory=list)
unprocessed_messages: List[Dict[str, Any]] = field(default_factory=list) unprocessed_messages: List[Dict[str, Any]] = field(default_factory=list)
active_users: Set[str] = field(default_factory=set) active_users: Set[str] = field(default_factory=set)
#data # data
last_bot_speak_time: Optional[float] = None last_bot_speak_time: Optional[float] = None
last_user_speak_time: Optional[float] = None last_user_speak_time: Optional[float] = None
last_message_time: Optional[float] = None last_message_time: Optional[float] = None
@@ -102,7 +103,7 @@ class ObservationInfo:
new_messages_count: int = 0 new_messages_count: int = 0
cold_chat_duration: float = 0.0 cold_chat_duration: float = 0.0
#state # state
is_typing: bool = False is_typing: bool = False
has_unread_messages: bool = False has_unread_messages: bool = False
is_cold_chat: bool = False is_cold_chat: bool = False
@@ -124,28 +125,20 @@ class ObservationInfo:
""" """
self.chat_observer = ChatObserver.get_instance(stream_id) self.chat_observer = ChatObserver.get_instance(stream_id)
self.chat_observer.notification_manager.register_handler( self.chat_observer.notification_manager.register_handler(
target="observation_info", target="observation_info", notification_type="NEW_MESSAGE", handler=self.handler
notification_type="NEW_MESSAGE",
handler=self.handler
) )
self.chat_observer.notification_manager.register_handler( self.chat_observer.notification_manager.register_handler(
target="observation_info", target="observation_info", notification_type="COLD_CHAT", handler=self.handler
notification_type="COLD_CHAT",
handler=self.handler
) )
def unbind_from_chat_observer(self): def unbind_from_chat_observer(self):
"""解除与chat_observer的绑定""" """解除与chat_observer的绑定"""
if self.chat_observer: if self.chat_observer:
self.chat_observer.notification_manager.unregister_handler( self.chat_observer.notification_manager.unregister_handler(
target="observation_info", target="observation_info", notification_type="NEW_MESSAGE", handler=self.handler
notification_type="NEW_MESSAGE",
handler=self.handler
) )
self.chat_observer.notification_manager.unregister_handler( self.chat_observer.notification_manager.unregister_handler(
target="observation_info", target="observation_info", notification_type="COLD_CHAT", handler=self.handler
notification_type="COLD_CHAT",
handler=self.handler
) )
self.chat_observer = None self.chat_observer = None

View File

@@ -53,14 +53,13 @@ class GoalAnalyzer:
Returns: Returns:
Tuple[str, str, str]: (目标, 方法, 原因) Tuple[str, str, str]: (目标, 方法, 原因)
""" """
#构建对话目标 # 构建对话目标
goal_list = conversation_info.goal_list goal_list = conversation_info.goal_list
goal_text = "" goal_text = ""
for goal, reason in goal_list: for goal, reason in goal_list:
goal_text += f"目标:{goal};" goal_text += f"目标:{goal};"
goal_text += f"原因:{reason}\n" goal_text += f"原因:{reason}\n"
# 获取聊天历史记录 # 获取聊天历史记录
chat_history_list = observation_info.chat_history chat_history_list = observation_info.chat_history
chat_history_text = "" chat_history_text = ""
@@ -76,7 +75,6 @@ class GoalAnalyzer:
observation_info.clear_unprocessed_messages() observation_info.clear_unprocessed_messages()
personality_text = f"你的名字是{self.name}{self.personality_info}" personality_text = f"你的名字是{self.name}{self.personality_info}"
# 构建action历史文本 # 构建action历史文本
@@ -85,7 +83,6 @@ class GoalAnalyzer:
for action in action_history_list: for action in action_history_list:
action_history_text += f"{action}\n" action_history_text += f"{action}\n"
prompt = f"""{personality_text}。现在你在参与一场QQ聊天请分析以下聊天记录并根据你的性格特征确定多个明确的对话目标。 prompt = f"""{personality_text}。现在你在参与一场QQ聊天请分析以下聊天记录并根据你的性格特征确定多个明确的对话目标。
这些目标应该反映出对话的不同方面和意图。 这些目标应该反映出对话的不同方面和意图。
@@ -122,17 +119,12 @@ class GoalAnalyzer:
# 使用简化函数提取JSON内容 # 使用简化函数提取JSON内容
success, result = get_items_from_json( success, result = get_items_from_json(
content, content, "goal", "reasoning", required_types={"goal": str, "reasoning": str}
"goal", "reasoning",
required_types={"goal": str, "reasoning": str}
) )
#TODO # TODO
conversation_info.goal_list.append(result) conversation_info.goal_list.append(result)
async def _update_goals(self, new_goal: str, method: str, reasoning: str): async def _update_goals(self, new_goal: str, method: str, reasoning: str):
"""更新目标列表 """更新目标列表
@@ -233,8 +225,10 @@ class GoalAnalyzer:
# 尝试解析JSON # 尝试解析JSON
success, result = get_items_from_json( success, result = get_items_from_json(
content, content,
"goal_achieved", "stop_conversation", "reason", "goal_achieved",
required_types={"goal_achieved": bool, "stop_conversation": bool, "reason": str} "stop_conversation",
"reason",
required_types={"goal_achieved": bool, "stop_conversation": bool, "reason": str},
) )
if not success: if not success:
@@ -285,7 +279,6 @@ class Waiter:
logger.info("等待中...") logger.info("等待中...")
class DirectMessageSender: class DirectMessageSender:
"""直接发送消息到平台的发送器""" """直接发送消息到平台的发送器"""

View File

@@ -5,6 +5,7 @@ import traceback
logger = get_module_logger("pfc_manager") logger = get_module_logger("pfc_manager")
class PFCManager: class PFCManager:
"""PFC对话管理器负责管理所有对话实例""" """PFC对话管理器负责管理所有对话实例"""
@@ -16,7 +17,7 @@ class PFCManager:
_initializing: Dict[str, bool] = {} _initializing: Dict[str, bool] = {}
@classmethod @classmethod
def get_instance(cls) -> 'PFCManager': def get_instance(cls) -> "PFCManager":
"""获取管理器单例 """获取管理器单例
Returns: Returns:
@@ -60,7 +61,6 @@ class PFCManager:
return conversation_instance return conversation_instance
async def _initialize_conversation(self, conversation: Conversation): async def _initialize_conversation(self, conversation: Conversation):
"""初始化会话实例 """初始化会话实例
@@ -84,7 +84,6 @@ class PFCManager:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
# 清理失败的初始化 # 清理失败的初始化
async def get_conversation(self, stream_id: str) -> Optional[Conversation]: async def get_conversation(self, stream_id: str) -> Optional[Conversation]:
"""获取已存在的会话实例 """获取已存在的会话实例

View File

@@ -4,6 +4,7 @@ from typing import Literal
class ConversationState(Enum): class ConversationState(Enum):
"""对话状态""" """对话状态"""
INIT = "初始化" INIT = "初始化"
RETHINKING = "重新思考" RETHINKING = "重新思考"
ANALYZING = "分析历史" ANALYZING = "分析历史"

View File

@@ -16,21 +16,14 @@ class ReplyGenerator:
def __init__(self, stream_id: str): def __init__(self, stream_id: str):
self.llm = LLM_request( self.llm = LLM_request(
model=global_config.llm_normal, model=global_config.llm_normal, temperature=0.7, max_tokens=300, request_type="reply_generation"
temperature=0.7,
max_tokens=300,
request_type="reply_generation"
) )
self.personality_info = Individuality.get_instance().get_prompt(type = "personality", x_person = 2, level = 2) self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2)
self.name = global_config.BOT_NICKNAME self.name = global_config.BOT_NICKNAME
self.chat_observer = ChatObserver.get_instance(stream_id) self.chat_observer = ChatObserver.get_instance(stream_id)
self.reply_checker = ReplyChecker(stream_id) self.reply_checker = ReplyChecker(stream_id)
async def generate( async def generate(self, observation_info: ObservationInfo, conversation_info: ConversationInfo) -> str:
self,
observation_info: ObservationInfo,
conversation_info: ConversationInfo
) -> str:
"""生成回复 """生成回复
Args: Args:
@@ -58,7 +51,6 @@ class ReplyGenerator:
for msg in chat_history_list: for msg in chat_history_list:
chat_history_text += f"{msg}\n" chat_history_text += f"{msg}\n"
# 整理知识缓存 # 整理知识缓存
knowledge_text = "" knowledge_text = ""
knowledge_list = conversation_info.knowledge_list knowledge_list = conversation_info.knowledge_list
@@ -107,12 +99,7 @@ class ReplyGenerator:
logger.error(f"生成回复时出错: {e}") logger.error(f"生成回复时出错: {e}")
return "抱歉,我现在有点混乱,让我重新思考一下..." return "抱歉,我现在有点混乱,让我重新思考一下..."
async def check_reply( async def check_reply(self, reply: str, goal: str, retry_count: int = 0) -> Tuple[bool, str, bool]:
self,
reply: str,
goal: str,
retry_count: int = 0
) -> Tuple[bool, str, bool]:
"""检查回复是否合适 """检查回复是否合适
Args: Args:

View File

@@ -3,6 +3,7 @@ from .chat_observer import ChatObserver
logger = get_module_logger("waiter") logger = get_module_logger("waiter")
class Waiter: class Waiter:
"""等待器,用于等待对话流中的事件""" """等待器,用于等待对话流中的事件"""

View File

@@ -46,7 +46,6 @@ class ChatBot:
chat_id = str(message.chat_stream.stream_id) chat_id = str(message.chat_stream.stream_id)
if global_config.enable_pfc_chatting: if global_config.enable_pfc_chatting:
await self.pfc_manager.get_or_create_conversation(chat_id) await self.pfc_manager.get_or_create_conversation(chat_id)
except Exception as e: except Exception as e:

View File

@@ -24,7 +24,7 @@ config_config = LogConfig(
# 配置主程序日志格式 # 配置主程序日志格式
logger = get_module_logger("config", config=config_config) logger = get_module_logger("config", config=config_config)
#考虑到实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码 # 考虑到实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
is_test = True is_test = True
mai_version_main = "0.6.2" mai_version_main = "0.6.2"
mai_version_fix = "snapshot-1" mai_version_fix = "snapshot-1"

View File

@@ -2,7 +2,7 @@
__version__ = "0.1.0" __version__ = "0.1.0"
from .api import BaseMessageAPI, global_api from .api import global_api
from .message_base import ( from .message_base import (
Seg, Seg,
GroupInfo, GroupInfo,
@@ -14,7 +14,6 @@ from .message_base import (
) )
__all__ = [ __all__ = [
"BaseMessageAPI",
"Seg", "Seg",
"global_api", "global_api",
"GroupInfo", "GroupInfo",

View File

@@ -1,7 +1,8 @@
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from typing import Dict, Any, Callable, List, Set from typing import Dict, Any, Callable, List, Set, Optional
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
from src.plugins.message.message_base import MessageBase from src.plugins.message.message_base import MessageBase
from src.common.server import global_server
import aiohttp import aiohttp
import asyncio import asyncio
import uvicorn import uvicorn
@@ -49,13 +50,22 @@ class MessageServer(BaseMessageHandler):
_class_handlers: List[Callable] = [] # 类级别的消息处理器 _class_handlers: List[Callable] = [] # 类级别的消息处理器
def __init__(self, host: str = "0.0.0.0", port: int = 18000, enable_token=False): def __init__(
self,
host: str = "0.0.0.0",
port: int = 18000,
enable_token=False,
app: Optional[FastAPI] = None,
path: str = "/ws",
):
super().__init__() super().__init__()
# 将类级别的处理器添加到实例处理器中 # 将类级别的处理器添加到实例处理器中
self.message_handlers.extend(self._class_handlers) self.message_handlers.extend(self._class_handlers)
self.app = FastAPI()
self.host = host self.host = host
self.port = port self.port = port
self.path = path
self.app = app or FastAPI()
self.own_app = app is None # 标记是否使用自己创建的app
self.active_websockets: Set[WebSocket] = set() self.active_websockets: Set[WebSocket] = set()
self.platform_websockets: Dict[str, WebSocket] = {} # 平台到websocket的映射 self.platform_websockets: Dict[str, WebSocket] = {} # 平台到websocket的映射
self.valid_tokens: Set[str] = set() self.valid_tokens: Set[str] = set()
@@ -63,28 +73,6 @@ class MessageServer(BaseMessageHandler):
self._setup_routes() self._setup_routes()
self._running = False self._running = False
@classmethod
def register_class_handler(cls, handler: Callable):
"""注册类级别的消息处理器"""
if handler not in cls._class_handlers:
cls._class_handlers.append(handler)
def register_message_handler(self, handler: Callable):
"""注册实例级别的消息处理器"""
if handler not in self.message_handlers:
self.message_handlers.append(handler)
async def verify_token(self, token: str) -> bool:
if not self.enable_token:
return True
return token in self.valid_tokens
def add_valid_token(self, token: str):
self.valid_tokens.add(token)
def remove_valid_token(self, token: str):
self.valid_tokens.discard(token)
def _setup_routes(self): def _setup_routes(self):
@self.app.post("/api/message") @self.app.post("/api/message")
async def handle_message(message: Dict[str, Any]): async def handle_message(message: Dict[str, Any]):
@@ -125,6 +113,90 @@ class MessageServer(BaseMessageHandler):
finally: finally:
self._remove_websocket(websocket, platform) self._remove_websocket(websocket, platform)
@classmethod
def register_class_handler(cls, handler: Callable):
"""注册类级别的消息处理器"""
if handler not in cls._class_handlers:
cls._class_handlers.append(handler)
def register_message_handler(self, handler: Callable):
"""注册实例级别的消息处理器"""
if handler not in self.message_handlers:
self.message_handlers.append(handler)
async def verify_token(self, token: str) -> bool:
if not self.enable_token:
return True
return token in self.valid_tokens
def add_valid_token(self, token: str):
self.valid_tokens.add(token)
def remove_valid_token(self, token: str):
self.valid_tokens.discard(token)
def run_sync(self):
"""同步方式运行服务器"""
if not self.own_app:
raise RuntimeError("当使用外部FastAPI实例时请使用该实例的运行方法")
uvicorn.run(self.app, host=self.host, port=self.port)
async def run(self):
"""异步方式运行服务器"""
self._running = True
try:
if self.own_app:
# 如果使用自己的 FastAPI 实例,运行 uvicorn 服务器
config = uvicorn.Config(self.app, host=self.host, port=self.port, loop="asyncio")
self.server = uvicorn.Server(config)
await self.server.serve()
else:
# 如果使用外部 FastAPI 实例,保持运行状态以处理消息
while self._running:
await asyncio.sleep(1)
except KeyboardInterrupt:
await self.stop()
raise
except Exception as e:
await self.stop()
raise RuntimeError(f"服务器运行错误: {str(e)}") from e
finally:
await self.stop()
async def start_server(self):
"""启动服务器的异步方法"""
if not self._running:
self._running = True
await self.run()
async def stop(self):
"""停止服务器"""
# 清理platform映射
self.platform_websockets.clear()
# 取消所有后台任务
for task in self.background_tasks:
task.cancel()
# 等待所有任务完成
await asyncio.gather(*self.background_tasks, return_exceptions=True)
self.background_tasks.clear()
# 关闭所有WebSocket连接
for websocket in self.active_websockets:
await websocket.close()
self.active_websockets.clear()
if hasattr(self, "server") and self.own_app:
self._running = False
# 正确关闭 uvicorn 服务器
self.server.should_exit = True
await self.server.shutdown()
# 等待服务器完全停止
if hasattr(self.server, "started") and self.server.started:
await self.server.main_loop()
# 清理处理程序
self.message_handlers.clear()
def _remove_websocket(self, websocket: WebSocket, platform: str): def _remove_websocket(self, websocket: WebSocket, platform: str):
"""从所有集合中移除websocket""" """从所有集合中移除websocket"""
if websocket in self.active_websockets: if websocket in self.active_websockets:
@@ -161,54 +233,6 @@ class MessageServer(BaseMessageHandler):
async def send_message(self, message: MessageBase): async def send_message(self, message: MessageBase):
await self.broadcast_to_platform(message.message_info.platform, message.to_dict()) await self.broadcast_to_platform(message.message_info.platform, message.to_dict())
def run_sync(self):
"""同步方式运行服务器"""
uvicorn.run(self.app, host=self.host, port=self.port)
async def run(self):
"""异步方式运行服务器"""
config = uvicorn.Config(self.app, host=self.host, port=self.port, loop="asyncio")
self.server = uvicorn.Server(config)
try:
await self.server.serve()
except KeyboardInterrupt as e:
await self.stop()
raise KeyboardInterrupt from e
async def start_server(self):
"""启动服务器的异步方法"""
if not self._running:
self._running = True
await self.run()
async def stop(self):
"""停止服务器"""
# 清理platform映射
self.platform_websockets.clear()
# 取消所有后台任务
for task in self.background_tasks:
task.cancel()
# 等待所有任务完成
await asyncio.gather(*self.background_tasks, return_exceptions=True)
self.background_tasks.clear()
# 关闭所有WebSocket连接
for websocket in self.active_websockets:
await websocket.close()
self.active_websockets.clear()
if hasattr(self, "server"):
self._running = False
# 正确关闭 uvicorn 服务器
self.server.should_exit = True
await self.server.shutdown()
# 等待服务器完全停止
if hasattr(self.server, "started") and self.server.started:
await self.server.main_loop()
# 清理处理程序
self.message_handlers.clear()
async def send_message_REST(self, url: str, data: Dict[str, Any]) -> Dict[str, Any]: async def send_message_REST(self, url: str, data: Dict[str, Any]) -> Dict[str, Any]:
"""发送消息到指定端点""" """发送消息到指定端点"""
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
@@ -219,105 +243,4 @@ class MessageServer(BaseMessageHandler):
raise e raise e
class BaseMessageAPI: global_api = MessageServer(host=os.environ["HOST"], port=int(os.environ["PORT"]), app=global_server.get_app())
def __init__(self, host: str = "0.0.0.0", port: int = 18000):
self.app = FastAPI()
self.host = host
self.port = port
self.message_handlers: List[Callable] = []
self.cache = []
self._setup_routes()
self._running = False
def _setup_routes(self):
"""设置基础路由"""
@self.app.post("/api/message")
async def handle_message(message: Dict[str, Any]):
try:
# 创建后台任务处理消息
asyncio.create_task(self._background_message_handler(message))
return {"status": "success"}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e
async def _background_message_handler(self, message: Dict[str, Any]):
"""后台处理单个消息"""
try:
await self.process_single_message(message)
except Exception as e:
logger.error(f"Background message processing failed: {str(e)}")
logger.error(traceback.format_exc())
def register_message_handler(self, handler: Callable):
"""注册消息处理函数"""
self.message_handlers.append(handler)
async def send_message(self, url: str, data: Dict[str, Any]) -> Dict[str, Any]:
"""发送消息到指定端点"""
async with aiohttp.ClientSession() as session:
try:
async with session.post(url, json=data, headers={"Content-Type": "application/json"}) as response:
return await response.json()
except Exception:
# logger.error(f"发送消息失败: {str(e)}")
pass
async def process_single_message(self, message: Dict[str, Any]):
"""处理单条消息"""
tasks = []
for handler in self.message_handlers:
try:
tasks.append(handler(message))
except Exception as e:
logger.error(str(e))
logger.error(traceback.format_exc())
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
def run_sync(self):
"""同步方式运行服务器"""
uvicorn.run(self.app, host=self.host, port=self.port)
async def run(self):
"""异步方式运行服务器"""
config = uvicorn.Config(self.app, host=self.host, port=self.port, loop="asyncio")
self.server = uvicorn.Server(config)
try:
await self.server.serve()
except KeyboardInterrupt as e:
await self.stop()
raise KeyboardInterrupt from e
async def start_server(self):
"""启动服务器的异步方法"""
if not self._running:
self._running = True
await self.run()
async def stop(self):
"""停止服务器"""
if hasattr(self, "server"):
self._running = False
# 正确关闭 uvicorn 服务器
self.server.should_exit = True
await self.server.shutdown()
# 等待服务器完全停止
if hasattr(self.server, "started") and self.server.started:
await self.server.main_loop()
# 清理处理程序
self.message_handlers.clear()
def start(self):
"""启动服务器的便捷方法"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(self.start_server())
except KeyboardInterrupt:
pass
finally:
loop.close()
global_api = MessageServer(host=os.environ["HOST"], port=int(os.environ["PORT"]))