Merge branch 'dev' of https://github.com/MaiM-with-u/MaiBot into dev
This commit is contained in:
@@ -321,9 +321,7 @@ CHAT_STYLE_CONFIG = {
|
||||
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 见闻 | {message}",
|
||||
},
|
||||
"simple": {
|
||||
"console_format": (
|
||||
"<level>{time:MM-DD HH:mm}</level> | <green>见闻</green> | <green>{message}</green>"
|
||||
), # noqa: E501
|
||||
"console_format": ("<level>{time:MM-DD HH:mm}</level> | <green>见闻</green> | <green>{message}</green>"), # noqa: E501
|
||||
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 见闻 | {message}",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -22,7 +22,7 @@ class Observation:
|
||||
self.observe_type = observe_type
|
||||
self.observe_id = observe_id
|
||||
self.last_observe_time = datetime.now().timestamp() # 初始化为当前时间
|
||||
|
||||
|
||||
async def observe(self):
|
||||
pass
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ class InterestChatting:
|
||||
self.max_reply_probability: float = max_probability
|
||||
self.current_reply_probability: float = 0.0
|
||||
self.is_above_threshold: bool = False
|
||||
|
||||
|
||||
# 任务相关属性初始化
|
||||
self.update_task: Optional[asyncio.Task] = None
|
||||
self._stop_event = asyncio.Event()
|
||||
|
||||
@@ -271,11 +271,10 @@ class SubHeartflowManager:
|
||||
log_prefix = "[兴趣评估]"
|
||||
current_state = current_mai_state.get_current_state()
|
||||
focused_limit = current_state.get_focused_chat_max_num()
|
||||
|
||||
|
||||
|
||||
if int(time.time()) % 20 == 0: # 每20秒输出一次
|
||||
logger.debug(f"{log_prefix} 当前状态 ({current_state.value}) 可以在{focused_limit}个群激情聊天")
|
||||
|
||||
|
||||
if focused_limit <= 0:
|
||||
# logger.debug(f"{log_prefix} 当前状态 ({current_state.value}) 不允许 FOCUSED 子心流")
|
||||
return
|
||||
@@ -288,22 +287,23 @@ class SubHeartflowManager:
|
||||
states_num = (
|
||||
self.count_subflows_by_state(ChatState.ABSENT),
|
||||
self.count_subflows_by_state(ChatState.CHAT),
|
||||
current_focused_count
|
||||
current_focused_count,
|
||||
)
|
||||
|
||||
for sub_hf in list(self.subheartflows.values()):
|
||||
flow_id = sub_hf.subheartflow_id
|
||||
stream_name = chat_manager.get_stream_name(flow_id) or flow_id
|
||||
|
||||
|
||||
# 跳过非CHAT状态或已经是FOCUSED状态的子心流
|
||||
if sub_hf.chat_state.chat_status == ChatState.FOCUSED:
|
||||
continue
|
||||
|
||||
|
||||
from .mai_state_manager import enable_unlimited_hfc_chat
|
||||
|
||||
if not enable_unlimited_hfc_chat:
|
||||
if sub_hf.chat_state.chat_status != ChatState.CHAT:
|
||||
continue
|
||||
|
||||
|
||||
# 检查是否满足提升概率
|
||||
if random.random() >= sub_hf.interest_chatting.start_hfc_probability:
|
||||
continue
|
||||
@@ -324,12 +324,12 @@ class SubHeartflowManager:
|
||||
|
||||
# 执行状态提升
|
||||
await current_subflow.set_chat_state(ChatState.FOCUSED, states_num)
|
||||
|
||||
# 验证提升结果
|
||||
if (final_subflow := self.subheartflows.get(flow_id)) and \
|
||||
final_subflow.chat_state.chat_status == ChatState.FOCUSED:
|
||||
current_focused_count += 1
|
||||
|
||||
# 验证提升结果
|
||||
if (
|
||||
final_subflow := self.subheartflows.get(flow_id)
|
||||
) and final_subflow.chat_state.chat_status == ChatState.FOCUSED:
|
||||
current_focused_count += 1
|
||||
|
||||
async def randomly_deactivate_subflows(self, deactivation_probability: float = 0.1):
|
||||
"""以一定概率将 FOCUSED 或 CHAT 状态的子心流回退到 ABSENT 状态。"""
|
||||
|
||||
@@ -3,7 +3,7 @@ import asyncio
|
||||
import traceback
|
||||
from typing import Optional, Dict, Any, List
|
||||
from src.common.logger import get_module_logger
|
||||
from ..message.message_base import UserInfo
|
||||
from maim_message import UserInfo
|
||||
from ...config.config import global_config
|
||||
from .chat_states import NotificationManager, create_new_message_notification, create_cold_chat_notification
|
||||
from .message_storage import MongoDBMessageStorage
|
||||
|
||||
@@ -13,7 +13,7 @@ from .observation_info import ObservationInfo
|
||||
from .conversation_info import ConversationInfo
|
||||
from .reply_generator import ReplyGenerator
|
||||
from ..chat.chat_stream import ChatStream
|
||||
from ..message.message_base import UserInfo
|
||||
from maim_message import UserInfo
|
||||
from src.plugins.chat.chat_stream import chat_manager
|
||||
from .pfc_KnowledgeFetcher import KnowledgeFetcher
|
||||
from .waiter import Waiter
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import Optional
|
||||
from src.common.logger import get_module_logger
|
||||
from ..chat.chat_stream import ChatStream
|
||||
from ..chat.message import Message
|
||||
from ..message.message_base import Seg
|
||||
from maim_message import Seg
|
||||
from src.plugins.chat.message import MessageSending, MessageSet
|
||||
from src.plugins.chat.message_sender import message_manager
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# Programmable Friendly Conversationalist
|
||||
# Prefrontal cortex
|
||||
from typing import List, Optional, Dict, Any, Set
|
||||
from ..message.message_base import UserInfo
|
||||
from maim_message import UserInfo
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from src.common.logger import get_module_logger
|
||||
|
||||
@@ -6,7 +6,7 @@ import datetime
|
||||
from typing import List, Optional, Tuple, TYPE_CHECKING
|
||||
from src.common.logger import get_module_logger
|
||||
from ..chat.chat_stream import ChatStream
|
||||
from ..message.message_base import UserInfo, Seg
|
||||
from maim_message import UserInfo, Seg
|
||||
from ..chat.message import Message
|
||||
from ..models.utils_model import LLMRequest
|
||||
from ...config.config import global_config
|
||||
@@ -371,22 +371,11 @@ class DirectMessageSender:
|
||||
# 处理消息
|
||||
await message.process()
|
||||
|
||||
message_json = message.to_dict()
|
||||
_message_json = message.to_dict()
|
||||
|
||||
# 发送消息
|
||||
try:
|
||||
end_point = global_config.api_urls.get(message.message_info.platform, None)
|
||||
if end_point:
|
||||
# logger.info(f"发送消息到{end_point}")
|
||||
# logger.info(message_json)
|
||||
try:
|
||||
await global_api.send_message_REST(end_point, message_json)
|
||||
except Exception as e:
|
||||
logger.error(f"REST方式发送失败,出现错误: {str(e)}")
|
||||
logger.info("尝试使用ws发送")
|
||||
await self.send_via_ws(message)
|
||||
else:
|
||||
await self.send_via_ws(message)
|
||||
await self.send_via_ws(message)
|
||||
logger.success(f"PFC消息已发送: {content}")
|
||||
except Exception as e:
|
||||
logger.error(f"PFC消息发送失败: {str(e)}")
|
||||
|
||||
@@ -5,7 +5,7 @@ from src.common.logger import get_module_logger
|
||||
from ..models.utils_model import LLMRequest
|
||||
from ...config.config import global_config
|
||||
from .chat_observer import ChatObserver
|
||||
from ..message.message_base import UserInfo
|
||||
from maim_message import UserInfo
|
||||
|
||||
logger = get_module_logger("reply_checker")
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Dict, Optional
|
||||
|
||||
|
||||
from ...common.database import db
|
||||
from ..message.message_base import GroupInfo, UserInfo
|
||||
from maim_message import GroupInfo, UserInfo
|
||||
|
||||
from src.common.logger import get_module_logger, LogConfig, CHAT_STREAM_STYLE_CONFIG
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import urllib3
|
||||
from src.common.logger import get_module_logger
|
||||
from .chat_stream import ChatStream
|
||||
from .utils_image import image_manager
|
||||
from ..message.message_base import Seg, UserInfo, BaseMessageInfo, MessageBase
|
||||
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
|
||||
|
||||
logger = get_module_logger("chat_message")
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from src.common.logger import get_module_logger
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
from .message import MessageRecv
|
||||
from ..message.message_base import BaseMessageInfo, GroupInfo, Seg
|
||||
from maim_message import BaseMessageInfo, GroupInfo, Seg
|
||||
import hashlib
|
||||
from typing import Dict
|
||||
from collections import OrderedDict
|
||||
|
||||
@@ -62,20 +62,10 @@ class MessageSender:
|
||||
# logger.trace(f"{message.processed_plain_text},{typing_time},等待输入时间结束") # 减少日志
|
||||
# --- 结束打字延迟 ---
|
||||
|
||||
message_json = message.to_dict()
|
||||
message_preview = truncate_message(message.processed_plain_text)
|
||||
|
||||
try:
|
||||
end_point = global_config.api_urls.get(message.message_info.platform, None)
|
||||
if end_point:
|
||||
try:
|
||||
await global_api.send_message_rest(end_point, message_json)
|
||||
except Exception as e:
|
||||
logger.error(f"REST发送失败: {str(e)}")
|
||||
logger.info(f"[{message.chat_stream.stream_id}] 尝试使用WS发送")
|
||||
await self.send_via_ws(message)
|
||||
else:
|
||||
await self.send_via_ws(message)
|
||||
await self.send_via_ws(message)
|
||||
logger.success(f"发送消息 '{message_preview}' 成功") # 调整日志格式
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息 '{message_preview}' 失败: {str(e)}")
|
||||
|
||||
@@ -12,7 +12,7 @@ from ..models.utils_model import LLMRequest
|
||||
from ..utils.typo_generator import ChineseTypoGenerator
|
||||
from ...config.config import global_config
|
||||
from .message import MessageRecv, Message
|
||||
from ..message.message_base import UserInfo
|
||||
from maim_message import UserInfo
|
||||
from .chat_stream import ChatStream
|
||||
from ..moods.moods import MoodManager
|
||||
from ...common.database import db
|
||||
|
||||
@@ -29,10 +29,11 @@ EMOJI_DIR = os.path.join(BASE_DIR, "emoji") # 表情包存储目录
|
||||
EMOJI_REGISTED_DIR = os.path.join(BASE_DIR, "emoji_registed") # 已注册的表情包注册目录
|
||||
|
||||
|
||||
'''
|
||||
"""
|
||||
还没经过测试,有些地方数据库和内存数据同步可能不完全
|
||||
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
class MaiEmoji:
|
||||
"""定义一个表情包"""
|
||||
@@ -258,7 +259,7 @@ class EmojiManager:
|
||||
if emoji.hash == hash:
|
||||
emoji.usage_count += 1
|
||||
break
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"记录表情使用失败: {str(e)}")
|
||||
|
||||
@@ -316,7 +317,9 @@ class EmojiManager:
|
||||
|
||||
time_end = time.time()
|
||||
|
||||
logger.info(f"找到[{text_emotion}]表情包,用时:{time_end - time_start:.2f}秒: {selected_emoji.description} (相似度: {similarity:.4f})")
|
||||
logger.info(
|
||||
f"找到[{text_emotion}]表情包,用时:{time_end - time_start:.2f}秒: {selected_emoji.description} (相似度: {similarity:.4f})"
|
||||
)
|
||||
return selected_emoji.path, f"[ {selected_emoji.description} ]"
|
||||
|
||||
except Exception as e:
|
||||
@@ -784,16 +787,15 @@ class EmojiManager:
|
||||
logger.error(f"[错误] 注册表情包失败: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
|
||||
|
||||
|
||||
async def clear_temp_emoji(self):
|
||||
"""每天清理临时表情包
|
||||
清理/data/emoji和/data/image目录下的所有文件
|
||||
当目录中文件数超过50时,会全部删除
|
||||
"""
|
||||
|
||||
|
||||
logger.info("[清理] 开始清理临时表情包...")
|
||||
|
||||
|
||||
# 清理emoji目录
|
||||
emoji_dir = os.path.join(BASE_DIR, "emoji")
|
||||
if os.path.exists(emoji_dir):
|
||||
@@ -805,7 +807,7 @@ class EmojiManager:
|
||||
if os.path.isfile(file_path):
|
||||
os.remove(file_path)
|
||||
logger.debug(f"[清理] 删除表情包文件: {filename}")
|
||||
|
||||
|
||||
# 清理image目录
|
||||
image_dir = os.path.join(BASE_DIR, "image")
|
||||
if os.path.exists(image_dir):
|
||||
@@ -817,10 +819,8 @@ class EmojiManager:
|
||||
if os.path.isfile(file_path):
|
||||
os.remove(file_path)
|
||||
logger.debug(f"[清理] 删除图片文件: {filename}")
|
||||
|
||||
|
||||
logger.success("[清理] 临时文件清理完成")
|
||||
|
||||
|
||||
|
||||
|
||||
# 创建全局单例
|
||||
|
||||
@@ -13,7 +13,7 @@ from src.common.logger import get_module_logger, LogConfig, PFC_STYLE_CONFIG #
|
||||
from src.plugins.models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
from src.plugins.chat.utils_image import image_path_to_base64 # Local import needed after move
|
||||
from src.plugins.utils.timer_calculater import Timer # <--- Import Timer
|
||||
from src.plugins.utils.timer_calculator import Timer # <--- Import Timer
|
||||
from src.plugins.heartFC_chat.heartFC_generator import HeartFCGenerator
|
||||
from src.do_tool.tool_use import ToolUser
|
||||
from ..chat.message_sender import message_manager # <-- Import the global manager
|
||||
@@ -40,31 +40,28 @@ logger = get_module_logger("HeartFCLoop", config=interest_log_config) # Logger
|
||||
|
||||
|
||||
# 默认动作定义
|
||||
DEFAULT_ACTIONS = {
|
||||
"no_reply": "不回复",
|
||||
"text_reply": "文本回复, 可选附带表情",
|
||||
"emoji_reply": "仅表情回复"
|
||||
}
|
||||
DEFAULT_ACTIONS = {"no_reply": "不回复", "text_reply": "文本回复, 可选附带表情", "emoji_reply": "仅表情回复"}
|
||||
|
||||
|
||||
class ActionManager:
|
||||
"""动作管理器:控制每次决策可以使用的动作"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
# 初始化为默认动作集
|
||||
self._available_actions: Dict[str, str] = DEFAULT_ACTIONS.copy()
|
||||
|
||||
|
||||
def get_available_actions(self) -> Dict[str, str]:
|
||||
"""获取当前可用的动作集"""
|
||||
return self._available_actions
|
||||
|
||||
|
||||
def add_action(self, action_name: str, description: str) -> bool:
|
||||
"""
|
||||
添加新的动作
|
||||
|
||||
|
||||
参数:
|
||||
action_name: 动作名称
|
||||
description: 动作描述
|
||||
|
||||
|
||||
返回:
|
||||
bool: 是否添加成功
|
||||
"""
|
||||
@@ -72,14 +69,14 @@ class ActionManager:
|
||||
return False
|
||||
self._available_actions[action_name] = description
|
||||
return True
|
||||
|
||||
|
||||
def remove_action(self, action_name: str) -> bool:
|
||||
"""
|
||||
移除指定动作
|
||||
|
||||
|
||||
参数:
|
||||
action_name: 动作名称
|
||||
|
||||
|
||||
返回:
|
||||
bool: 是否移除成功
|
||||
"""
|
||||
@@ -87,58 +84,67 @@ class ActionManager:
|
||||
return False
|
||||
del self._available_actions[action_name]
|
||||
return True
|
||||
|
||||
|
||||
def clear_actions(self):
|
||||
"""清空所有动作"""
|
||||
self._available_actions.clear()
|
||||
|
||||
|
||||
def reset_to_default(self):
|
||||
"""重置为默认动作集"""
|
||||
self._available_actions = DEFAULT_ACTIONS.copy()
|
||||
|
||||
|
||||
def get_planner_tool_definition(self) -> List[Dict[str, Any]]:
|
||||
"""获取当前动作集对应的规划器工具定义"""
|
||||
return [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "decide_reply_action",
|
||||
"description": "根据当前聊天内容和上下文,决定机器人是否应该回复以及如何回复。",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": list(self._available_actions.keys()),
|
||||
"description": "决定采取的行动:" +
|
||||
", ".join([f"'{k}'({v})" for k, v in self._available_actions.items()]),
|
||||
},
|
||||
"reasoning": {"type": "string", "description": "做出此决定的简要理由。"},
|
||||
"emoji_query": {
|
||||
"type": "string",
|
||||
"description": "如果行动是'emoji_reply',指定表情的主题或概念。如果行动是'text_reply'且希望在文本后追加表情,也在此指定表情主题。",
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "decide_reply_action",
|
||||
"description": "根据当前聊天内容和上下文,决定机器人是否应该回复以及如何回复。",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": list(self._available_actions.keys()),
|
||||
"description": "决定采取的行动:"
|
||||
+ ", ".join([f"'{k}'({v})" for k, v in self._available_actions.items()]),
|
||||
},
|
||||
"reasoning": {"type": "string", "description": "做出此决定的简要理由。"},
|
||||
"emoji_query": {
|
||||
"type": "string",
|
||||
"description": "如果行动是'emoji_reply',指定表情的主题或概念。如果行动是'text_reply'且希望在文本后追加表情,也在此指定表情主题。",
|
||||
},
|
||||
},
|
||||
"required": ["action", "reasoning"],
|
||||
},
|
||||
"required": ["action", "reasoning"],
|
||||
},
|
||||
},
|
||||
}]
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
# 在文件开头添加自定义异常类
|
||||
class HeartFCError(Exception):
|
||||
"""麦麦聊天系统基础异常类"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class PlannerError(HeartFCError):
|
||||
"""规划器异常"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ReplierError(HeartFCError):
|
||||
"""回复器异常"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class SenderError(HeartFCError):
|
||||
"""发送器异常"""
|
||||
|
||||
pass
|
||||
|
||||
class HeartFChatting:
|
||||
@@ -160,7 +166,7 @@ class HeartFChatting:
|
||||
self.chat_stream: Optional[ChatStream] = None # 关联的聊天流
|
||||
self.sub_mind: SubMind = sub_mind # 关联的子思维
|
||||
self.observations: List[Observation] = observations # 关联的观察列表,用于监控聊天流状态
|
||||
|
||||
|
||||
# 日志前缀
|
||||
self.log_prefix: str = f"[{chat_manager.get_stream_name(chat_id) or chat_id}]"
|
||||
|
||||
@@ -206,7 +212,7 @@ class HeartFChatting:
|
||||
|
||||
# 更新日志前缀(以防流名称发生变化)
|
||||
self.log_prefix = f"[{chat_manager.get_stream_name(self.stream_id) or self.stream_id}]"
|
||||
|
||||
|
||||
self._initialized = True
|
||||
logger.info(f"麦麦感觉到了,可以开始激情水群{self.log_prefix} ")
|
||||
return True
|
||||
@@ -265,22 +271,22 @@ class HeartFChatting:
|
||||
self._processing_lock.release()
|
||||
|
||||
async def _hfc_loop(self):
|
||||
"""主循环,持续进行计划并可能回复消息,直到被外部取消。"""
|
||||
"""主循环,持续进行计划并可能回复消息,直到被外部取消。"""
|
||||
try:
|
||||
while True: # 主循环
|
||||
# 创建新的循环信息
|
||||
self._cycle_counter += 1
|
||||
self._current_cycle = CycleInfo(self._cycle_counter)
|
||||
|
||||
|
||||
# 初始化周期状态
|
||||
cycle_timers = {}
|
||||
loop_cycle_start_time = time.monotonic()
|
||||
|
||||
|
||||
# 执行规划和处理阶段
|
||||
async with self._get_cycle_context() as acquired_lock:
|
||||
if not acquired_lock:
|
||||
continue
|
||||
|
||||
|
||||
# 记录规划开始时间点
|
||||
planner_start_db_time = time.time()
|
||||
|
||||
@@ -295,22 +301,22 @@ class HeartFChatting:
|
||||
|
||||
# 防止循环过快消耗资源
|
||||
await self._handle_cycle_delay(action_taken, loop_cycle_start_time, self.log_prefix)
|
||||
|
||||
|
||||
# 等待直到所有消息都发送完成
|
||||
with Timer("发送消息", cycle_timers):
|
||||
while await self._should_skip_cycle(thinking_id):
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
|
||||
# 完成当前循环并保存历史
|
||||
self._current_cycle.complete_cycle()
|
||||
self._cycle_history.append(self._current_cycle)
|
||||
|
||||
|
||||
# 记录循环信息和计时器结果
|
||||
timer_strings = []
|
||||
for name, elapsed in cycle_timers.items():
|
||||
formatted_time = f"{elapsed * 1000:.2f}毫秒" if elapsed < 1 else f"{elapsed:.2f}秒"
|
||||
timer_strings.append(f"{name}: {formatted_time}")
|
||||
|
||||
|
||||
logger.debug(
|
||||
f"{self.log_prefix} 第 #{self._current_cycle.cycle_id}次思考完成,"
|
||||
f"耗时: {self._current_cycle.end_time - self._current_cycle.start_time:.2f}秒, "
|
||||
@@ -328,7 +334,7 @@ class HeartFChatting:
|
||||
async def _get_cycle_context(self):
|
||||
"""
|
||||
循环周期的上下文管理器
|
||||
|
||||
|
||||
用于确保资源的正确获取和释放:
|
||||
1. 获取处理锁
|
||||
2. 执行操作
|
||||
@@ -346,10 +352,10 @@ class HeartFChatting:
|
||||
async def _check_new_messages(self, start_time: float) -> bool:
|
||||
"""
|
||||
检查从指定时间点后是否有新消息
|
||||
|
||||
|
||||
参数:
|
||||
start_time: 开始检查的时间点
|
||||
|
||||
|
||||
返回:
|
||||
bool: 是否有新消息
|
||||
"""
|
||||
@@ -363,9 +369,7 @@ class HeartFChatting:
|
||||
logger.error(f"{self.log_prefix} 检查新消息时出错: {e}")
|
||||
return False
|
||||
|
||||
async def _think_plan_execute_loop(
|
||||
self, cycle_timers: dict, planner_start_db_time: float
|
||||
) -> tuple[bool, str]:
|
||||
async def _think_plan_execute_loop(self, cycle_timers: dict, planner_start_db_time: float) -> tuple[bool, str]:
|
||||
"""执行规划阶段"""
|
||||
try:
|
||||
# think:思考
|
||||
@@ -398,7 +402,7 @@ class HeartFChatting:
|
||||
reasoning = planner_result.get("reasoning", "未提供理由")
|
||||
# 更新循环信息
|
||||
self._current_cycle.set_action_info(action, reasoning, True)
|
||||
|
||||
|
||||
# 处理LLM错误
|
||||
if planner_result.get("llm_error"):
|
||||
logger.error(f"{self.log_prefix} LLM失败: {reasoning}")
|
||||
@@ -415,37 +419,32 @@ class HeartFChatting:
|
||||
return False, ""
|
||||
|
||||
async def _handle_action(
|
||||
self,
|
||||
action: str,
|
||||
reasoning: str,
|
||||
emoji_query: str,
|
||||
cycle_timers: dict,
|
||||
planner_start_db_time: float
|
||||
self, action: str, reasoning: str, emoji_query: str, cycle_timers: dict, planner_start_db_time: float
|
||||
) -> tuple[bool, str]:
|
||||
"""
|
||||
处理规划动作
|
||||
|
||||
|
||||
参数:
|
||||
action: 动作类型
|
||||
reasoning: 决策理由
|
||||
emoji_query: 表情查询
|
||||
cycle_timers: 计时器字典
|
||||
planner_start_db_time: 规划开始时间
|
||||
|
||||
|
||||
返回:
|
||||
tuple[bool, str]: (是否执行了动作, 思考消息ID)
|
||||
"""
|
||||
action_handlers = {
|
||||
"text_reply": self._handle_text_reply,
|
||||
"emoji_reply": self._handle_emoji_reply,
|
||||
"no_reply": self._handle_no_reply
|
||||
"no_reply": self._handle_no_reply,
|
||||
}
|
||||
|
||||
|
||||
handler = action_handlers.get(action)
|
||||
if not handler:
|
||||
logger.warning(f"{self.log_prefix} 未知动作: {action}, 原因: {reasoning}")
|
||||
return False, ""
|
||||
|
||||
|
||||
try:
|
||||
if action == "text_reply":
|
||||
return await handler(reasoning, emoji_query, cycle_timers)
|
||||
@@ -457,37 +456,35 @@ class HeartFChatting:
|
||||
logger.error(f"{self.log_prefix} 处理{action}时出错: {e}")
|
||||
return False, ""
|
||||
|
||||
async def _handle_text_reply(
|
||||
self, reasoning: str, emoji_query: str, cycle_timers: dict
|
||||
) -> tuple[bool, str]:
|
||||
async def _handle_text_reply(self, reasoning: str, emoji_query: str, cycle_timers: dict) -> tuple[bool, str]:
|
||||
"""
|
||||
处理文本回复
|
||||
|
||||
|
||||
工作流程:
|
||||
1. 获取锚点消息
|
||||
2. 创建思考消息
|
||||
3. 生成回复
|
||||
4. 发送消息
|
||||
|
||||
|
||||
参数:
|
||||
reasoning: 回复原因
|
||||
emoji_query: 表情查询
|
||||
cycle_timers: 计时器字典
|
||||
|
||||
|
||||
返回:
|
||||
tuple[bool, str]: (是否回复成功, 思考消息ID)
|
||||
"""
|
||||
|
||||
|
||||
# 获取锚点消息
|
||||
anchor_message = await self._get_anchor_message()
|
||||
if not anchor_message:
|
||||
raise PlannerError("无法获取锚点消息")
|
||||
|
||||
|
||||
# 创建思考消息
|
||||
thinking_id = await self._create_thinking_message(anchor_message)
|
||||
if not thinking_id:
|
||||
raise PlannerError("无法创建思考消息")
|
||||
|
||||
|
||||
try:
|
||||
# 生成回复
|
||||
with Timer("生成回复", cycle_timers):
|
||||
@@ -496,10 +493,10 @@ class HeartFChatting:
|
||||
thinking_id=thinking_id,
|
||||
reason=reasoning,
|
||||
)
|
||||
|
||||
|
||||
if not reply:
|
||||
raise ReplierError("回复生成失败")
|
||||
|
||||
|
||||
# 发送消息
|
||||
|
||||
await self._sender(
|
||||
@@ -510,7 +507,7 @@ class HeartFChatting:
|
||||
)
|
||||
|
||||
return True, thinking_id
|
||||
|
||||
|
||||
except (ReplierError, SenderError) as e:
|
||||
logger.error(f"{self.log_prefix} 回复失败: {e}")
|
||||
return True, thinking_id # 仍然返回thinking_id以便跟踪
|
||||
@@ -518,72 +515,68 @@ class HeartFChatting:
|
||||
async def _handle_emoji_reply(self, reasoning: str, emoji_query: str) -> bool:
|
||||
"""
|
||||
处理表情回复
|
||||
|
||||
|
||||
工作流程:
|
||||
1. 获取锚点消息
|
||||
2. 发送表情
|
||||
|
||||
|
||||
参数:
|
||||
reasoning: 回复原因
|
||||
emoji_query: 表情查询
|
||||
|
||||
|
||||
返回:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
logger.info(f"{self.log_prefix} 决定回复表情({emoji_query}): {reasoning}")
|
||||
|
||||
|
||||
try:
|
||||
anchor = await self._get_anchor_message()
|
||||
if not anchor:
|
||||
raise PlannerError("无法获取锚点消息")
|
||||
|
||||
|
||||
await self._handle_emoji(anchor, [], emoji_query)
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 表情发送失败: {e}")
|
||||
return False
|
||||
|
||||
async def _handle_no_reply(
|
||||
self, reasoning: str, planner_start_db_time: float, cycle_timers: dict
|
||||
) -> bool:
|
||||
async def _handle_no_reply(self, reasoning: str, planner_start_db_time: float, cycle_timers: dict) -> bool:
|
||||
"""
|
||||
处理不回复的情况
|
||||
|
||||
|
||||
工作流程:
|
||||
1. 等待新消息
|
||||
2. 超时或收到新消息时返回
|
||||
|
||||
|
||||
参数:
|
||||
reasoning: 不回复的原因
|
||||
planner_start_db_time: 规划开始时间
|
||||
cycle_timers: 计时器字典
|
||||
|
||||
|
||||
返回:
|
||||
bool: 是否成功处理
|
||||
"""
|
||||
logger.info(f"{self.log_prefix} 决定不回复: {reasoning}")
|
||||
|
||||
|
||||
observation = self.observations[0] if self.observations else None
|
||||
|
||||
|
||||
try:
|
||||
with Timer("Wait New Msg", cycle_timers):
|
||||
return await self._wait_for_new_message(observation, planner_start_db_time, self.log_prefix)
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"{self.log_prefix} 等待被中断")
|
||||
raise
|
||||
|
||||
async def _wait_for_new_message(
|
||||
self, observation, planner_start_db_time: float, log_prefix: str
|
||||
) -> bool:
|
||||
|
||||
async def _wait_for_new_message(self, observation, planner_start_db_time: float, log_prefix: str) -> bool:
|
||||
"""
|
||||
等待新消息
|
||||
|
||||
|
||||
参数:
|
||||
observation: 观察实例
|
||||
planner_start_db_time: 开始等待的时间
|
||||
log_prefix: 日志前缀
|
||||
|
||||
|
||||
返回:
|
||||
bool: 是否检测到新消息
|
||||
"""
|
||||
@@ -592,11 +585,11 @@ class HeartFChatting:
|
||||
if await observation.has_new_messages_since(planner_start_db_time):
|
||||
logger.info(f"{log_prefix} 检测到新消息")
|
||||
return True
|
||||
|
||||
|
||||
if time.monotonic() - wait_start_time > 60:
|
||||
logger.warning(f"{log_prefix} 等待超时(60秒)")
|
||||
return False
|
||||
|
||||
|
||||
await asyncio.sleep(1.5)
|
||||
|
||||
async def _should_skip_cycle(self, thinking_id: str) -> bool:
|
||||
@@ -614,13 +607,11 @@ class HeartFChatting:
|
||||
if timer_strings:
|
||||
logger.debug(f"{log_prefix} 该次决策耗时: {'; '.join(timer_strings)}")
|
||||
|
||||
async def _handle_cycle_delay(
|
||||
self, action_taken_this_cycle: bool, cycle_start_time: float, log_prefix: str
|
||||
):
|
||||
async def _handle_cycle_delay(self, action_taken_this_cycle: bool, cycle_start_time: float, log_prefix: str):
|
||||
"""处理循环延迟"""
|
||||
cycle_duration = time.monotonic() - cycle_start_time
|
||||
# if cycle_duration > 0.1:
|
||||
# logger.debug(f"{log_prefix} HeartFChatting: 周期耗时 {cycle_duration:.2f}s.")
|
||||
# logger.debug(f"{log_prefix} HeartFChatting: 周期耗时 {cycle_duration:.2f}s.")
|
||||
|
||||
try:
|
||||
sleep_duration = 0.0
|
||||
@@ -639,7 +630,7 @@ class HeartFChatting:
|
||||
async def _get_submind_thinking(self, cycle_timers: dict) -> str:
|
||||
"""
|
||||
获取子思维的思考结果
|
||||
|
||||
|
||||
返回:
|
||||
str: 思考结果,如果思考失败则返回错误信息
|
||||
"""
|
||||
@@ -666,7 +657,7 @@ class HeartFChatting:
|
||||
async def _planner(self, current_mind: str, cycle_timers: dict, is_re_planned: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
规划器 (Planner): 使用LLM根据上下文决定是否和如何回复。
|
||||
|
||||
|
||||
参数:
|
||||
current_mind: 子思维的当前思考结果
|
||||
"""
|
||||
@@ -734,7 +725,9 @@ class HeartFChatting:
|
||||
action = arguments.get("action", "no_reply")
|
||||
# 验证动作是否在可用动作集中
|
||||
if action not in self.action_manager.get_available_actions():
|
||||
logger.warning(f"{self.log_prefix}[Planner] LLM返回了未授权的动作: {action},使用默认动作no_reply")
|
||||
logger.warning(
|
||||
f"{self.log_prefix}[Planner] LLM返回了未授权的动作: {action},使用默认动作no_reply"
|
||||
)
|
||||
action = "no_reply"
|
||||
reasoning = f"LLM返回了未授权的动作: {action}"
|
||||
else:
|
||||
@@ -742,7 +735,9 @@ class HeartFChatting:
|
||||
emoji_query = arguments.get("emoji_query", "")
|
||||
|
||||
# 记录决策结果
|
||||
logger.debug(f"{self.log_prefix}[要做什么]\nPrompt:\n{prompt}\n\n决策结果: {action}, 理由: {reasoning}, 表情查询: '{emoji_query}'")
|
||||
logger.debug(
|
||||
f"{self.log_prefix}[要做什么]\nPrompt:\n{prompt}\n\n决策结果: {action}, 理由: {reasoning}, 表情查询: '{emoji_query}'"
|
||||
)
|
||||
else:
|
||||
# 处理工具调用失败
|
||||
logger.warning(f"{self.log_prefix}[Planner] {error_msg}")
|
||||
@@ -926,9 +921,6 @@ class HeartFChatting:
|
||||
message=anchor_message, # Pass anchor_message positionally (matches 'message' parameter)
|
||||
thinking_id=thinking_id, # Pass thinking_id positionally
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
if not response_set:
|
||||
logger.warning(f"{self.log_prefix}[Replier-{thinking_id}] LLM生成了一个空回复集。")
|
||||
@@ -980,8 +972,7 @@ class HeartFChatting:
|
||||
# 记录锚点消息ID
|
||||
if self._current_cycle and anchor_message:
|
||||
self._current_cycle.set_response_info(
|
||||
response_text=response_set,
|
||||
anchor_message_id=anchor_message.message_info.message_id
|
||||
response_text=response_set, anchor_message_id=anchor_message.message_info.message_id
|
||||
)
|
||||
|
||||
chat = anchor_message.chat_stream
|
||||
@@ -1056,9 +1047,7 @@ class HeartFChatting:
|
||||
emoji_path, description = emoji_raw
|
||||
# 记录表情信息
|
||||
if self._current_cycle:
|
||||
self._current_cycle.set_response_info(
|
||||
emoji_info=f"表情: {description}, 路径: {emoji_path}"
|
||||
)
|
||||
self._current_cycle.set_response_info(emoji_info=f"表情: {description}, 路径: {emoji_path}")
|
||||
|
||||
emoji_cq = image_path_to_base64(emoji_path)
|
||||
thinking_time_point = round(time.time(), 2)
|
||||
@@ -1083,10 +1072,10 @@ class HeartFChatting:
|
||||
|
||||
def get_cycle_history(self, last_n: Optional[int] = None) -> List[Dict[str, Any]]:
|
||||
"""获取循环历史记录
|
||||
|
||||
|
||||
参数:
|
||||
last_n: 获取最近n个循环的信息,如果为None则获取所有历史记录
|
||||
|
||||
|
||||
返回:
|
||||
List[Dict[str, Any]]: 循环历史记录列表
|
||||
"""
|
||||
@@ -1100,4 +1089,3 @@ class HeartFChatting:
|
||||
if self._cycle_history:
|
||||
return self._cycle_history[-1].to_dict()
|
||||
return None
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from .heartflow_prompt_builder import prompt_builder
|
||||
from ..chat.utils import process_llm_response
|
||||
from src.common.logger import get_module_logger, LogConfig, LLM_STYLE_CONFIG
|
||||
from src.plugins.respon_info_catcher.info_catcher import info_catcher_manager
|
||||
from ..utils.timer_calculater import Timer
|
||||
from ..utils.timer_calculator import Timer
|
||||
|
||||
from src.plugins.moods.moods import MoodManager
|
||||
|
||||
|
||||
@@ -5,12 +5,12 @@ from ...config.config import global_config
|
||||
from ..chat.message import MessageRecv
|
||||
from ..storage.storage import MessageStorage
|
||||
from ..chat.utils import is_mentioned_bot_in_message
|
||||
from ..message import Seg
|
||||
from maim_message import Seg
|
||||
from src.heart_flow.heartflow import heartflow
|
||||
from src.common.logger import get_module_logger, CHAT_STYLE_CONFIG, LogConfig
|
||||
from ..chat.chat_stream import chat_manager
|
||||
from ..chat.message_buffer import message_buffer
|
||||
from ..utils.timer_calculater import Timer
|
||||
from ..utils.timer_calculator import Timer
|
||||
from src.plugins.person_info.relationship_manager import relationship_manager
|
||||
from typing import Optional, Tuple
|
||||
|
||||
@@ -24,14 +24,14 @@ logger = get_module_logger("heartflow_processor", config=processor_config)
|
||||
|
||||
class HeartFCProcessor:
|
||||
"""心流处理器,负责处理接收到的消息并计算兴趣度"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
"""初始化心流处理器,创建消息存储实例"""
|
||||
self.storage = MessageStorage()
|
||||
|
||||
async def _handle_error(self, error: Exception, context: str, message: Optional[MessageRecv] = None) -> None:
|
||||
"""统一的错误处理函数
|
||||
|
||||
|
||||
Args:
|
||||
error: 捕获到的异常
|
||||
context: 错误发生的上下文描述
|
||||
@@ -39,12 +39,12 @@ class HeartFCProcessor:
|
||||
"""
|
||||
logger.error(f"{context}: {error}")
|
||||
logger.error(traceback.format_exc())
|
||||
if message and hasattr(message, 'raw_message'):
|
||||
if message and hasattr(message, "raw_message"):
|
||||
logger.error(f"相关消息原始内容: {message.raw_message}")
|
||||
|
||||
async def _process_relationship(self, message: MessageRecv) -> None:
|
||||
"""处理用户关系逻辑
|
||||
|
||||
|
||||
Args:
|
||||
message: 消息对象,包含用户信息
|
||||
"""
|
||||
@@ -54,24 +54,20 @@ class HeartFCProcessor:
|
||||
cardname = message.message_info.user_info.user_cardname or nickname
|
||||
|
||||
is_known = await relationship_manager.is_known_some_one(platform, user_id)
|
||||
|
||||
|
||||
if not is_known:
|
||||
logger.info(f"首次认识用户: {nickname}")
|
||||
await relationship_manager.first_knowing_some_one(
|
||||
platform, user_id, nickname, cardname, ""
|
||||
)
|
||||
await relationship_manager.first_knowing_some_one(platform, user_id, nickname, cardname, "")
|
||||
elif not await relationship_manager.is_qved_name(platform, user_id):
|
||||
logger.info(f"给用户({nickname},{cardname})取名: {nickname}")
|
||||
await relationship_manager.first_knowing_some_one(
|
||||
platform, user_id, nickname, cardname, ""
|
||||
)
|
||||
await relationship_manager.first_knowing_some_one(platform, user_id, nickname, cardname, "")
|
||||
|
||||
async def _calculate_interest(self, message: MessageRecv) -> Tuple[float, bool]:
|
||||
"""计算消息的兴趣度
|
||||
|
||||
|
||||
Args:
|
||||
message: 待处理的消息对象
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[float, bool]: (兴趣度, 是否被提及)
|
||||
"""
|
||||
@@ -93,33 +89,35 @@ class HeartFCProcessor:
|
||||
|
||||
def _get_message_type(self, message: MessageRecv) -> str:
|
||||
"""获取消息类型
|
||||
|
||||
|
||||
Args:
|
||||
message: 消息对象
|
||||
|
||||
|
||||
Returns:
|
||||
str: 消息类型
|
||||
"""
|
||||
if message.message_segment.type != "seglist":
|
||||
return message.message_segment.type
|
||||
|
||||
if (isinstance(message.message_segment.data, list)
|
||||
|
||||
if (
|
||||
isinstance(message.message_segment.data, list)
|
||||
and all(isinstance(x, Seg) for x in message.message_segment.data)
|
||||
and len(message.message_segment.data) == 1):
|
||||
and len(message.message_segment.data) == 1
|
||||
):
|
||||
return message.message_segment.data[0].type
|
||||
|
||||
|
||||
return "seglist"
|
||||
|
||||
async def process_message(self, message_data: str) -> None:
|
||||
"""处理接收到的原始消息数据
|
||||
|
||||
|
||||
主要流程:
|
||||
1. 消息解析与初始化
|
||||
2. 消息缓冲处理
|
||||
3. 过滤检查
|
||||
4. 兴趣度计算
|
||||
5. 关系处理
|
||||
|
||||
|
||||
Args:
|
||||
message_data: 原始消息字符串
|
||||
"""
|
||||
@@ -133,20 +131,21 @@ class HeartFCProcessor:
|
||||
|
||||
# 2. 消息缓冲与流程序化
|
||||
await message_buffer.start_caching_messages(message)
|
||||
|
||||
|
||||
chat = await chat_manager.get_or_create_stream(
|
||||
platform=messageinfo.platform,
|
||||
user_info=userinfo,
|
||||
group_info=groupinfo,
|
||||
)
|
||||
|
||||
|
||||
subheartflow = await heartflow.create_subheartflow(chat.stream_id)
|
||||
message.update_chat_stream(chat)
|
||||
await message.process()
|
||||
|
||||
|
||||
# 3. 过滤检查
|
||||
if self._check_ban_words(message.processed_plain_text, chat, userinfo) or \
|
||||
self._check_ban_regex(message.raw_message, chat, userinfo):
|
||||
if self._check_ban_words(message.processed_plain_text, chat, userinfo) or self._check_ban_regex(
|
||||
message.raw_message, chat, userinfo
|
||||
):
|
||||
return
|
||||
|
||||
# 4. 缓冲检查
|
||||
@@ -156,7 +155,7 @@ class HeartFCProcessor:
|
||||
type_messages = {
|
||||
"text": f"触发缓冲,消息:{message.processed_plain_text}",
|
||||
"image": "触发缓冲,表情包/图片等待中",
|
||||
"seglist": "触发缓冲,消息列表等待中"
|
||||
"seglist": "触发缓冲,消息列表等待中",
|
||||
}
|
||||
logger.debug(type_messages.get(msg_type, "触发未知类型缓冲"))
|
||||
return
|
||||
@@ -189,12 +188,12 @@ class HeartFCProcessor:
|
||||
|
||||
def _check_ban_words(self, text: str, chat, userinfo) -> bool:
|
||||
"""检查消息是否包含过滤词
|
||||
|
||||
|
||||
Args:
|
||||
text: 待检查的文本
|
||||
chat: 聊天对象
|
||||
userinfo: 用户信息
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否包含过滤词
|
||||
"""
|
||||
@@ -208,12 +207,12 @@ class HeartFCProcessor:
|
||||
|
||||
def _check_ban_regex(self, text: str, chat, userinfo) -> bool:
|
||||
"""检查消息是否匹配过滤正则表达式
|
||||
|
||||
|
||||
Args:
|
||||
text: 待检查的文本
|
||||
chat: 聊天对象
|
||||
userinfo: 用户信息
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否匹配过滤正则
|
||||
"""
|
||||
|
||||
@@ -37,12 +37,15 @@ def init_prompt():
|
||||
{moderation_prompt}。注意:回复不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。""",
|
||||
"heart_flow_prompt",
|
||||
)
|
||||
|
||||
Prompt("""
|
||||
|
||||
Prompt(
|
||||
"""
|
||||
你有以下信息可供参考:
|
||||
{structured_info}
|
||||
以上的消息是你获取到的消息,或许可以帮助你更好地回复。
|
||||
""", "info_from_tools")
|
||||
""",
|
||||
"info_from_tools",
|
||||
)
|
||||
|
||||
# Planner提示词 - 优化版
|
||||
Prompt(
|
||||
@@ -187,11 +190,11 @@ class PromptBuilder:
|
||||
prompt_ger += "你喜欢用倒装句"
|
||||
if random.random() < 0.02:
|
||||
prompt_ger += "你喜欢用反问句"
|
||||
|
||||
|
||||
if structured_info:
|
||||
structured_info_prompt = await global_prompt_manager.format_prompt(
|
||||
"info_from_tools",
|
||||
structured_info = structured_info)
|
||||
"info_from_tools", structured_info=structured_info
|
||||
)
|
||||
else:
|
||||
structured_info_prompt = ""
|
||||
|
||||
|
||||
@@ -12,12 +12,12 @@ from ..chat.message import MessageSending, MessageRecv, MessageThinking, Message
|
||||
from ..chat.message_sender import message_manager
|
||||
from ..chat.utils_image import image_path_to_base64
|
||||
from ..willing.willing_manager import willing_manager
|
||||
from ..message import UserInfo, Seg
|
||||
from maim_message import UserInfo, Seg
|
||||
from src.common.logger import get_module_logger, CHAT_STYLE_CONFIG, LogConfig
|
||||
from src.plugins.chat.chat_stream import ChatStream, chat_manager
|
||||
from src.plugins.person_info.relationship_manager import relationship_manager
|
||||
from src.plugins.respon_info_catcher.info_catcher import info_catcher_manager
|
||||
from src.plugins.utils.timer_calculater import Timer
|
||||
from src.plugins.utils.timer_calculator import Timer
|
||||
|
||||
# 定义日志配置
|
||||
chat_config = LogConfig(
|
||||
|
||||
@@ -5,7 +5,7 @@ from ...config.config import global_config
|
||||
from ..chat.message import MessageThinking
|
||||
from .heartflow_prompt_builder import prompt_builder
|
||||
from ..chat.utils import process_llm_response
|
||||
from ..utils.timer_calculater import Timer
|
||||
from ..utils.timer_calculator import Timer
|
||||
from src.common.logger import get_module_logger, LogConfig, LLM_STYLE_CONFIG
|
||||
from src.plugins.respon_info_catcher.info_catcher import info_catcher_manager
|
||||
|
||||
|
||||
@@ -3,23 +3,8 @@
|
||||
__version__ = "0.1.0"
|
||||
|
||||
from .api import global_api
|
||||
from .message_base import (
|
||||
Seg,
|
||||
GroupInfo,
|
||||
UserInfo,
|
||||
FormatInfo,
|
||||
TemplateInfo,
|
||||
BaseMessageInfo,
|
||||
MessageBase,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Seg",
|
||||
"global_api",
|
||||
"GroupInfo",
|
||||
"UserInfo",
|
||||
"FormatInfo",
|
||||
"TemplateInfo",
|
||||
"BaseMessageInfo",
|
||||
"MessageBase",
|
||||
]
|
||||
|
||||
@@ -1,250 +1,6 @@
|
||||
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
|
||||
from typing import Dict, Any, Callable, List, Set, Optional
|
||||
from src.common.logger import get_module_logger
|
||||
from src.plugins.message.message_base import MessageBase
|
||||
from src.common.server import global_server
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import uvicorn
|
||||
import os
|
||||
import traceback
|
||||
|
||||
logger = get_module_logger("api")
|
||||
|
||||
|
||||
class BaseMessageHandler:
|
||||
"""消息处理基类"""
|
||||
|
||||
def __init__(self):
|
||||
self.message_handlers: List[Callable] = []
|
||||
self.background_tasks = set()
|
||||
|
||||
def register_message_handler(self, handler: Callable):
|
||||
"""注册消息处理函数"""
|
||||
self.message_handlers.append(handler)
|
||||
|
||||
async def process_message(self, message: Dict[str, Any]):
|
||||
"""处理单条消息"""
|
||||
tasks = []
|
||||
for handler in self.message_handlers:
|
||||
try:
|
||||
tasks.append(handler(message))
|
||||
except Exception as e:
|
||||
logger.error(f"消息处理出错: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
# 不抛出异常,而是记录错误并继续处理其他消息
|
||||
continue
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
async def _handle_message(self, message: Dict[str, Any]):
|
||||
"""后台处理单个消息"""
|
||||
try:
|
||||
await self.process_message(message)
|
||||
except Exception as e:
|
||||
raise RuntimeError(str(e)) from e
|
||||
|
||||
|
||||
class MessageServer(BaseMessageHandler):
|
||||
"""WebSocket服务端"""
|
||||
|
||||
_class_handlers: List[Callable] = [] # 类级别的消息处理器
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str = "0.0.0.0",
|
||||
port: int = 18000,
|
||||
enable_token=False,
|
||||
app: Optional[FastAPI] = None,
|
||||
path: str = "/ws",
|
||||
):
|
||||
super().__init__()
|
||||
# 将类级别的处理器添加到实例处理器中
|
||||
self.message_handlers.extend(self._class_handlers)
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.path = path
|
||||
self.app = app or FastAPI()
|
||||
self.own_app = app is None # 标记是否使用自己创建的app
|
||||
self.active_websockets: Set[WebSocket] = set()
|
||||
self.platform_websockets: Dict[str, WebSocket] = {} # 平台到websocket的映射
|
||||
self.valid_tokens: Set[str] = set()
|
||||
self.enable_token = enable_token
|
||||
self._setup_routes()
|
||||
self._running = False
|
||||
|
||||
def _setup_routes(self):
|
||||
@self.app.post("/api/message")
|
||||
async def handle_message(message: Dict[str, Any]):
|
||||
try:
|
||||
# 创建后台任务处理消息
|
||||
asyncio.create_task(self._handle_message(message))
|
||||
return {"status": "success"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
@self.app.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
headers = dict(websocket.headers)
|
||||
token = headers.get("authorization")
|
||||
platform = headers.get("platform", "default") # 获取platform标识
|
||||
if self.enable_token:
|
||||
if not token or not await self.verify_token(token):
|
||||
await websocket.close(code=1008, reason="Invalid or missing token")
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
self.active_websockets.add(websocket)
|
||||
|
||||
# 添加到platform映射
|
||||
if platform not in self.platform_websockets:
|
||||
self.platform_websockets[platform] = websocket
|
||||
|
||||
try:
|
||||
while True:
|
||||
message = await websocket.receive_json()
|
||||
# print(f"Received message: {message}")
|
||||
asyncio.create_task(self._handle_message(message))
|
||||
except WebSocketDisconnect:
|
||||
self._remove_websocket(websocket, platform)
|
||||
except Exception as e:
|
||||
self._remove_websocket(websocket, platform)
|
||||
raise RuntimeError(str(e)) from e
|
||||
finally:
|
||||
self._remove_websocket(websocket, platform)
|
||||
|
||||
@classmethod
|
||||
def register_class_handler(cls, handler: Callable):
|
||||
"""注册类级别的消息处理器"""
|
||||
if handler not in cls._class_handlers:
|
||||
cls._class_handlers.append(handler)
|
||||
|
||||
def register_message_handler(self, handler: Callable):
|
||||
"""注册实例级别的消息处理器"""
|
||||
if handler not in self.message_handlers:
|
||||
self.message_handlers.append(handler)
|
||||
|
||||
async def verify_token(self, token: str) -> bool:
|
||||
if not self.enable_token:
|
||||
return True
|
||||
return token in self.valid_tokens
|
||||
|
||||
def add_valid_token(self, token: str):
|
||||
self.valid_tokens.add(token)
|
||||
|
||||
def remove_valid_token(self, token: str):
|
||||
self.valid_tokens.discard(token)
|
||||
|
||||
def run_sync(self):
|
||||
"""同步方式运行服务器"""
|
||||
if not self.own_app:
|
||||
raise RuntimeError("当使用外部FastAPI实例时,请使用该实例的运行方法")
|
||||
uvicorn.run(self.app, host=self.host, port=self.port)
|
||||
|
||||
async def run(self):
|
||||
"""异步方式运行服务器"""
|
||||
self._running = True
|
||||
try:
|
||||
if self.own_app:
|
||||
# 如果使用自己的 FastAPI 实例,运行 uvicorn 服务器
|
||||
# 禁用 uvicorn 默认日志和访问日志
|
||||
config = uvicorn.Config(
|
||||
self.app, host=self.host, port=self.port, loop="asyncio", log_config=None, access_log=False
|
||||
)
|
||||
self.server = uvicorn.Server(config)
|
||||
await self.server.serve()
|
||||
else:
|
||||
# 如果使用外部 FastAPI 实例,保持运行状态以处理消息
|
||||
while self._running:
|
||||
await asyncio.sleep(1)
|
||||
except KeyboardInterrupt:
|
||||
await self.stop()
|
||||
raise
|
||||
except Exception as e:
|
||||
await self.stop()
|
||||
raise RuntimeError(f"服务器运行错误: {str(e)}") from e
|
||||
finally:
|
||||
await self.stop()
|
||||
|
||||
async def start_server(self):
|
||||
"""启动服务器的异步方法"""
|
||||
if not self._running:
|
||||
self._running = True
|
||||
await self.run()
|
||||
|
||||
async def stop(self):
|
||||
"""停止服务器"""
|
||||
# 清理platform映射
|
||||
self.platform_websockets.clear()
|
||||
|
||||
# 取消所有后台任务
|
||||
for task in self.background_tasks:
|
||||
task.cancel()
|
||||
# 等待所有任务完成
|
||||
await asyncio.gather(*self.background_tasks, return_exceptions=True)
|
||||
self.background_tasks.clear()
|
||||
|
||||
# 关闭所有WebSocket连接
|
||||
for websocket in self.active_websockets:
|
||||
await websocket.close()
|
||||
self.active_websockets.clear()
|
||||
|
||||
if hasattr(self, "server") and self.own_app:
|
||||
self._running = False
|
||||
# 正确关闭 uvicorn 服务器
|
||||
self.server.should_exit = True
|
||||
await self.server.shutdown()
|
||||
# 等待服务器完全停止
|
||||
if hasattr(self.server, "started") and self.server.started:
|
||||
await self.server.main_loop()
|
||||
# 清理处理程序
|
||||
self.message_handlers.clear()
|
||||
|
||||
def _remove_websocket(self, websocket: WebSocket, platform: str):
|
||||
"""从所有集合中移除websocket"""
|
||||
if websocket in self.active_websockets:
|
||||
self.active_websockets.remove(websocket)
|
||||
if platform in self.platform_websockets:
|
||||
if self.platform_websockets[platform] == websocket:
|
||||
del self.platform_websockets[platform]
|
||||
|
||||
async def broadcast_message(self, message: Dict[str, Any]):
|
||||
disconnected = set()
|
||||
for websocket in self.active_websockets:
|
||||
try:
|
||||
await websocket.send_json(message)
|
||||
except Exception:
|
||||
disconnected.add(websocket)
|
||||
for websocket in disconnected:
|
||||
self.active_websockets.remove(websocket)
|
||||
|
||||
async def broadcast_to_platform(self, platform: str, message: Dict[str, Any]):
|
||||
"""向指定平台的所有WebSocket客户端广播消息"""
|
||||
if platform not in self.platform_websockets:
|
||||
raise ValueError(f"平台:{platform} 未连接")
|
||||
|
||||
disconnected = set()
|
||||
try:
|
||||
await self.platform_websockets[platform].send_json(message)
|
||||
except Exception:
|
||||
disconnected.add(self.platform_websockets[platform])
|
||||
|
||||
# 清理断开的连接
|
||||
for websocket in disconnected:
|
||||
self._remove_websocket(websocket, platform)
|
||||
|
||||
async def send_message(self, message: MessageBase):
|
||||
await self.broadcast_to_platform(message.message_info.platform, message.to_dict())
|
||||
|
||||
@staticmethod
|
||||
async def send_message_rest(url: str, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""发送消息到指定端点"""
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
async with session.post(url, json=data, headers={"Content-Type": "application/json"}) as response:
|
||||
return await response.json()
|
||||
except Exception as e:
|
||||
raise e
|
||||
from maim_message import MessageServer
|
||||
|
||||
|
||||
global_api = MessageServer(host=os.environ["HOST"], port=int(os.environ["PORT"]), app=global_server.get_app())
|
||||
|
||||
@@ -1,247 +0,0 @@
|
||||
from dataclasses import dataclass, asdict
|
||||
from typing import List, Optional, Union, Dict
|
||||
|
||||
|
||||
@dataclass
|
||||
class Seg:
|
||||
"""消息片段类,用于表示消息的不同部分
|
||||
|
||||
Attributes:
|
||||
type: 片段类型,可以是 'text'、'image'、'seglist' 等
|
||||
data: 片段的具体内容
|
||||
- 对于 text 类型,data 是字符串
|
||||
- 对于 image 类型,data 是 base64 字符串
|
||||
- 对于 seglist 类型,data 是 Seg 列表
|
||||
"""
|
||||
|
||||
type: str
|
||||
data: Union[str, List["Seg"]]
|
||||
|
||||
# def __init__(self, type: str, data: Union[str, List['Seg']],):
|
||||
# """初始化实例,确保字典和属性同步"""
|
||||
# # 先初始化字典
|
||||
# self.type = type
|
||||
# self.data = data
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict) -> "Seg":
|
||||
"""从字典创建Seg实例"""
|
||||
type = data.get("type")
|
||||
data = data.get("data")
|
||||
if type == "seglist":
|
||||
data = [Seg.from_dict(seg) for seg in data]
|
||||
return cls(type=type, data=data)
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""转换为字典格式"""
|
||||
result = {"type": self.type}
|
||||
if self.type == "seglist":
|
||||
result["data"] = [seg.to_dict() for seg in self.data]
|
||||
else:
|
||||
result["data"] = self.data
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class GroupInfo:
|
||||
"""群组信息类"""
|
||||
|
||||
platform: Optional[str] = None
|
||||
group_id: Optional[int] = None
|
||||
group_name: Optional[str] = None # 群名称
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""转换为字典格式"""
|
||||
return {k: v for k, v in asdict(self).items() if v is not None}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict) -> "GroupInfo":
|
||||
"""从字典创建GroupInfo实例
|
||||
|
||||
Args:
|
||||
data: 包含必要字段的字典
|
||||
|
||||
Returns:
|
||||
GroupInfo: 新的实例
|
||||
"""
|
||||
if data.get("group_id") is None:
|
||||
return None
|
||||
return cls(
|
||||
platform=data.get("platform"), group_id=data.get("group_id"), group_name=data.get("group_name", None)
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserInfo:
|
||||
"""用户信息类"""
|
||||
|
||||
platform: Optional[str] = None
|
||||
user_id: Optional[int] = None
|
||||
user_nickname: Optional[str] = None # 用户昵称
|
||||
user_cardname: Optional[str] = None # 用户群昵称
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""转换为字典格式"""
|
||||
return {k: v for k, v in asdict(self).items() if v is not None}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict) -> "UserInfo":
|
||||
"""从字典创建UserInfo实例
|
||||
|
||||
Args:
|
||||
data: 包含必要字段的字典
|
||||
|
||||
Returns:
|
||||
UserInfo: 新的实例
|
||||
"""
|
||||
return cls(
|
||||
platform=data.get("platform"),
|
||||
user_id=data.get("user_id"),
|
||||
user_nickname=data.get("user_nickname", None),
|
||||
user_cardname=data.get("user_cardname", None),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FormatInfo:
|
||||
"""格式信息类"""
|
||||
|
||||
"""
|
||||
目前maimcore可接受的格式为text,image,emoji
|
||||
可发送的格式为text,emoji,reply
|
||||
"""
|
||||
|
||||
content_format: Optional[str] = None
|
||||
accept_format: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""转换为字典格式"""
|
||||
return {k: v for k, v in asdict(self).items() if v is not None}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict) -> "FormatInfo":
|
||||
"""从字典创建FormatInfo实例
|
||||
Args:
|
||||
data: 包含必要字段的字典
|
||||
Returns:
|
||||
FormatInfo: 新的实例
|
||||
"""
|
||||
return cls(
|
||||
content_format=data.get("content_format"),
|
||||
accept_format=data.get("accept_format"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TemplateInfo:
|
||||
"""模板信息类"""
|
||||
|
||||
template_items: Optional[Dict] = None
|
||||
template_name: Optional[str] = None
|
||||
template_default: bool = True
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""转换为字典格式"""
|
||||
return {k: v for k, v in asdict(self).items() if v is not None}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict) -> "TemplateInfo":
|
||||
"""从字典创建TemplateInfo实例
|
||||
Args:
|
||||
data: 包含必要字段的字典
|
||||
Returns:
|
||||
TemplateInfo: 新的实例
|
||||
"""
|
||||
return cls(
|
||||
template_items=data.get("template_items"),
|
||||
template_name=data.get("template_name"),
|
||||
template_default=data.get("template_default", True),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseMessageInfo:
|
||||
"""消息信息类"""
|
||||
|
||||
platform: Optional[str] = None
|
||||
message_id: Union[str, int, None] = None
|
||||
time: Optional[float] = None
|
||||
group_info: Optional[GroupInfo] = None
|
||||
user_info: Optional[UserInfo] = None
|
||||
format_info: Optional[FormatInfo] = None
|
||||
template_info: Optional[TemplateInfo] = None
|
||||
additional_config: Optional[dict] = None
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""转换为字典格式"""
|
||||
result = {}
|
||||
for field, value in asdict(self).items():
|
||||
if value is not None:
|
||||
if isinstance(value, (GroupInfo, UserInfo, FormatInfo, TemplateInfo)):
|
||||
result[field] = value.to_dict()
|
||||
else:
|
||||
result[field] = value
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict) -> "BaseMessageInfo":
|
||||
"""从字典创建BaseMessageInfo实例
|
||||
|
||||
Args:
|
||||
data: 包含必要字段的字典
|
||||
|
||||
Returns:
|
||||
BaseMessageInfo: 新的实例
|
||||
"""
|
||||
group_info = GroupInfo.from_dict(data.get("group_info", {}))
|
||||
user_info = UserInfo.from_dict(data.get("user_info", {}))
|
||||
format_info = FormatInfo.from_dict(data.get("format_info", {}))
|
||||
template_info = TemplateInfo.from_dict(data.get("template_info", {}))
|
||||
return cls(
|
||||
platform=data.get("platform"),
|
||||
message_id=data.get("message_id"),
|
||||
time=data.get("time"),
|
||||
additional_config=data.get("additional_config", None),
|
||||
group_info=group_info,
|
||||
user_info=user_info,
|
||||
format_info=format_info,
|
||||
template_info=template_info,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageBase:
|
||||
"""消息类"""
|
||||
|
||||
message_info: BaseMessageInfo
|
||||
message_segment: Seg
|
||||
raw_message: Optional[str] = None # 原始消息,包含未解析的cq码
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""转换为字典格式
|
||||
|
||||
Returns:
|
||||
Dict: 包含所有非None字段的字典,其中:
|
||||
- message_info: 转换为字典格式
|
||||
- message_segment: 转换为字典格式
|
||||
- raw_message: 如果存在则包含
|
||||
"""
|
||||
result = {"message_info": self.message_info.to_dict(), "message_segment": self.message_segment.to_dict()}
|
||||
if self.raw_message is not None:
|
||||
result["raw_message"] = self.raw_message
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict) -> "MessageBase":
|
||||
"""从字典创建MessageBase实例
|
||||
|
||||
Args:
|
||||
data: 包含必要字段的字典
|
||||
|
||||
Returns:
|
||||
MessageBase: 新的实例
|
||||
"""
|
||||
message_info = BaseMessageInfo.from_dict(data.get("message_info", {}))
|
||||
message_segment = Seg.from_dict(data.get("message_segment", {}))
|
||||
raw_message = data.get("raw_message", None)
|
||||
return cls(message_info=message_info, message_segment=message_segment, raw_message=raw_message)
|
||||
Reference in New Issue
Block a user