Merge branch 'dev' of https://github.com/MaiM-with-u/MaiBot into dev
This commit is contained in:
@@ -4,6 +4,7 @@ import logging
|
||||
from pathlib import Path
|
||||
from logging.handlers import RotatingFileHandler
|
||||
|
||||
|
||||
def setup_crash_logger():
|
||||
"""设置崩溃日志记录器"""
|
||||
# 创建logs/crash目录(如果不存在)
|
||||
@@ -11,15 +12,12 @@ def setup_crash_logger():
|
||||
crash_log_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 创建日志记录器
|
||||
crash_logger = logging.getLogger('crash_logger')
|
||||
crash_logger = logging.getLogger("crash_logger")
|
||||
crash_logger.setLevel(logging.ERROR)
|
||||
|
||||
# 设置日志格式
|
||||
formatter = logging.Formatter(
|
||||
'%(asctime)s - %(name)s - %(levelname)s\n'
|
||||
'异常类型: %(exc_info)s\n'
|
||||
'详细信息:\n%(message)s\n'
|
||||
'-------------------\n'
|
||||
"%(asctime)s - %(name)s - %(levelname)s\n异常类型: %(exc_info)s\n详细信息:\n%(message)s\n-------------------\n"
|
||||
)
|
||||
|
||||
# 创建按大小轮转的文件处理器(最大10MB,保留5个备份)
|
||||
@@ -28,29 +26,28 @@ def setup_crash_logger():
|
||||
log_file,
|
||||
maxBytes=10 * 1024 * 1024, # 10MB
|
||||
backupCount=5,
|
||||
encoding='utf-8'
|
||||
encoding="utf-8",
|
||||
)
|
||||
file_handler.setFormatter(formatter)
|
||||
crash_logger.addHandler(file_handler)
|
||||
|
||||
return crash_logger
|
||||
|
||||
|
||||
def log_crash(exc_type, exc_value, exc_traceback):
|
||||
"""记录崩溃信息到日志文件"""
|
||||
if exc_type is None:
|
||||
return
|
||||
|
||||
# 获取崩溃日志记录器
|
||||
crash_logger = logging.getLogger('crash_logger')
|
||||
crash_logger = logging.getLogger("crash_logger")
|
||||
|
||||
# 获取完整的异常堆栈信息
|
||||
stack_trace = ''.join(traceback.format_exception(exc_type, exc_value, exc_traceback))
|
||||
stack_trace = "".join(traceback.format_exception(exc_type, exc_value, exc_traceback))
|
||||
|
||||
# 记录崩溃信息
|
||||
crash_logger.error(
|
||||
stack_trace,
|
||||
exc_info=(exc_type, exc_value, exc_traceback)
|
||||
)
|
||||
crash_logger.error(stack_trace, exc_info=(exc_type, exc_value, exc_traceback))
|
||||
|
||||
|
||||
def install_crash_handler():
|
||||
"""安装全局异常处理器"""
|
||||
|
||||
73
src/common/server.py
Normal file
73
src/common/server.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from fastapi import FastAPI, APIRouter
|
||||
from typing import Optional, Union
|
||||
from uvicorn import Config, Server as UvicornServer
|
||||
import os
|
||||
|
||||
|
||||
class Server:
|
||||
def __init__(self, host: Optional[str] = None, port: Optional[int] = None, app_name: str = "MaiMCore"):
|
||||
self.app = FastAPI(title=app_name)
|
||||
self._host: str = "127.0.0.1"
|
||||
self._port: int = 8080
|
||||
self._server: Optional[UvicornServer] = None
|
||||
self.set_address(host, port)
|
||||
|
||||
def register_router(self, router: APIRouter, prefix: str = ""):
|
||||
"""注册路由
|
||||
|
||||
APIRouter 用于对相关的路由端点进行分组和模块化管理:
|
||||
1. 可以将相关的端点组织在一起,便于管理
|
||||
2. 支持添加统一的路由前缀
|
||||
3. 可以为一组路由添加共同的依赖项、标签等
|
||||
|
||||
示例:
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/users")
|
||||
def get_users():
|
||||
return {"users": [...]}
|
||||
|
||||
@router.post("/users")
|
||||
def create_user():
|
||||
return {"msg": "user created"}
|
||||
|
||||
# 注册路由,添加前缀 "/api/v1"
|
||||
server.register_router(router, prefix="/api/v1")
|
||||
"""
|
||||
self.app.include_router(router, prefix=prefix)
|
||||
|
||||
def set_address(self, host: Optional[str] = None, port: Optional[int] = None):
|
||||
"""设置服务器地址和端口"""
|
||||
if host:
|
||||
self._host = host
|
||||
if port:
|
||||
self._port = port
|
||||
|
||||
async def run(self):
|
||||
"""启动服务器"""
|
||||
config = Config(app=self.app, host=self._host, port=self._port)
|
||||
self._server = UvicornServer(config=config)
|
||||
try:
|
||||
await self._server.serve()
|
||||
except KeyboardInterrupt:
|
||||
await self.shutdown()
|
||||
raise
|
||||
except Exception as e:
|
||||
await self.shutdown()
|
||||
raise RuntimeError(f"服务器运行错误: {str(e)}") from e
|
||||
finally:
|
||||
await self.shutdown()
|
||||
|
||||
async def shutdown(self):
|
||||
"""安全关闭服务器"""
|
||||
if self._server:
|
||||
self._server.should_exit = True
|
||||
await self._server.shutdown()
|
||||
self._server = None
|
||||
|
||||
def get_app(self) -> FastAPI:
|
||||
"""获取 FastAPI 实例"""
|
||||
return self.app
|
||||
|
||||
|
||||
global_server = Server(host=os.environ["HOST"], port=int(os.environ["PORT"]))
|
||||
@@ -16,7 +16,7 @@ from .plugins.chat.bot import chat_bot
|
||||
from .common.logger import get_module_logger
|
||||
from .plugins.remote import heartbeat_thread # noqa: F401
|
||||
from .individuality.individuality import Individuality
|
||||
|
||||
from .common.server import global_server
|
||||
|
||||
logger = get_module_logger("main")
|
||||
|
||||
@@ -33,6 +33,7 @@ class MainSystem:
|
||||
from .plugins.message import global_api
|
||||
|
||||
self.app = global_api
|
||||
self.server = global_server
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化系统组件"""
|
||||
@@ -126,6 +127,7 @@ class MainSystem:
|
||||
emoji_manager.start_periodic_check_register(),
|
||||
# emoji_manager.start_periodic_register(),
|
||||
self.app.run(),
|
||||
self.server.run(),
|
||||
]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from .conversation_info import ConversationInfo
|
||||
|
||||
logger = get_module_logger("action_planner")
|
||||
|
||||
|
||||
class ActionPlannerInfo:
|
||||
def __init__(self):
|
||||
self.done_action = []
|
||||
@@ -23,20 +24,13 @@ class ActionPlanner:
|
||||
|
||||
def __init__(self, stream_id: str):
|
||||
self.llm = LLM_request(
|
||||
model=global_config.llm_normal,
|
||||
temperature=0.7,
|
||||
max_tokens=1000,
|
||||
request_type="action_planning"
|
||||
model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="action_planning"
|
||||
)
|
||||
self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2)
|
||||
self.name = global_config.BOT_NICKNAME
|
||||
self.chat_observer = ChatObserver.get_instance(stream_id)
|
||||
|
||||
async def plan(
|
||||
self,
|
||||
observation_info: ObservationInfo,
|
||||
conversation_info: ConversationInfo
|
||||
) -> Tuple[str, str]:
|
||||
async def plan(self, observation_info: ObservationInfo, conversation_info: ConversationInfo) -> Tuple[str, str]:
|
||||
"""规划下一步行动
|
||||
|
||||
Args:
|
||||
@@ -56,7 +50,6 @@ class ActionPlanner:
|
||||
goal = "目前没有明确对话目标"
|
||||
reasoning = "目前没有明确对话目标,最好思考一个对话目标"
|
||||
|
||||
|
||||
# 获取聊天历史记录
|
||||
chat_history_list = observation_info.chat_history
|
||||
chat_history_text = ""
|
||||
@@ -72,7 +65,6 @@ class ActionPlanner:
|
||||
|
||||
observation_info.clear_unprocessed_messages()
|
||||
|
||||
|
||||
personality_text = f"你的名字是{self.name},{self.personality_info}"
|
||||
|
||||
# 构建action历史文本
|
||||
@@ -81,8 +73,6 @@ class ActionPlanner:
|
||||
for action in action_history_list:
|
||||
action_history_text += f"{action}\n"
|
||||
|
||||
|
||||
|
||||
prompt = f"""{personality_text}。现在你在参与一场QQ聊天,请分析以下内容,根据信息决定下一步行动:
|
||||
|
||||
当前对话目标:{goal}
|
||||
@@ -114,9 +104,7 @@ rethink_goal: 重新思考对话目标,当发现对话目标不合适时选择
|
||||
|
||||
# 使用简化函数提取JSON内容
|
||||
success, result = get_items_from_json(
|
||||
content,
|
||||
"action", "reason",
|
||||
default_values={"action": "direct_reply", "reason": "没有明确原因"}
|
||||
content, "action", "reason", default_values={"action": "direct_reply", "reason": "没有明确原因"}
|
||||
)
|
||||
|
||||
if not success:
|
||||
|
||||
@@ -17,7 +17,7 @@ class ChatObserver:
|
||||
_instances: Dict[str, "ChatObserver"] = {}
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls, stream_id: str, message_storage: Optional[MessageStorage] = None) -> 'ChatObserver':
|
||||
def get_instance(cls, stream_id: str, message_storage: Optional[MessageStorage] = None) -> "ChatObserver":
|
||||
"""获取或创建观察器实例
|
||||
|
||||
Args:
|
||||
@@ -84,10 +84,7 @@ class ChatObserver:
|
||||
"""
|
||||
logger.debug(f"检查距离上一次观察之后是否有了新消息: {self.last_check_time}")
|
||||
|
||||
new_message_exists = await self.message_storage.has_new_messages(
|
||||
self.stream_id,
|
||||
self.last_check_time
|
||||
)
|
||||
new_message_exists = await self.message_storage.has_new_messages(self.stream_id, self.last_check_time)
|
||||
|
||||
if new_message_exists:
|
||||
logger.debug("发现新消息")
|
||||
@@ -114,11 +111,7 @@ class ChatObserver:
|
||||
self.last_user_speak_time = message["time"]
|
||||
|
||||
# 发送新消息通知
|
||||
notification = create_new_message_notification(
|
||||
sender="chat_observer",
|
||||
target="pfc",
|
||||
message=message
|
||||
)
|
||||
notification = create_new_message_notification(sender="chat_observer", target="pfc", message=message)
|
||||
await self.notification_manager.send_notification(notification)
|
||||
|
||||
# 检查并更新冷场状态
|
||||
@@ -144,19 +137,12 @@ class ChatObserver:
|
||||
# 如果冷场状态发生变化,发送通知
|
||||
if is_cold != self.is_cold_chat_state:
|
||||
self.is_cold_chat_state = is_cold
|
||||
notification = create_cold_chat_notification(
|
||||
sender="chat_observer",
|
||||
target="pfc",
|
||||
is_cold=is_cold
|
||||
)
|
||||
notification = create_cold_chat_notification(sender="chat_observer", target="pfc", is_cold=is_cold)
|
||||
await self.notification_manager.send_notification(notification)
|
||||
|
||||
async def get_new_message(self) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
|
||||
"""获取上一次观察的时间点后的新消息,插入到历史记录中,并返回新消息和历史记录两个对象"""
|
||||
messages = await self.message_storage.get_messages_after(
|
||||
self.stream_id,
|
||||
self.last_message_read
|
||||
)
|
||||
messages = await self.message_storage.get_messages_after(self.stream_id, self.last_message_read)
|
||||
for message in messages:
|
||||
await self._add_message_to_history(message)
|
||||
return messages, self.message_history
|
||||
@@ -224,10 +210,7 @@ class ChatObserver:
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 新消息列表
|
||||
"""
|
||||
new_messages = await self.message_storage.get_messages_after(
|
||||
self.stream_id,
|
||||
self.last_message_read
|
||||
)
|
||||
new_messages = await self.message_storage.get_messages_after(self.stream_id, self.last_message_read)
|
||||
|
||||
if new_messages:
|
||||
self.last_message_read = new_messages[-1]["message_id"]
|
||||
@@ -243,17 +226,15 @@ class ChatObserver:
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 最多5条消息
|
||||
"""
|
||||
new_messages = await self.message_storage.get_messages_before(
|
||||
self.stream_id,
|
||||
time_point
|
||||
)
|
||||
new_messages = await self.message_storage.get_messages_before(self.stream_id, time_point)
|
||||
|
||||
if new_messages:
|
||||
self.last_message_read = new_messages[-1]["message_id"]
|
||||
|
||||
return new_messages
|
||||
|
||||
'''主要观察循环'''
|
||||
"""主要观察循环"""
|
||||
|
||||
async def _update_loop(self):
|
||||
"""更新循环"""
|
||||
try:
|
||||
@@ -396,17 +377,16 @@ class ChatObserver:
|
||||
bool: 是否有新消息
|
||||
"""
|
||||
try:
|
||||
messages = await self.message_storage.get_messages_for_stream(
|
||||
self.stream_id,
|
||||
limit=50
|
||||
)
|
||||
messages = await self.message_storage.get_messages_for_stream(self.stream_id, limit=50)
|
||||
|
||||
if not messages:
|
||||
return False
|
||||
|
||||
# 检查是否有新消息
|
||||
has_new_messages = False
|
||||
if messages and (not self.message_cache or messages[0]["message_id"] != self.message_cache[0]["message_id"]):
|
||||
if messages and (
|
||||
not self.message_cache or messages[0]["message_id"] != self.message_cache[0]["message_id"]
|
||||
):
|
||||
has_new_messages = True
|
||||
|
||||
self.message_cache = messages
|
||||
|
||||
@@ -4,8 +4,10 @@ from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class ChatState(Enum):
|
||||
"""聊天状态枚举"""
|
||||
|
||||
NORMAL = auto() # 正常状态
|
||||
NEW_MESSAGE = auto() # 有新消息
|
||||
COLD_CHAT = auto() # 冷场状态
|
||||
@@ -15,8 +17,10 @@ class ChatState(Enum):
|
||||
SILENT = auto() # 沉默状态
|
||||
ERROR = auto() # 错误状态
|
||||
|
||||
|
||||
class NotificationType(Enum):
|
||||
"""通知类型枚举"""
|
||||
|
||||
NEW_MESSAGE = auto() # 新消息通知
|
||||
COLD_CHAT = auto() # 冷场通知
|
||||
ACTIVE_CHAT = auto() # 活跃通知
|
||||
@@ -27,9 +31,11 @@ class NotificationType(Enum):
|
||||
USER_LEFT = auto() # 用户离开通知
|
||||
ERROR = auto() # 错误通知
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatStateInfo:
|
||||
"""聊天状态信息"""
|
||||
|
||||
state: ChatState
|
||||
last_message_time: Optional[float] = None
|
||||
last_message_content: Optional[str] = None
|
||||
@@ -38,9 +44,11 @@ class ChatStateInfo:
|
||||
cold_duration: float = 0.0 # 冷场持续时间(秒)
|
||||
active_duration: float = 0.0 # 活跃持续时间(秒)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Notification:
|
||||
"""通知基类"""
|
||||
|
||||
type: NotificationType
|
||||
timestamp: float
|
||||
sender: str # 发送者标识
|
||||
@@ -49,15 +57,13 @@ class Notification:
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典格式"""
|
||||
return {
|
||||
"type": self.type.name,
|
||||
"timestamp": self.timestamp,
|
||||
"data": self.data
|
||||
}
|
||||
return {"type": self.type.name, "timestamp": self.timestamp, "data": self.data}
|
||||
|
||||
|
||||
@dataclass
|
||||
class StateNotification(Notification):
|
||||
"""持续状态通知"""
|
||||
|
||||
is_active: bool = True
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
@@ -65,6 +71,7 @@ class StateNotification(Notification):
|
||||
base_dict["is_active"] = self.is_active
|
||||
return base_dict
|
||||
|
||||
|
||||
class NotificationHandler(ABC):
|
||||
"""通知处理器接口"""
|
||||
|
||||
@@ -73,6 +80,7 @@ class NotificationHandler(ABC):
|
||||
"""处理通知"""
|
||||
pass
|
||||
|
||||
|
||||
class NotificationManager:
|
||||
"""通知管理器"""
|
||||
|
||||
@@ -141,10 +149,9 @@ class NotificationManager:
|
||||
"""检查特定状态是否活跃"""
|
||||
return state_type in self._active_states
|
||||
|
||||
def get_notification_history(self,
|
||||
sender: Optional[str] = None,
|
||||
target: Optional[str] = None,
|
||||
limit: Optional[int] = None) -> List[Notification]:
|
||||
def get_notification_history(
|
||||
self, sender: Optional[str] = None, target: Optional[str] = None, limit: Optional[int] = None
|
||||
) -> List[Notification]:
|
||||
"""获取通知历史
|
||||
|
||||
Args:
|
||||
@@ -164,6 +171,7 @@ class NotificationManager:
|
||||
|
||||
return history
|
||||
|
||||
|
||||
# 一些常用的通知创建函数
|
||||
def create_new_message_notification(sender: str, target: str, message: Dict[str, Any]) -> Notification:
|
||||
"""创建新消息通知"""
|
||||
@@ -176,10 +184,11 @@ def create_new_message_notification(sender: str, target: str, message: Dict[str,
|
||||
"message_id": message.get("message_id"),
|
||||
"content": message.get("content"),
|
||||
"sender": message.get("sender"),
|
||||
"time": message.get("time")
|
||||
}
|
||||
"time": message.get("time"),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def create_cold_chat_notification(sender: str, target: str, is_cold: bool) -> StateNotification:
|
||||
"""创建冷场状态通知"""
|
||||
return StateNotification(
|
||||
@@ -188,9 +197,10 @@ def create_cold_chat_notification(sender: str, target: str, is_cold: bool) -> St
|
||||
sender=sender,
|
||||
target=target,
|
||||
data={"is_cold": is_cold},
|
||||
is_active=is_cold
|
||||
is_active=is_cold,
|
||||
)
|
||||
|
||||
|
||||
def create_active_chat_notification(sender: str, target: str, is_active: bool) -> StateNotification:
|
||||
"""创建活跃状态通知"""
|
||||
return StateNotification(
|
||||
@@ -199,9 +209,10 @@ def create_active_chat_notification(sender: str, target: str, is_active: bool) -
|
||||
sender=sender,
|
||||
target=target,
|
||||
data={"is_active": is_active},
|
||||
is_active=is_active
|
||||
is_active=is_active,
|
||||
)
|
||||
|
||||
|
||||
class ChatStateManager:
|
||||
"""聊天状态管理器"""
|
||||
|
||||
|
||||
@@ -54,7 +54,6 @@ class Conversation:
|
||||
logger.error(traceback.format_exc())
|
||||
raise
|
||||
|
||||
|
||||
try:
|
||||
# 决策所需要的信息,包括自身自信和观察信息两部分
|
||||
# 注册观察器和观测信息
|
||||
@@ -74,7 +73,6 @@ class Conversation:
|
||||
self.should_continue = True
|
||||
asyncio.create_task(self.start())
|
||||
|
||||
|
||||
async def start(self):
|
||||
"""开始对话流程"""
|
||||
try:
|
||||
@@ -84,16 +82,12 @@ class Conversation:
|
||||
logger.error(f"启动对话系统失败: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def _plan_and_action_loop(self):
|
||||
"""思考步,PFC核心循环模块"""
|
||||
# 获取最近的消息历史
|
||||
while self.should_continue:
|
||||
# 使用决策信息来辅助行动规划
|
||||
action, reason = await self.action_planner.plan(
|
||||
self.observation_info,
|
||||
self.conversation_info
|
||||
)
|
||||
action, reason = await self.action_planner.plan(self.observation_info, self.conversation_info)
|
||||
if self._check_new_messages_after_planning():
|
||||
continue
|
||||
|
||||
@@ -108,7 +102,6 @@ class Conversation:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _convert_to_message(self, msg_dict: Dict[str, Any]) -> Message:
|
||||
"""将消息字典转换为Message对象"""
|
||||
try:
|
||||
@@ -122,31 +115,31 @@ class Conversation:
|
||||
time=msg_dict["time"],
|
||||
user_info=user_info,
|
||||
processed_plain_text=msg_dict.get("processed_plain_text", ""),
|
||||
detailed_plain_text=msg_dict.get("detailed_plain_text", "")
|
||||
detailed_plain_text=msg_dict.get("detailed_plain_text", ""),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"转换消息时出错: {e}")
|
||||
raise
|
||||
|
||||
async def _handle_action(self, action: str, reason: str, observation_info: ObservationInfo, conversation_info: ConversationInfo):
|
||||
async def _handle_action(
|
||||
self, action: str, reason: str, observation_info: ObservationInfo, conversation_info: ConversationInfo
|
||||
):
|
||||
"""处理规划的行动"""
|
||||
logger.info(f"执行行动: {action}, 原因: {reason}")
|
||||
|
||||
# 记录action历史,先设置为stop,完成后再设置为done
|
||||
conversation_info.done_action.append({
|
||||
conversation_info.done_action.append(
|
||||
{
|
||||
"action": action,
|
||||
"reason": reason,
|
||||
"status": "start",
|
||||
"time": datetime.datetime.now().strftime("%H:%M:%S")
|
||||
})
|
||||
|
||||
"time": datetime.datetime.now().strftime("%H:%M:%S"),
|
||||
}
|
||||
)
|
||||
|
||||
if action == "direct_reply":
|
||||
self.state = ConversationState.GENERATING
|
||||
self.generated_reply = await self.reply_generator.generate(
|
||||
observation_info,
|
||||
conversation_info
|
||||
)
|
||||
self.generated_reply = await self.reply_generator.generate(observation_info, conversation_info)
|
||||
|
||||
# # 检查回复是否合适
|
||||
# is_suitable, reason, need_replan = await self.reply_generator.check_reply(
|
||||
@@ -159,12 +152,14 @@ class Conversation:
|
||||
|
||||
await self._send_reply()
|
||||
|
||||
conversation_info.done_action.append({
|
||||
conversation_info.done_action.append(
|
||||
{
|
||||
"action": action,
|
||||
"reason": reason,
|
||||
"status": "done",
|
||||
"time": datetime.datetime.now().strftime("%H:%M:%S")
|
||||
})
|
||||
"time": datetime.datetime.now().strftime("%H:%M:%S"),
|
||||
}
|
||||
)
|
||||
|
||||
elif action == "fetch_knowledge":
|
||||
self.state = ConversationState.FETCHING
|
||||
@@ -175,10 +170,7 @@ class Conversation:
|
||||
|
||||
if knowledge:
|
||||
if topic not in self.conversation_info.knowledge_list:
|
||||
self.conversation_info.knowledge_list.append({
|
||||
"topic": topic,
|
||||
"knowledge": knowledge
|
||||
})
|
||||
self.conversation_info.knowledge_list.append({"topic": topic, "knowledge": knowledge})
|
||||
else:
|
||||
self.conversation_info.knowledge_list[topic] += knowledge
|
||||
|
||||
@@ -186,7 +178,6 @@ class Conversation:
|
||||
self.state = ConversationState.RETHINKING
|
||||
await self.goal_analyzer.analyze_goal(conversation_info, observation_info)
|
||||
|
||||
|
||||
elif action == "listening":
|
||||
self.state = ConversationState.LISTENING
|
||||
logger.info("倾听对方发言...")
|
||||
@@ -210,9 +201,7 @@ class Conversation:
|
||||
|
||||
latest_message = self._convert_to_message(messages[0])
|
||||
await self.direct_sender.send_message(
|
||||
chat_stream=self.chat_stream,
|
||||
content="TODO:超时消息",
|
||||
reply_to_message=latest_message
|
||||
chat_stream=self.chat_stream, content="TODO:超时消息", reply_to_message=latest_message
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"发送超时消息失败: {str(e)}")
|
||||
@@ -231,9 +220,7 @@ class Conversation:
|
||||
latest_message = self._convert_to_message(messages[0])
|
||||
try:
|
||||
await self.direct_sender.send_message(
|
||||
chat_stream=self.chat_stream,
|
||||
content=self.generated_reply,
|
||||
reply_to_message=latest_message
|
||||
chat_stream=self.chat_stream, content=self.generated_reply, reply_to_message=latest_message
|
||||
)
|
||||
self.chat_observer.trigger_update() # 触发立即更新
|
||||
if not await self.chat_observer.wait_for_update():
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
|
||||
|
||||
class ConversationInfo:
|
||||
def __init__(self):
|
||||
self.done_action = []
|
||||
|
||||
@@ -7,6 +7,7 @@ from src.plugins.chat.message import MessageSending
|
||||
|
||||
logger = get_module_logger("message_sender")
|
||||
|
||||
|
||||
class DirectMessageSender:
|
||||
"""直接消息发送器"""
|
||||
|
||||
@@ -33,10 +34,7 @@ class DirectMessageSender:
|
||||
# 检查是否需要引用回复
|
||||
if reply_to_message:
|
||||
reply_id = reply_to_message.message_id
|
||||
message_sending = MessageSending(
|
||||
segments=segments,
|
||||
reply_to_id=reply_id
|
||||
)
|
||||
message_sending = MessageSending(segments=segments, reply_to_id=reply_id)
|
||||
else:
|
||||
message_sending = MessageSending(segments=segments)
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any, Optional
|
||||
from src.common.database import db
|
||||
|
||||
|
||||
class MessageStorage(ABC):
|
||||
"""消息存储接口"""
|
||||
|
||||
@@ -45,6 +46,7 @@ class MessageStorage(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class MongoDBMessageStorage(MessageStorage):
|
||||
"""MongoDB消息存储实现"""
|
||||
|
||||
@@ -60,32 +62,23 @@ class MongoDBMessageStorage(MessageStorage):
|
||||
if last_message:
|
||||
query["time"] = {"$gt": last_message["time"]}
|
||||
|
||||
return list(
|
||||
self.db.messages.find(query).sort("time", 1)
|
||||
)
|
||||
return list(self.db.messages.find(query).sort("time", 1))
|
||||
|
||||
async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:
|
||||
query = {
|
||||
"chat_id": chat_id,
|
||||
"time": {"$lt": time_point}
|
||||
}
|
||||
query = {"chat_id": chat_id, "time": {"$lt": time_point}}
|
||||
|
||||
messages = list(
|
||||
self.db.messages.find(query).sort("time", -1).limit(limit)
|
||||
)
|
||||
messages = list(self.db.messages.find(query).sort("time", -1).limit(limit))
|
||||
|
||||
# 将消息按时间正序排列
|
||||
messages.reverse()
|
||||
return messages
|
||||
|
||||
async def has_new_messages(self, chat_id: str, after_time: float) -> bool:
|
||||
query = {
|
||||
"chat_id": chat_id,
|
||||
"time": {"$gt": after_time}
|
||||
}
|
||||
query = {"chat_id": chat_id, "time": {"$gt": after_time}}
|
||||
|
||||
return self.db.messages.find_one(query) is not None
|
||||
|
||||
|
||||
# # 创建一个内存消息存储实现,用于测试
|
||||
# class InMemoryMessageStorage(MessageStorage):
|
||||
# """内存消息存储实现,主要用于测试"""
|
||||
|
||||
@@ -7,10 +7,11 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = get_module_logger("notification_handler")
|
||||
|
||||
|
||||
class PFCNotificationHandler(NotificationHandler):
|
||||
"""PFC通知处理器"""
|
||||
|
||||
def __init__(self, conversation: 'Conversation'):
|
||||
def __init__(self, conversation: "Conversation"):
|
||||
"""初始化PFC通知处理器
|
||||
|
||||
Args:
|
||||
@@ -68,4 +69,3 @@ class PFCNotificationHandler(NotificationHandler):
|
||||
observation_info.conversation_cold_duration = cold_duration
|
||||
|
||||
logger.info(f"对话已冷: {cold_duration}秒")
|
||||
|
||||
@@ -10,10 +10,11 @@ from .chat_states import NotificationHandler
|
||||
|
||||
logger = get_module_logger("observation_info")
|
||||
|
||||
|
||||
class ObservationInfoHandler(NotificationHandler):
|
||||
"""ObservationInfo的通知处理器"""
|
||||
|
||||
def __init__(self, observation_info: 'ObservationInfo'):
|
||||
def __init__(self, observation_info: "ObservationInfo"):
|
||||
"""初始化处理器
|
||||
|
||||
Args:
|
||||
@@ -62,8 +63,7 @@ class ObservationInfoHandler(NotificationHandler):
|
||||
# 处理消息删除通知
|
||||
message_id = data.get("message_id")
|
||||
self.observation_info.unprocessed_messages = [
|
||||
msg for msg in self.observation_info.unprocessed_messages
|
||||
if msg.get("message_id") != message_id
|
||||
msg for msg in self.observation_info.unprocessed_messages if msg.get("message_id") != message_id
|
||||
]
|
||||
|
||||
elif notification_type == "USER_JOINED":
|
||||
@@ -83,6 +83,7 @@ class ObservationInfoHandler(NotificationHandler):
|
||||
error_msg = data.get("error", "")
|
||||
logger.error(f"收到错误通知: {error_msg}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ObservationInfo:
|
||||
"""决策信息类,用于收集和管理来自chat_observer的通知信息"""
|
||||
@@ -124,28 +125,20 @@ class ObservationInfo:
|
||||
"""
|
||||
self.chat_observer = ChatObserver.get_instance(stream_id)
|
||||
self.chat_observer.notification_manager.register_handler(
|
||||
target="observation_info",
|
||||
notification_type="NEW_MESSAGE",
|
||||
handler=self.handler
|
||||
target="observation_info", notification_type="NEW_MESSAGE", handler=self.handler
|
||||
)
|
||||
self.chat_observer.notification_manager.register_handler(
|
||||
target="observation_info",
|
||||
notification_type="COLD_CHAT",
|
||||
handler=self.handler
|
||||
target="observation_info", notification_type="COLD_CHAT", handler=self.handler
|
||||
)
|
||||
|
||||
def unbind_from_chat_observer(self):
|
||||
"""解除与chat_observer的绑定"""
|
||||
if self.chat_observer:
|
||||
self.chat_observer.notification_manager.unregister_handler(
|
||||
target="observation_info",
|
||||
notification_type="NEW_MESSAGE",
|
||||
handler=self.handler
|
||||
target="observation_info", notification_type="NEW_MESSAGE", handler=self.handler
|
||||
)
|
||||
self.chat_observer.notification_manager.unregister_handler(
|
||||
target="observation_info",
|
||||
notification_type="COLD_CHAT",
|
||||
handler=self.handler
|
||||
target="observation_info", notification_type="COLD_CHAT", handler=self.handler
|
||||
)
|
||||
self.chat_observer = None
|
||||
|
||||
|
||||
@@ -60,7 +60,6 @@ class GoalAnalyzer:
|
||||
goal_text += f"目标:{goal};"
|
||||
goal_text += f"原因:{reason}\n"
|
||||
|
||||
|
||||
# 获取聊天历史记录
|
||||
chat_history_list = observation_info.chat_history
|
||||
chat_history_text = ""
|
||||
@@ -76,7 +75,6 @@ class GoalAnalyzer:
|
||||
|
||||
observation_info.clear_unprocessed_messages()
|
||||
|
||||
|
||||
personality_text = f"你的名字是{self.name},{self.personality_info}"
|
||||
|
||||
# 构建action历史文本
|
||||
@@ -85,7 +83,6 @@ class GoalAnalyzer:
|
||||
for action in action_history_list:
|
||||
action_history_text += f"{action}\n"
|
||||
|
||||
|
||||
prompt = f"""{personality_text}。现在你在参与一场QQ聊天,请分析以下聊天记录,并根据你的性格特征确定多个明确的对话目标。
|
||||
这些目标应该反映出对话的不同方面和意图。
|
||||
|
||||
@@ -122,17 +119,12 @@ class GoalAnalyzer:
|
||||
|
||||
# 使用简化函数提取JSON内容
|
||||
success, result = get_items_from_json(
|
||||
content,
|
||||
"goal", "reasoning",
|
||||
required_types={"goal": str, "reasoning": str}
|
||||
content, "goal", "reasoning", required_types={"goal": str, "reasoning": str}
|
||||
)
|
||||
# TODO
|
||||
|
||||
|
||||
conversation_info.goal_list.append(result)
|
||||
|
||||
|
||||
|
||||
async def _update_goals(self, new_goal: str, method: str, reasoning: str):
|
||||
"""更新目标列表
|
||||
|
||||
@@ -233,8 +225,10 @@ class GoalAnalyzer:
|
||||
# 尝试解析JSON
|
||||
success, result = get_items_from_json(
|
||||
content,
|
||||
"goal_achieved", "stop_conversation", "reason",
|
||||
required_types={"goal_achieved": bool, "stop_conversation": bool, "reason": str}
|
||||
"goal_achieved",
|
||||
"stop_conversation",
|
||||
"reason",
|
||||
required_types={"goal_achieved": bool, "stop_conversation": bool, "reason": str},
|
||||
)
|
||||
|
||||
if not success:
|
||||
@@ -285,7 +279,6 @@ class Waiter:
|
||||
logger.info("等待中...")
|
||||
|
||||
|
||||
|
||||
class DirectMessageSender:
|
||||
"""直接发送消息到平台的发送器"""
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import traceback
|
||||
|
||||
logger = get_module_logger("pfc_manager")
|
||||
|
||||
|
||||
class PFCManager:
|
||||
"""PFC对话管理器,负责管理所有对话实例"""
|
||||
|
||||
@@ -16,7 +17,7 @@ class PFCManager:
|
||||
_initializing: Dict[str, bool] = {}
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> 'PFCManager':
|
||||
def get_instance(cls) -> "PFCManager":
|
||||
"""获取管理器单例
|
||||
|
||||
Returns:
|
||||
@@ -60,7 +61,6 @@ class PFCManager:
|
||||
|
||||
return conversation_instance
|
||||
|
||||
|
||||
async def _initialize_conversation(self, conversation: Conversation):
|
||||
"""初始化会话实例
|
||||
|
||||
@@ -84,7 +84,6 @@ class PFCManager:
|
||||
logger.error(traceback.format_exc())
|
||||
# 清理失败的初始化
|
||||
|
||||
|
||||
async def get_conversation(self, stream_id: str) -> Optional[Conversation]:
|
||||
"""获取已存在的会话实例
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Literal
|
||||
|
||||
class ConversationState(Enum):
|
||||
"""对话状态"""
|
||||
|
||||
INIT = "初始化"
|
||||
RETHINKING = "重新思考"
|
||||
ANALYZING = "分析历史"
|
||||
|
||||
@@ -16,21 +16,14 @@ class ReplyGenerator:
|
||||
|
||||
def __init__(self, stream_id: str):
|
||||
self.llm = LLM_request(
|
||||
model=global_config.llm_normal,
|
||||
temperature=0.7,
|
||||
max_tokens=300,
|
||||
request_type="reply_generation"
|
||||
model=global_config.llm_normal, temperature=0.7, max_tokens=300, request_type="reply_generation"
|
||||
)
|
||||
self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2)
|
||||
self.name = global_config.BOT_NICKNAME
|
||||
self.chat_observer = ChatObserver.get_instance(stream_id)
|
||||
self.reply_checker = ReplyChecker(stream_id)
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
observation_info: ObservationInfo,
|
||||
conversation_info: ConversationInfo
|
||||
) -> str:
|
||||
async def generate(self, observation_info: ObservationInfo, conversation_info: ConversationInfo) -> str:
|
||||
"""生成回复
|
||||
|
||||
Args:
|
||||
@@ -58,7 +51,6 @@ class ReplyGenerator:
|
||||
for msg in chat_history_list:
|
||||
chat_history_text += f"{msg}\n"
|
||||
|
||||
|
||||
# 整理知识缓存
|
||||
knowledge_text = ""
|
||||
knowledge_list = conversation_info.knowledge_list
|
||||
@@ -107,12 +99,7 @@ class ReplyGenerator:
|
||||
logger.error(f"生成回复时出错: {e}")
|
||||
return "抱歉,我现在有点混乱,让我重新思考一下..."
|
||||
|
||||
async def check_reply(
|
||||
self,
|
||||
reply: str,
|
||||
goal: str,
|
||||
retry_count: int = 0
|
||||
) -> Tuple[bool, str, bool]:
|
||||
async def check_reply(self, reply: str, goal: str, retry_count: int = 0) -> Tuple[bool, str, bool]:
|
||||
"""检查回复是否合适
|
||||
|
||||
Args:
|
||||
|
||||
@@ -3,6 +3,7 @@ from .chat_observer import ChatObserver
|
||||
|
||||
logger = get_module_logger("waiter")
|
||||
|
||||
|
||||
class Waiter:
|
||||
"""等待器,用于等待对话流中的事件"""
|
||||
|
||||
|
||||
@@ -46,7 +46,6 @@ class ChatBot:
|
||||
chat_id = str(message.chat_stream.stream_id)
|
||||
|
||||
if global_config.enable_pfc_chatting:
|
||||
|
||||
await self.pfc_manager.get_or_create_conversation(chat_id)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
__version__ = "0.1.0"
|
||||
|
||||
from .api import BaseMessageAPI, global_api
|
||||
from .api import global_api
|
||||
from .message_base import (
|
||||
Seg,
|
||||
GroupInfo,
|
||||
@@ -14,7 +14,6 @@ from .message_base import (
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseMessageAPI",
|
||||
"Seg",
|
||||
"global_api",
|
||||
"GroupInfo",
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
|
||||
from typing import Dict, Any, Callable, List, Set
|
||||
from typing import Dict, Any, Callable, List, Set, Optional
|
||||
from src.common.logger import get_module_logger
|
||||
from src.plugins.message.message_base import MessageBase
|
||||
from src.common.server import global_server
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import uvicorn
|
||||
@@ -49,13 +50,22 @@ class MessageServer(BaseMessageHandler):
|
||||
|
||||
_class_handlers: List[Callable] = [] # 类级别的消息处理器
|
||||
|
||||
def __init__(self, host: str = "0.0.0.0", port: int = 18000, enable_token=False):
|
||||
def __init__(
|
||||
self,
|
||||
host: str = "0.0.0.0",
|
||||
port: int = 18000,
|
||||
enable_token=False,
|
||||
app: Optional[FastAPI] = None,
|
||||
path: str = "/ws",
|
||||
):
|
||||
super().__init__()
|
||||
# 将类级别的处理器添加到实例处理器中
|
||||
self.message_handlers.extend(self._class_handlers)
|
||||
self.app = FastAPI()
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.path = path
|
||||
self.app = app or FastAPI()
|
||||
self.own_app = app is None # 标记是否使用自己创建的app
|
||||
self.active_websockets: Set[WebSocket] = set()
|
||||
self.platform_websockets: Dict[str, WebSocket] = {} # 平台到websocket的映射
|
||||
self.valid_tokens: Set[str] = set()
|
||||
@@ -63,28 +73,6 @@ class MessageServer(BaseMessageHandler):
|
||||
self._setup_routes()
|
||||
self._running = False
|
||||
|
||||
@classmethod
|
||||
def register_class_handler(cls, handler: Callable):
|
||||
"""注册类级别的消息处理器"""
|
||||
if handler not in cls._class_handlers:
|
||||
cls._class_handlers.append(handler)
|
||||
|
||||
def register_message_handler(self, handler: Callable):
|
||||
"""注册实例级别的消息处理器"""
|
||||
if handler not in self.message_handlers:
|
||||
self.message_handlers.append(handler)
|
||||
|
||||
async def verify_token(self, token: str) -> bool:
|
||||
if not self.enable_token:
|
||||
return True
|
||||
return token in self.valid_tokens
|
||||
|
||||
def add_valid_token(self, token: str):
|
||||
self.valid_tokens.add(token)
|
||||
|
||||
def remove_valid_token(self, token: str):
|
||||
self.valid_tokens.discard(token)
|
||||
|
||||
def _setup_routes(self):
|
||||
@self.app.post("/api/message")
|
||||
async def handle_message(message: Dict[str, Any]):
|
||||
@@ -125,6 +113,90 @@ class MessageServer(BaseMessageHandler):
|
||||
finally:
|
||||
self._remove_websocket(websocket, platform)
|
||||
|
||||
@classmethod
|
||||
def register_class_handler(cls, handler: Callable):
|
||||
"""注册类级别的消息处理器"""
|
||||
if handler not in cls._class_handlers:
|
||||
cls._class_handlers.append(handler)
|
||||
|
||||
def register_message_handler(self, handler: Callable):
|
||||
"""注册实例级别的消息处理器"""
|
||||
if handler not in self.message_handlers:
|
||||
self.message_handlers.append(handler)
|
||||
|
||||
async def verify_token(self, token: str) -> bool:
|
||||
if not self.enable_token:
|
||||
return True
|
||||
return token in self.valid_tokens
|
||||
|
||||
def add_valid_token(self, token: str):
|
||||
self.valid_tokens.add(token)
|
||||
|
||||
def remove_valid_token(self, token: str):
|
||||
self.valid_tokens.discard(token)
|
||||
|
||||
def run_sync(self):
|
||||
"""同步方式运行服务器"""
|
||||
if not self.own_app:
|
||||
raise RuntimeError("当使用外部FastAPI实例时,请使用该实例的运行方法")
|
||||
uvicorn.run(self.app, host=self.host, port=self.port)
|
||||
|
||||
async def run(self):
|
||||
"""异步方式运行服务器"""
|
||||
self._running = True
|
||||
try:
|
||||
if self.own_app:
|
||||
# 如果使用自己的 FastAPI 实例,运行 uvicorn 服务器
|
||||
config = uvicorn.Config(self.app, host=self.host, port=self.port, loop="asyncio")
|
||||
self.server = uvicorn.Server(config)
|
||||
await self.server.serve()
|
||||
else:
|
||||
# 如果使用外部 FastAPI 实例,保持运行状态以处理消息
|
||||
while self._running:
|
||||
await asyncio.sleep(1)
|
||||
except KeyboardInterrupt:
|
||||
await self.stop()
|
||||
raise
|
||||
except Exception as e:
|
||||
await self.stop()
|
||||
raise RuntimeError(f"服务器运行错误: {str(e)}") from e
|
||||
finally:
|
||||
await self.stop()
|
||||
|
||||
async def start_server(self):
|
||||
"""启动服务器的异步方法"""
|
||||
if not self._running:
|
||||
self._running = True
|
||||
await self.run()
|
||||
|
||||
async def stop(self):
|
||||
"""停止服务器"""
|
||||
# 清理platform映射
|
||||
self.platform_websockets.clear()
|
||||
|
||||
# 取消所有后台任务
|
||||
for task in self.background_tasks:
|
||||
task.cancel()
|
||||
# 等待所有任务完成
|
||||
await asyncio.gather(*self.background_tasks, return_exceptions=True)
|
||||
self.background_tasks.clear()
|
||||
|
||||
# 关闭所有WebSocket连接
|
||||
for websocket in self.active_websockets:
|
||||
await websocket.close()
|
||||
self.active_websockets.clear()
|
||||
|
||||
if hasattr(self, "server") and self.own_app:
|
||||
self._running = False
|
||||
# 正确关闭 uvicorn 服务器
|
||||
self.server.should_exit = True
|
||||
await self.server.shutdown()
|
||||
# 等待服务器完全停止
|
||||
if hasattr(self.server, "started") and self.server.started:
|
||||
await self.server.main_loop()
|
||||
# 清理处理程序
|
||||
self.message_handlers.clear()
|
||||
|
||||
def _remove_websocket(self, websocket: WebSocket, platform: str):
|
||||
"""从所有集合中移除websocket"""
|
||||
if websocket in self.active_websockets:
|
||||
@@ -161,54 +233,6 @@ class MessageServer(BaseMessageHandler):
|
||||
async def send_message(self, message: MessageBase):
|
||||
await self.broadcast_to_platform(message.message_info.platform, message.to_dict())
|
||||
|
||||
def run_sync(self):
|
||||
"""同步方式运行服务器"""
|
||||
uvicorn.run(self.app, host=self.host, port=self.port)
|
||||
|
||||
async def run(self):
|
||||
"""异步方式运行服务器"""
|
||||
config = uvicorn.Config(self.app, host=self.host, port=self.port, loop="asyncio")
|
||||
self.server = uvicorn.Server(config)
|
||||
try:
|
||||
await self.server.serve()
|
||||
except KeyboardInterrupt as e:
|
||||
await self.stop()
|
||||
raise KeyboardInterrupt from e
|
||||
|
||||
async def start_server(self):
|
||||
"""启动服务器的异步方法"""
|
||||
if not self._running:
|
||||
self._running = True
|
||||
await self.run()
|
||||
|
||||
async def stop(self):
|
||||
"""停止服务器"""
|
||||
# 清理platform映射
|
||||
self.platform_websockets.clear()
|
||||
|
||||
# 取消所有后台任务
|
||||
for task in self.background_tasks:
|
||||
task.cancel()
|
||||
# 等待所有任务完成
|
||||
await asyncio.gather(*self.background_tasks, return_exceptions=True)
|
||||
self.background_tasks.clear()
|
||||
|
||||
# 关闭所有WebSocket连接
|
||||
for websocket in self.active_websockets:
|
||||
await websocket.close()
|
||||
self.active_websockets.clear()
|
||||
|
||||
if hasattr(self, "server"):
|
||||
self._running = False
|
||||
# 正确关闭 uvicorn 服务器
|
||||
self.server.should_exit = True
|
||||
await self.server.shutdown()
|
||||
# 等待服务器完全停止
|
||||
if hasattr(self.server, "started") and self.server.started:
|
||||
await self.server.main_loop()
|
||||
# 清理处理程序
|
||||
self.message_handlers.clear()
|
||||
|
||||
async def send_message_REST(self, url: str, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""发送消息到指定端点"""
|
||||
async with aiohttp.ClientSession() as session:
|
||||
@@ -219,105 +243,4 @@ class MessageServer(BaseMessageHandler):
|
||||
raise e
|
||||
|
||||
|
||||
class BaseMessageAPI:
|
||||
def __init__(self, host: str = "0.0.0.0", port: int = 18000):
|
||||
self.app = FastAPI()
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.message_handlers: List[Callable] = []
|
||||
self.cache = []
|
||||
self._setup_routes()
|
||||
self._running = False
|
||||
|
||||
def _setup_routes(self):
|
||||
"""设置基础路由"""
|
||||
|
||||
@self.app.post("/api/message")
|
||||
async def handle_message(message: Dict[str, Any]):
|
||||
try:
|
||||
# 创建后台任务处理消息
|
||||
asyncio.create_task(self._background_message_handler(message))
|
||||
return {"status": "success"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
async def _background_message_handler(self, message: Dict[str, Any]):
|
||||
"""后台处理单个消息"""
|
||||
try:
|
||||
await self.process_single_message(message)
|
||||
except Exception as e:
|
||||
logger.error(f"Background message processing failed: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
def register_message_handler(self, handler: Callable):
|
||||
"""注册消息处理函数"""
|
||||
self.message_handlers.append(handler)
|
||||
|
||||
async def send_message(self, url: str, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""发送消息到指定端点"""
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
async with session.post(url, json=data, headers={"Content-Type": "application/json"}) as response:
|
||||
return await response.json()
|
||||
except Exception:
|
||||
# logger.error(f"发送消息失败: {str(e)}")
|
||||
pass
|
||||
|
||||
async def process_single_message(self, message: Dict[str, Any]):
|
||||
"""处理单条消息"""
|
||||
tasks = []
|
||||
for handler in self.message_handlers:
|
||||
try:
|
||||
tasks.append(handler(message))
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
logger.error(traceback.format_exc())
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
def run_sync(self):
|
||||
"""同步方式运行服务器"""
|
||||
uvicorn.run(self.app, host=self.host, port=self.port)
|
||||
|
||||
async def run(self):
|
||||
"""异步方式运行服务器"""
|
||||
config = uvicorn.Config(self.app, host=self.host, port=self.port, loop="asyncio")
|
||||
self.server = uvicorn.Server(config)
|
||||
try:
|
||||
await self.server.serve()
|
||||
except KeyboardInterrupt as e:
|
||||
await self.stop()
|
||||
raise KeyboardInterrupt from e
|
||||
|
||||
async def start_server(self):
|
||||
"""启动服务器的异步方法"""
|
||||
if not self._running:
|
||||
self._running = True
|
||||
await self.run()
|
||||
|
||||
async def stop(self):
|
||||
"""停止服务器"""
|
||||
if hasattr(self, "server"):
|
||||
self._running = False
|
||||
# 正确关闭 uvicorn 服务器
|
||||
self.server.should_exit = True
|
||||
await self.server.shutdown()
|
||||
# 等待服务器完全停止
|
||||
if hasattr(self.server, "started") and self.server.started:
|
||||
await self.server.main_loop()
|
||||
# 清理处理程序
|
||||
self.message_handlers.clear()
|
||||
|
||||
def start(self):
|
||||
"""启动服务器的便捷方法"""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(self.start_server())
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
global_api = MessageServer(host=os.environ["HOST"], port=int(os.environ["PORT"]))
|
||||
global_api = MessageServer(host=os.environ["HOST"], port=int(os.environ["PORT"]), app=global_server.get_app())
|
||||
|
||||
Reference in New Issue
Block a user