This commit is contained in:
SengokuCola
2025-04-09 20:10:58 +08:00
22 changed files with 614 additions and 692 deletions

View File

@@ -4,6 +4,7 @@ import logging
from pathlib import Path
from logging.handlers import RotatingFileHandler
def setup_crash_logger():
"""设置崩溃日志记录器"""
# 创建logs/crash目录如果不存在
@@ -11,15 +12,12 @@ def setup_crash_logger():
crash_log_dir.mkdir(parents=True, exist_ok=True)
# 创建日志记录器
crash_logger = logging.getLogger('crash_logger')
crash_logger = logging.getLogger("crash_logger")
crash_logger.setLevel(logging.ERROR)
# 设置日志格式
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s\n'
'异常类型: %(exc_info)s\n'
'详细信息:\n%(message)s\n'
'-------------------\n'
"%(asctime)s - %(name)s - %(levelname)s\n异常类型: %(exc_info)s\n详细信息:\n%(message)s\n-------------------\n"
)
# 创建按大小轮转的文件处理器最大10MB保留5个备份
@@ -28,29 +26,28 @@ def setup_crash_logger():
log_file,
maxBytes=10 * 1024 * 1024, # 10MB
backupCount=5,
encoding='utf-8'
encoding="utf-8",
)
file_handler.setFormatter(formatter)
crash_logger.addHandler(file_handler)
return crash_logger
def log_crash(exc_type, exc_value, exc_traceback):
"""记录崩溃信息到日志文件"""
if exc_type is None:
return
# 获取崩溃日志记录器
crash_logger = logging.getLogger('crash_logger')
crash_logger = logging.getLogger("crash_logger")
# 获取完整的异常堆栈信息
stack_trace = ''.join(traceback.format_exception(exc_type, exc_value, exc_traceback))
stack_trace = "".join(traceback.format_exception(exc_type, exc_value, exc_traceback))
# 记录崩溃信息
crash_logger.error(
stack_trace,
exc_info=(exc_type, exc_value, exc_traceback)
)
crash_logger.error(stack_trace, exc_info=(exc_type, exc_value, exc_traceback))
def install_crash_handler():
"""安装全局异常处理器"""

73
src/common/server.py Normal file
View File

@@ -0,0 +1,73 @@
from fastapi import FastAPI, APIRouter
from typing import Optional, Union
from uvicorn import Config, Server as UvicornServer
import os
class Server:
def __init__(self, host: Optional[str] = None, port: Optional[int] = None, app_name: str = "MaiMCore"):
self.app = FastAPI(title=app_name)
self._host: str = "127.0.0.1"
self._port: int = 8080
self._server: Optional[UvicornServer] = None
self.set_address(host, port)
def register_router(self, router: APIRouter, prefix: str = ""):
"""注册路由
APIRouter 用于对相关的路由端点进行分组和模块化管理:
1. 可以将相关的端点组织在一起,便于管理
2. 支持添加统一的路由前缀
3. 可以为一组路由添加共同的依赖项、标签等
示例:
router = APIRouter()
@router.get("/users")
def get_users():
return {"users": [...]}
@router.post("/users")
def create_user():
return {"msg": "user created"}
# 注册路由,添加前缀 "/api/v1"
server.register_router(router, prefix="/api/v1")
"""
self.app.include_router(router, prefix=prefix)
def set_address(self, host: Optional[str] = None, port: Optional[int] = None):
"""设置服务器地址和端口"""
if host:
self._host = host
if port:
self._port = port
async def run(self):
"""启动服务器"""
config = Config(app=self.app, host=self._host, port=self._port)
self._server = UvicornServer(config=config)
try:
await self._server.serve()
except KeyboardInterrupt:
await self.shutdown()
raise
except Exception as e:
await self.shutdown()
raise RuntimeError(f"服务器运行错误: {str(e)}") from e
finally:
await self.shutdown()
async def shutdown(self):
"""安全关闭服务器"""
if self._server:
self._server.should_exit = True
await self._server.shutdown()
self._server = None
def get_app(self) -> FastAPI:
"""获取 FastAPI 实例"""
return self.app
global_server = Server(host=os.environ["HOST"], port=int(os.environ["PORT"]))

View File

@@ -16,7 +16,7 @@ from .plugins.chat.bot import chat_bot
from .common.logger import get_module_logger
from .plugins.remote import heartbeat_thread # noqa: F401
from .individuality.individuality import Individuality
from .common.server import global_server
logger = get_module_logger("main")
@@ -33,6 +33,7 @@ class MainSystem:
from .plugins.message import global_api
self.app = global_api
self.server = global_server
async def initialize(self):
"""初始化系统组件"""
@@ -126,6 +127,7 @@ class MainSystem:
emoji_manager.start_periodic_check_register(),
# emoji_manager.start_periodic_register(),
self.app.run(),
self.server.run(),
]
await asyncio.gather(*tasks)

View File

@@ -10,6 +10,7 @@ from .conversation_info import ConversationInfo
logger = get_module_logger("action_planner")
class ActionPlannerInfo:
def __init__(self):
self.done_action = []
@@ -23,20 +24,13 @@ class ActionPlanner:
def __init__(self, stream_id: str):
self.llm = LLM_request(
model=global_config.llm_normal,
temperature=0.7,
max_tokens=1000,
request_type="action_planning"
model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="action_planning"
)
self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2)
self.name = global_config.BOT_NICKNAME
self.chat_observer = ChatObserver.get_instance(stream_id)
async def plan(
self,
observation_info: ObservationInfo,
conversation_info: ConversationInfo
) -> Tuple[str, str]:
async def plan(self, observation_info: ObservationInfo, conversation_info: ConversationInfo) -> Tuple[str, str]:
"""规划下一步行动
Args:
@@ -56,7 +50,6 @@ class ActionPlanner:
goal = "目前没有明确对话目标"
reasoning = "目前没有明确对话目标,最好思考一个对话目标"
# 获取聊天历史记录
chat_history_list = observation_info.chat_history
chat_history_text = ""
@@ -72,7 +65,6 @@ class ActionPlanner:
observation_info.clear_unprocessed_messages()
personality_text = f"你的名字是{self.name}{self.personality_info}"
# 构建action历史文本
@@ -81,8 +73,6 @@ class ActionPlanner:
for action in action_history_list:
action_history_text += f"{action}\n"
prompt = f"""{personality_text}。现在你在参与一场QQ聊天请分析以下内容根据信息决定下一步行动
当前对话目标:{goal}
@@ -114,9 +104,7 @@ rethink_goal: 重新思考对话目标,当发现对话目标不合适时选择
# 使用简化函数提取JSON内容
success, result = get_items_from_json(
content,
"action", "reason",
default_values={"action": "direct_reply", "reason": "没有明确原因"}
content, "action", "reason", default_values={"action": "direct_reply", "reason": "没有明确原因"}
)
if not success:

View File

@@ -17,7 +17,7 @@ class ChatObserver:
_instances: Dict[str, "ChatObserver"] = {}
@classmethod
def get_instance(cls, stream_id: str, message_storage: Optional[MessageStorage] = None) -> 'ChatObserver':
def get_instance(cls, stream_id: str, message_storage: Optional[MessageStorage] = None) -> "ChatObserver":
"""获取或创建观察器实例
Args:
@@ -84,10 +84,7 @@ class ChatObserver:
"""
logger.debug(f"检查距离上一次观察之后是否有了新消息: {self.last_check_time}")
new_message_exists = await self.message_storage.has_new_messages(
self.stream_id,
self.last_check_time
)
new_message_exists = await self.message_storage.has_new_messages(self.stream_id, self.last_check_time)
if new_message_exists:
logger.debug("发现新消息")
@@ -114,11 +111,7 @@ class ChatObserver:
self.last_user_speak_time = message["time"]
# 发送新消息通知
notification = create_new_message_notification(
sender="chat_observer",
target="pfc",
message=message
)
notification = create_new_message_notification(sender="chat_observer", target="pfc", message=message)
await self.notification_manager.send_notification(notification)
# 检查并更新冷场状态
@@ -144,19 +137,12 @@ class ChatObserver:
# 如果冷场状态发生变化,发送通知
if is_cold != self.is_cold_chat_state:
self.is_cold_chat_state = is_cold
notification = create_cold_chat_notification(
sender="chat_observer",
target="pfc",
is_cold=is_cold
)
notification = create_cold_chat_notification(sender="chat_observer", target="pfc", is_cold=is_cold)
await self.notification_manager.send_notification(notification)
async def get_new_message(self) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
"""获取上一次观察的时间点后的新消息,插入到历史记录中,并返回新消息和历史记录两个对象"""
messages = await self.message_storage.get_messages_after(
self.stream_id,
self.last_message_read
)
messages = await self.message_storage.get_messages_after(self.stream_id, self.last_message_read)
for message in messages:
await self._add_message_to_history(message)
return messages, self.message_history
@@ -224,10 +210,7 @@ class ChatObserver:
Returns:
List[Dict[str, Any]]: 新消息列表
"""
new_messages = await self.message_storage.get_messages_after(
self.stream_id,
self.last_message_read
)
new_messages = await self.message_storage.get_messages_after(self.stream_id, self.last_message_read)
if new_messages:
self.last_message_read = new_messages[-1]["message_id"]
@@ -243,17 +226,15 @@ class ChatObserver:
Returns:
List[Dict[str, Any]]: 最多5条消息
"""
new_messages = await self.message_storage.get_messages_before(
self.stream_id,
time_point
)
new_messages = await self.message_storage.get_messages_before(self.stream_id, time_point)
if new_messages:
self.last_message_read = new_messages[-1]["message_id"]
return new_messages
'''主要观察循环'''
"""主要观察循环"""
async def _update_loop(self):
"""更新循环"""
try:
@@ -396,17 +377,16 @@ class ChatObserver:
bool: 是否有新消息
"""
try:
messages = await self.message_storage.get_messages_for_stream(
self.stream_id,
limit=50
)
messages = await self.message_storage.get_messages_for_stream(self.stream_id, limit=50)
if not messages:
return False
# 检查是否有新消息
has_new_messages = False
if messages and (not self.message_cache or messages[0]["message_id"] != self.message_cache[0]["message_id"]):
if messages and (
not self.message_cache or messages[0]["message_id"] != self.message_cache[0]["message_id"]
):
has_new_messages = True
self.message_cache = messages

View File

@@ -4,8 +4,10 @@ from dataclasses import dataclass
from datetime import datetime
from abc import ABC, abstractmethod
class ChatState(Enum):
"""聊天状态枚举"""
NORMAL = auto() # 正常状态
NEW_MESSAGE = auto() # 有新消息
COLD_CHAT = auto() # 冷场状态
@@ -15,8 +17,10 @@ class ChatState(Enum):
SILENT = auto() # 沉默状态
ERROR = auto() # 错误状态
class NotificationType(Enum):
"""通知类型枚举"""
NEW_MESSAGE = auto() # 新消息通知
COLD_CHAT = auto() # 冷场通知
ACTIVE_CHAT = auto() # 活跃通知
@@ -27,9 +31,11 @@ class NotificationType(Enum):
USER_LEFT = auto() # 用户离开通知
ERROR = auto() # 错误通知
@dataclass
class ChatStateInfo:
"""聊天状态信息"""
state: ChatState
last_message_time: Optional[float] = None
last_message_content: Optional[str] = None
@@ -38,9 +44,11 @@ class ChatStateInfo:
cold_duration: float = 0.0 # 冷场持续时间(秒)
active_duration: float = 0.0 # 活跃持续时间(秒)
@dataclass
class Notification:
"""通知基类"""
type: NotificationType
timestamp: float
sender: str # 发送者标识
@@ -49,15 +57,13 @@ class Notification:
def to_dict(self) -> Dict[str, Any]:
"""转换为字典格式"""
return {
"type": self.type.name,
"timestamp": self.timestamp,
"data": self.data
}
return {"type": self.type.name, "timestamp": self.timestamp, "data": self.data}
@dataclass
class StateNotification(Notification):
"""持续状态通知"""
is_active: bool = True
def to_dict(self) -> Dict[str, Any]:
@@ -65,6 +71,7 @@ class StateNotification(Notification):
base_dict["is_active"] = self.is_active
return base_dict
class NotificationHandler(ABC):
"""通知处理器接口"""
@@ -73,6 +80,7 @@ class NotificationHandler(ABC):
"""处理通知"""
pass
class NotificationManager:
"""通知管理器"""
@@ -141,10 +149,9 @@ class NotificationManager:
"""检查特定状态是否活跃"""
return state_type in self._active_states
def get_notification_history(self,
sender: Optional[str] = None,
target: Optional[str] = None,
limit: Optional[int] = None) -> List[Notification]:
def get_notification_history(
self, sender: Optional[str] = None, target: Optional[str] = None, limit: Optional[int] = None
) -> List[Notification]:
"""获取通知历史
Args:
@@ -164,6 +171,7 @@ class NotificationManager:
return history
# 一些常用的通知创建函数
def create_new_message_notification(sender: str, target: str, message: Dict[str, Any]) -> Notification:
"""创建新消息通知"""
@@ -176,10 +184,11 @@ def create_new_message_notification(sender: str, target: str, message: Dict[str,
"message_id": message.get("message_id"),
"content": message.get("content"),
"sender": message.get("sender"),
"time": message.get("time")
}
"time": message.get("time"),
},
)
def create_cold_chat_notification(sender: str, target: str, is_cold: bool) -> StateNotification:
"""创建冷场状态通知"""
return StateNotification(
@@ -188,9 +197,10 @@ def create_cold_chat_notification(sender: str, target: str, is_cold: bool) -> St
sender=sender,
target=target,
data={"is_cold": is_cold},
is_active=is_cold
is_active=is_cold,
)
def create_active_chat_notification(sender: str, target: str, is_active: bool) -> StateNotification:
"""创建活跃状态通知"""
return StateNotification(
@@ -199,9 +209,10 @@ def create_active_chat_notification(sender: str, target: str, is_active: bool) -
sender=sender,
target=target,
data={"is_active": is_active},
is_active=is_active
is_active=is_active,
)
class ChatStateManager:
"""聊天状态管理器"""

View File

@@ -54,7 +54,6 @@ class Conversation:
logger.error(traceback.format_exc())
raise
try:
# 决策所需要的信息,包括自身自信和观察信息两部分
# 注册观察器和观测信息
@@ -74,7 +73,6 @@ class Conversation:
self.should_continue = True
asyncio.create_task(self.start())
async def start(self):
"""开始对话流程"""
try:
@@ -84,16 +82,12 @@ class Conversation:
logger.error(f"启动对话系统失败: {e}")
raise
async def _plan_and_action_loop(self):
"""思考步PFC核心循环模块"""
# 获取最近的消息历史
while self.should_continue:
# 使用决策信息来辅助行动规划
action, reason = await self.action_planner.plan(
self.observation_info,
self.conversation_info
)
action, reason = await self.action_planner.plan(self.observation_info, self.conversation_info)
if self._check_new_messages_after_planning():
continue
@@ -108,7 +102,6 @@ class Conversation:
return True
return False
def _convert_to_message(self, msg_dict: Dict[str, Any]) -> Message:
"""将消息字典转换为Message对象"""
try:
@@ -122,31 +115,31 @@ class Conversation:
time=msg_dict["time"],
user_info=user_info,
processed_plain_text=msg_dict.get("processed_plain_text", ""),
detailed_plain_text=msg_dict.get("detailed_plain_text", "")
detailed_plain_text=msg_dict.get("detailed_plain_text", ""),
)
except Exception as e:
logger.warning(f"转换消息时出错: {e}")
raise
async def _handle_action(self, action: str, reason: str, observation_info: ObservationInfo, conversation_info: ConversationInfo):
async def _handle_action(
self, action: str, reason: str, observation_info: ObservationInfo, conversation_info: ConversationInfo
):
"""处理规划的行动"""
logger.info(f"执行行动: {action}, 原因: {reason}")
# 记录action历史先设置为stop完成后再设置为done
conversation_info.done_action.append({
conversation_info.done_action.append(
{
"action": action,
"reason": reason,
"status": "start",
"time": datetime.datetime.now().strftime("%H:%M:%S")
})
"time": datetime.datetime.now().strftime("%H:%M:%S"),
}
)
if action == "direct_reply":
self.state = ConversationState.GENERATING
self.generated_reply = await self.reply_generator.generate(
observation_info,
conversation_info
)
self.generated_reply = await self.reply_generator.generate(observation_info, conversation_info)
# # 检查回复是否合适
# is_suitable, reason, need_replan = await self.reply_generator.check_reply(
@@ -159,12 +152,14 @@ class Conversation:
await self._send_reply()
conversation_info.done_action.append({
conversation_info.done_action.append(
{
"action": action,
"reason": reason,
"status": "done",
"time": datetime.datetime.now().strftime("%H:%M:%S")
})
"time": datetime.datetime.now().strftime("%H:%M:%S"),
}
)
elif action == "fetch_knowledge":
self.state = ConversationState.FETCHING
@@ -175,10 +170,7 @@ class Conversation:
if knowledge:
if topic not in self.conversation_info.knowledge_list:
self.conversation_info.knowledge_list.append({
"topic": topic,
"knowledge": knowledge
})
self.conversation_info.knowledge_list.append({"topic": topic, "knowledge": knowledge})
else:
self.conversation_info.knowledge_list[topic] += knowledge
@@ -186,7 +178,6 @@ class Conversation:
self.state = ConversationState.RETHINKING
await self.goal_analyzer.analyze_goal(conversation_info, observation_info)
elif action == "listening":
self.state = ConversationState.LISTENING
logger.info("倾听对方发言...")
@@ -210,9 +201,7 @@ class Conversation:
latest_message = self._convert_to_message(messages[0])
await self.direct_sender.send_message(
chat_stream=self.chat_stream,
content="TODO:超时消息",
reply_to_message=latest_message
chat_stream=self.chat_stream, content="TODO:超时消息", reply_to_message=latest_message
)
except Exception as e:
logger.error(f"发送超时消息失败: {str(e)}")
@@ -231,9 +220,7 @@ class Conversation:
latest_message = self._convert_to_message(messages[0])
try:
await self.direct_sender.send_message(
chat_stream=self.chat_stream,
content=self.generated_reply,
reply_to_message=latest_message
chat_stream=self.chat_stream, content=self.generated_reply, reply_to_message=latest_message
)
self.chat_observer.trigger_update() # 触发立即更新
if not await self.chat_observer.wait_for_update():

View File

@@ -1,5 +1,3 @@
class ConversationInfo:
def __init__(self):
self.done_action = []

View File

@@ -7,6 +7,7 @@ from src.plugins.chat.message import MessageSending
logger = get_module_logger("message_sender")
class DirectMessageSender:
"""直接消息发送器"""
@@ -33,10 +34,7 @@ class DirectMessageSender:
# 检查是否需要引用回复
if reply_to_message:
reply_id = reply_to_message.message_id
message_sending = MessageSending(
segments=segments,
reply_to_id=reply_id
)
message_sending = MessageSending(segments=segments, reply_to_id=reply_id)
else:
message_sending = MessageSending(segments=segments)

View File

@@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional
from src.common.database import db
class MessageStorage(ABC):
"""消息存储接口"""
@@ -45,6 +46,7 @@ class MessageStorage(ABC):
"""
pass
class MongoDBMessageStorage(MessageStorage):
"""MongoDB消息存储实现"""
@@ -60,32 +62,23 @@ class MongoDBMessageStorage(MessageStorage):
if last_message:
query["time"] = {"$gt": last_message["time"]}
return list(
self.db.messages.find(query).sort("time", 1)
)
return list(self.db.messages.find(query).sort("time", 1))
async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:
query = {
"chat_id": chat_id,
"time": {"$lt": time_point}
}
query = {"chat_id": chat_id, "time": {"$lt": time_point}}
messages = list(
self.db.messages.find(query).sort("time", -1).limit(limit)
)
messages = list(self.db.messages.find(query).sort("time", -1).limit(limit))
# 将消息按时间正序排列
messages.reverse()
return messages
async def has_new_messages(self, chat_id: str, after_time: float) -> bool:
query = {
"chat_id": chat_id,
"time": {"$gt": after_time}
}
query = {"chat_id": chat_id, "time": {"$gt": after_time}}
return self.db.messages.find_one(query) is not None
# # 创建一个内存消息存储实现,用于测试
# class InMemoryMessageStorage(MessageStorage):
# """内存消息存储实现,主要用于测试"""

View File

@@ -7,10 +7,11 @@ if TYPE_CHECKING:
logger = get_module_logger("notification_handler")
class PFCNotificationHandler(NotificationHandler):
"""PFC通知处理器"""
def __init__(self, conversation: 'Conversation'):
def __init__(self, conversation: "Conversation"):
"""初始化PFC通知处理器
Args:
@@ -68,4 +69,3 @@ class PFCNotificationHandler(NotificationHandler):
observation_info.conversation_cold_duration = cold_duration
logger.info(f"对话已冷: {cold_duration}")

View File

@@ -10,10 +10,11 @@ from .chat_states import NotificationHandler
logger = get_module_logger("observation_info")
class ObservationInfoHandler(NotificationHandler):
"""ObservationInfo的通知处理器"""
def __init__(self, observation_info: 'ObservationInfo'):
def __init__(self, observation_info: "ObservationInfo"):
"""初始化处理器
Args:
@@ -62,8 +63,7 @@ class ObservationInfoHandler(NotificationHandler):
# 处理消息删除通知
message_id = data.get("message_id")
self.observation_info.unprocessed_messages = [
msg for msg in self.observation_info.unprocessed_messages
if msg.get("message_id") != message_id
msg for msg in self.observation_info.unprocessed_messages if msg.get("message_id") != message_id
]
elif notification_type == "USER_JOINED":
@@ -83,6 +83,7 @@ class ObservationInfoHandler(NotificationHandler):
error_msg = data.get("error", "")
logger.error(f"收到错误通知: {error_msg}")
@dataclass
class ObservationInfo:
"""决策信息类用于收集和管理来自chat_observer的通知信息"""
@@ -124,28 +125,20 @@ class ObservationInfo:
"""
self.chat_observer = ChatObserver.get_instance(stream_id)
self.chat_observer.notification_manager.register_handler(
target="observation_info",
notification_type="NEW_MESSAGE",
handler=self.handler
target="observation_info", notification_type="NEW_MESSAGE", handler=self.handler
)
self.chat_observer.notification_manager.register_handler(
target="observation_info",
notification_type="COLD_CHAT",
handler=self.handler
target="observation_info", notification_type="COLD_CHAT", handler=self.handler
)
def unbind_from_chat_observer(self):
"""解除与chat_observer的绑定"""
if self.chat_observer:
self.chat_observer.notification_manager.unregister_handler(
target="observation_info",
notification_type="NEW_MESSAGE",
handler=self.handler
target="observation_info", notification_type="NEW_MESSAGE", handler=self.handler
)
self.chat_observer.notification_manager.unregister_handler(
target="observation_info",
notification_type="COLD_CHAT",
handler=self.handler
target="observation_info", notification_type="COLD_CHAT", handler=self.handler
)
self.chat_observer = None

View File

@@ -60,7 +60,6 @@ class GoalAnalyzer:
goal_text += f"目标:{goal};"
goal_text += f"原因:{reason}\n"
# 获取聊天历史记录
chat_history_list = observation_info.chat_history
chat_history_text = ""
@@ -76,7 +75,6 @@ class GoalAnalyzer:
observation_info.clear_unprocessed_messages()
personality_text = f"你的名字是{self.name}{self.personality_info}"
# 构建action历史文本
@@ -85,7 +83,6 @@ class GoalAnalyzer:
for action in action_history_list:
action_history_text += f"{action}\n"
prompt = f"""{personality_text}。现在你在参与一场QQ聊天请分析以下聊天记录并根据你的性格特征确定多个明确的对话目标。
这些目标应该反映出对话的不同方面和意图。
@@ -122,17 +119,12 @@ class GoalAnalyzer:
# 使用简化函数提取JSON内容
success, result = get_items_from_json(
content,
"goal", "reasoning",
required_types={"goal": str, "reasoning": str}
content, "goal", "reasoning", required_types={"goal": str, "reasoning": str}
)
# TODO
conversation_info.goal_list.append(result)
async def _update_goals(self, new_goal: str, method: str, reasoning: str):
"""更新目标列表
@@ -233,8 +225,10 @@ class GoalAnalyzer:
# 尝试解析JSON
success, result = get_items_from_json(
content,
"goal_achieved", "stop_conversation", "reason",
required_types={"goal_achieved": bool, "stop_conversation": bool, "reason": str}
"goal_achieved",
"stop_conversation",
"reason",
required_types={"goal_achieved": bool, "stop_conversation": bool, "reason": str},
)
if not success:
@@ -285,7 +279,6 @@ class Waiter:
logger.info("等待中...")
class DirectMessageSender:
"""直接发送消息到平台的发送器"""

View File

@@ -5,6 +5,7 @@ import traceback
logger = get_module_logger("pfc_manager")
class PFCManager:
"""PFC对话管理器负责管理所有对话实例"""
@@ -16,7 +17,7 @@ class PFCManager:
_initializing: Dict[str, bool] = {}
@classmethod
def get_instance(cls) -> 'PFCManager':
def get_instance(cls) -> "PFCManager":
"""获取管理器单例
Returns:
@@ -60,7 +61,6 @@ class PFCManager:
return conversation_instance
async def _initialize_conversation(self, conversation: Conversation):
"""初始化会话实例
@@ -84,7 +84,6 @@ class PFCManager:
logger.error(traceback.format_exc())
# 清理失败的初始化
async def get_conversation(self, stream_id: str) -> Optional[Conversation]:
"""获取已存在的会话实例

View File

@@ -4,6 +4,7 @@ from typing import Literal
class ConversationState(Enum):
"""对话状态"""
INIT = "初始化"
RETHINKING = "重新思考"
ANALYZING = "分析历史"

View File

@@ -16,21 +16,14 @@ class ReplyGenerator:
def __init__(self, stream_id: str):
self.llm = LLM_request(
model=global_config.llm_normal,
temperature=0.7,
max_tokens=300,
request_type="reply_generation"
model=global_config.llm_normal, temperature=0.7, max_tokens=300, request_type="reply_generation"
)
self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2)
self.name = global_config.BOT_NICKNAME
self.chat_observer = ChatObserver.get_instance(stream_id)
self.reply_checker = ReplyChecker(stream_id)
async def generate(
self,
observation_info: ObservationInfo,
conversation_info: ConversationInfo
) -> str:
async def generate(self, observation_info: ObservationInfo, conversation_info: ConversationInfo) -> str:
"""生成回复
Args:
@@ -58,7 +51,6 @@ class ReplyGenerator:
for msg in chat_history_list:
chat_history_text += f"{msg}\n"
# 整理知识缓存
knowledge_text = ""
knowledge_list = conversation_info.knowledge_list
@@ -107,12 +99,7 @@ class ReplyGenerator:
logger.error(f"生成回复时出错: {e}")
return "抱歉,我现在有点混乱,让我重新思考一下..."
async def check_reply(
self,
reply: str,
goal: str,
retry_count: int = 0
) -> Tuple[bool, str, bool]:
async def check_reply(self, reply: str, goal: str, retry_count: int = 0) -> Tuple[bool, str, bool]:
"""检查回复是否合适
Args:

View File

@@ -3,6 +3,7 @@ from .chat_observer import ChatObserver
logger = get_module_logger("waiter")
class Waiter:
"""等待器,用于等待对话流中的事件"""

View File

@@ -46,7 +46,6 @@ class ChatBot:
chat_id = str(message.chat_stream.stream_id)
if global_config.enable_pfc_chatting:
await self.pfc_manager.get_or_create_conversation(chat_id)
except Exception as e:

View File

@@ -2,7 +2,7 @@
__version__ = "0.1.0"
from .api import BaseMessageAPI, global_api
from .api import global_api
from .message_base import (
Seg,
GroupInfo,
@@ -14,7 +14,6 @@ from .message_base import (
)
__all__ = [
"BaseMessageAPI",
"Seg",
"global_api",
"GroupInfo",

View File

@@ -1,7 +1,8 @@
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from typing import Dict, Any, Callable, List, Set
from typing import Dict, Any, Callable, List, Set, Optional
from src.common.logger import get_module_logger
from src.plugins.message.message_base import MessageBase
from src.common.server import global_server
import aiohttp
import asyncio
import uvicorn
@@ -49,13 +50,22 @@ class MessageServer(BaseMessageHandler):
_class_handlers: List[Callable] = [] # 类级别的消息处理器
def __init__(self, host: str = "0.0.0.0", port: int = 18000, enable_token=False):
def __init__(
self,
host: str = "0.0.0.0",
port: int = 18000,
enable_token=False,
app: Optional[FastAPI] = None,
path: str = "/ws",
):
super().__init__()
# 将类级别的处理器添加到实例处理器中
self.message_handlers.extend(self._class_handlers)
self.app = FastAPI()
self.host = host
self.port = port
self.path = path
self.app = app or FastAPI()
self.own_app = app is None # 标记是否使用自己创建的app
self.active_websockets: Set[WebSocket] = set()
self.platform_websockets: Dict[str, WebSocket] = {} # 平台到websocket的映射
self.valid_tokens: Set[str] = set()
@@ -63,28 +73,6 @@ class MessageServer(BaseMessageHandler):
self._setup_routes()
self._running = False
@classmethod
def register_class_handler(cls, handler: Callable):
"""注册类级别的消息处理器"""
if handler not in cls._class_handlers:
cls._class_handlers.append(handler)
def register_message_handler(self, handler: Callable):
"""注册实例级别的消息处理器"""
if handler not in self.message_handlers:
self.message_handlers.append(handler)
async def verify_token(self, token: str) -> bool:
if not self.enable_token:
return True
return token in self.valid_tokens
def add_valid_token(self, token: str):
self.valid_tokens.add(token)
def remove_valid_token(self, token: str):
self.valid_tokens.discard(token)
def _setup_routes(self):
@self.app.post("/api/message")
async def handle_message(message: Dict[str, Any]):
@@ -125,6 +113,90 @@ class MessageServer(BaseMessageHandler):
finally:
self._remove_websocket(websocket, platform)
@classmethod
def register_class_handler(cls, handler: Callable):
"""注册类级别的消息处理器"""
if handler not in cls._class_handlers:
cls._class_handlers.append(handler)
def register_message_handler(self, handler: Callable):
"""注册实例级别的消息处理器"""
if handler not in self.message_handlers:
self.message_handlers.append(handler)
async def verify_token(self, token: str) -> bool:
if not self.enable_token:
return True
return token in self.valid_tokens
def add_valid_token(self, token: str):
self.valid_tokens.add(token)
def remove_valid_token(self, token: str):
self.valid_tokens.discard(token)
def run_sync(self):
"""同步方式运行服务器"""
if not self.own_app:
raise RuntimeError("当使用外部FastAPI实例时请使用该实例的运行方法")
uvicorn.run(self.app, host=self.host, port=self.port)
async def run(self):
"""异步方式运行服务器"""
self._running = True
try:
if self.own_app:
# 如果使用自己的 FastAPI 实例,运行 uvicorn 服务器
config = uvicorn.Config(self.app, host=self.host, port=self.port, loop="asyncio")
self.server = uvicorn.Server(config)
await self.server.serve()
else:
# 如果使用外部 FastAPI 实例,保持运行状态以处理消息
while self._running:
await asyncio.sleep(1)
except KeyboardInterrupt:
await self.stop()
raise
except Exception as e:
await self.stop()
raise RuntimeError(f"服务器运行错误: {str(e)}") from e
finally:
await self.stop()
async def start_server(self):
"""启动服务器的异步方法"""
if not self._running:
self._running = True
await self.run()
async def stop(self):
"""停止服务器"""
# 清理platform映射
self.platform_websockets.clear()
# 取消所有后台任务
for task in self.background_tasks:
task.cancel()
# 等待所有任务完成
await asyncio.gather(*self.background_tasks, return_exceptions=True)
self.background_tasks.clear()
# 关闭所有WebSocket连接
for websocket in self.active_websockets:
await websocket.close()
self.active_websockets.clear()
if hasattr(self, "server") and self.own_app:
self._running = False
# 正确关闭 uvicorn 服务器
self.server.should_exit = True
await self.server.shutdown()
# 等待服务器完全停止
if hasattr(self.server, "started") and self.server.started:
await self.server.main_loop()
# 清理处理程序
self.message_handlers.clear()
def _remove_websocket(self, websocket: WebSocket, platform: str):
"""从所有集合中移除websocket"""
if websocket in self.active_websockets:
@@ -161,54 +233,6 @@ class MessageServer(BaseMessageHandler):
async def send_message(self, message: MessageBase):
await self.broadcast_to_platform(message.message_info.platform, message.to_dict())
def run_sync(self):
"""同步方式运行服务器"""
uvicorn.run(self.app, host=self.host, port=self.port)
async def run(self):
"""异步方式运行服务器"""
config = uvicorn.Config(self.app, host=self.host, port=self.port, loop="asyncio")
self.server = uvicorn.Server(config)
try:
await self.server.serve()
except KeyboardInterrupt as e:
await self.stop()
raise KeyboardInterrupt from e
async def start_server(self):
"""启动服务器的异步方法"""
if not self._running:
self._running = True
await self.run()
async def stop(self):
"""停止服务器"""
# 清理platform映射
self.platform_websockets.clear()
# 取消所有后台任务
for task in self.background_tasks:
task.cancel()
# 等待所有任务完成
await asyncio.gather(*self.background_tasks, return_exceptions=True)
self.background_tasks.clear()
# 关闭所有WebSocket连接
for websocket in self.active_websockets:
await websocket.close()
self.active_websockets.clear()
if hasattr(self, "server"):
self._running = False
# 正确关闭 uvicorn 服务器
self.server.should_exit = True
await self.server.shutdown()
# 等待服务器完全停止
if hasattr(self.server, "started") and self.server.started:
await self.server.main_loop()
# 清理处理程序
self.message_handlers.clear()
async def send_message_REST(self, url: str, data: Dict[str, Any]) -> Dict[str, Any]:
"""发送消息到指定端点"""
async with aiohttp.ClientSession() as session:
@@ -219,105 +243,4 @@ class MessageServer(BaseMessageHandler):
raise e
class BaseMessageAPI:
def __init__(self, host: str = "0.0.0.0", port: int = 18000):
self.app = FastAPI()
self.host = host
self.port = port
self.message_handlers: List[Callable] = []
self.cache = []
self._setup_routes()
self._running = False
def _setup_routes(self):
"""设置基础路由"""
@self.app.post("/api/message")
async def handle_message(message: Dict[str, Any]):
try:
# 创建后台任务处理消息
asyncio.create_task(self._background_message_handler(message))
return {"status": "success"}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e
async def _background_message_handler(self, message: Dict[str, Any]):
"""后台处理单个消息"""
try:
await self.process_single_message(message)
except Exception as e:
logger.error(f"Background message processing failed: {str(e)}")
logger.error(traceback.format_exc())
def register_message_handler(self, handler: Callable):
"""注册消息处理函数"""
self.message_handlers.append(handler)
async def send_message(self, url: str, data: Dict[str, Any]) -> Dict[str, Any]:
"""发送消息到指定端点"""
async with aiohttp.ClientSession() as session:
try:
async with session.post(url, json=data, headers={"Content-Type": "application/json"}) as response:
return await response.json()
except Exception:
# logger.error(f"发送消息失败: {str(e)}")
pass
async def process_single_message(self, message: Dict[str, Any]):
"""处理单条消息"""
tasks = []
for handler in self.message_handlers:
try:
tasks.append(handler(message))
except Exception as e:
logger.error(str(e))
logger.error(traceback.format_exc())
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
def run_sync(self):
"""同步方式运行服务器"""
uvicorn.run(self.app, host=self.host, port=self.port)
async def run(self):
"""异步方式运行服务器"""
config = uvicorn.Config(self.app, host=self.host, port=self.port, loop="asyncio")
self.server = uvicorn.Server(config)
try:
await self.server.serve()
except KeyboardInterrupt as e:
await self.stop()
raise KeyboardInterrupt from e
async def start_server(self):
"""启动服务器的异步方法"""
if not self._running:
self._running = True
await self.run()
async def stop(self):
"""停止服务器"""
if hasattr(self, "server"):
self._running = False
# 正确关闭 uvicorn 服务器
self.server.should_exit = True
await self.server.shutdown()
# 等待服务器完全停止
if hasattr(self.server, "started") and self.server.started:
await self.server.main_loop()
# 清理处理程序
self.message_handlers.clear()
def start(self):
"""启动服务器的便捷方法"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(self.start_server())
except KeyboardInterrupt:
pass
finally:
loop.close()
global_api = MessageServer(host=os.environ["HOST"], port=int(os.environ["PORT"]))
global_api = MessageServer(host=os.environ["HOST"], port=int(os.environ["PORT"]), app=global_server.get_app())