Merge branch 'dev' of https://github.com/MaiM-with-u/MaiBot into dev
This commit is contained in:
@@ -4,6 +4,7 @@ import logging
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from logging.handlers import RotatingFileHandler
|
from logging.handlers import RotatingFileHandler
|
||||||
|
|
||||||
|
|
||||||
def setup_crash_logger():
|
def setup_crash_logger():
|
||||||
"""设置崩溃日志记录器"""
|
"""设置崩溃日志记录器"""
|
||||||
# 创建logs/crash目录(如果不存在)
|
# 创建logs/crash目录(如果不存在)
|
||||||
@@ -11,15 +12,12 @@ def setup_crash_logger():
|
|||||||
crash_log_dir.mkdir(parents=True, exist_ok=True)
|
crash_log_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# 创建日志记录器
|
# 创建日志记录器
|
||||||
crash_logger = logging.getLogger('crash_logger')
|
crash_logger = logging.getLogger("crash_logger")
|
||||||
crash_logger.setLevel(logging.ERROR)
|
crash_logger.setLevel(logging.ERROR)
|
||||||
|
|
||||||
# 设置日志格式
|
# 设置日志格式
|
||||||
formatter = logging.Formatter(
|
formatter = logging.Formatter(
|
||||||
'%(asctime)s - %(name)s - %(levelname)s\n'
|
"%(asctime)s - %(name)s - %(levelname)s\n异常类型: %(exc_info)s\n详细信息:\n%(message)s\n-------------------\n"
|
||||||
'异常类型: %(exc_info)s\n'
|
|
||||||
'详细信息:\n%(message)s\n'
|
|
||||||
'-------------------\n'
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建按大小轮转的文件处理器(最大10MB,保留5个备份)
|
# 创建按大小轮转的文件处理器(最大10MB,保留5个备份)
|
||||||
@@ -28,29 +26,28 @@ def setup_crash_logger():
|
|||||||
log_file,
|
log_file,
|
||||||
maxBytes=10 * 1024 * 1024, # 10MB
|
maxBytes=10 * 1024 * 1024, # 10MB
|
||||||
backupCount=5,
|
backupCount=5,
|
||||||
encoding='utf-8'
|
encoding="utf-8",
|
||||||
)
|
)
|
||||||
file_handler.setFormatter(formatter)
|
file_handler.setFormatter(formatter)
|
||||||
crash_logger.addHandler(file_handler)
|
crash_logger.addHandler(file_handler)
|
||||||
|
|
||||||
return crash_logger
|
return crash_logger
|
||||||
|
|
||||||
|
|
||||||
def log_crash(exc_type, exc_value, exc_traceback):
|
def log_crash(exc_type, exc_value, exc_traceback):
|
||||||
"""记录崩溃信息到日志文件"""
|
"""记录崩溃信息到日志文件"""
|
||||||
if exc_type is None:
|
if exc_type is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
# 获取崩溃日志记录器
|
# 获取崩溃日志记录器
|
||||||
crash_logger = logging.getLogger('crash_logger')
|
crash_logger = logging.getLogger("crash_logger")
|
||||||
|
|
||||||
# 获取完整的异常堆栈信息
|
# 获取完整的异常堆栈信息
|
||||||
stack_trace = ''.join(traceback.format_exception(exc_type, exc_value, exc_traceback))
|
stack_trace = "".join(traceback.format_exception(exc_type, exc_value, exc_traceback))
|
||||||
|
|
||||||
# 记录崩溃信息
|
# 记录崩溃信息
|
||||||
crash_logger.error(
|
crash_logger.error(stack_trace, exc_info=(exc_type, exc_value, exc_traceback))
|
||||||
stack_trace,
|
|
||||||
exc_info=(exc_type, exc_value, exc_traceback)
|
|
||||||
)
|
|
||||||
|
|
||||||
def install_crash_handler():
|
def install_crash_handler():
|
||||||
"""安装全局异常处理器"""
|
"""安装全局异常处理器"""
|
||||||
|
|||||||
73
src/common/server.py
Normal file
73
src/common/server.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
from fastapi import FastAPI, APIRouter
|
||||||
|
from typing import Optional, Union
|
||||||
|
from uvicorn import Config, Server as UvicornServer
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
class Server:
|
||||||
|
def __init__(self, host: Optional[str] = None, port: Optional[int] = None, app_name: str = "MaiMCore"):
|
||||||
|
self.app = FastAPI(title=app_name)
|
||||||
|
self._host: str = "127.0.0.1"
|
||||||
|
self._port: int = 8080
|
||||||
|
self._server: Optional[UvicornServer] = None
|
||||||
|
self.set_address(host, port)
|
||||||
|
|
||||||
|
def register_router(self, router: APIRouter, prefix: str = ""):
|
||||||
|
"""注册路由
|
||||||
|
|
||||||
|
APIRouter 用于对相关的路由端点进行分组和模块化管理:
|
||||||
|
1. 可以将相关的端点组织在一起,便于管理
|
||||||
|
2. 支持添加统一的路由前缀
|
||||||
|
3. 可以为一组路由添加共同的依赖项、标签等
|
||||||
|
|
||||||
|
示例:
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
@router.get("/users")
|
||||||
|
def get_users():
|
||||||
|
return {"users": [...]}
|
||||||
|
|
||||||
|
@router.post("/users")
|
||||||
|
def create_user():
|
||||||
|
return {"msg": "user created"}
|
||||||
|
|
||||||
|
# 注册路由,添加前缀 "/api/v1"
|
||||||
|
server.register_router(router, prefix="/api/v1")
|
||||||
|
"""
|
||||||
|
self.app.include_router(router, prefix=prefix)
|
||||||
|
|
||||||
|
def set_address(self, host: Optional[str] = None, port: Optional[int] = None):
|
||||||
|
"""设置服务器地址和端口"""
|
||||||
|
if host:
|
||||||
|
self._host = host
|
||||||
|
if port:
|
||||||
|
self._port = port
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
"""启动服务器"""
|
||||||
|
config = Config(app=self.app, host=self._host, port=self._port)
|
||||||
|
self._server = UvicornServer(config=config)
|
||||||
|
try:
|
||||||
|
await self._server.serve()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
await self.shutdown()
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
await self.shutdown()
|
||||||
|
raise RuntimeError(f"服务器运行错误: {str(e)}") from e
|
||||||
|
finally:
|
||||||
|
await self.shutdown()
|
||||||
|
|
||||||
|
async def shutdown(self):
|
||||||
|
"""安全关闭服务器"""
|
||||||
|
if self._server:
|
||||||
|
self._server.should_exit = True
|
||||||
|
await self._server.shutdown()
|
||||||
|
self._server = None
|
||||||
|
|
||||||
|
def get_app(self) -> FastAPI:
|
||||||
|
"""获取 FastAPI 实例"""
|
||||||
|
return self.app
|
||||||
|
|
||||||
|
|
||||||
|
global_server = Server(host=os.environ["HOST"], port=int(os.environ["PORT"]))
|
||||||
@@ -16,7 +16,7 @@ from .plugins.chat.bot import chat_bot
|
|||||||
from .common.logger import get_module_logger
|
from .common.logger import get_module_logger
|
||||||
from .plugins.remote import heartbeat_thread # noqa: F401
|
from .plugins.remote import heartbeat_thread # noqa: F401
|
||||||
from .individuality.individuality import Individuality
|
from .individuality.individuality import Individuality
|
||||||
|
from .common.server import global_server
|
||||||
|
|
||||||
logger = get_module_logger("main")
|
logger = get_module_logger("main")
|
||||||
|
|
||||||
@@ -33,6 +33,7 @@ class MainSystem:
|
|||||||
from .plugins.message import global_api
|
from .plugins.message import global_api
|
||||||
|
|
||||||
self.app = global_api
|
self.app = global_api
|
||||||
|
self.server = global_server
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
"""初始化系统组件"""
|
"""初始化系统组件"""
|
||||||
@@ -126,6 +127,7 @@ class MainSystem:
|
|||||||
emoji_manager.start_periodic_check_register(),
|
emoji_manager.start_periodic_check_register(),
|
||||||
# emoji_manager.start_periodic_register(),
|
# emoji_manager.start_periodic_register(),
|
||||||
self.app.run(),
|
self.app.run(),
|
||||||
|
self.server.run(),
|
||||||
]
|
]
|
||||||
await asyncio.gather(*tasks)
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from .conversation_info import ConversationInfo
|
|||||||
|
|
||||||
logger = get_module_logger("action_planner")
|
logger = get_module_logger("action_planner")
|
||||||
|
|
||||||
|
|
||||||
class ActionPlannerInfo:
|
class ActionPlannerInfo:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.done_action = []
|
self.done_action = []
|
||||||
@@ -23,20 +24,13 @@ class ActionPlanner:
|
|||||||
|
|
||||||
def __init__(self, stream_id: str):
|
def __init__(self, stream_id: str):
|
||||||
self.llm = LLM_request(
|
self.llm = LLM_request(
|
||||||
model=global_config.llm_normal,
|
model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="action_planning"
|
||||||
temperature=0.7,
|
|
||||||
max_tokens=1000,
|
|
||||||
request_type="action_planning"
|
|
||||||
)
|
)
|
||||||
self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2)
|
self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2)
|
||||||
self.name = global_config.BOT_NICKNAME
|
self.name = global_config.BOT_NICKNAME
|
||||||
self.chat_observer = ChatObserver.get_instance(stream_id)
|
self.chat_observer = ChatObserver.get_instance(stream_id)
|
||||||
|
|
||||||
async def plan(
|
async def plan(self, observation_info: ObservationInfo, conversation_info: ConversationInfo) -> Tuple[str, str]:
|
||||||
self,
|
|
||||||
observation_info: ObservationInfo,
|
|
||||||
conversation_info: ConversationInfo
|
|
||||||
) -> Tuple[str, str]:
|
|
||||||
"""规划下一步行动
|
"""规划下一步行动
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -56,7 +50,6 @@ class ActionPlanner:
|
|||||||
goal = "目前没有明确对话目标"
|
goal = "目前没有明确对话目标"
|
||||||
reasoning = "目前没有明确对话目标,最好思考一个对话目标"
|
reasoning = "目前没有明确对话目标,最好思考一个对话目标"
|
||||||
|
|
||||||
|
|
||||||
# 获取聊天历史记录
|
# 获取聊天历史记录
|
||||||
chat_history_list = observation_info.chat_history
|
chat_history_list = observation_info.chat_history
|
||||||
chat_history_text = ""
|
chat_history_text = ""
|
||||||
@@ -72,7 +65,6 @@ class ActionPlanner:
|
|||||||
|
|
||||||
observation_info.clear_unprocessed_messages()
|
observation_info.clear_unprocessed_messages()
|
||||||
|
|
||||||
|
|
||||||
personality_text = f"你的名字是{self.name},{self.personality_info}"
|
personality_text = f"你的名字是{self.name},{self.personality_info}"
|
||||||
|
|
||||||
# 构建action历史文本
|
# 构建action历史文本
|
||||||
@@ -81,8 +73,6 @@ class ActionPlanner:
|
|||||||
for action in action_history_list:
|
for action in action_history_list:
|
||||||
action_history_text += f"{action}\n"
|
action_history_text += f"{action}\n"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
prompt = f"""{personality_text}。现在你在参与一场QQ聊天,请分析以下内容,根据信息决定下一步行动:
|
prompt = f"""{personality_text}。现在你在参与一场QQ聊天,请分析以下内容,根据信息决定下一步行动:
|
||||||
|
|
||||||
当前对话目标:{goal}
|
当前对话目标:{goal}
|
||||||
@@ -114,9 +104,7 @@ rethink_goal: 重新思考对话目标,当发现对话目标不合适时选择
|
|||||||
|
|
||||||
# 使用简化函数提取JSON内容
|
# 使用简化函数提取JSON内容
|
||||||
success, result = get_items_from_json(
|
success, result = get_items_from_json(
|
||||||
content,
|
content, "action", "reason", default_values={"action": "direct_reply", "reason": "没有明确原因"}
|
||||||
"action", "reason",
|
|
||||||
default_values={"action": "direct_reply", "reason": "没有明确原因"}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if not success:
|
if not success:
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ class ChatObserver:
|
|||||||
_instances: Dict[str, "ChatObserver"] = {}
|
_instances: Dict[str, "ChatObserver"] = {}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_instance(cls, stream_id: str, message_storage: Optional[MessageStorage] = None) -> 'ChatObserver':
|
def get_instance(cls, stream_id: str, message_storage: Optional[MessageStorage] = None) -> "ChatObserver":
|
||||||
"""获取或创建观察器实例
|
"""获取或创建观察器实例
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -84,10 +84,7 @@ class ChatObserver:
|
|||||||
"""
|
"""
|
||||||
logger.debug(f"检查距离上一次观察之后是否有了新消息: {self.last_check_time}")
|
logger.debug(f"检查距离上一次观察之后是否有了新消息: {self.last_check_time}")
|
||||||
|
|
||||||
new_message_exists = await self.message_storage.has_new_messages(
|
new_message_exists = await self.message_storage.has_new_messages(self.stream_id, self.last_check_time)
|
||||||
self.stream_id,
|
|
||||||
self.last_check_time
|
|
||||||
)
|
|
||||||
|
|
||||||
if new_message_exists:
|
if new_message_exists:
|
||||||
logger.debug("发现新消息")
|
logger.debug("发现新消息")
|
||||||
@@ -114,11 +111,7 @@ class ChatObserver:
|
|||||||
self.last_user_speak_time = message["time"]
|
self.last_user_speak_time = message["time"]
|
||||||
|
|
||||||
# 发送新消息通知
|
# 发送新消息通知
|
||||||
notification = create_new_message_notification(
|
notification = create_new_message_notification(sender="chat_observer", target="pfc", message=message)
|
||||||
sender="chat_observer",
|
|
||||||
target="pfc",
|
|
||||||
message=message
|
|
||||||
)
|
|
||||||
await self.notification_manager.send_notification(notification)
|
await self.notification_manager.send_notification(notification)
|
||||||
|
|
||||||
# 检查并更新冷场状态
|
# 检查并更新冷场状态
|
||||||
@@ -144,19 +137,12 @@ class ChatObserver:
|
|||||||
# 如果冷场状态发生变化,发送通知
|
# 如果冷场状态发生变化,发送通知
|
||||||
if is_cold != self.is_cold_chat_state:
|
if is_cold != self.is_cold_chat_state:
|
||||||
self.is_cold_chat_state = is_cold
|
self.is_cold_chat_state = is_cold
|
||||||
notification = create_cold_chat_notification(
|
notification = create_cold_chat_notification(sender="chat_observer", target="pfc", is_cold=is_cold)
|
||||||
sender="chat_observer",
|
|
||||||
target="pfc",
|
|
||||||
is_cold=is_cold
|
|
||||||
)
|
|
||||||
await self.notification_manager.send_notification(notification)
|
await self.notification_manager.send_notification(notification)
|
||||||
|
|
||||||
async def get_new_message(self) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
|
async def get_new_message(self) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
|
||||||
"""获取上一次观察的时间点后的新消息,插入到历史记录中,并返回新消息和历史记录两个对象"""
|
"""获取上一次观察的时间点后的新消息,插入到历史记录中,并返回新消息和历史记录两个对象"""
|
||||||
messages = await self.message_storage.get_messages_after(
|
messages = await self.message_storage.get_messages_after(self.stream_id, self.last_message_read)
|
||||||
self.stream_id,
|
|
||||||
self.last_message_read
|
|
||||||
)
|
|
||||||
for message in messages:
|
for message in messages:
|
||||||
await self._add_message_to_history(message)
|
await self._add_message_to_history(message)
|
||||||
return messages, self.message_history
|
return messages, self.message_history
|
||||||
@@ -224,10 +210,7 @@ class ChatObserver:
|
|||||||
Returns:
|
Returns:
|
||||||
List[Dict[str, Any]]: 新消息列表
|
List[Dict[str, Any]]: 新消息列表
|
||||||
"""
|
"""
|
||||||
new_messages = await self.message_storage.get_messages_after(
|
new_messages = await self.message_storage.get_messages_after(self.stream_id, self.last_message_read)
|
||||||
self.stream_id,
|
|
||||||
self.last_message_read
|
|
||||||
)
|
|
||||||
|
|
||||||
if new_messages:
|
if new_messages:
|
||||||
self.last_message_read = new_messages[-1]["message_id"]
|
self.last_message_read = new_messages[-1]["message_id"]
|
||||||
@@ -243,17 +226,15 @@ class ChatObserver:
|
|||||||
Returns:
|
Returns:
|
||||||
List[Dict[str, Any]]: 最多5条消息
|
List[Dict[str, Any]]: 最多5条消息
|
||||||
"""
|
"""
|
||||||
new_messages = await self.message_storage.get_messages_before(
|
new_messages = await self.message_storage.get_messages_before(self.stream_id, time_point)
|
||||||
self.stream_id,
|
|
||||||
time_point
|
|
||||||
)
|
|
||||||
|
|
||||||
if new_messages:
|
if new_messages:
|
||||||
self.last_message_read = new_messages[-1]["message_id"]
|
self.last_message_read = new_messages[-1]["message_id"]
|
||||||
|
|
||||||
return new_messages
|
return new_messages
|
||||||
|
|
||||||
'''主要观察循环'''
|
"""主要观察循环"""
|
||||||
|
|
||||||
async def _update_loop(self):
|
async def _update_loop(self):
|
||||||
"""更新循环"""
|
"""更新循环"""
|
||||||
try:
|
try:
|
||||||
@@ -396,17 +377,16 @@ class ChatObserver:
|
|||||||
bool: 是否有新消息
|
bool: 是否有新消息
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
messages = await self.message_storage.get_messages_for_stream(
|
messages = await self.message_storage.get_messages_for_stream(self.stream_id, limit=50)
|
||||||
self.stream_id,
|
|
||||||
limit=50
|
|
||||||
)
|
|
||||||
|
|
||||||
if not messages:
|
if not messages:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 检查是否有新消息
|
# 检查是否有新消息
|
||||||
has_new_messages = False
|
has_new_messages = False
|
||||||
if messages and (not self.message_cache or messages[0]["message_id"] != self.message_cache[0]["message_id"]):
|
if messages and (
|
||||||
|
not self.message_cache or messages[0]["message_id"] != self.message_cache[0]["message_id"]
|
||||||
|
):
|
||||||
has_new_messages = True
|
has_new_messages = True
|
||||||
|
|
||||||
self.message_cache = messages
|
self.message_cache = messages
|
||||||
|
|||||||
@@ -4,8 +4,10 @@ from dataclasses import dataclass
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
class ChatState(Enum):
|
class ChatState(Enum):
|
||||||
"""聊天状态枚举"""
|
"""聊天状态枚举"""
|
||||||
|
|
||||||
NORMAL = auto() # 正常状态
|
NORMAL = auto() # 正常状态
|
||||||
NEW_MESSAGE = auto() # 有新消息
|
NEW_MESSAGE = auto() # 有新消息
|
||||||
COLD_CHAT = auto() # 冷场状态
|
COLD_CHAT = auto() # 冷场状态
|
||||||
@@ -15,8 +17,10 @@ class ChatState(Enum):
|
|||||||
SILENT = auto() # 沉默状态
|
SILENT = auto() # 沉默状态
|
||||||
ERROR = auto() # 错误状态
|
ERROR = auto() # 错误状态
|
||||||
|
|
||||||
|
|
||||||
class NotificationType(Enum):
|
class NotificationType(Enum):
|
||||||
"""通知类型枚举"""
|
"""通知类型枚举"""
|
||||||
|
|
||||||
NEW_MESSAGE = auto() # 新消息通知
|
NEW_MESSAGE = auto() # 新消息通知
|
||||||
COLD_CHAT = auto() # 冷场通知
|
COLD_CHAT = auto() # 冷场通知
|
||||||
ACTIVE_CHAT = auto() # 活跃通知
|
ACTIVE_CHAT = auto() # 活跃通知
|
||||||
@@ -27,9 +31,11 @@ class NotificationType(Enum):
|
|||||||
USER_LEFT = auto() # 用户离开通知
|
USER_LEFT = auto() # 用户离开通知
|
||||||
ERROR = auto() # 错误通知
|
ERROR = auto() # 错误通知
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ChatStateInfo:
|
class ChatStateInfo:
|
||||||
"""聊天状态信息"""
|
"""聊天状态信息"""
|
||||||
|
|
||||||
state: ChatState
|
state: ChatState
|
||||||
last_message_time: Optional[float] = None
|
last_message_time: Optional[float] = None
|
||||||
last_message_content: Optional[str] = None
|
last_message_content: Optional[str] = None
|
||||||
@@ -38,9 +44,11 @@ class ChatStateInfo:
|
|||||||
cold_duration: float = 0.0 # 冷场持续时间(秒)
|
cold_duration: float = 0.0 # 冷场持续时间(秒)
|
||||||
active_duration: float = 0.0 # 活跃持续时间(秒)
|
active_duration: float = 0.0 # 活跃持续时间(秒)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Notification:
|
class Notification:
|
||||||
"""通知基类"""
|
"""通知基类"""
|
||||||
|
|
||||||
type: NotificationType
|
type: NotificationType
|
||||||
timestamp: float
|
timestamp: float
|
||||||
sender: str # 发送者标识
|
sender: str # 发送者标识
|
||||||
@@ -49,15 +57,13 @@ class Notification:
|
|||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
"""转换为字典格式"""
|
"""转换为字典格式"""
|
||||||
return {
|
return {"type": self.type.name, "timestamp": self.timestamp, "data": self.data}
|
||||||
"type": self.type.name,
|
|
||||||
"timestamp": self.timestamp,
|
|
||||||
"data": self.data
|
|
||||||
}
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class StateNotification(Notification):
|
class StateNotification(Notification):
|
||||||
"""持续状态通知"""
|
"""持续状态通知"""
|
||||||
|
|
||||||
is_active: bool = True
|
is_active: bool = True
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
@@ -65,6 +71,7 @@ class StateNotification(Notification):
|
|||||||
base_dict["is_active"] = self.is_active
|
base_dict["is_active"] = self.is_active
|
||||||
return base_dict
|
return base_dict
|
||||||
|
|
||||||
|
|
||||||
class NotificationHandler(ABC):
|
class NotificationHandler(ABC):
|
||||||
"""通知处理器接口"""
|
"""通知处理器接口"""
|
||||||
|
|
||||||
@@ -73,6 +80,7 @@ class NotificationHandler(ABC):
|
|||||||
"""处理通知"""
|
"""处理通知"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class NotificationManager:
|
class NotificationManager:
|
||||||
"""通知管理器"""
|
"""通知管理器"""
|
||||||
|
|
||||||
@@ -141,10 +149,9 @@ class NotificationManager:
|
|||||||
"""检查特定状态是否活跃"""
|
"""检查特定状态是否活跃"""
|
||||||
return state_type in self._active_states
|
return state_type in self._active_states
|
||||||
|
|
||||||
def get_notification_history(self,
|
def get_notification_history(
|
||||||
sender: Optional[str] = None,
|
self, sender: Optional[str] = None, target: Optional[str] = None, limit: Optional[int] = None
|
||||||
target: Optional[str] = None,
|
) -> List[Notification]:
|
||||||
limit: Optional[int] = None) -> List[Notification]:
|
|
||||||
"""获取通知历史
|
"""获取通知历史
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -164,6 +171,7 @@ class NotificationManager:
|
|||||||
|
|
||||||
return history
|
return history
|
||||||
|
|
||||||
|
|
||||||
# 一些常用的通知创建函数
|
# 一些常用的通知创建函数
|
||||||
def create_new_message_notification(sender: str, target: str, message: Dict[str, Any]) -> Notification:
|
def create_new_message_notification(sender: str, target: str, message: Dict[str, Any]) -> Notification:
|
||||||
"""创建新消息通知"""
|
"""创建新消息通知"""
|
||||||
@@ -176,10 +184,11 @@ def create_new_message_notification(sender: str, target: str, message: Dict[str,
|
|||||||
"message_id": message.get("message_id"),
|
"message_id": message.get("message_id"),
|
||||||
"content": message.get("content"),
|
"content": message.get("content"),
|
||||||
"sender": message.get("sender"),
|
"sender": message.get("sender"),
|
||||||
"time": message.get("time")
|
"time": message.get("time"),
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_cold_chat_notification(sender: str, target: str, is_cold: bool) -> StateNotification:
|
def create_cold_chat_notification(sender: str, target: str, is_cold: bool) -> StateNotification:
|
||||||
"""创建冷场状态通知"""
|
"""创建冷场状态通知"""
|
||||||
return StateNotification(
|
return StateNotification(
|
||||||
@@ -188,9 +197,10 @@ def create_cold_chat_notification(sender: str, target: str, is_cold: bool) -> St
|
|||||||
sender=sender,
|
sender=sender,
|
||||||
target=target,
|
target=target,
|
||||||
data={"is_cold": is_cold},
|
data={"is_cold": is_cold},
|
||||||
is_active=is_cold
|
is_active=is_cold,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_active_chat_notification(sender: str, target: str, is_active: bool) -> StateNotification:
|
def create_active_chat_notification(sender: str, target: str, is_active: bool) -> StateNotification:
|
||||||
"""创建活跃状态通知"""
|
"""创建活跃状态通知"""
|
||||||
return StateNotification(
|
return StateNotification(
|
||||||
@@ -199,9 +209,10 @@ def create_active_chat_notification(sender: str, target: str, is_active: bool) -
|
|||||||
sender=sender,
|
sender=sender,
|
||||||
target=target,
|
target=target,
|
||||||
data={"is_active": is_active},
|
data={"is_active": is_active},
|
||||||
is_active=is_active
|
is_active=is_active,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ChatStateManager:
|
class ChatStateManager:
|
||||||
"""聊天状态管理器"""
|
"""聊天状态管理器"""
|
||||||
|
|
||||||
|
|||||||
@@ -54,7 +54,6 @@ class Conversation:
|
|||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 决策所需要的信息,包括自身自信和观察信息两部分
|
# 决策所需要的信息,包括自身自信和观察信息两部分
|
||||||
# 注册观察器和观测信息
|
# 注册观察器和观测信息
|
||||||
@@ -74,7 +73,6 @@ class Conversation:
|
|||||||
self.should_continue = True
|
self.should_continue = True
|
||||||
asyncio.create_task(self.start())
|
asyncio.create_task(self.start())
|
||||||
|
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
"""开始对话流程"""
|
"""开始对话流程"""
|
||||||
try:
|
try:
|
||||||
@@ -84,16 +82,12 @@ class Conversation:
|
|||||||
logger.error(f"启动对话系统失败: {e}")
|
logger.error(f"启动对话系统失败: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
async def _plan_and_action_loop(self):
|
async def _plan_and_action_loop(self):
|
||||||
"""思考步,PFC核心循环模块"""
|
"""思考步,PFC核心循环模块"""
|
||||||
# 获取最近的消息历史
|
# 获取最近的消息历史
|
||||||
while self.should_continue:
|
while self.should_continue:
|
||||||
# 使用决策信息来辅助行动规划
|
# 使用决策信息来辅助行动规划
|
||||||
action, reason = await self.action_planner.plan(
|
action, reason = await self.action_planner.plan(self.observation_info, self.conversation_info)
|
||||||
self.observation_info,
|
|
||||||
self.conversation_info
|
|
||||||
)
|
|
||||||
if self._check_new_messages_after_planning():
|
if self._check_new_messages_after_planning():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -108,7 +102,6 @@ class Conversation:
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _convert_to_message(self, msg_dict: Dict[str, Any]) -> Message:
|
def _convert_to_message(self, msg_dict: Dict[str, Any]) -> Message:
|
||||||
"""将消息字典转换为Message对象"""
|
"""将消息字典转换为Message对象"""
|
||||||
try:
|
try:
|
||||||
@@ -122,31 +115,31 @@ class Conversation:
|
|||||||
time=msg_dict["time"],
|
time=msg_dict["time"],
|
||||||
user_info=user_info,
|
user_info=user_info,
|
||||||
processed_plain_text=msg_dict.get("processed_plain_text", ""),
|
processed_plain_text=msg_dict.get("processed_plain_text", ""),
|
||||||
detailed_plain_text=msg_dict.get("detailed_plain_text", "")
|
detailed_plain_text=msg_dict.get("detailed_plain_text", ""),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"转换消息时出错: {e}")
|
logger.warning(f"转换消息时出错: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def _handle_action(self, action: str, reason: str, observation_info: ObservationInfo, conversation_info: ConversationInfo):
|
async def _handle_action(
|
||||||
|
self, action: str, reason: str, observation_info: ObservationInfo, conversation_info: ConversationInfo
|
||||||
|
):
|
||||||
"""处理规划的行动"""
|
"""处理规划的行动"""
|
||||||
logger.info(f"执行行动: {action}, 原因: {reason}")
|
logger.info(f"执行行动: {action}, 原因: {reason}")
|
||||||
|
|
||||||
# 记录action历史,先设置为stop,完成后再设置为done
|
# 记录action历史,先设置为stop,完成后再设置为done
|
||||||
conversation_info.done_action.append({
|
conversation_info.done_action.append(
|
||||||
|
{
|
||||||
"action": action,
|
"action": action,
|
||||||
"reason": reason,
|
"reason": reason,
|
||||||
"status": "start",
|
"status": "start",
|
||||||
"time": datetime.datetime.now().strftime("%H:%M:%S")
|
"time": datetime.datetime.now().strftime("%H:%M:%S"),
|
||||||
})
|
}
|
||||||
|
)
|
||||||
|
|
||||||
if action == "direct_reply":
|
if action == "direct_reply":
|
||||||
self.state = ConversationState.GENERATING
|
self.state = ConversationState.GENERATING
|
||||||
self.generated_reply = await self.reply_generator.generate(
|
self.generated_reply = await self.reply_generator.generate(observation_info, conversation_info)
|
||||||
observation_info,
|
|
||||||
conversation_info
|
|
||||||
)
|
|
||||||
|
|
||||||
# # 检查回复是否合适
|
# # 检查回复是否合适
|
||||||
# is_suitable, reason, need_replan = await self.reply_generator.check_reply(
|
# is_suitable, reason, need_replan = await self.reply_generator.check_reply(
|
||||||
@@ -159,12 +152,14 @@ class Conversation:
|
|||||||
|
|
||||||
await self._send_reply()
|
await self._send_reply()
|
||||||
|
|
||||||
conversation_info.done_action.append({
|
conversation_info.done_action.append(
|
||||||
|
{
|
||||||
"action": action,
|
"action": action,
|
||||||
"reason": reason,
|
"reason": reason,
|
||||||
"status": "done",
|
"status": "done",
|
||||||
"time": datetime.datetime.now().strftime("%H:%M:%S")
|
"time": datetime.datetime.now().strftime("%H:%M:%S"),
|
||||||
})
|
}
|
||||||
|
)
|
||||||
|
|
||||||
elif action == "fetch_knowledge":
|
elif action == "fetch_knowledge":
|
||||||
self.state = ConversationState.FETCHING
|
self.state = ConversationState.FETCHING
|
||||||
@@ -175,10 +170,7 @@ class Conversation:
|
|||||||
|
|
||||||
if knowledge:
|
if knowledge:
|
||||||
if topic not in self.conversation_info.knowledge_list:
|
if topic not in self.conversation_info.knowledge_list:
|
||||||
self.conversation_info.knowledge_list.append({
|
self.conversation_info.knowledge_list.append({"topic": topic, "knowledge": knowledge})
|
||||||
"topic": topic,
|
|
||||||
"knowledge": knowledge
|
|
||||||
})
|
|
||||||
else:
|
else:
|
||||||
self.conversation_info.knowledge_list[topic] += knowledge
|
self.conversation_info.knowledge_list[topic] += knowledge
|
||||||
|
|
||||||
@@ -186,7 +178,6 @@ class Conversation:
|
|||||||
self.state = ConversationState.RETHINKING
|
self.state = ConversationState.RETHINKING
|
||||||
await self.goal_analyzer.analyze_goal(conversation_info, observation_info)
|
await self.goal_analyzer.analyze_goal(conversation_info, observation_info)
|
||||||
|
|
||||||
|
|
||||||
elif action == "listening":
|
elif action == "listening":
|
||||||
self.state = ConversationState.LISTENING
|
self.state = ConversationState.LISTENING
|
||||||
logger.info("倾听对方发言...")
|
logger.info("倾听对方发言...")
|
||||||
@@ -210,9 +201,7 @@ class Conversation:
|
|||||||
|
|
||||||
latest_message = self._convert_to_message(messages[0])
|
latest_message = self._convert_to_message(messages[0])
|
||||||
await self.direct_sender.send_message(
|
await self.direct_sender.send_message(
|
||||||
chat_stream=self.chat_stream,
|
chat_stream=self.chat_stream, content="TODO:超时消息", reply_to_message=latest_message
|
||||||
content="TODO:超时消息",
|
|
||||||
reply_to_message=latest_message
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"发送超时消息失败: {str(e)}")
|
logger.error(f"发送超时消息失败: {str(e)}")
|
||||||
@@ -231,9 +220,7 @@ class Conversation:
|
|||||||
latest_message = self._convert_to_message(messages[0])
|
latest_message = self._convert_to_message(messages[0])
|
||||||
try:
|
try:
|
||||||
await self.direct_sender.send_message(
|
await self.direct_sender.send_message(
|
||||||
chat_stream=self.chat_stream,
|
chat_stream=self.chat_stream, content=self.generated_reply, reply_to_message=latest_message
|
||||||
content=self.generated_reply,
|
|
||||||
reply_to_message=latest_message
|
|
||||||
)
|
)
|
||||||
self.chat_observer.trigger_update() # 触发立即更新
|
self.chat_observer.trigger_update() # 触发立即更新
|
||||||
if not await self.chat_observer.wait_for_update():
|
if not await self.chat_observer.wait_for_update():
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
|
|
||||||
|
|
||||||
class ConversationInfo:
|
class ConversationInfo:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.done_action = []
|
self.done_action = []
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from src.plugins.chat.message import MessageSending
|
|||||||
|
|
||||||
logger = get_module_logger("message_sender")
|
logger = get_module_logger("message_sender")
|
||||||
|
|
||||||
|
|
||||||
class DirectMessageSender:
|
class DirectMessageSender:
|
||||||
"""直接消息发送器"""
|
"""直接消息发送器"""
|
||||||
|
|
||||||
@@ -33,10 +34,7 @@ class DirectMessageSender:
|
|||||||
# 检查是否需要引用回复
|
# 检查是否需要引用回复
|
||||||
if reply_to_message:
|
if reply_to_message:
|
||||||
reply_id = reply_to_message.message_id
|
reply_id = reply_to_message.message_id
|
||||||
message_sending = MessageSending(
|
message_sending = MessageSending(segments=segments, reply_to_id=reply_id)
|
||||||
segments=segments,
|
|
||||||
reply_to_id=reply_id
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
message_sending = MessageSending(segments=segments)
|
message_sending = MessageSending(segments=segments)
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
|
|||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Dict, Any, Optional
|
||||||
from src.common.database import db
|
from src.common.database import db
|
||||||
|
|
||||||
|
|
||||||
class MessageStorage(ABC):
|
class MessageStorage(ABC):
|
||||||
"""消息存储接口"""
|
"""消息存储接口"""
|
||||||
|
|
||||||
@@ -45,6 +46,7 @@ class MessageStorage(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class MongoDBMessageStorage(MessageStorage):
|
class MongoDBMessageStorage(MessageStorage):
|
||||||
"""MongoDB消息存储实现"""
|
"""MongoDB消息存储实现"""
|
||||||
|
|
||||||
@@ -60,32 +62,23 @@ class MongoDBMessageStorage(MessageStorage):
|
|||||||
if last_message:
|
if last_message:
|
||||||
query["time"] = {"$gt": last_message["time"]}
|
query["time"] = {"$gt": last_message["time"]}
|
||||||
|
|
||||||
return list(
|
return list(self.db.messages.find(query).sort("time", 1))
|
||||||
self.db.messages.find(query).sort("time", 1)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:
|
async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:
|
||||||
query = {
|
query = {"chat_id": chat_id, "time": {"$lt": time_point}}
|
||||||
"chat_id": chat_id,
|
|
||||||
"time": {"$lt": time_point}
|
|
||||||
}
|
|
||||||
|
|
||||||
messages = list(
|
messages = list(self.db.messages.find(query).sort("time", -1).limit(limit))
|
||||||
self.db.messages.find(query).sort("time", -1).limit(limit)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 将消息按时间正序排列
|
# 将消息按时间正序排列
|
||||||
messages.reverse()
|
messages.reverse()
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
async def has_new_messages(self, chat_id: str, after_time: float) -> bool:
|
async def has_new_messages(self, chat_id: str, after_time: float) -> bool:
|
||||||
query = {
|
query = {"chat_id": chat_id, "time": {"$gt": after_time}}
|
||||||
"chat_id": chat_id,
|
|
||||||
"time": {"$gt": after_time}
|
|
||||||
}
|
|
||||||
|
|
||||||
return self.db.messages.find_one(query) is not None
|
return self.db.messages.find_one(query) is not None
|
||||||
|
|
||||||
|
|
||||||
# # 创建一个内存消息存储实现,用于测试
|
# # 创建一个内存消息存储实现,用于测试
|
||||||
# class InMemoryMessageStorage(MessageStorage):
|
# class InMemoryMessageStorage(MessageStorage):
|
||||||
# """内存消息存储实现,主要用于测试"""
|
# """内存消息存储实现,主要用于测试"""
|
||||||
|
|||||||
@@ -7,10 +7,11 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
logger = get_module_logger("notification_handler")
|
logger = get_module_logger("notification_handler")
|
||||||
|
|
||||||
|
|
||||||
class PFCNotificationHandler(NotificationHandler):
|
class PFCNotificationHandler(NotificationHandler):
|
||||||
"""PFC通知处理器"""
|
"""PFC通知处理器"""
|
||||||
|
|
||||||
def __init__(self, conversation: 'Conversation'):
|
def __init__(self, conversation: "Conversation"):
|
||||||
"""初始化PFC通知处理器
|
"""初始化PFC通知处理器
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -68,4 +69,3 @@ class PFCNotificationHandler(NotificationHandler):
|
|||||||
observation_info.conversation_cold_duration = cold_duration
|
observation_info.conversation_cold_duration = cold_duration
|
||||||
|
|
||||||
logger.info(f"对话已冷: {cold_duration}秒")
|
logger.info(f"对话已冷: {cold_duration}秒")
|
||||||
|
|
||||||
@@ -10,10 +10,11 @@ from .chat_states import NotificationHandler
|
|||||||
|
|
||||||
logger = get_module_logger("observation_info")
|
logger = get_module_logger("observation_info")
|
||||||
|
|
||||||
|
|
||||||
class ObservationInfoHandler(NotificationHandler):
|
class ObservationInfoHandler(NotificationHandler):
|
||||||
"""ObservationInfo的通知处理器"""
|
"""ObservationInfo的通知处理器"""
|
||||||
|
|
||||||
def __init__(self, observation_info: 'ObservationInfo'):
|
def __init__(self, observation_info: "ObservationInfo"):
|
||||||
"""初始化处理器
|
"""初始化处理器
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -62,8 +63,7 @@ class ObservationInfoHandler(NotificationHandler):
|
|||||||
# 处理消息删除通知
|
# 处理消息删除通知
|
||||||
message_id = data.get("message_id")
|
message_id = data.get("message_id")
|
||||||
self.observation_info.unprocessed_messages = [
|
self.observation_info.unprocessed_messages = [
|
||||||
msg for msg in self.observation_info.unprocessed_messages
|
msg for msg in self.observation_info.unprocessed_messages if msg.get("message_id") != message_id
|
||||||
if msg.get("message_id") != message_id
|
|
||||||
]
|
]
|
||||||
|
|
||||||
elif notification_type == "USER_JOINED":
|
elif notification_type == "USER_JOINED":
|
||||||
@@ -83,6 +83,7 @@ class ObservationInfoHandler(NotificationHandler):
|
|||||||
error_msg = data.get("error", "")
|
error_msg = data.get("error", "")
|
||||||
logger.error(f"收到错误通知: {error_msg}")
|
logger.error(f"收到错误通知: {error_msg}")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ObservationInfo:
|
class ObservationInfo:
|
||||||
"""决策信息类,用于收集和管理来自chat_observer的通知信息"""
|
"""决策信息类,用于收集和管理来自chat_observer的通知信息"""
|
||||||
@@ -124,28 +125,20 @@ class ObservationInfo:
|
|||||||
"""
|
"""
|
||||||
self.chat_observer = ChatObserver.get_instance(stream_id)
|
self.chat_observer = ChatObserver.get_instance(stream_id)
|
||||||
self.chat_observer.notification_manager.register_handler(
|
self.chat_observer.notification_manager.register_handler(
|
||||||
target="observation_info",
|
target="observation_info", notification_type="NEW_MESSAGE", handler=self.handler
|
||||||
notification_type="NEW_MESSAGE",
|
|
||||||
handler=self.handler
|
|
||||||
)
|
)
|
||||||
self.chat_observer.notification_manager.register_handler(
|
self.chat_observer.notification_manager.register_handler(
|
||||||
target="observation_info",
|
target="observation_info", notification_type="COLD_CHAT", handler=self.handler
|
||||||
notification_type="COLD_CHAT",
|
|
||||||
handler=self.handler
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def unbind_from_chat_observer(self):
|
def unbind_from_chat_observer(self):
|
||||||
"""解除与chat_observer的绑定"""
|
"""解除与chat_observer的绑定"""
|
||||||
if self.chat_observer:
|
if self.chat_observer:
|
||||||
self.chat_observer.notification_manager.unregister_handler(
|
self.chat_observer.notification_manager.unregister_handler(
|
||||||
target="observation_info",
|
target="observation_info", notification_type="NEW_MESSAGE", handler=self.handler
|
||||||
notification_type="NEW_MESSAGE",
|
|
||||||
handler=self.handler
|
|
||||||
)
|
)
|
||||||
self.chat_observer.notification_manager.unregister_handler(
|
self.chat_observer.notification_manager.unregister_handler(
|
||||||
target="observation_info",
|
target="observation_info", notification_type="COLD_CHAT", handler=self.handler
|
||||||
notification_type="COLD_CHAT",
|
|
||||||
handler=self.handler
|
|
||||||
)
|
)
|
||||||
self.chat_observer = None
|
self.chat_observer = None
|
||||||
|
|
||||||
|
|||||||
@@ -60,7 +60,6 @@ class GoalAnalyzer:
|
|||||||
goal_text += f"目标:{goal};"
|
goal_text += f"目标:{goal};"
|
||||||
goal_text += f"原因:{reason}\n"
|
goal_text += f"原因:{reason}\n"
|
||||||
|
|
||||||
|
|
||||||
# 获取聊天历史记录
|
# 获取聊天历史记录
|
||||||
chat_history_list = observation_info.chat_history
|
chat_history_list = observation_info.chat_history
|
||||||
chat_history_text = ""
|
chat_history_text = ""
|
||||||
@@ -76,7 +75,6 @@ class GoalAnalyzer:
|
|||||||
|
|
||||||
observation_info.clear_unprocessed_messages()
|
observation_info.clear_unprocessed_messages()
|
||||||
|
|
||||||
|
|
||||||
personality_text = f"你的名字是{self.name},{self.personality_info}"
|
personality_text = f"你的名字是{self.name},{self.personality_info}"
|
||||||
|
|
||||||
# 构建action历史文本
|
# 构建action历史文本
|
||||||
@@ -85,7 +83,6 @@ class GoalAnalyzer:
|
|||||||
for action in action_history_list:
|
for action in action_history_list:
|
||||||
action_history_text += f"{action}\n"
|
action_history_text += f"{action}\n"
|
||||||
|
|
||||||
|
|
||||||
prompt = f"""{personality_text}。现在你在参与一场QQ聊天,请分析以下聊天记录,并根据你的性格特征确定多个明确的对话目标。
|
prompt = f"""{personality_text}。现在你在参与一场QQ聊天,请分析以下聊天记录,并根据你的性格特征确定多个明确的对话目标。
|
||||||
这些目标应该反映出对话的不同方面和意图。
|
这些目标应该反映出对话的不同方面和意图。
|
||||||
|
|
||||||
@@ -122,17 +119,12 @@ class GoalAnalyzer:
|
|||||||
|
|
||||||
# 使用简化函数提取JSON内容
|
# 使用简化函数提取JSON内容
|
||||||
success, result = get_items_from_json(
|
success, result = get_items_from_json(
|
||||||
content,
|
content, "goal", "reasoning", required_types={"goal": str, "reasoning": str}
|
||||||
"goal", "reasoning",
|
|
||||||
required_types={"goal": str, "reasoning": str}
|
|
||||||
)
|
)
|
||||||
# TODO
|
# TODO
|
||||||
|
|
||||||
|
|
||||||
conversation_info.goal_list.append(result)
|
conversation_info.goal_list.append(result)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def _update_goals(self, new_goal: str, method: str, reasoning: str):
|
async def _update_goals(self, new_goal: str, method: str, reasoning: str):
|
||||||
"""更新目标列表
|
"""更新目标列表
|
||||||
|
|
||||||
@@ -233,8 +225,10 @@ class GoalAnalyzer:
|
|||||||
# 尝试解析JSON
|
# 尝试解析JSON
|
||||||
success, result = get_items_from_json(
|
success, result = get_items_from_json(
|
||||||
content,
|
content,
|
||||||
"goal_achieved", "stop_conversation", "reason",
|
"goal_achieved",
|
||||||
required_types={"goal_achieved": bool, "stop_conversation": bool, "reason": str}
|
"stop_conversation",
|
||||||
|
"reason",
|
||||||
|
required_types={"goal_achieved": bool, "stop_conversation": bool, "reason": str},
|
||||||
)
|
)
|
||||||
|
|
||||||
if not success:
|
if not success:
|
||||||
@@ -285,7 +279,6 @@ class Waiter:
|
|||||||
logger.info("等待中...")
|
logger.info("等待中...")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class DirectMessageSender:
|
class DirectMessageSender:
|
||||||
"""直接发送消息到平台的发送器"""
|
"""直接发送消息到平台的发送器"""
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import traceback
|
|||||||
|
|
||||||
logger = get_module_logger("pfc_manager")
|
logger = get_module_logger("pfc_manager")
|
||||||
|
|
||||||
|
|
||||||
class PFCManager:
|
class PFCManager:
|
||||||
"""PFC对话管理器,负责管理所有对话实例"""
|
"""PFC对话管理器,负责管理所有对话实例"""
|
||||||
|
|
||||||
@@ -16,7 +17,7 @@ class PFCManager:
|
|||||||
_initializing: Dict[str, bool] = {}
|
_initializing: Dict[str, bool] = {}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_instance(cls) -> 'PFCManager':
|
def get_instance(cls) -> "PFCManager":
|
||||||
"""获取管理器单例
|
"""获取管理器单例
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -60,7 +61,6 @@ class PFCManager:
|
|||||||
|
|
||||||
return conversation_instance
|
return conversation_instance
|
||||||
|
|
||||||
|
|
||||||
async def _initialize_conversation(self, conversation: Conversation):
|
async def _initialize_conversation(self, conversation: Conversation):
|
||||||
"""初始化会话实例
|
"""初始化会话实例
|
||||||
|
|
||||||
@@ -84,7 +84,6 @@ class PFCManager:
|
|||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
# 清理失败的初始化
|
# 清理失败的初始化
|
||||||
|
|
||||||
|
|
||||||
async def get_conversation(self, stream_id: str) -> Optional[Conversation]:
|
async def get_conversation(self, stream_id: str) -> Optional[Conversation]:
|
||||||
"""获取已存在的会话实例
|
"""获取已存在的会话实例
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from typing import Literal
|
|||||||
|
|
||||||
class ConversationState(Enum):
|
class ConversationState(Enum):
|
||||||
"""对话状态"""
|
"""对话状态"""
|
||||||
|
|
||||||
INIT = "初始化"
|
INIT = "初始化"
|
||||||
RETHINKING = "重新思考"
|
RETHINKING = "重新思考"
|
||||||
ANALYZING = "分析历史"
|
ANALYZING = "分析历史"
|
||||||
|
|||||||
@@ -16,21 +16,14 @@ class ReplyGenerator:
|
|||||||
|
|
||||||
def __init__(self, stream_id: str):
|
def __init__(self, stream_id: str):
|
||||||
self.llm = LLM_request(
|
self.llm = LLM_request(
|
||||||
model=global_config.llm_normal,
|
model=global_config.llm_normal, temperature=0.7, max_tokens=300, request_type="reply_generation"
|
||||||
temperature=0.7,
|
|
||||||
max_tokens=300,
|
|
||||||
request_type="reply_generation"
|
|
||||||
)
|
)
|
||||||
self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2)
|
self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2)
|
||||||
self.name = global_config.BOT_NICKNAME
|
self.name = global_config.BOT_NICKNAME
|
||||||
self.chat_observer = ChatObserver.get_instance(stream_id)
|
self.chat_observer = ChatObserver.get_instance(stream_id)
|
||||||
self.reply_checker = ReplyChecker(stream_id)
|
self.reply_checker = ReplyChecker(stream_id)
|
||||||
|
|
||||||
async def generate(
|
async def generate(self, observation_info: ObservationInfo, conversation_info: ConversationInfo) -> str:
|
||||||
self,
|
|
||||||
observation_info: ObservationInfo,
|
|
||||||
conversation_info: ConversationInfo
|
|
||||||
) -> str:
|
|
||||||
"""生成回复
|
"""生成回复
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -58,7 +51,6 @@ class ReplyGenerator:
|
|||||||
for msg in chat_history_list:
|
for msg in chat_history_list:
|
||||||
chat_history_text += f"{msg}\n"
|
chat_history_text += f"{msg}\n"
|
||||||
|
|
||||||
|
|
||||||
# 整理知识缓存
|
# 整理知识缓存
|
||||||
knowledge_text = ""
|
knowledge_text = ""
|
||||||
knowledge_list = conversation_info.knowledge_list
|
knowledge_list = conversation_info.knowledge_list
|
||||||
@@ -107,12 +99,7 @@ class ReplyGenerator:
|
|||||||
logger.error(f"生成回复时出错: {e}")
|
logger.error(f"生成回复时出错: {e}")
|
||||||
return "抱歉,我现在有点混乱,让我重新思考一下..."
|
return "抱歉,我现在有点混乱,让我重新思考一下..."
|
||||||
|
|
||||||
async def check_reply(
|
async def check_reply(self, reply: str, goal: str, retry_count: int = 0) -> Tuple[bool, str, bool]:
|
||||||
self,
|
|
||||||
reply: str,
|
|
||||||
goal: str,
|
|
||||||
retry_count: int = 0
|
|
||||||
) -> Tuple[bool, str, bool]:
|
|
||||||
"""检查回复是否合适
|
"""检查回复是否合适
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from .chat_observer import ChatObserver
|
|||||||
|
|
||||||
logger = get_module_logger("waiter")
|
logger = get_module_logger("waiter")
|
||||||
|
|
||||||
|
|
||||||
class Waiter:
|
class Waiter:
|
||||||
"""等待器,用于等待对话流中的事件"""
|
"""等待器,用于等待对话流中的事件"""
|
||||||
|
|
||||||
|
|||||||
@@ -46,7 +46,6 @@ class ChatBot:
|
|||||||
chat_id = str(message.chat_stream.stream_id)
|
chat_id = str(message.chat_stream.stream_id)
|
||||||
|
|
||||||
if global_config.enable_pfc_chatting:
|
if global_config.enable_pfc_chatting:
|
||||||
|
|
||||||
await self.pfc_manager.get_or_create_conversation(chat_id)
|
await self.pfc_manager.get_or_create_conversation(chat_id)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
__version__ = "0.1.0"
|
__version__ = "0.1.0"
|
||||||
|
|
||||||
from .api import BaseMessageAPI, global_api
|
from .api import global_api
|
||||||
from .message_base import (
|
from .message_base import (
|
||||||
Seg,
|
Seg,
|
||||||
GroupInfo,
|
GroupInfo,
|
||||||
@@ -14,7 +14,6 @@ from .message_base import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BaseMessageAPI",
|
|
||||||
"Seg",
|
"Seg",
|
||||||
"global_api",
|
"global_api",
|
||||||
"GroupInfo",
|
"GroupInfo",
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
|
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
|
||||||
from typing import Dict, Any, Callable, List, Set
|
from typing import Dict, Any, Callable, List, Set, Optional
|
||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
from src.plugins.message.message_base import MessageBase
|
from src.plugins.message.message_base import MessageBase
|
||||||
|
from src.common.server import global_server
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import asyncio
|
import asyncio
|
||||||
import uvicorn
|
import uvicorn
|
||||||
@@ -49,13 +50,22 @@ class MessageServer(BaseMessageHandler):
|
|||||||
|
|
||||||
_class_handlers: List[Callable] = [] # 类级别的消息处理器
|
_class_handlers: List[Callable] = [] # 类级别的消息处理器
|
||||||
|
|
||||||
def __init__(self, host: str = "0.0.0.0", port: int = 18000, enable_token=False):
|
def __init__(
|
||||||
|
self,
|
||||||
|
host: str = "0.0.0.0",
|
||||||
|
port: int = 18000,
|
||||||
|
enable_token=False,
|
||||||
|
app: Optional[FastAPI] = None,
|
||||||
|
path: str = "/ws",
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# 将类级别的处理器添加到实例处理器中
|
# 将类级别的处理器添加到实例处理器中
|
||||||
self.message_handlers.extend(self._class_handlers)
|
self.message_handlers.extend(self._class_handlers)
|
||||||
self.app = FastAPI()
|
|
||||||
self.host = host
|
self.host = host
|
||||||
self.port = port
|
self.port = port
|
||||||
|
self.path = path
|
||||||
|
self.app = app or FastAPI()
|
||||||
|
self.own_app = app is None # 标记是否使用自己创建的app
|
||||||
self.active_websockets: Set[WebSocket] = set()
|
self.active_websockets: Set[WebSocket] = set()
|
||||||
self.platform_websockets: Dict[str, WebSocket] = {} # 平台到websocket的映射
|
self.platform_websockets: Dict[str, WebSocket] = {} # 平台到websocket的映射
|
||||||
self.valid_tokens: Set[str] = set()
|
self.valid_tokens: Set[str] = set()
|
||||||
@@ -63,28 +73,6 @@ class MessageServer(BaseMessageHandler):
|
|||||||
self._setup_routes()
|
self._setup_routes()
|
||||||
self._running = False
|
self._running = False
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def register_class_handler(cls, handler: Callable):
|
|
||||||
"""注册类级别的消息处理器"""
|
|
||||||
if handler not in cls._class_handlers:
|
|
||||||
cls._class_handlers.append(handler)
|
|
||||||
|
|
||||||
def register_message_handler(self, handler: Callable):
|
|
||||||
"""注册实例级别的消息处理器"""
|
|
||||||
if handler not in self.message_handlers:
|
|
||||||
self.message_handlers.append(handler)
|
|
||||||
|
|
||||||
async def verify_token(self, token: str) -> bool:
|
|
||||||
if not self.enable_token:
|
|
||||||
return True
|
|
||||||
return token in self.valid_tokens
|
|
||||||
|
|
||||||
def add_valid_token(self, token: str):
|
|
||||||
self.valid_tokens.add(token)
|
|
||||||
|
|
||||||
def remove_valid_token(self, token: str):
|
|
||||||
self.valid_tokens.discard(token)
|
|
||||||
|
|
||||||
def _setup_routes(self):
|
def _setup_routes(self):
|
||||||
@self.app.post("/api/message")
|
@self.app.post("/api/message")
|
||||||
async def handle_message(message: Dict[str, Any]):
|
async def handle_message(message: Dict[str, Any]):
|
||||||
@@ -125,6 +113,90 @@ class MessageServer(BaseMessageHandler):
|
|||||||
finally:
|
finally:
|
||||||
self._remove_websocket(websocket, platform)
|
self._remove_websocket(websocket, platform)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_class_handler(cls, handler: Callable):
|
||||||
|
"""注册类级别的消息处理器"""
|
||||||
|
if handler not in cls._class_handlers:
|
||||||
|
cls._class_handlers.append(handler)
|
||||||
|
|
||||||
|
def register_message_handler(self, handler: Callable):
|
||||||
|
"""注册实例级别的消息处理器"""
|
||||||
|
if handler not in self.message_handlers:
|
||||||
|
self.message_handlers.append(handler)
|
||||||
|
|
||||||
|
async def verify_token(self, token: str) -> bool:
|
||||||
|
if not self.enable_token:
|
||||||
|
return True
|
||||||
|
return token in self.valid_tokens
|
||||||
|
|
||||||
|
def add_valid_token(self, token: str):
|
||||||
|
self.valid_tokens.add(token)
|
||||||
|
|
||||||
|
def remove_valid_token(self, token: str):
|
||||||
|
self.valid_tokens.discard(token)
|
||||||
|
|
||||||
|
def run_sync(self):
|
||||||
|
"""同步方式运行服务器"""
|
||||||
|
if not self.own_app:
|
||||||
|
raise RuntimeError("当使用外部FastAPI实例时,请使用该实例的运行方法")
|
||||||
|
uvicorn.run(self.app, host=self.host, port=self.port)
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
"""异步方式运行服务器"""
|
||||||
|
self._running = True
|
||||||
|
try:
|
||||||
|
if self.own_app:
|
||||||
|
# 如果使用自己的 FastAPI 实例,运行 uvicorn 服务器
|
||||||
|
config = uvicorn.Config(self.app, host=self.host, port=self.port, loop="asyncio")
|
||||||
|
self.server = uvicorn.Server(config)
|
||||||
|
await self.server.serve()
|
||||||
|
else:
|
||||||
|
# 如果使用外部 FastAPI 实例,保持运行状态以处理消息
|
||||||
|
while self._running:
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
await self.stop()
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
await self.stop()
|
||||||
|
raise RuntimeError(f"服务器运行错误: {str(e)}") from e
|
||||||
|
finally:
|
||||||
|
await self.stop()
|
||||||
|
|
||||||
|
async def start_server(self):
|
||||||
|
"""启动服务器的异步方法"""
|
||||||
|
if not self._running:
|
||||||
|
self._running = True
|
||||||
|
await self.run()
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
"""停止服务器"""
|
||||||
|
# 清理platform映射
|
||||||
|
self.platform_websockets.clear()
|
||||||
|
|
||||||
|
# 取消所有后台任务
|
||||||
|
for task in self.background_tasks:
|
||||||
|
task.cancel()
|
||||||
|
# 等待所有任务完成
|
||||||
|
await asyncio.gather(*self.background_tasks, return_exceptions=True)
|
||||||
|
self.background_tasks.clear()
|
||||||
|
|
||||||
|
# 关闭所有WebSocket连接
|
||||||
|
for websocket in self.active_websockets:
|
||||||
|
await websocket.close()
|
||||||
|
self.active_websockets.clear()
|
||||||
|
|
||||||
|
if hasattr(self, "server") and self.own_app:
|
||||||
|
self._running = False
|
||||||
|
# 正确关闭 uvicorn 服务器
|
||||||
|
self.server.should_exit = True
|
||||||
|
await self.server.shutdown()
|
||||||
|
# 等待服务器完全停止
|
||||||
|
if hasattr(self.server, "started") and self.server.started:
|
||||||
|
await self.server.main_loop()
|
||||||
|
# 清理处理程序
|
||||||
|
self.message_handlers.clear()
|
||||||
|
|
||||||
def _remove_websocket(self, websocket: WebSocket, platform: str):
|
def _remove_websocket(self, websocket: WebSocket, platform: str):
|
||||||
"""从所有集合中移除websocket"""
|
"""从所有集合中移除websocket"""
|
||||||
if websocket in self.active_websockets:
|
if websocket in self.active_websockets:
|
||||||
@@ -161,54 +233,6 @@ class MessageServer(BaseMessageHandler):
|
|||||||
async def send_message(self, message: MessageBase):
|
async def send_message(self, message: MessageBase):
|
||||||
await self.broadcast_to_platform(message.message_info.platform, message.to_dict())
|
await self.broadcast_to_platform(message.message_info.platform, message.to_dict())
|
||||||
|
|
||||||
def run_sync(self):
|
|
||||||
"""同步方式运行服务器"""
|
|
||||||
uvicorn.run(self.app, host=self.host, port=self.port)
|
|
||||||
|
|
||||||
async def run(self):
|
|
||||||
"""异步方式运行服务器"""
|
|
||||||
config = uvicorn.Config(self.app, host=self.host, port=self.port, loop="asyncio")
|
|
||||||
self.server = uvicorn.Server(config)
|
|
||||||
try:
|
|
||||||
await self.server.serve()
|
|
||||||
except KeyboardInterrupt as e:
|
|
||||||
await self.stop()
|
|
||||||
raise KeyboardInterrupt from e
|
|
||||||
|
|
||||||
async def start_server(self):
|
|
||||||
"""启动服务器的异步方法"""
|
|
||||||
if not self._running:
|
|
||||||
self._running = True
|
|
||||||
await self.run()
|
|
||||||
|
|
||||||
async def stop(self):
|
|
||||||
"""停止服务器"""
|
|
||||||
# 清理platform映射
|
|
||||||
self.platform_websockets.clear()
|
|
||||||
|
|
||||||
# 取消所有后台任务
|
|
||||||
for task in self.background_tasks:
|
|
||||||
task.cancel()
|
|
||||||
# 等待所有任务完成
|
|
||||||
await asyncio.gather(*self.background_tasks, return_exceptions=True)
|
|
||||||
self.background_tasks.clear()
|
|
||||||
|
|
||||||
# 关闭所有WebSocket连接
|
|
||||||
for websocket in self.active_websockets:
|
|
||||||
await websocket.close()
|
|
||||||
self.active_websockets.clear()
|
|
||||||
|
|
||||||
if hasattr(self, "server"):
|
|
||||||
self._running = False
|
|
||||||
# 正确关闭 uvicorn 服务器
|
|
||||||
self.server.should_exit = True
|
|
||||||
await self.server.shutdown()
|
|
||||||
# 等待服务器完全停止
|
|
||||||
if hasattr(self.server, "started") and self.server.started:
|
|
||||||
await self.server.main_loop()
|
|
||||||
# 清理处理程序
|
|
||||||
self.message_handlers.clear()
|
|
||||||
|
|
||||||
async def send_message_REST(self, url: str, data: Dict[str, Any]) -> Dict[str, Any]:
|
async def send_message_REST(self, url: str, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""发送消息到指定端点"""
|
"""发送消息到指定端点"""
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
@@ -219,105 +243,4 @@ class MessageServer(BaseMessageHandler):
|
|||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
class BaseMessageAPI:
|
global_api = MessageServer(host=os.environ["HOST"], port=int(os.environ["PORT"]), app=global_server.get_app())
|
||||||
def __init__(self, host: str = "0.0.0.0", port: int = 18000):
|
|
||||||
self.app = FastAPI()
|
|
||||||
self.host = host
|
|
||||||
self.port = port
|
|
||||||
self.message_handlers: List[Callable] = []
|
|
||||||
self.cache = []
|
|
||||||
self._setup_routes()
|
|
||||||
self._running = False
|
|
||||||
|
|
||||||
def _setup_routes(self):
|
|
||||||
"""设置基础路由"""
|
|
||||||
|
|
||||||
@self.app.post("/api/message")
|
|
||||||
async def handle_message(message: Dict[str, Any]):
|
|
||||||
try:
|
|
||||||
# 创建后台任务处理消息
|
|
||||||
asyncio.create_task(self._background_message_handler(message))
|
|
||||||
return {"status": "success"}
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
|
||||||
|
|
||||||
async def _background_message_handler(self, message: Dict[str, Any]):
|
|
||||||
"""后台处理单个消息"""
|
|
||||||
try:
|
|
||||||
await self.process_single_message(message)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Background message processing failed: {str(e)}")
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
|
|
||||||
def register_message_handler(self, handler: Callable):
|
|
||||||
"""注册消息处理函数"""
|
|
||||||
self.message_handlers.append(handler)
|
|
||||||
|
|
||||||
async def send_message(self, url: str, data: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""发送消息到指定端点"""
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
try:
|
|
||||||
async with session.post(url, json=data, headers={"Content-Type": "application/json"}) as response:
|
|
||||||
return await response.json()
|
|
||||||
except Exception:
|
|
||||||
# logger.error(f"发送消息失败: {str(e)}")
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def process_single_message(self, message: Dict[str, Any]):
|
|
||||||
"""处理单条消息"""
|
|
||||||
tasks = []
|
|
||||||
for handler in self.message_handlers:
|
|
||||||
try:
|
|
||||||
tasks.append(handler(message))
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(str(e))
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
if tasks:
|
|
||||||
await asyncio.gather(*tasks, return_exceptions=True)
|
|
||||||
|
|
||||||
def run_sync(self):
|
|
||||||
"""同步方式运行服务器"""
|
|
||||||
uvicorn.run(self.app, host=self.host, port=self.port)
|
|
||||||
|
|
||||||
async def run(self):
|
|
||||||
"""异步方式运行服务器"""
|
|
||||||
config = uvicorn.Config(self.app, host=self.host, port=self.port, loop="asyncio")
|
|
||||||
self.server = uvicorn.Server(config)
|
|
||||||
try:
|
|
||||||
await self.server.serve()
|
|
||||||
except KeyboardInterrupt as e:
|
|
||||||
await self.stop()
|
|
||||||
raise KeyboardInterrupt from e
|
|
||||||
|
|
||||||
async def start_server(self):
|
|
||||||
"""启动服务器的异步方法"""
|
|
||||||
if not self._running:
|
|
||||||
self._running = True
|
|
||||||
await self.run()
|
|
||||||
|
|
||||||
async def stop(self):
|
|
||||||
"""停止服务器"""
|
|
||||||
if hasattr(self, "server"):
|
|
||||||
self._running = False
|
|
||||||
# 正确关闭 uvicorn 服务器
|
|
||||||
self.server.should_exit = True
|
|
||||||
await self.server.shutdown()
|
|
||||||
# 等待服务器完全停止
|
|
||||||
if hasattr(self.server, "started") and self.server.started:
|
|
||||||
await self.server.main_loop()
|
|
||||||
# 清理处理程序
|
|
||||||
self.message_handlers.clear()
|
|
||||||
|
|
||||||
def start(self):
|
|
||||||
"""启动服务器的便捷方法"""
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
try:
|
|
||||||
loop.run_until_complete(self.start_server())
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
pass
|
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
global_api = MessageServer(host=os.environ["HOST"], port=int(os.environ["PORT"]))
|
|
||||||
|
|||||||
Reference in New Issue
Block a user