Merge branch 'dev' of https://github.com/MaiM-with-u/MaiBot into dev
This commit is contained in:
@@ -321,9 +321,7 @@ CHAT_STYLE_CONFIG = {
|
|||||||
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 见闻 | {message}",
|
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 见闻 | {message}",
|
||||||
},
|
},
|
||||||
"simple": {
|
"simple": {
|
||||||
"console_format": (
|
"console_format": ("<level>{time:MM-DD HH:mm}</level> | <green>见闻</green> | <green>{message}</green>"), # noqa: E501
|
||||||
"<level>{time:MM-DD HH:mm}</level> | <green>见闻</green> | <green>{message}</green>"
|
|
||||||
), # noqa: E501
|
|
||||||
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 见闻 | {message}",
|
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 见闻 | {message}",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -272,7 +272,6 @@ class SubHeartflowManager:
|
|||||||
current_state = current_mai_state.get_current_state()
|
current_state = current_mai_state.get_current_state()
|
||||||
focused_limit = current_state.get_focused_chat_max_num()
|
focused_limit = current_state.get_focused_chat_max_num()
|
||||||
|
|
||||||
|
|
||||||
if int(time.time()) % 20 == 0: # 每20秒输出一次
|
if int(time.time()) % 20 == 0: # 每20秒输出一次
|
||||||
logger.debug(f"{log_prefix} 当前状态 ({current_state.value}) 可以在{focused_limit}个群激情聊天")
|
logger.debug(f"{log_prefix} 当前状态 ({current_state.value}) 可以在{focused_limit}个群激情聊天")
|
||||||
|
|
||||||
@@ -288,7 +287,7 @@ class SubHeartflowManager:
|
|||||||
states_num = (
|
states_num = (
|
||||||
self.count_subflows_by_state(ChatState.ABSENT),
|
self.count_subflows_by_state(ChatState.ABSENT),
|
||||||
self.count_subflows_by_state(ChatState.CHAT),
|
self.count_subflows_by_state(ChatState.CHAT),
|
||||||
current_focused_count
|
current_focused_count,
|
||||||
)
|
)
|
||||||
|
|
||||||
for sub_hf in list(self.subheartflows.values()):
|
for sub_hf in list(self.subheartflows.values()):
|
||||||
@@ -300,6 +299,7 @@ class SubHeartflowManager:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
from .mai_state_manager import enable_unlimited_hfc_chat
|
from .mai_state_manager import enable_unlimited_hfc_chat
|
||||||
|
|
||||||
if not enable_unlimited_hfc_chat:
|
if not enable_unlimited_hfc_chat:
|
||||||
if sub_hf.chat_state.chat_status != ChatState.CHAT:
|
if sub_hf.chat_state.chat_status != ChatState.CHAT:
|
||||||
continue
|
continue
|
||||||
@@ -326,11 +326,11 @@ class SubHeartflowManager:
|
|||||||
await current_subflow.set_chat_state(ChatState.FOCUSED, states_num)
|
await current_subflow.set_chat_state(ChatState.FOCUSED, states_num)
|
||||||
|
|
||||||
# 验证提升结果
|
# 验证提升结果
|
||||||
if (final_subflow := self.subheartflows.get(flow_id)) and \
|
if (
|
||||||
final_subflow.chat_state.chat_status == ChatState.FOCUSED:
|
final_subflow := self.subheartflows.get(flow_id)
|
||||||
|
) and final_subflow.chat_state.chat_status == ChatState.FOCUSED:
|
||||||
current_focused_count += 1
|
current_focused_count += 1
|
||||||
|
|
||||||
|
|
||||||
async def randomly_deactivate_subflows(self, deactivation_probability: float = 0.1):
|
async def randomly_deactivate_subflows(self, deactivation_probability: float = 0.1):
|
||||||
"""以一定概率将 FOCUSED 或 CHAT 状态的子心流回退到 ABSENT 状态。"""
|
"""以一定概率将 FOCUSED 或 CHAT 状态的子心流回退到 ABSENT 状态。"""
|
||||||
log_prefix_manager = "[子心流管理器-随机停用]"
|
log_prefix_manager = "[子心流管理器-随机停用]"
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import asyncio
|
|||||||
import traceback
|
import traceback
|
||||||
from typing import Optional, Dict, Any, List
|
from typing import Optional, Dict, Any, List
|
||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
from ..message.message_base import UserInfo
|
from maim_message import UserInfo
|
||||||
from ...config.config import global_config
|
from ...config.config import global_config
|
||||||
from .chat_states import NotificationManager, create_new_message_notification, create_cold_chat_notification
|
from .chat_states import NotificationManager, create_new_message_notification, create_cold_chat_notification
|
||||||
from .message_storage import MongoDBMessageStorage
|
from .message_storage import MongoDBMessageStorage
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from .observation_info import ObservationInfo
|
|||||||
from .conversation_info import ConversationInfo
|
from .conversation_info import ConversationInfo
|
||||||
from .reply_generator import ReplyGenerator
|
from .reply_generator import ReplyGenerator
|
||||||
from ..chat.chat_stream import ChatStream
|
from ..chat.chat_stream import ChatStream
|
||||||
from ..message.message_base import UserInfo
|
from maim_message import UserInfo
|
||||||
from src.plugins.chat.chat_stream import chat_manager
|
from src.plugins.chat.chat_stream import chat_manager
|
||||||
from .pfc_KnowledgeFetcher import KnowledgeFetcher
|
from .pfc_KnowledgeFetcher import KnowledgeFetcher
|
||||||
from .waiter import Waiter
|
from .waiter import Waiter
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from typing import Optional
|
|||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
from ..chat.chat_stream import ChatStream
|
from ..chat.chat_stream import ChatStream
|
||||||
from ..chat.message import Message
|
from ..chat.message import Message
|
||||||
from ..message.message_base import Seg
|
from maim_message import Seg
|
||||||
from src.plugins.chat.message import MessageSending, MessageSet
|
from src.plugins.chat.message import MessageSending, MessageSet
|
||||||
from src.plugins.chat.message_sender import message_manager
|
from src.plugins.chat.message_sender import message_manager
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
# Programmable Friendly Conversationalist
|
# Programmable Friendly Conversationalist
|
||||||
# Prefrontal cortex
|
# Prefrontal cortex
|
||||||
from typing import List, Optional, Dict, Any, Set
|
from typing import List, Optional, Dict, Any, Set
|
||||||
from ..message.message_base import UserInfo
|
from maim_message import UserInfo
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import datetime
|
|||||||
from typing import List, Optional, Tuple, TYPE_CHECKING
|
from typing import List, Optional, Tuple, TYPE_CHECKING
|
||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
from ..chat.chat_stream import ChatStream
|
from ..chat.chat_stream import ChatStream
|
||||||
from ..message.message_base import UserInfo, Seg
|
from maim_message import UserInfo, Seg
|
||||||
from ..chat.message import Message
|
from ..chat.message import Message
|
||||||
from ..models.utils_model import LLMRequest
|
from ..models.utils_model import LLMRequest
|
||||||
from ...config.config import global_config
|
from ...config.config import global_config
|
||||||
@@ -371,21 +371,10 @@ class DirectMessageSender:
|
|||||||
# 处理消息
|
# 处理消息
|
||||||
await message.process()
|
await message.process()
|
||||||
|
|
||||||
message_json = message.to_dict()
|
_message_json = message.to_dict()
|
||||||
|
|
||||||
# 发送消息
|
# 发送消息
|
||||||
try:
|
try:
|
||||||
end_point = global_config.api_urls.get(message.message_info.platform, None)
|
|
||||||
if end_point:
|
|
||||||
# logger.info(f"发送消息到{end_point}")
|
|
||||||
# logger.info(message_json)
|
|
||||||
try:
|
|
||||||
await global_api.send_message_REST(end_point, message_json)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"REST方式发送失败,出现错误: {str(e)}")
|
|
||||||
logger.info("尝试使用ws发送")
|
|
||||||
await self.send_via_ws(message)
|
|
||||||
else:
|
|
||||||
await self.send_via_ws(message)
|
await self.send_via_ws(message)
|
||||||
logger.success(f"PFC消息已发送: {content}")
|
logger.success(f"PFC消息已发送: {content}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from src.common.logger import get_module_logger
|
|||||||
from ..models.utils_model import LLMRequest
|
from ..models.utils_model import LLMRequest
|
||||||
from ...config.config import global_config
|
from ...config.config import global_config
|
||||||
from .chat_observer import ChatObserver
|
from .chat_observer import ChatObserver
|
||||||
from ..message.message_base import UserInfo
|
from maim_message import UserInfo
|
||||||
|
|
||||||
logger = get_module_logger("reply_checker")
|
logger = get_module_logger("reply_checker")
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from typing import Dict, Optional
|
|||||||
|
|
||||||
|
|
||||||
from ...common.database import db
|
from ...common.database import db
|
||||||
from ..message.message_base import GroupInfo, UserInfo
|
from maim_message import GroupInfo, UserInfo
|
||||||
|
|
||||||
from src.common.logger import get_module_logger, LogConfig, CHAT_STREAM_STYLE_CONFIG
|
from src.common.logger import get_module_logger, LogConfig, CHAT_STREAM_STYLE_CONFIG
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import urllib3
|
|||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
from .chat_stream import ChatStream
|
from .chat_stream import ChatStream
|
||||||
from .utils_image import image_manager
|
from .utils_image import image_manager
|
||||||
from ..message.message_base import Seg, UserInfo, BaseMessageInfo, MessageBase
|
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
|
||||||
|
|
||||||
logger = get_module_logger("chat_message")
|
logger = get_module_logger("chat_message")
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from src.common.logger import get_module_logger
|
|||||||
import asyncio
|
import asyncio
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from .message import MessageRecv
|
from .message import MessageRecv
|
||||||
from ..message.message_base import BaseMessageInfo, GroupInfo, Seg
|
from maim_message import BaseMessageInfo, GroupInfo, Seg
|
||||||
import hashlib
|
import hashlib
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|||||||
@@ -62,19 +62,9 @@ class MessageSender:
|
|||||||
# logger.trace(f"{message.processed_plain_text},{typing_time},等待输入时间结束") # 减少日志
|
# logger.trace(f"{message.processed_plain_text},{typing_time},等待输入时间结束") # 减少日志
|
||||||
# --- 结束打字延迟 ---
|
# --- 结束打字延迟 ---
|
||||||
|
|
||||||
message_json = message.to_dict()
|
|
||||||
message_preview = truncate_message(message.processed_plain_text)
|
message_preview = truncate_message(message.processed_plain_text)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
end_point = global_config.api_urls.get(message.message_info.platform, None)
|
|
||||||
if end_point:
|
|
||||||
try:
|
|
||||||
await global_api.send_message_rest(end_point, message_json)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"REST发送失败: {str(e)}")
|
|
||||||
logger.info(f"[{message.chat_stream.stream_id}] 尝试使用WS发送")
|
|
||||||
await self.send_via_ws(message)
|
|
||||||
else:
|
|
||||||
await self.send_via_ws(message)
|
await self.send_via_ws(message)
|
||||||
logger.success(f"发送消息 '{message_preview}' 成功") # 调整日志格式
|
logger.success(f"发送消息 '{message_preview}' 成功") # 调整日志格式
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from ..models.utils_model import LLMRequest
|
|||||||
from ..utils.typo_generator import ChineseTypoGenerator
|
from ..utils.typo_generator import ChineseTypoGenerator
|
||||||
from ...config.config import global_config
|
from ...config.config import global_config
|
||||||
from .message import MessageRecv, Message
|
from .message import MessageRecv, Message
|
||||||
from ..message.message_base import UserInfo
|
from maim_message import UserInfo
|
||||||
from .chat_stream import ChatStream
|
from .chat_stream import ChatStream
|
||||||
from ..moods.moods import MoodManager
|
from ..moods.moods import MoodManager
|
||||||
from ...common.database import db
|
from ...common.database import db
|
||||||
|
|||||||
@@ -29,10 +29,11 @@ EMOJI_DIR = os.path.join(BASE_DIR, "emoji") # 表情包存储目录
|
|||||||
EMOJI_REGISTED_DIR = os.path.join(BASE_DIR, "emoji_registed") # 已注册的表情包注册目录
|
EMOJI_REGISTED_DIR = os.path.join(BASE_DIR, "emoji_registed") # 已注册的表情包注册目录
|
||||||
|
|
||||||
|
|
||||||
'''
|
"""
|
||||||
还没经过测试,有些地方数据库和内存数据同步可能不完全
|
还没经过测试,有些地方数据库和内存数据同步可能不完全
|
||||||
|
|
||||||
'''
|
"""
|
||||||
|
|
||||||
|
|
||||||
class MaiEmoji:
|
class MaiEmoji:
|
||||||
"""定义一个表情包"""
|
"""定义一个表情包"""
|
||||||
@@ -316,7 +317,9 @@ class EmojiManager:
|
|||||||
|
|
||||||
time_end = time.time()
|
time_end = time.time()
|
||||||
|
|
||||||
logger.info(f"找到[{text_emotion}]表情包,用时:{time_end - time_start:.2f}秒: {selected_emoji.description} (相似度: {similarity:.4f})")
|
logger.info(
|
||||||
|
f"找到[{text_emotion}]表情包,用时:{time_end - time_start:.2f}秒: {selected_emoji.description} (相似度: {similarity:.4f})"
|
||||||
|
)
|
||||||
return selected_emoji.path, f"[ {selected_emoji.description} ]"
|
return selected_emoji.path, f"[ {selected_emoji.description} ]"
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -785,7 +788,6 @@ class EmojiManager:
|
|||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
async def clear_temp_emoji(self):
|
async def clear_temp_emoji(self):
|
||||||
"""每天清理临时表情包
|
"""每天清理临时表情包
|
||||||
清理/data/emoji和/data/image目录下的所有文件
|
清理/data/emoji和/data/image目录下的所有文件
|
||||||
@@ -821,7 +823,5 @@ class EmojiManager:
|
|||||||
logger.success("[清理] 临时文件清理完成")
|
logger.success("[清理] 临时文件清理完成")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 创建全局单例
|
# 创建全局单例
|
||||||
emoji_manager = EmojiManager()
|
emoji_manager = EmojiManager()
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from src.common.logger import get_module_logger, LogConfig, PFC_STYLE_CONFIG #
|
|||||||
from src.plugins.models.utils_model import LLMRequest
|
from src.plugins.models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.plugins.chat.utils_image import image_path_to_base64 # Local import needed after move
|
from src.plugins.chat.utils_image import image_path_to_base64 # Local import needed after move
|
||||||
from src.plugins.utils.timer_calculater import Timer # <--- Import Timer
|
from src.plugins.utils.timer_calculator import Timer # <--- Import Timer
|
||||||
from src.plugins.heartFC_chat.heartFC_generator import HeartFCGenerator
|
from src.plugins.heartFC_chat.heartFC_generator import HeartFCGenerator
|
||||||
from src.do_tool.tool_use import ToolUser
|
from src.do_tool.tool_use import ToolUser
|
||||||
from ..chat.message_sender import message_manager # <-- Import the global manager
|
from ..chat.message_sender import message_manager # <-- Import the global manager
|
||||||
@@ -40,11 +40,8 @@ logger = get_module_logger("HeartFCLoop", config=interest_log_config) # Logger
|
|||||||
|
|
||||||
|
|
||||||
# 默认动作定义
|
# 默认动作定义
|
||||||
DEFAULT_ACTIONS = {
|
DEFAULT_ACTIONS = {"no_reply": "不回复", "text_reply": "文本回复, 可选附带表情", "emoji_reply": "仅表情回复"}
|
||||||
"no_reply": "不回复",
|
|
||||||
"text_reply": "文本回复, 可选附带表情",
|
|
||||||
"emoji_reply": "仅表情回复"
|
|
||||||
}
|
|
||||||
|
|
||||||
class ActionManager:
|
class ActionManager:
|
||||||
"""动作管理器:控制每次决策可以使用的动作"""
|
"""动作管理器:控制每次决策可以使用的动作"""
|
||||||
@@ -98,7 +95,8 @@ class ActionManager:
|
|||||||
|
|
||||||
def get_planner_tool_definition(self) -> List[Dict[str, Any]]:
|
def get_planner_tool_definition(self) -> List[Dict[str, Any]]:
|
||||||
"""获取当前动作集对应的规划器工具定义"""
|
"""获取当前动作集对应的规划器工具定义"""
|
||||||
return [{
|
return [
|
||||||
|
{
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"name": "decide_reply_action",
|
"name": "decide_reply_action",
|
||||||
@@ -109,8 +107,8 @@ class ActionManager:
|
|||||||
"action": {
|
"action": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"enum": list(self._available_actions.keys()),
|
"enum": list(self._available_actions.keys()),
|
||||||
"description": "决定采取的行动:" +
|
"description": "决定采取的行动:"
|
||||||
", ".join([f"'{k}'({v})" for k, v in self._available_actions.items()]),
|
+ ", ".join([f"'{k}'({v})" for k, v in self._available_actions.items()]),
|
||||||
},
|
},
|
||||||
"reasoning": {"type": "string", "description": "做出此决定的简要理由。"},
|
"reasoning": {"type": "string", "description": "做出此决定的简要理由。"},
|
||||||
"emoji_query": {
|
"emoji_query": {
|
||||||
@@ -121,24 +119,32 @@ class ActionManager:
|
|||||||
"required": ["action", "reasoning"],
|
"required": ["action", "reasoning"],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}]
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
# 在文件开头添加自定义异常类
|
# 在文件开头添加自定义异常类
|
||||||
class HeartFCError(Exception):
|
class HeartFCError(Exception):
|
||||||
"""麦麦聊天系统基础异常类"""
|
"""麦麦聊天系统基础异常类"""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class PlannerError(HeartFCError):
|
class PlannerError(HeartFCError):
|
||||||
"""规划器异常"""
|
"""规划器异常"""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ReplierError(HeartFCError):
|
class ReplierError(HeartFCError):
|
||||||
"""回复器异常"""
|
"""回复器异常"""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class SenderError(HeartFCError):
|
class SenderError(HeartFCError):
|
||||||
"""发送器异常"""
|
"""发送器异常"""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class HeartFChatting:
|
class HeartFChatting:
|
||||||
@@ -363,9 +369,7 @@ class HeartFChatting:
|
|||||||
logger.error(f"{self.log_prefix} 检查新消息时出错: {e}")
|
logger.error(f"{self.log_prefix} 检查新消息时出错: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def _think_plan_execute_loop(
|
async def _think_plan_execute_loop(self, cycle_timers: dict, planner_start_db_time: float) -> tuple[bool, str]:
|
||||||
self, cycle_timers: dict, planner_start_db_time: float
|
|
||||||
) -> tuple[bool, str]:
|
|
||||||
"""执行规划阶段"""
|
"""执行规划阶段"""
|
||||||
try:
|
try:
|
||||||
# think:思考
|
# think:思考
|
||||||
@@ -415,12 +419,7 @@ class HeartFChatting:
|
|||||||
return False, ""
|
return False, ""
|
||||||
|
|
||||||
async def _handle_action(
|
async def _handle_action(
|
||||||
self,
|
self, action: str, reasoning: str, emoji_query: str, cycle_timers: dict, planner_start_db_time: float
|
||||||
action: str,
|
|
||||||
reasoning: str,
|
|
||||||
emoji_query: str,
|
|
||||||
cycle_timers: dict,
|
|
||||||
planner_start_db_time: float
|
|
||||||
) -> tuple[bool, str]:
|
) -> tuple[bool, str]:
|
||||||
"""
|
"""
|
||||||
处理规划动作
|
处理规划动作
|
||||||
@@ -438,7 +437,7 @@ class HeartFChatting:
|
|||||||
action_handlers = {
|
action_handlers = {
|
||||||
"text_reply": self._handle_text_reply,
|
"text_reply": self._handle_text_reply,
|
||||||
"emoji_reply": self._handle_emoji_reply,
|
"emoji_reply": self._handle_emoji_reply,
|
||||||
"no_reply": self._handle_no_reply
|
"no_reply": self._handle_no_reply,
|
||||||
}
|
}
|
||||||
|
|
||||||
handler = action_handlers.get(action)
|
handler = action_handlers.get(action)
|
||||||
@@ -457,9 +456,7 @@ class HeartFChatting:
|
|||||||
logger.error(f"{self.log_prefix} 处理{action}时出错: {e}")
|
logger.error(f"{self.log_prefix} 处理{action}时出错: {e}")
|
||||||
return False, ""
|
return False, ""
|
||||||
|
|
||||||
async def _handle_text_reply(
|
async def _handle_text_reply(self, reasoning: str, emoji_query: str, cycle_timers: dict) -> tuple[bool, str]:
|
||||||
self, reasoning: str, emoji_query: str, cycle_timers: dict
|
|
||||||
) -> tuple[bool, str]:
|
|
||||||
"""
|
"""
|
||||||
处理文本回复
|
处理文本回复
|
||||||
|
|
||||||
@@ -544,9 +541,7 @@ class HeartFChatting:
|
|||||||
logger.error(f"{self.log_prefix} 表情发送失败: {e}")
|
logger.error(f"{self.log_prefix} 表情发送失败: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def _handle_no_reply(
|
async def _handle_no_reply(self, reasoning: str, planner_start_db_time: float, cycle_timers: dict) -> bool:
|
||||||
self, reasoning: str, planner_start_db_time: float, cycle_timers: dict
|
|
||||||
) -> bool:
|
|
||||||
"""
|
"""
|
||||||
处理不回复的情况
|
处理不回复的情况
|
||||||
|
|
||||||
@@ -573,9 +568,7 @@ class HeartFChatting:
|
|||||||
logger.info(f"{self.log_prefix} 等待被中断")
|
logger.info(f"{self.log_prefix} 等待被中断")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def _wait_for_new_message(
|
async def _wait_for_new_message(self, observation, planner_start_db_time: float, log_prefix: str) -> bool:
|
||||||
self, observation, planner_start_db_time: float, log_prefix: str
|
|
||||||
) -> bool:
|
|
||||||
"""
|
"""
|
||||||
等待新消息
|
等待新消息
|
||||||
|
|
||||||
@@ -614,9 +607,7 @@ class HeartFChatting:
|
|||||||
if timer_strings:
|
if timer_strings:
|
||||||
logger.debug(f"{log_prefix} 该次决策耗时: {'; '.join(timer_strings)}")
|
logger.debug(f"{log_prefix} 该次决策耗时: {'; '.join(timer_strings)}")
|
||||||
|
|
||||||
async def _handle_cycle_delay(
|
async def _handle_cycle_delay(self, action_taken_this_cycle: bool, cycle_start_time: float, log_prefix: str):
|
||||||
self, action_taken_this_cycle: bool, cycle_start_time: float, log_prefix: str
|
|
||||||
):
|
|
||||||
"""处理循环延迟"""
|
"""处理循环延迟"""
|
||||||
cycle_duration = time.monotonic() - cycle_start_time
|
cycle_duration = time.monotonic() - cycle_start_time
|
||||||
# if cycle_duration > 0.1:
|
# if cycle_duration > 0.1:
|
||||||
@@ -734,7 +725,9 @@ class HeartFChatting:
|
|||||||
action = arguments.get("action", "no_reply")
|
action = arguments.get("action", "no_reply")
|
||||||
# 验证动作是否在可用动作集中
|
# 验证动作是否在可用动作集中
|
||||||
if action not in self.action_manager.get_available_actions():
|
if action not in self.action_manager.get_available_actions():
|
||||||
logger.warning(f"{self.log_prefix}[Planner] LLM返回了未授权的动作: {action},使用默认动作no_reply")
|
logger.warning(
|
||||||
|
f"{self.log_prefix}[Planner] LLM返回了未授权的动作: {action},使用默认动作no_reply"
|
||||||
|
)
|
||||||
action = "no_reply"
|
action = "no_reply"
|
||||||
reasoning = f"LLM返回了未授权的动作: {action}"
|
reasoning = f"LLM返回了未授权的动作: {action}"
|
||||||
else:
|
else:
|
||||||
@@ -742,7 +735,9 @@ class HeartFChatting:
|
|||||||
emoji_query = arguments.get("emoji_query", "")
|
emoji_query = arguments.get("emoji_query", "")
|
||||||
|
|
||||||
# 记录决策结果
|
# 记录决策结果
|
||||||
logger.debug(f"{self.log_prefix}[要做什么]\nPrompt:\n{prompt}\n\n决策结果: {action}, 理由: {reasoning}, 表情查询: '{emoji_query}'")
|
logger.debug(
|
||||||
|
f"{self.log_prefix}[要做什么]\nPrompt:\n{prompt}\n\n决策结果: {action}, 理由: {reasoning}, 表情查询: '{emoji_query}'"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# 处理工具调用失败
|
# 处理工具调用失败
|
||||||
logger.warning(f"{self.log_prefix}[Planner] {error_msg}")
|
logger.warning(f"{self.log_prefix}[Planner] {error_msg}")
|
||||||
@@ -927,9 +922,6 @@ class HeartFChatting:
|
|||||||
thinking_id=thinking_id, # Pass thinking_id positionally
|
thinking_id=thinking_id, # Pass thinking_id positionally
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if not response_set:
|
if not response_set:
|
||||||
logger.warning(f"{self.log_prefix}[Replier-{thinking_id}] LLM生成了一个空回复集。")
|
logger.warning(f"{self.log_prefix}[Replier-{thinking_id}] LLM生成了一个空回复集。")
|
||||||
return None
|
return None
|
||||||
@@ -980,8 +972,7 @@ class HeartFChatting:
|
|||||||
# 记录锚点消息ID
|
# 记录锚点消息ID
|
||||||
if self._current_cycle and anchor_message:
|
if self._current_cycle and anchor_message:
|
||||||
self._current_cycle.set_response_info(
|
self._current_cycle.set_response_info(
|
||||||
response_text=response_set,
|
response_text=response_set, anchor_message_id=anchor_message.message_info.message_id
|
||||||
anchor_message_id=anchor_message.message_info.message_id
|
|
||||||
)
|
)
|
||||||
|
|
||||||
chat = anchor_message.chat_stream
|
chat = anchor_message.chat_stream
|
||||||
@@ -1056,9 +1047,7 @@ class HeartFChatting:
|
|||||||
emoji_path, description = emoji_raw
|
emoji_path, description = emoji_raw
|
||||||
# 记录表情信息
|
# 记录表情信息
|
||||||
if self._current_cycle:
|
if self._current_cycle:
|
||||||
self._current_cycle.set_response_info(
|
self._current_cycle.set_response_info(emoji_info=f"表情: {description}, 路径: {emoji_path}")
|
||||||
emoji_info=f"表情: {description}, 路径: {emoji_path}"
|
|
||||||
)
|
|
||||||
|
|
||||||
emoji_cq = image_path_to_base64(emoji_path)
|
emoji_cq = image_path_to_base64(emoji_path)
|
||||||
thinking_time_point = round(time.time(), 2)
|
thinking_time_point = round(time.time(), 2)
|
||||||
@@ -1100,4 +1089,3 @@ class HeartFChatting:
|
|||||||
if self._cycle_history:
|
if self._cycle_history:
|
||||||
return self._cycle_history[-1].to_dict()
|
return self._cycle_history[-1].to_dict()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from .heartflow_prompt_builder import prompt_builder
|
|||||||
from ..chat.utils import process_llm_response
|
from ..chat.utils import process_llm_response
|
||||||
from src.common.logger import get_module_logger, LogConfig, LLM_STYLE_CONFIG
|
from src.common.logger import get_module_logger, LogConfig, LLM_STYLE_CONFIG
|
||||||
from src.plugins.respon_info_catcher.info_catcher import info_catcher_manager
|
from src.plugins.respon_info_catcher.info_catcher import info_catcher_manager
|
||||||
from ..utils.timer_calculater import Timer
|
from ..utils.timer_calculator import Timer
|
||||||
|
|
||||||
from src.plugins.moods.moods import MoodManager
|
from src.plugins.moods.moods import MoodManager
|
||||||
|
|
||||||
|
|||||||
@@ -5,12 +5,12 @@ from ...config.config import global_config
|
|||||||
from ..chat.message import MessageRecv
|
from ..chat.message import MessageRecv
|
||||||
from ..storage.storage import MessageStorage
|
from ..storage.storage import MessageStorage
|
||||||
from ..chat.utils import is_mentioned_bot_in_message
|
from ..chat.utils import is_mentioned_bot_in_message
|
||||||
from ..message import Seg
|
from maim_message import Seg
|
||||||
from src.heart_flow.heartflow import heartflow
|
from src.heart_flow.heartflow import heartflow
|
||||||
from src.common.logger import get_module_logger, CHAT_STYLE_CONFIG, LogConfig
|
from src.common.logger import get_module_logger, CHAT_STYLE_CONFIG, LogConfig
|
||||||
from ..chat.chat_stream import chat_manager
|
from ..chat.chat_stream import chat_manager
|
||||||
from ..chat.message_buffer import message_buffer
|
from ..chat.message_buffer import message_buffer
|
||||||
from ..utils.timer_calculater import Timer
|
from ..utils.timer_calculator import Timer
|
||||||
from src.plugins.person_info.relationship_manager import relationship_manager
|
from src.plugins.person_info.relationship_manager import relationship_manager
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
@@ -39,7 +39,7 @@ class HeartFCProcessor:
|
|||||||
"""
|
"""
|
||||||
logger.error(f"{context}: {error}")
|
logger.error(f"{context}: {error}")
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
if message and hasattr(message, 'raw_message'):
|
if message and hasattr(message, "raw_message"):
|
||||||
logger.error(f"相关消息原始内容: {message.raw_message}")
|
logger.error(f"相关消息原始内容: {message.raw_message}")
|
||||||
|
|
||||||
async def _process_relationship(self, message: MessageRecv) -> None:
|
async def _process_relationship(self, message: MessageRecv) -> None:
|
||||||
@@ -57,14 +57,10 @@ class HeartFCProcessor:
|
|||||||
|
|
||||||
if not is_known:
|
if not is_known:
|
||||||
logger.info(f"首次认识用户: {nickname}")
|
logger.info(f"首次认识用户: {nickname}")
|
||||||
await relationship_manager.first_knowing_some_one(
|
await relationship_manager.first_knowing_some_one(platform, user_id, nickname, cardname, "")
|
||||||
platform, user_id, nickname, cardname, ""
|
|
||||||
)
|
|
||||||
elif not await relationship_manager.is_qved_name(platform, user_id):
|
elif not await relationship_manager.is_qved_name(platform, user_id):
|
||||||
logger.info(f"给用户({nickname},{cardname})取名: {nickname}")
|
logger.info(f"给用户({nickname},{cardname})取名: {nickname}")
|
||||||
await relationship_manager.first_knowing_some_one(
|
await relationship_manager.first_knowing_some_one(platform, user_id, nickname, cardname, "")
|
||||||
platform, user_id, nickname, cardname, ""
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _calculate_interest(self, message: MessageRecv) -> Tuple[float, bool]:
|
async def _calculate_interest(self, message: MessageRecv) -> Tuple[float, bool]:
|
||||||
"""计算消息的兴趣度
|
"""计算消息的兴趣度
|
||||||
@@ -103,9 +99,11 @@ class HeartFCProcessor:
|
|||||||
if message.message_segment.type != "seglist":
|
if message.message_segment.type != "seglist":
|
||||||
return message.message_segment.type
|
return message.message_segment.type
|
||||||
|
|
||||||
if (isinstance(message.message_segment.data, list)
|
if (
|
||||||
|
isinstance(message.message_segment.data, list)
|
||||||
and all(isinstance(x, Seg) for x in message.message_segment.data)
|
and all(isinstance(x, Seg) for x in message.message_segment.data)
|
||||||
and len(message.message_segment.data) == 1):
|
and len(message.message_segment.data) == 1
|
||||||
|
):
|
||||||
return message.message_segment.data[0].type
|
return message.message_segment.data[0].type
|
||||||
|
|
||||||
return "seglist"
|
return "seglist"
|
||||||
@@ -145,8 +143,9 @@ class HeartFCProcessor:
|
|||||||
await message.process()
|
await message.process()
|
||||||
|
|
||||||
# 3. 过滤检查
|
# 3. 过滤检查
|
||||||
if self._check_ban_words(message.processed_plain_text, chat, userinfo) or \
|
if self._check_ban_words(message.processed_plain_text, chat, userinfo) or self._check_ban_regex(
|
||||||
self._check_ban_regex(message.raw_message, chat, userinfo):
|
message.raw_message, chat, userinfo
|
||||||
|
):
|
||||||
return
|
return
|
||||||
|
|
||||||
# 4. 缓冲检查
|
# 4. 缓冲检查
|
||||||
@@ -156,7 +155,7 @@ class HeartFCProcessor:
|
|||||||
type_messages = {
|
type_messages = {
|
||||||
"text": f"触发缓冲,消息:{message.processed_plain_text}",
|
"text": f"触发缓冲,消息:{message.processed_plain_text}",
|
||||||
"image": "触发缓冲,表情包/图片等待中",
|
"image": "触发缓冲,表情包/图片等待中",
|
||||||
"seglist": "触发缓冲,消息列表等待中"
|
"seglist": "触发缓冲,消息列表等待中",
|
||||||
}
|
}
|
||||||
logger.debug(type_messages.get(msg_type, "触发未知类型缓冲"))
|
logger.debug(type_messages.get(msg_type, "触发未知类型缓冲"))
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -38,11 +38,14 @@ def init_prompt():
|
|||||||
"heart_flow_prompt",
|
"heart_flow_prompt",
|
||||||
)
|
)
|
||||||
|
|
||||||
Prompt("""
|
Prompt(
|
||||||
|
"""
|
||||||
你有以下信息可供参考:
|
你有以下信息可供参考:
|
||||||
{structured_info}
|
{structured_info}
|
||||||
以上的消息是你获取到的消息,或许可以帮助你更好地回复。
|
以上的消息是你获取到的消息,或许可以帮助你更好地回复。
|
||||||
""", "info_from_tools")
|
""",
|
||||||
|
"info_from_tools",
|
||||||
|
)
|
||||||
|
|
||||||
# Planner提示词 - 优化版
|
# Planner提示词 - 优化版
|
||||||
Prompt(
|
Prompt(
|
||||||
@@ -190,8 +193,8 @@ class PromptBuilder:
|
|||||||
|
|
||||||
if structured_info:
|
if structured_info:
|
||||||
structured_info_prompt = await global_prompt_manager.format_prompt(
|
structured_info_prompt = await global_prompt_manager.format_prompt(
|
||||||
"info_from_tools",
|
"info_from_tools", structured_info=structured_info
|
||||||
structured_info = structured_info)
|
)
|
||||||
else:
|
else:
|
||||||
structured_info_prompt = ""
|
structured_info_prompt = ""
|
||||||
|
|
||||||
|
|||||||
@@ -12,12 +12,12 @@ from ..chat.message import MessageSending, MessageRecv, MessageThinking, Message
|
|||||||
from ..chat.message_sender import message_manager
|
from ..chat.message_sender import message_manager
|
||||||
from ..chat.utils_image import image_path_to_base64
|
from ..chat.utils_image import image_path_to_base64
|
||||||
from ..willing.willing_manager import willing_manager
|
from ..willing.willing_manager import willing_manager
|
||||||
from ..message import UserInfo, Seg
|
from maim_message import UserInfo, Seg
|
||||||
from src.common.logger import get_module_logger, CHAT_STYLE_CONFIG, LogConfig
|
from src.common.logger import get_module_logger, CHAT_STYLE_CONFIG, LogConfig
|
||||||
from src.plugins.chat.chat_stream import ChatStream, chat_manager
|
from src.plugins.chat.chat_stream import ChatStream, chat_manager
|
||||||
from src.plugins.person_info.relationship_manager import relationship_manager
|
from src.plugins.person_info.relationship_manager import relationship_manager
|
||||||
from src.plugins.respon_info_catcher.info_catcher import info_catcher_manager
|
from src.plugins.respon_info_catcher.info_catcher import info_catcher_manager
|
||||||
from src.plugins.utils.timer_calculater import Timer
|
from src.plugins.utils.timer_calculator import Timer
|
||||||
|
|
||||||
# 定义日志配置
|
# 定义日志配置
|
||||||
chat_config = LogConfig(
|
chat_config = LogConfig(
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from ...config.config import global_config
|
|||||||
from ..chat.message import MessageThinking
|
from ..chat.message import MessageThinking
|
||||||
from .heartflow_prompt_builder import prompt_builder
|
from .heartflow_prompt_builder import prompt_builder
|
||||||
from ..chat.utils import process_llm_response
|
from ..chat.utils import process_llm_response
|
||||||
from ..utils.timer_calculater import Timer
|
from ..utils.timer_calculator import Timer
|
||||||
from src.common.logger import get_module_logger, LogConfig, LLM_STYLE_CONFIG
|
from src.common.logger import get_module_logger, LogConfig, LLM_STYLE_CONFIG
|
||||||
from src.plugins.respon_info_catcher.info_catcher import info_catcher_manager
|
from src.plugins.respon_info_catcher.info_catcher import info_catcher_manager
|
||||||
|
|
||||||
|
|||||||
@@ -3,23 +3,8 @@
|
|||||||
__version__ = "0.1.0"
|
__version__ = "0.1.0"
|
||||||
|
|
||||||
from .api import global_api
|
from .api import global_api
|
||||||
from .message_base import (
|
|
||||||
Seg,
|
|
||||||
GroupInfo,
|
|
||||||
UserInfo,
|
|
||||||
FormatInfo,
|
|
||||||
TemplateInfo,
|
|
||||||
BaseMessageInfo,
|
|
||||||
MessageBase,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Seg",
|
|
||||||
"global_api",
|
"global_api",
|
||||||
"GroupInfo",
|
|
||||||
"UserInfo",
|
|
||||||
"FormatInfo",
|
|
||||||
"TemplateInfo",
|
|
||||||
"BaseMessageInfo",
|
|
||||||
"MessageBase",
|
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,250 +1,6 @@
|
|||||||
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
|
|
||||||
from typing import Dict, Any, Callable, List, Set, Optional
|
|
||||||
from src.common.logger import get_module_logger
|
|
||||||
from src.plugins.message.message_base import MessageBase
|
|
||||||
from src.common.server import global_server
|
from src.common.server import global_server
|
||||||
import aiohttp
|
|
||||||
import asyncio
|
|
||||||
import uvicorn
|
|
||||||
import os
|
import os
|
||||||
import traceback
|
from maim_message import MessageServer
|
||||||
|
|
||||||
logger = get_module_logger("api")
|
|
||||||
|
|
||||||
|
|
||||||
class BaseMessageHandler:
|
|
||||||
"""消息处理基类"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.message_handlers: List[Callable] = []
|
|
||||||
self.background_tasks = set()
|
|
||||||
|
|
||||||
def register_message_handler(self, handler: Callable):
|
|
||||||
"""注册消息处理函数"""
|
|
||||||
self.message_handlers.append(handler)
|
|
||||||
|
|
||||||
async def process_message(self, message: Dict[str, Any]):
|
|
||||||
"""处理单条消息"""
|
|
||||||
tasks = []
|
|
||||||
for handler in self.message_handlers:
|
|
||||||
try:
|
|
||||||
tasks.append(handler(message))
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"消息处理出错: {str(e)}")
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
# 不抛出异常,而是记录错误并继续处理其他消息
|
|
||||||
continue
|
|
||||||
if tasks:
|
|
||||||
await asyncio.gather(*tasks, return_exceptions=True)
|
|
||||||
|
|
||||||
async def _handle_message(self, message: Dict[str, Any]):
|
|
||||||
"""后台处理单个消息"""
|
|
||||||
try:
|
|
||||||
await self.process_message(message)
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(str(e)) from e
|
|
||||||
|
|
||||||
|
|
||||||
class MessageServer(BaseMessageHandler):
|
|
||||||
"""WebSocket服务端"""
|
|
||||||
|
|
||||||
_class_handlers: List[Callable] = [] # 类级别的消息处理器
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
host: str = "0.0.0.0",
|
|
||||||
port: int = 18000,
|
|
||||||
enable_token=False,
|
|
||||||
app: Optional[FastAPI] = None,
|
|
||||||
path: str = "/ws",
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
# 将类级别的处理器添加到实例处理器中
|
|
||||||
self.message_handlers.extend(self._class_handlers)
|
|
||||||
self.host = host
|
|
||||||
self.port = port
|
|
||||||
self.path = path
|
|
||||||
self.app = app or FastAPI()
|
|
||||||
self.own_app = app is None # 标记是否使用自己创建的app
|
|
||||||
self.active_websockets: Set[WebSocket] = set()
|
|
||||||
self.platform_websockets: Dict[str, WebSocket] = {} # 平台到websocket的映射
|
|
||||||
self.valid_tokens: Set[str] = set()
|
|
||||||
self.enable_token = enable_token
|
|
||||||
self._setup_routes()
|
|
||||||
self._running = False
|
|
||||||
|
|
||||||
def _setup_routes(self):
|
|
||||||
@self.app.post("/api/message")
|
|
||||||
async def handle_message(message: Dict[str, Any]):
|
|
||||||
try:
|
|
||||||
# 创建后台任务处理消息
|
|
||||||
asyncio.create_task(self._handle_message(message))
|
|
||||||
return {"status": "success"}
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
|
||||||
|
|
||||||
@self.app.websocket("/ws")
|
|
||||||
async def websocket_endpoint(websocket: WebSocket):
|
|
||||||
headers = dict(websocket.headers)
|
|
||||||
token = headers.get("authorization")
|
|
||||||
platform = headers.get("platform", "default") # 获取platform标识
|
|
||||||
if self.enable_token:
|
|
||||||
if not token or not await self.verify_token(token):
|
|
||||||
await websocket.close(code=1008, reason="Invalid or missing token")
|
|
||||||
return
|
|
||||||
|
|
||||||
await websocket.accept()
|
|
||||||
self.active_websockets.add(websocket)
|
|
||||||
|
|
||||||
# 添加到platform映射
|
|
||||||
if platform not in self.platform_websockets:
|
|
||||||
self.platform_websockets[platform] = websocket
|
|
||||||
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
message = await websocket.receive_json()
|
|
||||||
# print(f"Received message: {message}")
|
|
||||||
asyncio.create_task(self._handle_message(message))
|
|
||||||
except WebSocketDisconnect:
|
|
||||||
self._remove_websocket(websocket, platform)
|
|
||||||
except Exception as e:
|
|
||||||
self._remove_websocket(websocket, platform)
|
|
||||||
raise RuntimeError(str(e)) from e
|
|
||||||
finally:
|
|
||||||
self._remove_websocket(websocket, platform)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def register_class_handler(cls, handler: Callable):
|
|
||||||
"""注册类级别的消息处理器"""
|
|
||||||
if handler not in cls._class_handlers:
|
|
||||||
cls._class_handlers.append(handler)
|
|
||||||
|
|
||||||
def register_message_handler(self, handler: Callable):
|
|
||||||
"""注册实例级别的消息处理器"""
|
|
||||||
if handler not in self.message_handlers:
|
|
||||||
self.message_handlers.append(handler)
|
|
||||||
|
|
||||||
async def verify_token(self, token: str) -> bool:
|
|
||||||
if not self.enable_token:
|
|
||||||
return True
|
|
||||||
return token in self.valid_tokens
|
|
||||||
|
|
||||||
def add_valid_token(self, token: str):
|
|
||||||
self.valid_tokens.add(token)
|
|
||||||
|
|
||||||
def remove_valid_token(self, token: str):
|
|
||||||
self.valid_tokens.discard(token)
|
|
||||||
|
|
||||||
def run_sync(self):
|
|
||||||
"""同步方式运行服务器"""
|
|
||||||
if not self.own_app:
|
|
||||||
raise RuntimeError("当使用外部FastAPI实例时,请使用该实例的运行方法")
|
|
||||||
uvicorn.run(self.app, host=self.host, port=self.port)
|
|
||||||
|
|
||||||
async def run(self):
|
|
||||||
"""异步方式运行服务器"""
|
|
||||||
self._running = True
|
|
||||||
try:
|
|
||||||
if self.own_app:
|
|
||||||
# 如果使用自己的 FastAPI 实例,运行 uvicorn 服务器
|
|
||||||
# 禁用 uvicorn 默认日志和访问日志
|
|
||||||
config = uvicorn.Config(
|
|
||||||
self.app, host=self.host, port=self.port, loop="asyncio", log_config=None, access_log=False
|
|
||||||
)
|
|
||||||
self.server = uvicorn.Server(config)
|
|
||||||
await self.server.serve()
|
|
||||||
else:
|
|
||||||
# 如果使用外部 FastAPI 实例,保持运行状态以处理消息
|
|
||||||
while self._running:
|
|
||||||
await asyncio.sleep(1)
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
await self.stop()
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
await self.stop()
|
|
||||||
raise RuntimeError(f"服务器运行错误: {str(e)}") from e
|
|
||||||
finally:
|
|
||||||
await self.stop()
|
|
||||||
|
|
||||||
async def start_server(self):
|
|
||||||
"""启动服务器的异步方法"""
|
|
||||||
if not self._running:
|
|
||||||
self._running = True
|
|
||||||
await self.run()
|
|
||||||
|
|
||||||
async def stop(self):
|
|
||||||
"""停止服务器"""
|
|
||||||
# 清理platform映射
|
|
||||||
self.platform_websockets.clear()
|
|
||||||
|
|
||||||
# 取消所有后台任务
|
|
||||||
for task in self.background_tasks:
|
|
||||||
task.cancel()
|
|
||||||
# 等待所有任务完成
|
|
||||||
await asyncio.gather(*self.background_tasks, return_exceptions=True)
|
|
||||||
self.background_tasks.clear()
|
|
||||||
|
|
||||||
# 关闭所有WebSocket连接
|
|
||||||
for websocket in self.active_websockets:
|
|
||||||
await websocket.close()
|
|
||||||
self.active_websockets.clear()
|
|
||||||
|
|
||||||
if hasattr(self, "server") and self.own_app:
|
|
||||||
self._running = False
|
|
||||||
# 正确关闭 uvicorn 服务器
|
|
||||||
self.server.should_exit = True
|
|
||||||
await self.server.shutdown()
|
|
||||||
# 等待服务器完全停止
|
|
||||||
if hasattr(self.server, "started") and self.server.started:
|
|
||||||
await self.server.main_loop()
|
|
||||||
# 清理处理程序
|
|
||||||
self.message_handlers.clear()
|
|
||||||
|
|
||||||
def _remove_websocket(self, websocket: WebSocket, platform: str):
|
|
||||||
"""从所有集合中移除websocket"""
|
|
||||||
if websocket in self.active_websockets:
|
|
||||||
self.active_websockets.remove(websocket)
|
|
||||||
if platform in self.platform_websockets:
|
|
||||||
if self.platform_websockets[platform] == websocket:
|
|
||||||
del self.platform_websockets[platform]
|
|
||||||
|
|
||||||
async def broadcast_message(self, message: Dict[str, Any]):
|
|
||||||
disconnected = set()
|
|
||||||
for websocket in self.active_websockets:
|
|
||||||
try:
|
|
||||||
await websocket.send_json(message)
|
|
||||||
except Exception:
|
|
||||||
disconnected.add(websocket)
|
|
||||||
for websocket in disconnected:
|
|
||||||
self.active_websockets.remove(websocket)
|
|
||||||
|
|
||||||
async def broadcast_to_platform(self, platform: str, message: Dict[str, Any]):
|
|
||||||
"""向指定平台的所有WebSocket客户端广播消息"""
|
|
||||||
if platform not in self.platform_websockets:
|
|
||||||
raise ValueError(f"平台:{platform} 未连接")
|
|
||||||
|
|
||||||
disconnected = set()
|
|
||||||
try:
|
|
||||||
await self.platform_websockets[platform].send_json(message)
|
|
||||||
except Exception:
|
|
||||||
disconnected.add(self.platform_websockets[platform])
|
|
||||||
|
|
||||||
# 清理断开的连接
|
|
||||||
for websocket in disconnected:
|
|
||||||
self._remove_websocket(websocket, platform)
|
|
||||||
|
|
||||||
async def send_message(self, message: MessageBase):
|
|
||||||
await self.broadcast_to_platform(message.message_info.platform, message.to_dict())
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def send_message_rest(url: str, data: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""发送消息到指定端点"""
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
try:
|
|
||||||
async with session.post(url, json=data, headers={"Content-Type": "application/json"}) as response:
|
|
||||||
return await response.json()
|
|
||||||
except Exception as e:
|
|
||||||
raise e
|
|
||||||
|
|
||||||
|
|
||||||
global_api = MessageServer(host=os.environ["HOST"], port=int(os.environ["PORT"]), app=global_server.get_app())
|
global_api = MessageServer(host=os.environ["HOST"], port=int(os.environ["PORT"]), app=global_server.get_app())
|
||||||
|
|||||||
@@ -1,247 +0,0 @@
|
|||||||
from dataclasses import dataclass, asdict
|
|
||||||
from typing import List, Optional, Union, Dict
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Seg:
|
|
||||||
"""消息片段类,用于表示消息的不同部分
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
type: 片段类型,可以是 'text'、'image'、'seglist' 等
|
|
||||||
data: 片段的具体内容
|
|
||||||
- 对于 text 类型,data 是字符串
|
|
||||||
- 对于 image 类型,data 是 base64 字符串
|
|
||||||
- 对于 seglist 类型,data 是 Seg 列表
|
|
||||||
"""
|
|
||||||
|
|
||||||
type: str
|
|
||||||
data: Union[str, List["Seg"]]
|
|
||||||
|
|
||||||
# def __init__(self, type: str, data: Union[str, List['Seg']],):
|
|
||||||
# """初始化实例,确保字典和属性同步"""
|
|
||||||
# # 先初始化字典
|
|
||||||
# self.type = type
|
|
||||||
# self.data = data
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_dict(cls, data: Dict) -> "Seg":
|
|
||||||
"""从字典创建Seg实例"""
|
|
||||||
type = data.get("type")
|
|
||||||
data = data.get("data")
|
|
||||||
if type == "seglist":
|
|
||||||
data = [Seg.from_dict(seg) for seg in data]
|
|
||||||
return cls(type=type, data=data)
|
|
||||||
|
|
||||||
def to_dict(self) -> Dict:
|
|
||||||
"""转换为字典格式"""
|
|
||||||
result = {"type": self.type}
|
|
||||||
if self.type == "seglist":
|
|
||||||
result["data"] = [seg.to_dict() for seg in self.data]
|
|
||||||
else:
|
|
||||||
result["data"] = self.data
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class GroupInfo:
|
|
||||||
"""群组信息类"""
|
|
||||||
|
|
||||||
platform: Optional[str] = None
|
|
||||||
group_id: Optional[int] = None
|
|
||||||
group_name: Optional[str] = None # 群名称
|
|
||||||
|
|
||||||
def to_dict(self) -> Dict:
|
|
||||||
"""转换为字典格式"""
|
|
||||||
return {k: v for k, v in asdict(self).items() if v is not None}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_dict(cls, data: Dict) -> "GroupInfo":
|
|
||||||
"""从字典创建GroupInfo实例
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data: 包含必要字段的字典
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
GroupInfo: 新的实例
|
|
||||||
"""
|
|
||||||
if data.get("group_id") is None:
|
|
||||||
return None
|
|
||||||
return cls(
|
|
||||||
platform=data.get("platform"), group_id=data.get("group_id"), group_name=data.get("group_name", None)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class UserInfo:
|
|
||||||
"""用户信息类"""
|
|
||||||
|
|
||||||
platform: Optional[str] = None
|
|
||||||
user_id: Optional[int] = None
|
|
||||||
user_nickname: Optional[str] = None # 用户昵称
|
|
||||||
user_cardname: Optional[str] = None # 用户群昵称
|
|
||||||
|
|
||||||
def to_dict(self) -> Dict:
|
|
||||||
"""转换为字典格式"""
|
|
||||||
return {k: v for k, v in asdict(self).items() if v is not None}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_dict(cls, data: Dict) -> "UserInfo":
|
|
||||||
"""从字典创建UserInfo实例
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data: 包含必要字段的字典
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
UserInfo: 新的实例
|
|
||||||
"""
|
|
||||||
return cls(
|
|
||||||
platform=data.get("platform"),
|
|
||||||
user_id=data.get("user_id"),
|
|
||||||
user_nickname=data.get("user_nickname", None),
|
|
||||||
user_cardname=data.get("user_cardname", None),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class FormatInfo:
|
|
||||||
"""格式信息类"""
|
|
||||||
|
|
||||||
"""
|
|
||||||
目前maimcore可接受的格式为text,image,emoji
|
|
||||||
可发送的格式为text,emoji,reply
|
|
||||||
"""
|
|
||||||
|
|
||||||
content_format: Optional[str] = None
|
|
||||||
accept_format: Optional[str] = None
|
|
||||||
|
|
||||||
def to_dict(self) -> Dict:
|
|
||||||
"""转换为字典格式"""
|
|
||||||
return {k: v for k, v in asdict(self).items() if v is not None}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_dict(cls, data: Dict) -> "FormatInfo":
|
|
||||||
"""从字典创建FormatInfo实例
|
|
||||||
Args:
|
|
||||||
data: 包含必要字段的字典
|
|
||||||
Returns:
|
|
||||||
FormatInfo: 新的实例
|
|
||||||
"""
|
|
||||||
return cls(
|
|
||||||
content_format=data.get("content_format"),
|
|
||||||
accept_format=data.get("accept_format"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TemplateInfo:
|
|
||||||
"""模板信息类"""
|
|
||||||
|
|
||||||
template_items: Optional[Dict] = None
|
|
||||||
template_name: Optional[str] = None
|
|
||||||
template_default: bool = True
|
|
||||||
|
|
||||||
def to_dict(self) -> Dict:
|
|
||||||
"""转换为字典格式"""
|
|
||||||
return {k: v for k, v in asdict(self).items() if v is not None}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_dict(cls, data: Dict) -> "TemplateInfo":
|
|
||||||
"""从字典创建TemplateInfo实例
|
|
||||||
Args:
|
|
||||||
data: 包含必要字段的字典
|
|
||||||
Returns:
|
|
||||||
TemplateInfo: 新的实例
|
|
||||||
"""
|
|
||||||
return cls(
|
|
||||||
template_items=data.get("template_items"),
|
|
||||||
template_name=data.get("template_name"),
|
|
||||||
template_default=data.get("template_default", True),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class BaseMessageInfo:
|
|
||||||
"""消息信息类"""
|
|
||||||
|
|
||||||
platform: Optional[str] = None
|
|
||||||
message_id: Union[str, int, None] = None
|
|
||||||
time: Optional[float] = None
|
|
||||||
group_info: Optional[GroupInfo] = None
|
|
||||||
user_info: Optional[UserInfo] = None
|
|
||||||
format_info: Optional[FormatInfo] = None
|
|
||||||
template_info: Optional[TemplateInfo] = None
|
|
||||||
additional_config: Optional[dict] = None
|
|
||||||
|
|
||||||
def to_dict(self) -> Dict:
|
|
||||||
"""转换为字典格式"""
|
|
||||||
result = {}
|
|
||||||
for field, value in asdict(self).items():
|
|
||||||
if value is not None:
|
|
||||||
if isinstance(value, (GroupInfo, UserInfo, FormatInfo, TemplateInfo)):
|
|
||||||
result[field] = value.to_dict()
|
|
||||||
else:
|
|
||||||
result[field] = value
|
|
||||||
return result
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_dict(cls, data: Dict) -> "BaseMessageInfo":
|
|
||||||
"""从字典创建BaseMessageInfo实例
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data: 包含必要字段的字典
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
BaseMessageInfo: 新的实例
|
|
||||||
"""
|
|
||||||
group_info = GroupInfo.from_dict(data.get("group_info", {}))
|
|
||||||
user_info = UserInfo.from_dict(data.get("user_info", {}))
|
|
||||||
format_info = FormatInfo.from_dict(data.get("format_info", {}))
|
|
||||||
template_info = TemplateInfo.from_dict(data.get("template_info", {}))
|
|
||||||
return cls(
|
|
||||||
platform=data.get("platform"),
|
|
||||||
message_id=data.get("message_id"),
|
|
||||||
time=data.get("time"),
|
|
||||||
additional_config=data.get("additional_config", None),
|
|
||||||
group_info=group_info,
|
|
||||||
user_info=user_info,
|
|
||||||
format_info=format_info,
|
|
||||||
template_info=template_info,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MessageBase:
|
|
||||||
"""消息类"""
|
|
||||||
|
|
||||||
message_info: BaseMessageInfo
|
|
||||||
message_segment: Seg
|
|
||||||
raw_message: Optional[str] = None # 原始消息,包含未解析的cq码
|
|
||||||
|
|
||||||
def to_dict(self) -> Dict:
|
|
||||||
"""转换为字典格式
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict: 包含所有非None字段的字典,其中:
|
|
||||||
- message_info: 转换为字典格式
|
|
||||||
- message_segment: 转换为字典格式
|
|
||||||
- raw_message: 如果存在则包含
|
|
||||||
"""
|
|
||||||
result = {"message_info": self.message_info.to_dict(), "message_segment": self.message_segment.to_dict()}
|
|
||||||
if self.raw_message is not None:
|
|
||||||
result["raw_message"] = self.raw_message
|
|
||||||
return result
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_dict(cls, data: Dict) -> "MessageBase":
|
|
||||||
"""从字典创建MessageBase实例
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data: 包含必要字段的字典
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
MessageBase: 新的实例
|
|
||||||
"""
|
|
||||||
message_info = BaseMessageInfo.from_dict(data.get("message_info", {}))
|
|
||||||
message_segment = Seg.from_dict(data.get("message_segment", {}))
|
|
||||||
raw_message = data.get("raw_message", None)
|
|
||||||
return cls(message_info=message_info, message_segment=message_segment, raw_message=raw_message)
|
|
||||||
Reference in New Issue
Block a user