This commit is contained in:
SengokuCola
2025-04-09 20:10:58 +08:00
22 changed files with 614 additions and 692 deletions

2
bot.py
View File

@@ -196,7 +196,7 @@ def raw_main():
# 安装崩溃日志处理器
install_crash_handler()
check_eula()
print("检查EULA和隐私条款完成")
easter_egg()

View File

@@ -4,69 +4,66 @@ import logging
from pathlib import Path
from logging.handlers import RotatingFileHandler
def setup_crash_logger():
"""设置崩溃日志记录器"""
# 创建logs/crash目录如果不存在
crash_log_dir = Path("logs/crash")
crash_log_dir.mkdir(parents=True, exist_ok=True)
# 创建日志记录器
crash_logger = logging.getLogger('crash_logger')
crash_logger = logging.getLogger("crash_logger")
crash_logger.setLevel(logging.ERROR)
# 设置日志格式
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s\n'
'异常类型: %(exc_info)s\n'
'详细信息:\n%(message)s\n'
'-------------------\n'
"%(asctime)s - %(name)s - %(levelname)s\n异常类型: %(exc_info)s\n详细信息:\n%(message)s\n-------------------\n"
)
# 创建按大小轮转的文件处理器最大10MB保留5个备份
log_file = crash_log_dir / "crash.log"
file_handler = RotatingFileHandler(
log_file,
maxBytes=10*1024*1024, # 10MB
maxBytes=10 * 1024 * 1024, # 10MB
backupCount=5,
encoding='utf-8'
encoding="utf-8",
)
file_handler.setFormatter(formatter)
crash_logger.addHandler(file_handler)
return crash_logger
def log_crash(exc_type, exc_value, exc_traceback):
"""记录崩溃信息到日志文件"""
if exc_type is None:
return
# 获取崩溃日志记录器
crash_logger = logging.getLogger('crash_logger')
crash_logger = logging.getLogger("crash_logger")
# 获取完整的异常堆栈信息
stack_trace = ''.join(traceback.format_exception(exc_type, exc_value, exc_traceback))
stack_trace = "".join(traceback.format_exception(exc_type, exc_value, exc_traceback))
# 记录崩溃信息
crash_logger.error(
stack_trace,
exc_info=(exc_type, exc_value, exc_traceback)
)
crash_logger.error(stack_trace, exc_info=(exc_type, exc_value, exc_traceback))
def install_crash_handler():
"""安装全局异常处理器"""
# 设置崩溃日志记录器
setup_crash_logger()
# 保存原始的异常处理器
original_hook = sys.excepthook
def exception_handler(exc_type, exc_value, exc_traceback):
"""全局异常处理器"""
# 记录崩溃信息
log_crash(exc_type, exc_value, exc_traceback)
# 调用原始的异常处理器
original_hook(exc_type, exc_value, exc_traceback)
# 设置全局异常处理器
sys.excepthook = exception_handler
sys.excepthook = exception_handler

73
src/common/server.py Normal file
View File

@@ -0,0 +1,73 @@
from fastapi import FastAPI, APIRouter
from typing import Optional, Union
from uvicorn import Config, Server as UvicornServer
import os
class Server:
def __init__(self, host: Optional[str] = None, port: Optional[int] = None, app_name: str = "MaiMCore"):
self.app = FastAPI(title=app_name)
self._host: str = "127.0.0.1"
self._port: int = 8080
self._server: Optional[UvicornServer] = None
self.set_address(host, port)
def register_router(self, router: APIRouter, prefix: str = ""):
"""注册路由
APIRouter 用于对相关的路由端点进行分组和模块化管理:
1. 可以将相关的端点组织在一起,便于管理
2. 支持添加统一的路由前缀
3. 可以为一组路由添加共同的依赖项、标签等
示例:
router = APIRouter()
@router.get("/users")
def get_users():
return {"users": [...]}
@router.post("/users")
def create_user():
return {"msg": "user created"}
# 注册路由,添加前缀 "/api/v1"
server.register_router(router, prefix="/api/v1")
"""
self.app.include_router(router, prefix=prefix)
def set_address(self, host: Optional[str] = None, port: Optional[int] = None):
"""设置服务器地址和端口"""
if host:
self._host = host
if port:
self._port = port
async def run(self):
"""启动服务器"""
config = Config(app=self.app, host=self._host, port=self._port)
self._server = UvicornServer(config=config)
try:
await self._server.serve()
except KeyboardInterrupt:
await self.shutdown()
raise
except Exception as e:
await self.shutdown()
raise RuntimeError(f"服务器运行错误: {str(e)}") from e
finally:
await self.shutdown()
async def shutdown(self):
"""安全关闭服务器"""
if self._server:
self._server.should_exit = True
await self._server.shutdown()
self._server = None
def get_app(self) -> FastAPI:
"""获取 FastAPI 实例"""
return self.app
global_server = Server(host=os.environ["HOST"], port=int(os.environ["PORT"]))

View File

@@ -16,7 +16,7 @@ from .plugins.chat.bot import chat_bot
from .common.logger import get_module_logger
from .plugins.remote import heartbeat_thread # noqa: F401
from .individuality.individuality import Individuality
from .common.server import global_server
logger = get_module_logger("main")
@@ -33,6 +33,7 @@ class MainSystem:
from .plugins.message import global_api
self.app = global_api
self.server = global_server
async def initialize(self):
"""初始化系统组件"""
@@ -126,6 +127,7 @@ class MainSystem:
emoji_manager.start_periodic_check_register(),
# emoji_manager.start_periodic_register(),
self.app.run(),
self.server.run(),
]
await asyncio.gather(*tasks)

View File

@@ -10,6 +10,7 @@ from .conversation_info import ConversationInfo
logger = get_module_logger("action_planner")
class ActionPlannerInfo:
def __init__(self):
self.done_action = []
@@ -20,68 +21,57 @@ class ActionPlannerInfo:
class ActionPlanner:
"""行动规划器"""
def __init__(self, stream_id: str):
self.llm = LLM_request(
model=global_config.llm_normal,
temperature=0.7,
max_tokens=1000,
request_type="action_planning"
model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="action_planning"
)
self.personality_info = Individuality.get_instance().get_prompt(type = "personality", x_person = 2, level = 2)
self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2)
self.name = global_config.BOT_NICKNAME
self.chat_observer = ChatObserver.get_instance(stream_id)
async def plan(
self,
observation_info: ObservationInfo,
conversation_info: ConversationInfo
) -> Tuple[str, str]:
async def plan(self, observation_info: ObservationInfo, conversation_info: ConversationInfo) -> Tuple[str, str]:
"""规划下一步行动
Args:
observation_info: 决策信息
conversation_info: 对话信息
Returns:
Tuple[str, str]: (行动类型, 行动原因)
"""
# 构建提示词
logger.debug(f"开始规划行动:当前目标: {conversation_info.goal_list}")
#构建对话目标
# 构建对话目标
if conversation_info.goal_list:
goal, reasoning = conversation_info.goal_list[-1]
else:
goal = "目前没有明确对话目标"
reasoning = "目前没有明确对话目标,最好思考一个对话目标"
# 获取聊天历史记录
chat_history_list = observation_info.chat_history
chat_history_text = ""
for msg in chat_history_list:
chat_history_text += f"{msg}\n"
if observation_info.new_messages_count > 0:
new_messages_list = observation_info.unprocessed_messages
chat_history_text += f"{observation_info.new_messages_count}条新消息:\n"
for msg in new_messages_list:
chat_history_text += f"{msg}\n"
observation_info.clear_unprocessed_messages()
personality_text = f"你的名字是{self.name}{self.personality_info}"
# 构建action历史文本
action_history_list = conversation_info.done_action
action_history_text = "你之前做的事情是:"
for action in action_history_list:
action_history_text += f"{action}\n"
prompt = f"""{personality_text}。现在你在参与一场QQ聊天请分析以下内容根据信息决定下一步行动
@@ -111,29 +101,27 @@ rethink_goal: 重新思考对话目标,当发现对话目标不合适时选择
try:
content, _ = await self.llm.generate_response_async(prompt)
logger.debug(f"LLM原始返回内容: {content}")
# 使用简化函数提取JSON内容
success, result = get_items_from_json(
content,
"action", "reason",
default_values={"action": "direct_reply", "reason": "没有明确原因"}
content, "action", "reason", default_values={"action": "direct_reply", "reason": "没有明确原因"}
)
if not success:
return "direct_reply", "JSON解析失败选择直接回复"
action = result["action"]
reason = result["reason"]
# 验证action类型
if action not in ["direct_reply", "fetch_knowledge", "wait", "listening", "rethink_goal"]:
logger.warning(f"未知的行动类型: {action}默认使用listening")
action = "listening"
logger.info(f"规划的行动: {action}")
logger.info(f"行动原因: {reason}")
return action, reason
except Exception as e:
logger.error(f"规划行动时出错: {str(e)}")
return "direct_reply", "发生错误,选择直接回复"
return "direct_reply", "发生错误,选择直接回复"

View File

@@ -17,20 +17,20 @@ class ChatObserver:
_instances: Dict[str, "ChatObserver"] = {}
@classmethod
def get_instance(cls, stream_id: str, message_storage: Optional[MessageStorage] = None) -> 'ChatObserver':
def get_instance(cls, stream_id: str, message_storage: Optional[MessageStorage] = None) -> "ChatObserver":
"""获取或创建观察器实例
Args:
stream_id: 聊天流ID
message_storage: 消息存储实现如果为None则使用MongoDB实现
Returns:
ChatObserver: 观察器实例
"""
if stream_id not in cls._instances:
cls._instances[stream_id] = cls(stream_id, message_storage)
return cls._instances[stream_id]
def __init__(self, stream_id: str, message_storage: Optional[MessageStorage] = None):
"""初始化观察器
@@ -43,15 +43,15 @@ class ChatObserver:
self.stream_id = stream_id
self.message_storage = message_storage or MongoDBMessageStorage()
self.last_user_speak_time: Optional[float] = None # 对方上次发言时间
self.last_bot_speak_time: Optional[float] = None # 机器人上次发言时间
self.last_check_time: float = time.time() # 上次查看聊天记录时间
self.last_message_read: Optional[str] = None # 最后读取的消息ID
self.last_message_time: Optional[float] = None # 最后一条消息的时间戳
self.waiting_start_time: float = time.time() # 等待开始时间,初始化为当前时间
self.last_bot_speak_time: Optional[float] = None # 机器人上次发言时间
self.last_check_time: float = time.time() # 上次查看聊天记录时间
self.last_message_read: Optional[str] = None # 最后读取的消息ID
self.last_message_time: Optional[float] = None # 最后一条消息的时间戳
self.waiting_start_time: float = time.time() # 等待开始时间,初始化为当前时间
# 消息历史记录
self.message_history: List[Dict[str, Any]] = [] # 所有消息历史
self.last_message_id: Optional[str] = None # 最后一条消息的ID
@@ -62,20 +62,20 @@ class ChatObserver:
self._task: Optional[asyncio.Task] = None
self._update_event = asyncio.Event() # 触发更新的事件
self._update_complete = asyncio.Event() # 更新完成的事件
# 通知管理器
self.notification_manager = NotificationManager()
# 冷场检查配置
self.cold_chat_threshold: float = 60.0 # 60秒无消息判定为冷场
self.last_cold_chat_check: float = time.time()
self.is_cold_chat_state: bool = False
self.update_event = asyncio.Event()
self.update_interval = 5 # 更新间隔(秒)
self.message_cache = []
self.update_running = False
async def check(self) -> bool:
"""检查距离上一次观察之后是否有了新消息
@@ -83,21 +83,18 @@ class ChatObserver:
bool: 是否有新消息
"""
logger.debug(f"检查距离上一次观察之后是否有了新消息: {self.last_check_time}")
new_message_exists = await self.message_storage.has_new_messages(
self.stream_id,
self.last_check_time
)
new_message_exists = await self.message_storage.has_new_messages(self.stream_id, self.last_check_time)
if new_message_exists:
logger.debug("发现新消息")
self.last_check_time = time.time()
return new_message_exists
async def _add_message_to_history(self, message: Dict[str, Any]):
"""添加消息到历史记录并发送通知
Args:
message: 消息数据
"""
@@ -112,76 +109,65 @@ class ChatObserver:
self.last_bot_speak_time = message["time"]
else:
self.last_user_speak_time = message["time"]
# 发送新消息通知
notification = create_new_message_notification(
sender="chat_observer",
target="pfc",
message=message
)
notification = create_new_message_notification(sender="chat_observer", target="pfc", message=message)
await self.notification_manager.send_notification(notification)
# 检查并更新冷场状态
await self._check_cold_chat()
async def _check_cold_chat(self):
"""检查是否处于冷场状态并发送通知"""
current_time = time.time()
# 每10秒检查一次冷场状态
if current_time - self.last_cold_chat_check < 10:
return
self.last_cold_chat_check = current_time
# 判断是否冷场
is_cold = False
if self.last_message_time is None:
is_cold = True
else:
is_cold = (current_time - self.last_message_time) > self.cold_chat_threshold
# 如果冷场状态发生变化,发送通知
if is_cold != self.is_cold_chat_state:
self.is_cold_chat_state = is_cold
notification = create_cold_chat_notification(
sender="chat_observer",
target="pfc",
is_cold=is_cold
)
notification = create_cold_chat_notification(sender="chat_observer", target="pfc", is_cold=is_cold)
await self.notification_manager.send_notification(notification)
async def get_new_message(self) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
"""获取上一次观察的时间点后的新消息,插入到历史记录中,并返回新消息和历史记录两个对象"""
messages = await self.message_storage.get_messages_after(
self.stream_id,
self.last_message_read
)
messages = await self.message_storage.get_messages_after(self.stream_id, self.last_message_read)
for message in messages:
await self._add_message_to_history(message)
return messages, self.message_history
def new_message_after(self, time_point: float) -> bool:
"""判断是否在指定时间点后有新消息
Args:
time_point: 时间戳
Returns:
bool: 是否有新消息
"""
if time_point is None:
logger.warning("time_point 为 None返回 False")
return False
if self.last_message_time is None:
logger.debug("没有最后消息时间,返回 False")
return False
has_new = self.last_message_time > time_point
logger.debug(f"判断是否在指定时间点后有新消息: {self.last_message_time} > {time_point} = {has_new}")
return has_new
def get_message_history(
self,
start_time: Optional[float] = None,
@@ -224,11 +210,8 @@ class ChatObserver:
Returns:
List[Dict[str, Any]]: 新消息列表
"""
new_messages = await self.message_storage.get_messages_after(
self.stream_id,
self.last_message_read
)
new_messages = await self.message_storage.get_messages_after(self.stream_id, self.last_message_read)
if new_messages:
self.last_message_read = new_messages[-1]["message_id"]
@@ -243,17 +226,15 @@ class ChatObserver:
Returns:
List[Dict[str, Any]]: 最多5条消息
"""
new_messages = await self.message_storage.get_messages_before(
self.stream_id,
time_point
)
new_messages = await self.message_storage.get_messages_before(self.stream_id, time_point)
if new_messages:
self.last_message_read = new_messages[-1]["message_id"]
return new_messages
'''主要观察循环'''
"""主要观察循环"""
async def _update_loop(self):
"""更新循环"""
try:
@@ -282,7 +263,7 @@ class ChatObserver:
# 处理新消息
for message in new_messages:
await self._add_message_to_history(message)
# 设置完成事件
self._update_complete.set()
@@ -379,7 +360,7 @@ class ChatObserver:
if not self.update_running:
self.update_running = True
asyncio.create_task(self._periodic_update())
async def _periodic_update(self):
"""定期更新消息历史"""
try:
@@ -388,53 +369,52 @@ class ChatObserver:
await asyncio.sleep(self.update_interval)
except Exception as e:
logger.error(f"定期更新消息历史时出错: {str(e)}")
async def _update_message_history(self) -> bool:
"""更新消息历史
Returns:
bool: 是否有新消息
"""
try:
messages = await self.message_storage.get_messages_for_stream(
self.stream_id,
limit=50
)
messages = await self.message_storage.get_messages_for_stream(self.stream_id, limit=50)
if not messages:
return False
# 检查是否有新消息
has_new_messages = False
if messages and (not self.message_cache or messages[0]["message_id"] != self.message_cache[0]["message_id"]):
if messages and (
not self.message_cache or messages[0]["message_id"] != self.message_cache[0]["message_id"]
):
has_new_messages = True
self.message_cache = messages
if has_new_messages:
self.update_event.set()
self.update_event.clear()
return True
return False
except Exception as e:
logger.error(f"更新消息历史时出错: {str(e)}")
return False
def get_cached_messages(self, limit: int = 50) -> List[Dict[str, Any]]:
"""获取缓存的消息历史
Args:
limit: 获取的最大消息数量默认50
Returns:
List[Dict[str, Any]]: 缓存的消息历史列表
"""
"""
return self.message_cache[:limit]
def get_last_message(self) -> Optional[Dict[str, Any]]:
"""获取最后一条消息
Returns:
Optional[Dict[str, Any]]: 最后一条消息如果没有则返回None
"""

View File

@@ -4,32 +4,38 @@ from dataclasses import dataclass
from datetime import datetime
from abc import ABC, abstractmethod
class ChatState(Enum):
"""聊天状态枚举"""
NORMAL = auto() # 正常状态
NEW_MESSAGE = auto() # 有新消息
COLD_CHAT = auto() # 冷场状态
ACTIVE_CHAT = auto() # 活跃状态
BOT_SPEAKING = auto() # 机器人正在说话
USER_SPEAKING = auto() # 用户正在说话
SILENT = auto() # 沉默状态
ERROR = auto() # 错误状态
NORMAL = auto() # 正常状态
NEW_MESSAGE = auto() # 有新消息
COLD_CHAT = auto() # 冷场状态
ACTIVE_CHAT = auto() # 活跃状态
BOT_SPEAKING = auto() # 机器人正在说话
USER_SPEAKING = auto() # 用户正在说话
SILENT = auto() # 沉默状态
ERROR = auto() # 错误状态
class NotificationType(Enum):
"""通知类型枚举"""
NEW_MESSAGE = auto() # 新消息通知
COLD_CHAT = auto() # 冷场通知
ACTIVE_CHAT = auto() # 活跃通知
BOT_SPEAKING = auto() # 机器人说话通知
USER_SPEAKING = auto() # 用户说话通知
MESSAGE_DELETED = auto() # 消息删除通知
USER_JOINED = auto() # 用户加入通知
USER_LEFT = auto() # 用户离开通知
ERROR = auto() # 错误通知
NEW_MESSAGE = auto() # 新消息通知
COLD_CHAT = auto() # 冷场通知
ACTIVE_CHAT = auto() # 活跃通知
BOT_SPEAKING = auto() # 机器人说话通知
USER_SPEAKING = auto() # 用户说话通知
MESSAGE_DELETED = auto() # 消息删除通知
USER_JOINED = auto() # 用户加入通知
USER_LEFT = auto() # 用户离开通知
ERROR = auto() # 错误通知
@dataclass
class ChatStateInfo:
"""聊天状态信息"""
state: ChatState
last_message_time: Optional[float] = None
last_message_content: Optional[str] = None
@@ -38,53 +44,55 @@ class ChatStateInfo:
cold_duration: float = 0.0 # 冷场持续时间(秒)
active_duration: float = 0.0 # 活跃持续时间(秒)
@dataclass
class Notification:
"""通知基类"""
type: NotificationType
timestamp: float
sender: str # 发送者标识
target: str # 接收者标识
sender: str # 发送者标识
target: str # 接收者标识
data: Dict[str, Any]
def to_dict(self) -> Dict[str, Any]:
"""转换为字典格式"""
return {
"type": self.type.name,
"timestamp": self.timestamp,
"data": self.data
}
return {"type": self.type.name, "timestamp": self.timestamp, "data": self.data}
@dataclass
class StateNotification(Notification):
"""持续状态通知"""
is_active: bool = True
def to_dict(self) -> Dict[str, Any]:
base_dict = super().to_dict()
base_dict["is_active"] = self.is_active
return base_dict
class NotificationHandler(ABC):
"""通知处理器接口"""
@abstractmethod
async def handle_notification(self, notification: Notification):
"""处理通知"""
pass
class NotificationManager:
"""通知管理器"""
def __init__(self):
# 按接收者和通知类型存储处理器
self._handlers: Dict[str, Dict[NotificationType, List[NotificationHandler]]] = {}
self._active_states: Set[NotificationType] = set()
self._notification_history: List[Notification] = []
def register_handler(self, target: str, notification_type: NotificationType, handler: NotificationHandler):
"""注册通知处理器
Args:
target: 接收者标识(例如:"pfc"
notification_type: 要处理的通知类型
@@ -95,10 +103,10 @@ class NotificationManager:
if notification_type not in self._handlers[target]:
self._handlers[target][notification_type] = []
self._handlers[target][notification_type].append(handler)
def unregister_handler(self, target: str, notification_type: NotificationType, handler: NotificationHandler):
"""注销通知处理器
Args:
target: 接收者标识
notification_type: 通知类型
@@ -114,56 +122,56 @@ class NotificationManager:
# 如果该目标没有任何处理器,删除该目标
if not self._handlers[target]:
del self._handlers[target]
async def send_notification(self, notification: Notification):
"""发送通知"""
self._notification_history.append(notification)
# 如果是状态通知,更新活跃状态
if isinstance(notification, StateNotification):
if notification.is_active:
self._active_states.add(notification.type)
else:
self._active_states.discard(notification.type)
# 调用目标接收者的处理器
target = notification.target
if target in self._handlers:
handlers = self._handlers[target].get(notification.type, [])
for handler in handlers:
await handler.handle_notification(notification)
def get_active_states(self) -> Set[NotificationType]:
"""获取当前活跃的状态"""
return self._active_states.copy()
def is_state_active(self, state_type: NotificationType) -> bool:
"""检查特定状态是否活跃"""
return state_type in self._active_states
def get_notification_history(self,
sender: Optional[str] = None,
target: Optional[str] = None,
limit: Optional[int] = None) -> List[Notification]:
def get_notification_history(
self, sender: Optional[str] = None, target: Optional[str] = None, limit: Optional[int] = None
) -> List[Notification]:
"""获取通知历史
Args:
sender: 过滤特定发送者的通知
target: 过滤特定接收者的通知
limit: 限制返回数量
"""
history = self._notification_history
if sender:
history = [n for n in history if n.sender == sender]
if target:
history = [n for n in history if n.target == target]
if limit is not None:
history = history[-limit:]
return history
# 一些常用的通知创建函数
def create_new_message_notification(sender: str, target: str, message: Dict[str, Any]) -> Notification:
"""创建新消息通知"""
@@ -176,10 +184,11 @@ def create_new_message_notification(sender: str, target: str, message: Dict[str,
"message_id": message.get("message_id"),
"content": message.get("content"),
"sender": message.get("sender"),
"time": message.get("time")
}
"time": message.get("time"),
},
)
def create_cold_chat_notification(sender: str, target: str, is_cold: bool) -> StateNotification:
"""创建冷场状态通知"""
return StateNotification(
@@ -188,9 +197,10 @@ def create_cold_chat_notification(sender: str, target: str, is_cold: bool) -> St
sender=sender,
target=target,
data={"is_cold": is_cold},
is_active=is_cold
is_active=is_cold,
)
def create_active_chat_notification(sender: str, target: str, is_active: bool) -> StateNotification:
"""创建活跃状态通知"""
return StateNotification(
@@ -199,69 +209,70 @@ def create_active_chat_notification(sender: str, target: str, is_active: bool) -
sender=sender,
target=target,
data={"is_active": is_active},
is_active=is_active
is_active=is_active,
)
class ChatStateManager:
"""聊天状态管理器"""
def __init__(self):
self.current_state = ChatState.NORMAL
self.state_info = ChatStateInfo(state=ChatState.NORMAL)
self.state_history: list[ChatStateInfo] = []
def update_state(self, new_state: ChatState, **kwargs):
"""更新聊天状态
Args:
new_state: 新的状态
**kwargs: 其他状态信息
"""
self.current_state = new_state
self.state_info.state = new_state
# 更新其他状态信息
for key, value in kwargs.items():
if hasattr(self.state_info, key):
setattr(self.state_info, key, value)
# 记录状态历史
self.state_history.append(self.state_info)
def get_current_state_info(self) -> ChatStateInfo:
"""获取当前状态信息"""
return self.state_info
def get_state_history(self) -> list[ChatStateInfo]:
"""获取状态历史"""
return self.state_history
def is_cold_chat(self, threshold: float = 60.0) -> bool:
"""判断是否处于冷场状态
Args:
threshold: 冷场阈值(秒)
Returns:
bool: 是否冷场
"""
if not self.state_info.last_message_time:
return True
current_time = datetime.now().timestamp()
return (current_time - self.state_info.last_message_time) > threshold
def is_active_chat(self, threshold: float = 5.0) -> bool:
"""判断是否处于活跃状态
Args:
threshold: 活跃阈值(秒)
Returns:
bool: 是否活跃
"""
if not self.state_info.last_message_time:
return False
current_time = datetime.now().timestamp()
return (current_time - self.state_info.last_message_time) <= threshold
return (current_time - self.state_info.last_message_time) <= threshold

View File

@@ -20,23 +20,23 @@ logger = get_module_logger("pfc_conversation")
class Conversation:
"""对话类,负责管理单个对话的状态和行为"""
def __init__(self, stream_id: str):
"""初始化对话实例
Args:
stream_id: 聊天流ID
"""
self.stream_id = stream_id
self.state = ConversationState.INIT
self.should_continue = False
# 回复相关
self.generated_reply = ""
async def _initialize(self):
"""初始化实例,注册所有组件"""
try:
self.action_planner = ActionPlanner(self.stream_id)
self.goal_analyzer = GoalAnalyzer(self.stream_id)
@@ -44,37 +44,35 @@ class Conversation:
self.knowledge_fetcher = KnowledgeFetcher()
self.waiter = Waiter(self.stream_id)
self.direct_sender = DirectMessageSender()
# 获取聊天流信息
self.chat_stream = chat_manager.get_stream(self.stream_id)
self.stop_action_planner = False
except Exception as e:
logger.error(f"初始化对话实例:注册运行组件失败: {e}")
logger.error(traceback.format_exc())
raise
try:
#决策所需要的信息,包括自身自信和观察信息两部分
#注册观察器和观测信息
# 决策所需要的信息,包括自身自信和观察信息两部分
# 注册观察器和观测信息
self.chat_observer = ChatObserver.get_instance(self.stream_id)
self.chat_observer.start()
self.observation_info = ObservationInfo()
self.observation_info.bind_to_chat_observer(self.stream_id)
#对话信息
# 对话信息
self.conversation_info = ConversationInfo()
except Exception as e:
logger.error(f"初始化对话实例:注册信息组件失败: {e}")
logger.error(traceback.format_exc())
raise
# 组件准备完成,启动该论对话
self.should_continue = True
asyncio.create_task(self.start())
async def start(self):
"""开始对话流程"""
try:
@@ -83,17 +81,13 @@ class Conversation:
except Exception as e:
logger.error(f"启动对话系统失败: {e}")
raise
async def _plan_and_action_loop(self):
"""思考步PFC核心循环模块"""
# 获取最近的消息历史
while self.should_continue:
# 使用决策信息来辅助行动规划
action, reason = await self.action_planner.plan(
self.observation_info,
self.conversation_info
)
action, reason = await self.action_planner.plan(self.observation_info, self.conversation_info)
if self._check_new_messages_after_planning():
continue
@@ -107,93 +101,90 @@ class Conversation:
# 如果需要,可以在这里添加逻辑来根据新消息重新决定行动
return True
return False
def _convert_to_message(self, msg_dict: Dict[str, Any]) -> Message:
"""将消息字典转换为Message对象"""
try:
chat_info = msg_dict.get("chat_info", {})
chat_stream = ChatStream.from_dict(chat_info)
user_info = UserInfo.from_dict(msg_dict.get("user_info", {}))
return Message(
message_id=msg_dict["message_id"],
chat_stream=chat_stream,
time=msg_dict["time"],
user_info=user_info,
processed_plain_text=msg_dict.get("processed_plain_text", ""),
detailed_plain_text=msg_dict.get("detailed_plain_text", "")
detailed_plain_text=msg_dict.get("detailed_plain_text", ""),
)
except Exception as e:
logger.warning(f"转换消息时出错: {e}")
raise
async def _handle_action(self, action: str, reason: str, observation_info: ObservationInfo, conversation_info: ConversationInfo):
async def _handle_action(
self, action: str, reason: str, observation_info: ObservationInfo, conversation_info: ConversationInfo
):
"""处理规划的行动"""
logger.info(f"执行行动: {action}, 原因: {reason}")
# 记录action历史先设置为stop完成后再设置为done
conversation_info.done_action.append({
"action": action,
"reason": reason,
"status": "start",
"time": datetime.datetime.now().strftime("%H:%M:%S")
})
conversation_info.done_action.append(
{
"action": action,
"reason": reason,
"status": "start",
"time": datetime.datetime.now().strftime("%H:%M:%S"),
}
)
if action == "direct_reply":
self.state = ConversationState.GENERATING
self.generated_reply = await self.reply_generator.generate(
observation_info,
conversation_info
)
self.generated_reply = await self.reply_generator.generate(observation_info, conversation_info)
# # 检查回复是否合适
# is_suitable, reason, need_replan = await self.reply_generator.check_reply(
# self.generated_reply,
# self.current_goal
# )
if self._check_new_messages_after_planning():
return None
await self._send_reply()
conversation_info.done_action.append({
"action": action,
"reason": reason,
"status": "done",
"time": datetime.datetime.now().strftime("%H:%M:%S")
})
conversation_info.done_action.append(
{
"action": action,
"reason": reason,
"status": "done",
"time": datetime.datetime.now().strftime("%H:%M:%S"),
}
)
elif action == "fetch_knowledge":
self.state = ConversationState.FETCHING
knowledge = "TODO:知识"
topic = "TODO:关键词"
logger.info(f"假装获取到知识{knowledge},关键词是: {topic}")
if knowledge:
if topic not in self.conversation_info.knowledge_list:
self.conversation_info.knowledge_list.append({
"topic": topic,
"knowledge": knowledge
})
self.conversation_info.knowledge_list.append({"topic": topic, "knowledge": knowledge})
else:
self.conversation_info.knowledge_list[topic] += knowledge
elif action == "rethink_goal":
self.state = ConversationState.RETHINKING
await self.goal_analyzer.analyze_goal(conversation_info, observation_info)
elif action == "listening":
self.state = ConversationState.LISTENING
logger.info("倾听对方发言...")
if await self.waiter.wait(): # 如果返回True表示超时
await self._send_timeout_message()
await self._stop_conversation()
else: # wait
self.state = ConversationState.WAITING
logger.info("等待更多信息...")
@@ -207,12 +198,10 @@ class Conversation:
messages = self.chat_observer.get_cached_messages(limit=1)
if not messages:
return
latest_message = self._convert_to_message(messages[0])
await self.direct_sender.send_message(
chat_stream=self.chat_stream,
content="TODO:超时消息",
reply_to_message=latest_message
chat_stream=self.chat_stream, content="TODO:超时消息", reply_to_message=latest_message
)
except Exception as e:
logger.error(f"发送超时消息失败: {str(e)}")
@@ -222,24 +211,22 @@ class Conversation:
if not self.generated_reply:
logger.warning("没有生成回复")
return
messages = self.chat_observer.get_cached_messages(limit=1)
if not messages:
logger.warning("没有最近的消息可以回复")
return
latest_message = self._convert_to_message(messages[0])
try:
await self.direct_sender.send_message(
chat_stream=self.chat_stream,
content=self.generated_reply,
reply_to_message=latest_message
chat_stream=self.chat_stream, content=self.generated_reply, reply_to_message=latest_message
)
self.chat_observer.trigger_update() # 触发立即更新
if not await self.chat_observer.wait_for_update():
logger.warning("等待消息更新超时")
self.state = ConversationState.ANALYZING
except Exception as e:
logger.error(f"发送消息失败: {str(e)}")
self.state = ConversationState.ANALYZING
self.state = ConversationState.ANALYZING

View File

@@ -1,8 +1,6 @@
class ConversationInfo:
def __init__(self):
self.done_action = []
self.goal_list = []
self.knowledge_list = []
self.memory_list = []
self.memory_list = []

View File

@@ -7,12 +7,13 @@ from src.plugins.chat.message import MessageSending
logger = get_module_logger("message_sender")
class DirectMessageSender:
"""直接消息发送器"""
def __init__(self):
pass
async def send_message(
self,
chat_stream: ChatStream,
@@ -20,7 +21,7 @@ class DirectMessageSender:
reply_to_message: Optional[Message] = None,
) -> None:
"""发送消息到聊天流
Args:
chat_stream: 聊天流
content: 消息内容
@@ -29,21 +30,18 @@ class DirectMessageSender:
try:
# 创建消息内容
segments = [Seg(type="text", data={"text": content})]
# 检查是否需要引用回复
if reply_to_message:
reply_id = reply_to_message.message_id
message_sending = MessageSending(
segments=segments,
reply_to_id=reply_id
)
message_sending = MessageSending(segments=segments, reply_to_id=reply_id)
else:
message_sending = MessageSending(segments=segments)
# 发送消息
await chat_stream.send_message(message_sending)
logger.info(f"消息已发送: {content}")
except Exception as e:
logger.error(f"发送消息失败: {str(e)}")
raise
raise

View File

@@ -2,133 +2,126 @@ from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional
from src.common.database import db
class MessageStorage(ABC):
"""消息存储接口"""
@abstractmethod
async def get_messages_after(self, chat_id: str, message_id: Optional[str] = None) -> List[Dict[str, Any]]:
"""获取指定消息ID之后的所有消息
Args:
chat_id: 聊天ID
message_id: 消息ID如果为None则获取所有消息
Returns:
List[Dict[str, Any]]: 消息列表
"""
pass
@abstractmethod
async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:
"""获取指定时间点之前的消息
Args:
chat_id: 聊天ID
time_point: 时间戳
limit: 最大消息数量
Returns:
List[Dict[str, Any]]: 消息列表
"""
pass
@abstractmethod
async def has_new_messages(self, chat_id: str, after_time: float) -> bool:
"""检查是否有新消息
Args:
chat_id: 聊天ID
after_time: 时间戳
Returns:
bool: 是否有新消息
"""
pass
class MongoDBMessageStorage(MessageStorage):
"""MongoDB消息存储实现"""
def __init__(self):
self.db = db
async def get_messages_after(self, chat_id: str, message_id: Optional[str] = None) -> List[Dict[str, Any]]:
query = {"chat_id": chat_id}
if message_id:
# 获取ID大于message_id的消息
last_message = self.db.messages.find_one({"message_id": message_id})
if last_message:
query["time"] = {"$gt": last_message["time"]}
return list(
self.db.messages.find(query).sort("time", 1)
)
return list(self.db.messages.find(query).sort("time", 1))
async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:
query = {
"chat_id": chat_id,
"time": {"$lt": time_point}
}
messages = list(
self.db.messages.find(query).sort("time", -1).limit(limit)
)
query = {"chat_id": chat_id, "time": {"$lt": time_point}}
messages = list(self.db.messages.find(query).sort("time", -1).limit(limit))
# 将消息按时间正序排列
messages.reverse()
return messages
async def has_new_messages(self, chat_id: str, after_time: float) -> bool:
query = {
"chat_id": chat_id,
"time": {"$gt": after_time}
}
query = {"chat_id": chat_id, "time": {"$gt": after_time}}
return self.db.messages.find_one(query) is not None
# # 创建一个内存消息存储实现,用于测试
# class InMemoryMessageStorage(MessageStorage):
# """内存消息存储实现,主要用于测试"""
# def __init__(self):
# self.messages: Dict[str, List[Dict[str, Any]]] = {}
# async def get_messages_after(self, chat_id: str, message_id: Optional[str] = None) -> List[Dict[str, Any]]:
# if chat_id not in self.messages:
# return []
# messages = self.messages[chat_id]
# if not message_id:
# return messages
# # 找到message_id的索引
# try:
# index = next(i for i, m in enumerate(messages) if m["message_id"] == message_id)
# return messages[index + 1:]
# except StopIteration:
# return []
# async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:
# if chat_id not in self.messages:
# return []
# messages = [
# m for m in self.messages[chat_id]
# if m["time"] < time_point
# ]
# return messages[-limit:]
# async def has_new_messages(self, chat_id: str, after_time: float) -> bool:
# if chat_id not in self.messages:
# return False
# return any(m["time"] > after_time for m in self.messages[chat_id])
# # 测试辅助方法
# def add_message(self, chat_id: str, message: Dict[str, Any]):
# """添加测试消息"""
# if chat_id not in self.messages:
# self.messages[chat_id] = []
# self.messages[chat_id].append(message)
# self.messages[chat_id].sort(key=lambda m: m["time"])
# self.messages[chat_id].sort(key=lambda m: m["time"])

View File

@@ -7,25 +7,26 @@ if TYPE_CHECKING:
logger = get_module_logger("notification_handler")
class PFCNotificationHandler(NotificationHandler):
"""PFC通知处理器"""
def __init__(self, conversation: 'Conversation'):
def __init__(self, conversation: "Conversation"):
"""初始化PFC通知处理器
Args:
conversation: 对话实例
"""
self.conversation = conversation
async def handle_notification(self, notification: Notification):
"""处理通知
Args:
notification: 通知对象
"""
logger.debug(f"收到通知: {notification.type.name}, 数据: {notification.data}")
# 根据通知类型执行不同的处理
if notification.type == NotificationType.NEW_MESSAGE:
# 新消息通知
@@ -38,34 +39,33 @@ class PFCNotificationHandler(NotificationHandler):
await self._handle_command(notification)
else:
logger.warning(f"未知的通知类型: {notification.type.name}")
async def _handle_new_message(self, notification: Notification):
"""处理新消息通知
Args:
notification: 通知对象
"""
# 更新决策信息
observation_info = self.conversation.observation_info
observation_info.last_message_time = notification.data.get("time", 0)
observation_info.add_unprocessed_message(notification.data)
# 手动触发观察器更新
self.conversation.chat_observer.trigger_update()
async def _handle_cold_chat(self, notification: Notification):
"""处理冷聊天通知
Args:
notification: 通知对象
"""
# 获取冷聊天信息
cold_duration = notification.data.get("duration", 0)
# 更新决策信息
observation_info = self.conversation.observation_info
observation_info.conversation_cold_duration = cold_duration
logger.info(f"对话已冷: {cold_duration}")

View File

@@ -1,5 +1,5 @@
#Programmable Friendly Conversationalist
#Prefrontal cortex
# Programmable Friendly Conversationalist
# Prefrontal cortex
from typing import List, Optional, Dict, Any, Set
from ..message.message_base import UserInfo
import time
@@ -10,26 +10,27 @@ from .chat_states import NotificationHandler
logger = get_module_logger("observation_info")
class ObservationInfoHandler(NotificationHandler):
"""ObservationInfo的通知处理器"""
def __init__(self, observation_info: 'ObservationInfo'):
def __init__(self, observation_info: "ObservationInfo"):
"""初始化处理器
Args:
observation_info: 要更新的ObservationInfo实例
"""
self.observation_info = observation_info
async def handle_notification(self, notification: Dict[str, Any]):
"""处理通知
Args:
notification: 通知数据
"""
notification_type = notification.get("type")
data = notification.get("data", {})
if notification_type == "NEW_MESSAGE":
# 处理新消息通知
logger.debug(f"收到新消息通知data: {data}")
@@ -37,62 +38,62 @@ class ObservationInfoHandler(NotificationHandler):
self.observation_info.update_from_message(message)
# self.observation_info.has_unread_messages = True
# self.observation_info.new_unread_message.append(message.get("processed_plain_text", ""))
elif notification_type == "COLD_CHAT":
# 处理冷场通知
is_cold = data.get("is_cold", False)
self.observation_info.update_cold_chat_status(is_cold, time.time())
elif notification_type == "ACTIVE_CHAT":
# 处理活跃通知
is_active = data.get("is_active", False)
self.observation_info.is_cold = not is_active
elif notification_type == "BOT_SPEAKING":
# 处理机器人说话通知
self.observation_info.is_typing = False
self.observation_info.last_bot_speak_time = time.time()
elif notification_type == "USER_SPEAKING":
# 处理用户说话通知
self.observation_info.is_typing = False
self.observation_info.last_user_speak_time = time.time()
elif notification_type == "MESSAGE_DELETED":
# 处理消息删除通知
message_id = data.get("message_id")
self.observation_info.unprocessed_messages = [
msg for msg in self.observation_info.unprocessed_messages
if msg.get("message_id") != message_id
msg for msg in self.observation_info.unprocessed_messages if msg.get("message_id") != message_id
]
elif notification_type == "USER_JOINED":
# 处理用户加入通知
user_id = data.get("user_id")
if user_id:
self.observation_info.active_users.add(user_id)
elif notification_type == "USER_LEFT":
# 处理用户离开通知
user_id = data.get("user_id")
if user_id:
self.observation_info.active_users.discard(user_id)
elif notification_type == "ERROR":
# 处理错误通知
error_msg = data.get("error", "")
logger.error(f"收到错误通知: {error_msg}")
@dataclass
class ObservationInfo:
"""决策信息类用于收集和管理来自chat_observer的通知信息"""
#data_list
# data_list
chat_history: List[str] = field(default_factory=list)
unprocessed_messages: List[Dict[str, Any]] = field(default_factory=list)
active_users: Set[str] = field(default_factory=set)
#data
# data
last_bot_speak_time: Optional[float] = None
last_user_speak_time: Optional[float] = None
last_message_time: Optional[float] = None
@@ -101,78 +102,70 @@ class ObservationInfo:
bot_id: Optional[str] = None
new_messages_count: int = 0
cold_chat_duration: float = 0.0
#state
# state
is_typing: bool = False
has_unread_messages: bool = False
is_cold_chat: bool = False
changed: bool = False
# #spec
# meta_plan_trigger: bool = False
def __post_init__(self):
"""初始化后创建handler"""
self.chat_observer = None
self.handler = ObservationInfoHandler(self)
def bind_to_chat_observer(self, stream_id: str):
"""绑定到指定的chat_observer
Args:
stream_id: 聊天流ID
"""
self.chat_observer = ChatObserver.get_instance(stream_id)
self.chat_observer.notification_manager.register_handler(
target="observation_info",
notification_type="NEW_MESSAGE",
handler=self.handler
target="observation_info", notification_type="NEW_MESSAGE", handler=self.handler
)
self.chat_observer.notification_manager.register_handler(
target="observation_info",
notification_type="COLD_CHAT",
handler=self.handler
target="observation_info", notification_type="COLD_CHAT", handler=self.handler
)
def unbind_from_chat_observer(self):
"""解除与chat_observer的绑定"""
if self.chat_observer:
self.chat_observer.notification_manager.unregister_handler(
target="observation_info",
notification_type="NEW_MESSAGE",
handler=self.handler
target="observation_info", notification_type="NEW_MESSAGE", handler=self.handler
)
self.chat_observer.notification_manager.unregister_handler(
target="observation_info",
notification_type="COLD_CHAT",
handler=self.handler
target="observation_info", notification_type="COLD_CHAT", handler=self.handler
)
self.chat_observer = None
def update_from_message(self, message: Dict[str, Any]):
"""从消息更新信息
Args:
message: 消息数据
"""
logger.debug(f"更新信息from_message: {message}")
self.last_message_time = message["time"]
self.last_message_content = message.get("processed_plain_text", "")
user_info = UserInfo.from_dict(message.get("user_info", {}))
self.last_message_sender = user_info.user_id
if user_info.user_id == self.bot_id:
self.last_bot_speak_time = message["time"]
else:
self.last_user_speak_time = message["time"]
self.active_users.add(user_info.user_id)
self.new_messages_count += 1
self.unprocessed_messages.append(message)
self.update_changed()
def update_changed(self):
"""更新changed状态"""
self.changed = True
@@ -180,7 +173,7 @@ class ObservationInfo:
def update_cold_chat_status(self, is_cold: bool, current_time: float):
"""更新冷场状态
Args:
is_cold: 是否冷场
current_time: 当前时间
@@ -188,37 +181,37 @@ class ObservationInfo:
self.is_cold_chat = is_cold
if is_cold and self.last_message_time:
self.cold_chat_duration = current_time - self.last_message_time
def get_active_duration(self) -> float:
"""获取当前活跃时长
Returns:
float: 最后一条消息到现在的时长(秒)
"""
if not self.last_message_time:
return 0.0
return time.time() - self.last_message_time
def get_user_response_time(self) -> Optional[float]:
"""获取用户响应时间
Returns:
Optional[float]: 用户最后发言到现在的时长如果没有用户发言则返回None
"""
if not self.last_user_speak_time:
return None
return time.time() - self.last_user_speak_time
def get_bot_response_time(self) -> Optional[float]:
"""获取机器人响应时间
Returns:
Optional[float]: 机器人最后发言到现在的时长如果没有机器人发言则返回None
"""
if not self.last_bot_speak_time:
return None
return time.time() - self.last_bot_speak_time
def clear_unprocessed_messages(self):
"""清空未处理消息列表"""
# 将未处理消息添加到历史记录中
@@ -229,10 +222,10 @@ class ObservationInfo:
self.has_unread_messages = False
self.unprocessed_messages.clear()
self.new_messages_count = 0
def add_unprocessed_message(self, message: Dict[str, Any]):
"""添加未处理的消息
Args:
message: 消息数据
"""
@@ -241,6 +234,6 @@ class ObservationInfo:
if message_id and not any(m.get("message_id") == message_id for m in self.unprocessed_messages):
self.unprocessed_messages.append(message)
self.new_messages_count += 1
# 同时更新其他消息相关信息
self.update_from_message(message)
self.update_from_message(message)

View File

@@ -49,43 +49,40 @@ class GoalAnalyzer:
Args:
conversation_info: 对话信息
observation_info: 观察信息
Returns:
Tuple[str, str, str]: (目标, 方法, 原因)
"""
#构建对话目标
# 构建对话目标
goal_list = conversation_info.goal_list
goal_text = ""
for goal, reason in goal_list:
goal_text += f"目标:{goal};"
goal_text += f"原因:{reason}\n"
# 获取聊天历史记录
chat_history_list = observation_info.chat_history
chat_history_text = ""
for msg in chat_history_list:
chat_history_text += f"{msg}\n"
if observation_info.new_messages_count > 0:
new_messages_list = observation_info.unprocessed_messages
chat_history_text += f"{observation_info.new_messages_count}条新消息:\n"
for msg in new_messages_list:
chat_history_text += f"{msg}\n"
observation_info.clear_unprocessed_messages()
personality_text = f"你的名字是{self.name}{self.personality_info}"
# 构建action历史文本
action_history_list = conversation_info.done_action
action_history_text = "你之前做的事情是:"
for action in action_history_list:
action_history_text += f"{action}\n"
prompt = f"""{personality_text}。现在你在参与一场QQ聊天请分析以下聊天记录并根据你的性格特征确定多个明确的对话目标。
这些目标应该反映出对话的不同方面和意图。
@@ -119,20 +116,15 @@ class GoalAnalyzer:
logger.debug(f"发送到LLM的提示词: {prompt}")
content, _ = await self.llm.generate_response_async(prompt)
logger.debug(f"LLM原始返回内容: {content}")
# 使用简化函数提取JSON内容
success, result = get_items_from_json(
content,
"goal", "reasoning",
required_types={"goal": str, "reasoning": str}
content, "goal", "reasoning", required_types={"goal": str, "reasoning": str}
)
#TODO
# TODO
conversation_info.goal_list.append(result)
async def _update_goals(self, new_goal: str, method: str, reasoning: str):
"""更新目标列表
@@ -229,24 +221,26 @@ class GoalAnalyzer:
try:
content, _ = await self.llm.generate_response_async(prompt)
logger.debug(f"LLM原始返回内容: {content}")
# 尝试解析JSON
success, result = get_items_from_json(
content,
"goal_achieved", "stop_conversation", "reason",
required_types={"goal_achieved": bool, "stop_conversation": bool, "reason": str}
"goal_achieved",
"stop_conversation",
"reason",
required_types={"goal_achieved": bool, "stop_conversation": bool, "reason": str},
)
if not success:
logger.error("无法解析对话分析结果JSON")
return False, False, "解析结果失败"
goal_achieved = result["goal_achieved"]
stop_conversation = result["stop_conversation"]
reason = result["reason"]
return goal_achieved, stop_conversation, reason
except Exception as e:
logger.error(f"分析对话状态时出错: {str(e)}")
return False, False, f"分析出错: {str(e)}"
@@ -269,23 +263,22 @@ class Waiter:
# 使用当前时间作为等待开始时间
wait_start_time = time.time()
self.chat_observer.waiting_start_time = wait_start_time # 设置等待开始时间
while True:
# 检查是否有新消息
if self.chat_observer.new_message_after(wait_start_time):
logger.info("等待结束,收到新消息")
return False
# 检查是否超时
if time.time() - wait_start_time > 300:
logger.info("等待超过300秒结束对话")
return True
await asyncio.sleep(1)
logger.info("等待中...")
class DirectMessageSender:
"""直接发送消息到平台的发送器"""

View File

@@ -5,33 +5,34 @@ import traceback
logger = get_module_logger("pfc_manager")
class PFCManager:
"""PFC对话管理器负责管理所有对话实例"""
# 单例模式
_instance = None
# 会话实例管理
_instances: Dict[str, Conversation] = {}
_initializing: Dict[str, bool] = {}
@classmethod
def get_instance(cls) -> 'PFCManager':
def get_instance(cls) -> "PFCManager":
"""获取管理器单例
Returns:
PFCManager: 管理器实例
"""
if cls._instance is None:
cls._instance = PFCManager()
return cls._instance
async def get_or_create_conversation(self, stream_id: str) -> Optional[Conversation]:
"""获取或创建对话实例
Args:
stream_id: 聊天流ID
Returns:
Optional[Conversation]: 对话实例创建失败则返回None
"""
@@ -39,11 +40,11 @@ class PFCManager:
if stream_id in self._initializing and self._initializing[stream_id]:
logger.debug(f"会话实例正在初始化中: {stream_id}")
return None
if stream_id in self._instances:
logger.debug(f"使用现有会话实例: {stream_id}")
return self._instances[stream_id]
try:
# 创建新实例
logger.info(f"创建新的对话实例: {stream_id}")
@@ -51,47 +52,45 @@ class PFCManager:
# 创建实例
conversation_instance = Conversation(stream_id)
self._instances[stream_id] = conversation_instance
# 启动实例初始化
await self._initialize_conversation(conversation_instance)
except Exception as e:
logger.error(f"创建会话实例失败: {stream_id}, 错误: {e}")
return None
return conversation_instance
async def _initialize_conversation(self, conversation: Conversation):
"""初始化会话实例
Args:
conversation: 要初始化的会话实例
"""
stream_id = conversation.stream_id
try:
logger.info(f"开始初始化会话实例: {stream_id}")
# 启动初始化流程
await conversation._initialize()
# 标记初始化完成
self._initializing[stream_id] = False
logger.info(f"会话实例 {stream_id} 初始化完成")
except Exception as e:
logger.error(f"管理器初始化会话实例失败: {stream_id}, 错误: {e}")
logger.error(traceback.format_exc())
# 清理失败的初始化
async def get_conversation(self, stream_id: str) -> Optional[Conversation]:
"""获取已存在的会话实例
Args:
stream_id: 聊天流ID
Returns:
Optional[Conversation]: 会话实例不存在则返回None
"""
return self._instances.get(stream_id)
return self._instances.get(stream_id)

View File

@@ -4,6 +4,7 @@ from typing import Literal
class ConversationState(Enum):
"""对话状态"""
INIT = "初始化"
RETHINKING = "重新思考"
ANALYZING = "分析历史"
@@ -18,4 +19,4 @@ class ConversationState(Enum):
JUDGING = "判断"
ActionType = Literal["direct_reply", "fetch_knowledge", "wait"]
ActionType = Literal["direct_reply", "fetch_knowledge", "wait"]

View File

@@ -13,33 +13,26 @@ logger = get_module_logger("reply_generator")
class ReplyGenerator:
"""回复生成器"""
def __init__(self, stream_id: str):
self.llm = LLM_request(
model=global_config.llm_normal,
temperature=0.7,
max_tokens=300,
request_type="reply_generation"
model=global_config.llm_normal, temperature=0.7, max_tokens=300, request_type="reply_generation"
)
self.personality_info = Individuality.get_instance().get_prompt(type = "personality", x_person = 2, level = 2)
self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2)
self.name = global_config.BOT_NICKNAME
self.chat_observer = ChatObserver.get_instance(stream_id)
self.reply_checker = ReplyChecker(stream_id)
async def generate(
self,
observation_info: ObservationInfo,
conversation_info: ConversationInfo
) -> str:
async def generate(self, observation_info: ObservationInfo, conversation_info: ConversationInfo) -> str:
"""生成回复
Args:
goal: 对话目标
chat_history: 聊天历史
knowledge_cache: 知识缓存
previous_reply: 上一次生成的回复(如果有)
retry_count: 当前重试次数
Returns:
str: 生成的回复
"""
@@ -51,22 +44,21 @@ class ReplyGenerator:
for goal, reason in goal_list:
goal_text += f"目标:{goal};"
goal_text += f"原因:{reason}\n"
# 获取聊天历史记录
chat_history_list = observation_info.chat_history
chat_history_text = ""
for msg in chat_history_list:
chat_history_text += f"{msg}\n"
# 整理知识缓存
knowledge_text = ""
knowledge_list = conversation_info.knowledge_list
for knowledge in knowledge_list:
knowledge_text += f"知识:{knowledge}\n"
personality_text = f"你的名字是{self.name}{self.personality_info}"
prompt = f"""{personality_text}。现在你在参与一场QQ聊天请根据以下信息生成回复
当前对话目标:{goal_text}
@@ -92,7 +84,7 @@ class ReplyGenerator:
logger.info(f"生成的回复: {content}")
# is_new = self.chat_observer.check()
# logger.debug(f"再看一眼聊天记录,{'有' if is_new else '没有'}新消息")
# 如果有新消息,重新生成回复
# if is_new:
# logger.info("检测到新消息,重新生成回复")
@@ -100,27 +92,22 @@ class ReplyGenerator:
# goal, chat_history, knowledge_cache,
# None, retry_count
# )
return content
except Exception as e:
logger.error(f"生成回复时出错: {e}")
return "抱歉,我现在有点混乱,让我重新思考一下..."
async def check_reply(
self,
reply: str,
goal: str,
retry_count: int = 0
) -> Tuple[bool, str, bool]:
async def check_reply(self, reply: str, goal: str, retry_count: int = 0) -> Tuple[bool, str, bool]:
"""检查回复是否合适
Args:
reply: 生成的回复
goal: 对话目标
retry_count: 当前重试次数
Returns:
Tuple[bool, str, bool]: (是否合适, 原因, 是否需要重新规划)
"""
return await self.reply_checker.check(reply, goal, retry_count)
return await self.reply_checker.check(reply, goal, retry_count)

View File

@@ -3,43 +3,44 @@ from .chat_observer import ChatObserver
logger = get_module_logger("waiter")
class Waiter:
"""等待器,用于等待对话流中的事件"""
def __init__(self, stream_id: str):
self.stream_id = stream_id
self.chat_observer = ChatObserver.get_instance(stream_id)
async def wait(self, timeout: float = 20.0) -> bool:
"""等待用户回复或超时
Args:
timeout: 超时时间(秒)
Returns:
bool: 如果因为超时返回则为True否则为False
"""
try:
message_before = self.chat_observer.get_last_message()
# 等待新消息
logger.debug(f"等待新消息,超时时间: {timeout}")
is_timeout = await self.chat_observer.wait_for_update(timeout=timeout)
if is_timeout:
logger.debug("等待超时,没有收到新消息")
return True
# 检查是否是新消息
message_after = self.chat_observer.get_last_message()
if message_before and message_after and message_before.get("message_id") == message_after.get("message_id"):
# 如果消息ID相同说明没有新消息
logger.debug("没有收到新消息")
return True
logger.debug("收到新消息")
return False
except Exception as e:
logger.error(f"等待时出错: {str(e)}")
return True
return True

View File

@@ -30,7 +30,7 @@ class ChatBot:
self.think_flow_chat = ThinkFlowChat()
self.reasoning_chat = ReasoningChat()
self.only_process_chat = MessageProcessor()
# 创建初始化PFC管理器的任务会在_ensure_started时执行
self.pfc_manager = PFCManager.get_instance()
@@ -38,7 +38,7 @@ class ChatBot:
"""确保所有任务已启动"""
if not self._started:
logger.info("确保ChatBot所有任务已启动")
self._started = True
async def _create_PFC_chat(self, message: MessageRecv):
@@ -46,7 +46,6 @@ class ChatBot:
chat_id = str(message.chat_stream.stream_id)
if global_config.enable_pfc_chatting:
await self.pfc_manager.get_or_create_conversation(chat_id)
except Exception as e:
@@ -80,7 +79,7 @@ class ChatBot:
try:
# 确保所有任务已启动
await self._ensure_started()
message = MessageRecv(message_data)
groupinfo = message.message_info.group_info
userinfo = message.message_info.user_info

View File

@@ -24,7 +24,7 @@ config_config = LogConfig(
# 配置主程序日志格式
logger = get_module_logger("config", config=config_config)
#考虑到实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
# 考虑到实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
is_test = True
mai_version_main = "0.6.2"
mai_version_fix = "snapshot-1"

View File

@@ -2,7 +2,7 @@
__version__ = "0.1.0"
from .api import BaseMessageAPI, global_api
from .api import global_api
from .message_base import (
Seg,
GroupInfo,
@@ -14,7 +14,6 @@ from .message_base import (
)
__all__ = [
"BaseMessageAPI",
"Seg",
"global_api",
"GroupInfo",

View File

@@ -1,7 +1,8 @@
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from typing import Dict, Any, Callable, List, Set
from typing import Dict, Any, Callable, List, Set, Optional
from src.common.logger import get_module_logger
from src.plugins.message.message_base import MessageBase
from src.common.server import global_server
import aiohttp
import asyncio
import uvicorn
@@ -49,13 +50,22 @@ class MessageServer(BaseMessageHandler):
_class_handlers: List[Callable] = [] # 类级别的消息处理器
def __init__(self, host: str = "0.0.0.0", port: int = 18000, enable_token=False):
def __init__(
self,
host: str = "0.0.0.0",
port: int = 18000,
enable_token=False,
app: Optional[FastAPI] = None,
path: str = "/ws",
):
super().__init__()
# 将类级别的处理器添加到实例处理器中
self.message_handlers.extend(self._class_handlers)
self.app = FastAPI()
self.host = host
self.port = port
self.path = path
self.app = app or FastAPI()
self.own_app = app is None # 标记是否使用自己创建的app
self.active_websockets: Set[WebSocket] = set()
self.platform_websockets: Dict[str, WebSocket] = {} # 平台到websocket的映射
self.valid_tokens: Set[str] = set()
@@ -63,28 +73,6 @@ class MessageServer(BaseMessageHandler):
self._setup_routes()
self._running = False
@classmethod
def register_class_handler(cls, handler: Callable):
"""注册类级别的消息处理器"""
if handler not in cls._class_handlers:
cls._class_handlers.append(handler)
def register_message_handler(self, handler: Callable):
"""注册实例级别的消息处理器"""
if handler not in self.message_handlers:
self.message_handlers.append(handler)
async def verify_token(self, token: str) -> bool:
if not self.enable_token:
return True
return token in self.valid_tokens
def add_valid_token(self, token: str):
self.valid_tokens.add(token)
def remove_valid_token(self, token: str):
self.valid_tokens.discard(token)
def _setup_routes(self):
@self.app.post("/api/message")
async def handle_message(message: Dict[str, Any]):
@@ -125,6 +113,90 @@ class MessageServer(BaseMessageHandler):
finally:
self._remove_websocket(websocket, platform)
@classmethod
def register_class_handler(cls, handler: Callable):
"""注册类级别的消息处理器"""
if handler not in cls._class_handlers:
cls._class_handlers.append(handler)
def register_message_handler(self, handler: Callable):
"""注册实例级别的消息处理器"""
if handler not in self.message_handlers:
self.message_handlers.append(handler)
async def verify_token(self, token: str) -> bool:
if not self.enable_token:
return True
return token in self.valid_tokens
def add_valid_token(self, token: str):
self.valid_tokens.add(token)
def remove_valid_token(self, token: str):
self.valid_tokens.discard(token)
def run_sync(self):
"""同步方式运行服务器"""
if not self.own_app:
raise RuntimeError("当使用外部FastAPI实例时请使用该实例的运行方法")
uvicorn.run(self.app, host=self.host, port=self.port)
async def run(self):
"""异步方式运行服务器"""
self._running = True
try:
if self.own_app:
# 如果使用自己的 FastAPI 实例,运行 uvicorn 服务器
config = uvicorn.Config(self.app, host=self.host, port=self.port, loop="asyncio")
self.server = uvicorn.Server(config)
await self.server.serve()
else:
# 如果使用外部 FastAPI 实例,保持运行状态以处理消息
while self._running:
await asyncio.sleep(1)
except KeyboardInterrupt:
await self.stop()
raise
except Exception as e:
await self.stop()
raise RuntimeError(f"服务器运行错误: {str(e)}") from e
finally:
await self.stop()
async def start_server(self):
"""启动服务器的异步方法"""
if not self._running:
self._running = True
await self.run()
async def stop(self):
"""停止服务器"""
# 清理platform映射
self.platform_websockets.clear()
# 取消所有后台任务
for task in self.background_tasks:
task.cancel()
# 等待所有任务完成
await asyncio.gather(*self.background_tasks, return_exceptions=True)
self.background_tasks.clear()
# 关闭所有WebSocket连接
for websocket in self.active_websockets:
await websocket.close()
self.active_websockets.clear()
if hasattr(self, "server") and self.own_app:
self._running = False
# 正确关闭 uvicorn 服务器
self.server.should_exit = True
await self.server.shutdown()
# 等待服务器完全停止
if hasattr(self.server, "started") and self.server.started:
await self.server.main_loop()
# 清理处理程序
self.message_handlers.clear()
def _remove_websocket(self, websocket: WebSocket, platform: str):
"""从所有集合中移除websocket"""
if websocket in self.active_websockets:
@@ -161,54 +233,6 @@ class MessageServer(BaseMessageHandler):
async def send_message(self, message: MessageBase):
await self.broadcast_to_platform(message.message_info.platform, message.to_dict())
def run_sync(self):
"""同步方式运行服务器"""
uvicorn.run(self.app, host=self.host, port=self.port)
async def run(self):
"""异步方式运行服务器"""
config = uvicorn.Config(self.app, host=self.host, port=self.port, loop="asyncio")
self.server = uvicorn.Server(config)
try:
await self.server.serve()
except KeyboardInterrupt as e:
await self.stop()
raise KeyboardInterrupt from e
async def start_server(self):
"""启动服务器的异步方法"""
if not self._running:
self._running = True
await self.run()
async def stop(self):
"""停止服务器"""
# 清理platform映射
self.platform_websockets.clear()
# 取消所有后台任务
for task in self.background_tasks:
task.cancel()
# 等待所有任务完成
await asyncio.gather(*self.background_tasks, return_exceptions=True)
self.background_tasks.clear()
# 关闭所有WebSocket连接
for websocket in self.active_websockets:
await websocket.close()
self.active_websockets.clear()
if hasattr(self, "server"):
self._running = False
# 正确关闭 uvicorn 服务器
self.server.should_exit = True
await self.server.shutdown()
# 等待服务器完全停止
if hasattr(self.server, "started") and self.server.started:
await self.server.main_loop()
# 清理处理程序
self.message_handlers.clear()
async def send_message_REST(self, url: str, data: Dict[str, Any]) -> Dict[str, Any]:
"""发送消息到指定端点"""
async with aiohttp.ClientSession() as session:
@@ -219,105 +243,4 @@ class MessageServer(BaseMessageHandler):
raise e
class BaseMessageAPI:
def __init__(self, host: str = "0.0.0.0", port: int = 18000):
self.app = FastAPI()
self.host = host
self.port = port
self.message_handlers: List[Callable] = []
self.cache = []
self._setup_routes()
self._running = False
def _setup_routes(self):
"""设置基础路由"""
@self.app.post("/api/message")
async def handle_message(message: Dict[str, Any]):
try:
# 创建后台任务处理消息
asyncio.create_task(self._background_message_handler(message))
return {"status": "success"}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e
async def _background_message_handler(self, message: Dict[str, Any]):
"""后台处理单个消息"""
try:
await self.process_single_message(message)
except Exception as e:
logger.error(f"Background message processing failed: {str(e)}")
logger.error(traceback.format_exc())
def register_message_handler(self, handler: Callable):
"""注册消息处理函数"""
self.message_handlers.append(handler)
async def send_message(self, url: str, data: Dict[str, Any]) -> Dict[str, Any]:
"""发送消息到指定端点"""
async with aiohttp.ClientSession() as session:
try:
async with session.post(url, json=data, headers={"Content-Type": "application/json"}) as response:
return await response.json()
except Exception:
# logger.error(f"发送消息失败: {str(e)}")
pass
async def process_single_message(self, message: Dict[str, Any]):
"""处理单条消息"""
tasks = []
for handler in self.message_handlers:
try:
tasks.append(handler(message))
except Exception as e:
logger.error(str(e))
logger.error(traceback.format_exc())
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
def run_sync(self):
"""同步方式运行服务器"""
uvicorn.run(self.app, host=self.host, port=self.port)
async def run(self):
"""异步方式运行服务器"""
config = uvicorn.Config(self.app, host=self.host, port=self.port, loop="asyncio")
self.server = uvicorn.Server(config)
try:
await self.server.serve()
except KeyboardInterrupt as e:
await self.stop()
raise KeyboardInterrupt from e
async def start_server(self):
"""启动服务器的异步方法"""
if not self._running:
self._running = True
await self.run()
async def stop(self):
"""停止服务器"""
if hasattr(self, "server"):
self._running = False
# 正确关闭 uvicorn 服务器
self.server.should_exit = True
await self.server.shutdown()
# 等待服务器完全停止
if hasattr(self.server, "started") and self.server.started:
await self.server.main_loop()
# 清理处理程序
self.message_handlers.clear()
def start(self):
"""启动服务器的便捷方法"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(self.start_server())
except KeyboardInterrupt:
pass
finally:
loop.close()
global_api = MessageServer(host=os.environ["HOST"], port=int(os.environ["PORT"]))
global_api = MessageServer(host=os.environ["HOST"], port=int(os.environ["PORT"]), app=global_server.get_app())