Merge branch 'dev' into dev

This commit is contained in:
SmartMita
2025-04-27 11:21:18 +09:00
committed by GitHub
64 changed files with 4270 additions and 3297 deletions

View File

@@ -8,6 +8,7 @@ from .pfc_utils import get_items_from_json
from src.individuality.individuality import Individuality
from .observation_info import ObservationInfo
from .conversation_info import ConversationInfo
from src.plugins.utils.chat_message_builder import build_readable_messages
pfc_action_log_config = LogConfig(
console_format=PFC_ACTION_PLANNER_STYLE_CONFIG["console_format"],
@@ -132,12 +133,7 @@ class ActionPlanner:
chat_history_text = ""
try:
if hasattr(observation_info, "chat_history") and observation_info.chat_history:
chat_history_list = observation_info.chat_history[-20:]
for msg in chat_history_list:
if isinstance(msg, dict) and "detailed_plain_text" in msg:
chat_history_text += f"{msg.get('detailed_plain_text', '')}\n"
elif isinstance(msg, str):
chat_history_text += f"{msg}\n"
chat_history_text = observation_info.chat_history_str
if not chat_history_text: # 如果历史记录是空列表
chat_history_text = "还没有聊天记录。\n"
else:
@@ -146,12 +142,16 @@ class ActionPlanner:
if hasattr(observation_info, "new_messages_count") and observation_info.new_messages_count > 0:
if hasattr(observation_info, "unprocessed_messages") and observation_info.unprocessed_messages:
new_messages_list = observation_info.unprocessed_messages
chat_history_text += f"--- 以下是 {observation_info.new_messages_count} 条新消息 ---\n"
for msg in new_messages_list:
if isinstance(msg, dict) and "detailed_plain_text" in msg:
chat_history_text += f"{msg.get('detailed_plain_text', '')}\n"
elif isinstance(msg, str):
chat_history_text += f"{msg}\n"
new_messages_str = await build_readable_messages(
new_messages_list,
replace_bot_name=True,
merge_messages=False,
timestamp_mode="relative",
read_mark=0.0,
)
chat_history_text += (
f"\n--- 以下是 {observation_info.new_messages_count} 条新消息 ---\n{new_messages_str}"
)
# 清理消息应该由调用者或 observation_info 内部逻辑处理,这里不再调用 clear
# if hasattr(observation_info, 'clear_unprocessed_messages'):
# observation_info.clear_unprocessed_messages()

View File

@@ -3,7 +3,7 @@ import asyncio
import traceback
from typing import Optional, Dict, Any, List
from src.common.logger import get_module_logger
from ..message.message_base import UserInfo
from maim_message import UserInfo
from ...config.config import global_config
from .chat_states import NotificationManager, create_new_message_notification, create_cold_chat_notification
from .message_storage import MongoDBMessageStorage

View File

@@ -98,7 +98,7 @@ class NotificationManager:
notification_type: 要处理的通知类型
handler: 处理器实例
"""
print(1145145511114445551111444)
# print(1145145511114445551111444)
if target not in self._handlers:
# print("没11有target")
self._handlers[target] = {}
@@ -146,9 +146,9 @@ class NotificationManager:
if target in self._handlers:
handlers = self._handlers[target].get(notification.type, [])
# print(1111111)
print(handlers)
# print(handlers)
for handler in handlers:
print(f"调用处理器: {handler}")
# print(f"调用处理器: {handler}")
await handler.handle_notification(notification)
def get_active_states(self) -> Set[NotificationType]:

View File

@@ -2,7 +2,7 @@ import time
import asyncio
import datetime
# from .message_storage import MongoDBMessageStorage
from src.plugins.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat
from src.plugins.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_ch
from ...config.config import global_config
from typing import Dict, Any, Optional
from ..chat.message import Message
@@ -14,7 +14,7 @@ from .observation_info import ObservationInfo
from .conversation_info import ConversationInfo
from .reply_generator import ReplyGenerator
from ..chat.chat_stream import ChatStream
from ..message.message_base import UserInfo
from maim_message import UserInfo
from src.plugins.chat.chat_stream import chat_manager
from .pfc_KnowledgeFetcher import KnowledgeFetcher
from .waiter import Waiter
@@ -77,14 +77,22 @@ class Conversation:
raise
try:
logger.info(f"{self.stream_id} 加载初始聊天记录...")
initial_messages = await get_raw_msg_before_timestamp_with_chat( #
initial_messages = get_raw_msg_before_timestamp_with_chat( #
chat_id=self.stream_id,
timestamp=time.time(),
limit=30, # 加载最近30条作为初始上下文可以调整
)
chat_talking_prompt = await build_readable_messages(
initial_messages,
replace_bot_name=True,
merge_messages=False,
timestamp_mode="relative",
read_mark=0.0,
)
if initial_messages:
# 将加载的消息填充到 ObservationInfo 的 chat_history
self.observation_info.chat_history = initial_messages
self.observation_info.chat_history_str = chat_talking_prompt + "\n"
self.observation_info.chat_history_count = len(initial_messages)
# 更新 ObservationInfo 中的时间戳等信息
@@ -174,7 +182,7 @@ class Conversation:
if hasattr(self.observation_info, "clear_unprocessed_messages"):
# 确保 clear_unprocessed_messages 方法存在
logger.debug(f"准备执行 direct_reply清理 {initial_new_message_count} 条规划时已知的新消息。")
self.observation_info.clear_unprocessed_messages()
await self.observation_info.clear_unprocessed_messages()
# 手动重置计数器,确保状态一致性(理想情况下 clear 方法会做这个)
if hasattr(self.observation_info, "new_messages_count"):
self.observation_info.new_messages_count = 0
@@ -259,7 +267,7 @@ class Conversation:
# --- 根据不同的 action 执行 ---
if action == "direct_reply":
max_reply_attempts = 3 # 设置最大尝试次数(与 reply_checker.py 中的 max_retries 保持一致或稍大)
max_reply_attempts = 3 # 设置最大尝试次数(与 reply_checker.py 中的 max_retries 保持一致或稍大)
reply_attempt_count = 0
is_suitable = False
need_replan = False
@@ -284,17 +292,20 @@ class Conversation:
reply=self.generated_reply,
goal=current_goal_str,
chat_history=observation_info.chat_history,
retry_count=reply_attempt_count - 1, # 传递当前尝试次数从0开始计数
chat_history_str=observation_info.chat_history_str,
retry_count=reply_attempt_count - 1, # 传递当前尝试次数从0开始计数
)
logger.info(
f"{reply_attempt_count} 次检查结果: 合适={is_suitable}, 原因='{check_reason}', 需重新规划={need_replan}"
)
logger.info(f"{reply_attempt_count} 次检查结果: 合适={is_suitable}, 原因='{check_reason}', 需重新规划={need_replan}")
if is_suitable:
final_reply_to_send = self.generated_reply # 保存合适的回复
break # 回复合适,跳出循环
final_reply_to_send = self.generated_reply # 保存合适的回复
break # 回复合适,跳出循环
elif need_replan:
logger.warning(f"{reply_attempt_count} 次检查建议重新规划,停止尝试。原因: {check_reason}")
break # 如果检查器建议重新规划,也停止尝试
logger.warning(f"{reply_attempt_count} 次检查建议重新规划,停止尝试。原因: {check_reason}")
break # 如果检查器建议重新规划,也停止尝试
# 如果不合适但不需要重新规划,循环会继续进行下一次尝试
except Exception as check_err:
@@ -321,7 +332,7 @@ class Conversation:
return
# 发送合适的回复
self.generated_reply = final_reply_to_send # 确保 self.generated_reply 是最终要发送的内容
self.generated_reply = final_reply_to_send # 确保 self.generated_reply 是最终要发送的内容
await self._send_reply()
# 更新 action 历史状态为 done
@@ -337,7 +348,7 @@ class Conversation:
logger.warning(f"经过 {reply_attempt_count} 次尝试,未能生成合适的回复。最终原因: {check_reason}")
conversation_info.done_action[action_index].update(
{
"status": "recall", # 标记为 recall 因为没有成功发送
"status": "recall", # 标记为 recall 因为没有成功发送
"final_reason": f"尝试{reply_attempt_count}次后失败: {check_reason}",
"time": datetime.datetime.now().strftime("%H:%M:%S"),
}
@@ -352,7 +363,7 @@ class Conversation:
wait_action_record = {
"action": "wait",
"plan_reason": "因 direct_reply 多次尝试失败而执行的后备等待",
"status": "done", # wait 完成后可以认为是 done
"status": "done", # wait 完成后可以认为是 done
"time": datetime.datetime.now().strftime("%H:%M:%S"),
"final_reason": None,
}
@@ -461,42 +472,11 @@ class Conversation:
try:
# 外层 try: 捕获发送消息和后续处理中的主要错误
current_time = time.time() # 获取当前时间戳
_current_time = time.time() # 获取当前时间戳
reply_content = self.generated_reply # 获取要发送的内容
# 发送消息
await self.direct_sender.send_message(chat_stream=self.chat_stream, content=reply_content)
logger.info(f"消息已发送: {reply_content}") # 可以在发送后加个日志确认
# --- 添加的立即更新状态逻辑开始 ---
try:
# 内层 try: 专门捕获手动更新状态时可能出现的错误
# 创建一个代表刚刚发送的消息的字典
bot_message_info = {
"message_id": f"bot_sent_{current_time}", # 创建一个简单的唯一ID
"time": current_time,
"user_info": UserInfo( # 使用 UserInfo 类构建用户信息
user_id=str(global_config.BOT_QQ),
user_nickname=global_config.BOT_NICKNAME,
platform=self.chat_stream.platform, # 从 chat_stream 获取平台信息
).to_dict(), # 转换为字典格式存储
"processed_plain_text": reply_content, # 使用发送的内容
"detailed_plain_text": f"{int(current_time)},{global_config.BOT_NICKNAME}:{reply_content}", # 构造一个简单的详细文本, 时间戳取整
# 可以根据需要添加其他字段,保持与 observation_info.chat_history 中其他消息结构一致
}
# 直接更新 ObservationInfo 实例
if self.observation_info:
self.observation_info.chat_history.append(bot_message_info) # 将消息添加到历史记录末尾
self.observation_info.last_bot_speak_time = current_time # 更新 Bot 最后发言时间
self.observation_info.last_message_time = current_time # 更新最后消息时间
logger.debug("已手动将Bot发送的消息添加到 ObservationInfo")
else:
logger.warning("无法手动更新 ObservationInfo实例不存在")
except Exception as update_err:
logger.error(f"手动更新 ObservationInfo 时出错: {update_err}")
# --- 添加的立即更新状态逻辑结束 ---
# 原有的触发更新和等待代码
self.chat_observer.trigger_update()

View File

@@ -2,7 +2,7 @@ from typing import Optional
from src.common.logger import get_module_logger
from ..chat.chat_stream import ChatStream
from ..chat.message import Message
from ..message.message_base import Seg
from maim_message import Seg
from src.plugins.chat.message import MessageSending, MessageSet
from src.plugins.chat.message_sender import message_manager

View File

@@ -1,12 +1,13 @@
# Programmable Friendly Conversationalist
# Prefrontal cortex
from typing import List, Optional, Dict, Any, Set
from ..message.message_base import UserInfo
from maim_message import UserInfo
import time
from dataclasses import dataclass, field
from src.common.logger import get_module_logger
from .chat_observer import ChatObserver
from .chat_states import NotificationHandler, NotificationType
from src.plugins.utils.chat_message_builder import build_readable_messages
logger = get_module_logger("observation_info")
@@ -97,6 +98,7 @@ class ObservationInfo:
# data_list
chat_history: List[str] = field(default_factory=list)
chat_history_str: str = ""
unprocessed_messages: List[Dict[str, Any]] = field(default_factory=list)
active_users: Set[str] = field(default_factory=set)
@@ -223,11 +225,18 @@ class ObservationInfo:
return None
return time.time() - self.last_bot_speak_time
def clear_unprocessed_messages(self):
async def clear_unprocessed_messages(self):
"""清空未处理消息列表"""
# 将未处理消息添加到历史记录中
for message in self.unprocessed_messages:
self.chat_history.append(message)
self.chat_history_str = await build_readable_messages(
self.chat_history[-20:] if len(self.chat_history) > 20 else self.chat_history,
replace_bot_name=True,
merge_messages=False,
timestamp_mode="relative",
read_mark=0.0,
)
# 清空未处理消息列表
self.has_unread_messages = False
self.unprocessed_messages.clear()

View File

@@ -6,7 +6,7 @@ import datetime
from typing import List, Optional, Tuple, TYPE_CHECKING
from src.common.logger import get_module_logger
from ..chat.chat_stream import ChatStream
from ..message.message_base import UserInfo, Seg
from maim_message import UserInfo, Seg
from ..chat.message import Message
from ..models.utils_model import LLMRequest
from ...config.config import global_config
@@ -19,6 +19,7 @@ from src.individuality.individuality import Individuality
from .conversation_info import ConversationInfo
from .observation_info import ObservationInfo
import time
from src.plugins.utils.chat_message_builder import build_readable_messages
if TYPE_CHECKING:
pass
@@ -80,19 +81,20 @@ class GoalAnalyzer:
goals_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n"
# 获取聊天历史记录
chat_history_list = observation_info.chat_history
chat_history_text = ""
for msg in chat_history_list:
chat_history_text += f"{msg}\n"
chat_history_text = observation_info.chat_history
if observation_info.new_messages_count > 0:
new_messages_list = observation_info.unprocessed_messages
new_messages_str = await build_readable_messages(
new_messages_list,
replace_bot_name=True,
merge_messages=False,
timestamp_mode="relative",
read_mark=0.0,
)
chat_history_text += f"\n--- 以下是 {observation_info.new_messages_count} 条新消息 ---\n{new_messages_str}"
chat_history_text += f"{observation_info.new_messages_count}条新消息:\n"
for msg in new_messages_list:
chat_history_text += f"{msg}\n"
observation_info.clear_unprocessed_messages()
# await observation_info.clear_unprocessed_messages()
identity_details_only = self.identity_detail_info
identity_addon = ""
@@ -371,22 +373,12 @@ class DirectMessageSender:
# 处理消息
await message.process()
message_json = message.to_dict()
_message_json = message.to_dict()
# 发送消息
try:
end_point = global_config.api_urls.get(message.message_info.platform, None)
if end_point:
# logger.info(f"发送消息到{end_point}")
# logger.info(message_json)
try:
await global_api.send_message_REST(end_point, message_json)
except Exception as e:
logger.error(f"REST方式发送失败出现错误: {str(e)}")
logger.info("尝试使用ws发送")
await self.send_via_ws(message)
else:
await self.send_via_ws(message)
await self.send_via_ws(message)
await self.storage.store_message(message, chat_stream)
logger.success(f"PFC消息已发送: {content}")
except Exception as e:
logger.error(f"PFC消息发送失败: {str(e)}")

View File

@@ -4,6 +4,7 @@ from src.plugins.memory_system.Hippocampus import HippocampusManager
from ..models.utils_model import LLMRequest
from ...config.config import global_config
from ..chat.message import Message
from ..knowledge.knowledge_lib import qa_manager
logger = get_module_logger("knowledge_fetcher")
@@ -19,6 +20,25 @@ class KnowledgeFetcher:
request_type="knowledge_fetch",
)
def _lpmm_get_knowledge(self, query: str) -> str:
"""获取相关知识
Args:
query: 查询内容
Returns:
str: 构造好的,带相关度的知识
"""
logger.debug("正在从LPMM知识库中获取知识")
try:
knowledge_info = qa_manager.get_knowledge(query)
logger.debug(f"LPMM知识库查询结果: {knowledge_info:150}")
return knowledge_info
except Exception as e:
logger.error(f"LPMM知识库搜索工具执行失败: {str(e)}")
return "未找到匹配的知识"
async def fetch(self, query: str, chat_history: List[Message]) -> Tuple[str, str]:
"""获取相关知识
@@ -43,13 +63,16 @@ class KnowledgeFetcher:
max_depth=3,
fast_retrieval=False,
)
knowledge = ""
if related_memory:
knowledge = ""
sources = []
for memory in related_memory:
knowledge += memory[1] + "\n"
sources.append(f"记忆片段{memory[0]}")
return knowledge.strip(), "".join(sources)
knowledge = knowledge.strip(), "".join(sources)
knowledge += "现在有以下**知识**可供参考:\n "
knowledge += self._lpmm_get_knowledge(query)
knowledge += "请记住这些**知识**,并根据**知识**回答问题。\n"
return "未找到相关知识", "无记忆匹配"

View File

@@ -1,11 +1,10 @@
import json
import datetime
from typing import Tuple, List, Dict, Any
from src.common.logger import get_module_logger
from ..models.utils_model import LLMRequest
from ...config.config import global_config
from .chat_observer import ChatObserver
from ..message.message_base import UserInfo
from maim_message import UserInfo
logger = get_module_logger("reply_checker")
@@ -22,7 +21,7 @@ class ReplyChecker:
self.max_retries = 3 # 最大重试次数
async def check(
self, reply: str, goal: str, chat_history: List[Dict[str, Any]], retry_count: int = 0
self, reply: str, goal: str, chat_history: List[Dict[str, Any]], chat_history_text: str, retry_count: int = 0
) -> Tuple[bool, str, bool]:
"""检查生成的回复是否合适
@@ -36,7 +35,6 @@ class ReplyChecker:
"""
# 不再从 observer 获取,直接使用传入的 chat_history
# messages = self.chat_observer.get_cached_messages(limit=20)
chat_history_text = ""
try:
# 筛选出最近由 Bot 自己发送的消息
bot_messages = []
@@ -78,16 +76,9 @@ class ReplyChecker:
except Exception as e:
import traceback
logger.error(f"检查回复时出错: 类型={type(e)}, 值={e}")
logger.error(traceback.format_exc()) # 打印详细的回溯信息
for msg in chat_history[-20:]:
time_str = datetime.datetime.fromtimestamp(msg["time"]).strftime("%H:%M:%S")
user_info = UserInfo.from_dict(msg.get("user_info", {}))
sender = user_info.user_nickname or f"用户{user_info.user_id}"
if sender == self.name:
sender = "你说"
chat_history_text += f"{time_str},{sender}:{msg.get('processed_plain_text', '')}\n"
logger.error(f"检查回复时出错: 类型={type(e)}, 值={e}")
logger.error(traceback.format_exc()) # 打印详细的回溯信息
prompt = f"""请检查以下回复或消息是否合适:

View File

@@ -7,6 +7,7 @@ from .reply_checker import ReplyChecker
from src.individuality.individuality import Individuality
from .observation_info import ObservationInfo
from .conversation_info import ConversationInfo
from src.plugins.utils.chat_message_builder import build_readable_messages
logger = get_module_logger("reply_generator")
@@ -68,23 +69,19 @@ class ReplyGenerator:
goals_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n"
# 获取聊天历史记录
chat_history_list = (
observation_info.chat_history[-20:]
if len(observation_info.chat_history) >= 20
else observation_info.chat_history
)
chat_history_text = ""
for msg in chat_history_list:
chat_history_text += f"{msg.get('detailed_plain_text', '')}\n"
chat_history_text = observation_info.chat_history_str
if observation_info.new_messages_count > 0:
new_messages_list = observation_info.unprocessed_messages
chat_history_text += f"{observation_info.new_messages_count}条新消息:\n"
for msg in new_messages_list:
chat_history_text += f"{msg.get('detailed_plain_text', '')}\n"
observation_info.clear_unprocessed_messages()
new_messages_str = await build_readable_messages(
new_messages_list,
replace_bot_name=True,
merge_messages=False,
timestamp_mode="relative",
read_mark=0.0,
)
chat_history_text += f"\n--- 以下是 {observation_info.new_messages_count} 条新消息 ---\n{new_messages_str}"
# await observation_info.clear_unprocessed_messages()
identity_details_only = self.identity_detail_info
identity_addon = ""
@@ -173,7 +170,7 @@ class ReplyGenerator:
return "抱歉,我现在有点混乱,让我重新思考一下..."
async def check_reply(
self, reply: str, goal: str, chat_history: List[Dict[str, Any]], retry_count: int = 0
self, reply: str, goal: str, chat_history: List[Dict[str, Any]], chat_history_str: str, retry_count: int = 0
) -> Tuple[bool, str, bool]:
"""检查回复是否合适
@@ -185,4 +182,4 @@ class ReplyGenerator:
Returns:
Tuple[bool, str, bool]: (是否合适, 原因, 是否需要重新规划)
"""
return await self.reply_checker.check(reply, goal, chat_history, retry_count)
return await self.reply_checker.check(reply, goal, chat_history, chat_history_str, retry_count)

View File

@@ -4,7 +4,7 @@ MaiMBot插件系统
"""
from .chat.chat_stream import chat_manager
from .chat.emoji_manager import emoji_manager
from .emoji_system.emoji_manager import emoji_manager
from .person_info.relationship_manager import relationship_manager
from .moods.moods import MoodManager
from .willing.willing_manager import willing_manager

View File

@@ -1,4 +1,4 @@
from .emoji_manager import emoji_manager
from ..emoji_system.emoji_manager import emoji_manager
from ..person_info.relationship_manager import relationship_manager
from .chat_stream import chat_manager
from .message_sender import message_manager

View File

@@ -6,7 +6,7 @@ from typing import Dict, Optional
from ...common.database import db
from ..message.message_base import GroupInfo, UserInfo
from maim_message import GroupInfo, UserInfo
from src.common.logger import get_module_logger, LogConfig, CHAT_STREAM_STYLE_CONFIG

View File

@@ -1,595 +0,0 @@
import asyncio
import base64
import hashlib
import os
import random
import time
import traceback
from typing import Optional, Tuple
from PIL import Image
import io
from ...common.database import db
from ...config.config import global_config
from ..chat.utils import get_embedding
from ..chat.utils_image import ImageManager, image_path_to_base64
from ..models.utils_model import LLMRequest
from src.common.logger import get_module_logger, LogConfig, EMOJI_STYLE_CONFIG
emoji_log_config = LogConfig(
console_format=EMOJI_STYLE_CONFIG["console_format"],
file_format=EMOJI_STYLE_CONFIG["file_format"],
)
logger = get_module_logger("emoji", config=emoji_log_config)
image_manager = ImageManager()
class EmojiManager:
_instance = None
EMOJI_DIR = os.path.join("data", "emoji") # 表情包存储目录
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
self._scan_task = None
self.vlm = LLMRequest(model=global_config.vlm, temperature=0.3, max_tokens=1000, request_type="emoji")
self.llm_emotion_judge = LLMRequest(
model=global_config.llm_emotion_judge, max_tokens=600, temperature=0.8, request_type="emoji"
) # 更高的温度更少的token后续可以根据情绪来调整温度
self.emoji_num = 0
self.emoji_num_max = global_config.max_emoji_num
self.emoji_num_max_reach_deletion = global_config.max_reach_deletion
logger.info("启动表情包管理器")
def _ensure_emoji_dir(self):
"""确保表情存储目录存在"""
os.makedirs(self.EMOJI_DIR, exist_ok=True)
def _update_emoji_count(self):
"""更新表情包数量统计
检查数据库中的表情包数量并更新到 self.emoji_num
"""
try:
self._ensure_db()
self.emoji_num = db.emoji.count_documents({})
logger.info(f"[统计] 当前表情包数量: {self.emoji_num}")
except Exception as e:
logger.error(f"[错误] 更新表情包数量失败: {str(e)}")
def initialize(self):
"""初始化数据库连接和表情目录"""
if not self._initialized:
try:
self._ensure_emoji_collection()
self._ensure_emoji_dir()
self._initialized = True
# 更新表情包数量
self._update_emoji_count()
# 启动时执行一次完整性检查
self.check_emoji_file_integrity()
except Exception:
logger.exception("初始化表情管理器失败")
def _ensure_db(self):
"""确保数据库已初始化"""
if not self._initialized:
self.initialize()
if not self._initialized:
raise RuntimeError("EmojiManager not initialized")
@staticmethod
def _ensure_emoji_collection():
"""确保emoji集合存在并创建索引
这个函数用于确保MongoDB数据库中存在emoji集合,并创建必要的索引。
索引的作用是加快数据库查询速度:
- embedding字段的2dsphere索引: 用于加速向量相似度搜索,帮助快速找到相似的表情包
- tags字段的普通索引: 加快按标签搜索表情包的速度
- filename字段的唯一索引: 确保文件名不重复,同时加快按文件名查找的速度
没有索引的话,数据库每次查询都需要扫描全部数据,建立索引后可以大大提高查询效率。
"""
if "emoji" not in db.list_collection_names():
db.create_collection("emoji")
db.emoji.create_index([("embedding", "2dsphere")])
db.emoji.create_index([("filename", 1)], unique=True)
def record_usage(self, emoji_id: str):
"""记录表情使用次数"""
try:
self._ensure_db()
db.emoji.update_one({"_id": emoji_id}, {"$inc": {"usage_count": 1}})
except Exception as e:
logger.error(f"记录表情使用失败: {str(e)}")
async def get_emoji_for_text(self, text: str) -> Optional[Tuple[str, str]]:
"""根据文本内容获取相关表情包
Args:
text: 输入文本
Returns:
Optional[str]: 表情包文件路径如果没有找到则返回None
可不可以通过 配置文件中的指令 来自定义使用表情包的逻辑?
我觉得可行
"""
try:
self._ensure_db()
# 获取文本的embedding
text_for_search = await self._get_kimoji_for_text(text)
if not text_for_search:
logger.error("无法获取文本的情绪")
return None
text_embedding = await get_embedding(text_for_search, request_type="emoji")
if not text_embedding:
logger.error("无法获取文本的embedding")
return None
try:
# 获取所有表情包
all_emojis = [
e
for e in db.emoji.find({}, {"_id": 1, "path": 1, "embedding": 1, "description": 1, "blacklist": 1})
if "blacklist" not in e
]
if not all_emojis:
logger.warning("数据库中没有任何表情包")
return None
# 计算余弦相似度并排序
def cosine_similarity(v1, v2):
if not v1 or not v2:
return 0
dot_product = sum(a * b for a, b in zip(v1, v2))
norm_v1 = sum(a * a for a in v1) ** 0.5
norm_v2 = sum(b * b for b in v2) ** 0.5
if norm_v1 == 0 or norm_v2 == 0:
return 0
return dot_product / (norm_v1 * norm_v2)
# 计算所有表情包与输入文本的相似度
emoji_similarities = [
(emoji, cosine_similarity(text_embedding, emoji.get("embedding", []))) for emoji in all_emojis
]
# 按相似度降序排序
emoji_similarities.sort(key=lambda x: x[1], reverse=True)
# 获取前3个最相似的表情包
top_10_emojis = emoji_similarities[: 10 if len(emoji_similarities) > 10 else len(emoji_similarities)]
if not top_10_emojis:
logger.warning("未找到匹配的表情包")
return None
# 从前3个中随机选择一个
selected_emoji, similarity = random.choice(top_10_emojis)
if selected_emoji and "path" in selected_emoji:
# 更新使用次数
db.emoji.update_one({"_id": selected_emoji["_id"]}, {"$inc": {"usage_count": 1}})
logger.info(
f"[匹配] 找到表情包: {selected_emoji.get('description', '无描述')} (相似度: {similarity:.4f})"
)
# 稍微改一下文本描述,不然容易产生幻觉,描述已经包含 表情包 了
return selected_emoji["path"], "[ %s ]" % selected_emoji.get("description", "无描述")
except Exception as search_error:
logger.error(f"[错误] 搜索表情包失败: {str(search_error)}")
return None
return None
except Exception as e:
logger.error(f"[错误] 获取表情包失败: {str(e)}")
return None
@staticmethod
async def _get_emoji_description(image_base64: str) -> str:
"""获取表情包的标签使用image_manager的描述生成功能"""
try:
# 使用image_manager获取描述去掉前后的方括号和"表情包:"前缀
description = await image_manager.get_emoji_description(image_base64)
# 去掉[表情包xxx]的格式,只保留描述内容
description = description.strip("[]").replace("表情包:", "")
return description
except Exception as e:
logger.error(f"[错误] 获取表情包描述失败: {str(e)}")
return None
async def _check_emoji(self, image_base64: str, image_format: str) -> str:
try:
prompt = (
f'这是一个表情包,请回答这个表情包是否满足"{global_config.EMOJI_CHECK_PROMPT}"的要求,是则回答是,'
f"否则回答否,不要出现任何其他内容"
)
content, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
logger.debug(f"[检查] 表情包检查结果: {content}")
return content
except Exception as e:
logger.error(f"[错误] 表情包检查失败: {str(e)}")
return None
async def _get_kimoji_for_text(self, text: str):
try:
prompt = (
f"这是{global_config.BOT_NICKNAME}将要发送的消息内容:\n{text}\n若要为其配上表情包,"
f"请你输出这个表情包应该表达怎样的情感,应该给人什么样的感觉,不要太简洁也不要太长,"
f'注意不要输出任何对消息内容的分析内容,只输出"一种什么样的感觉"中间的形容词部分。'
)
content, _ = await self.llm_emotion_judge.generate_response_async(prompt, temperature=1.5)
logger.info(f"[情感] 表情包情感描述: {content}")
return content
except Exception as e:
logger.error(f"[错误] 获取表情包情感失败: {str(e)}")
return None
async def scan_new_emojis(self):
"""扫描新的表情包"""
try:
emoji_dir = self.EMOJI_DIR
os.makedirs(emoji_dir, exist_ok=True)
# 获取所有支持的图片文件
files_to_process = [
f for f in os.listdir(emoji_dir) if f.lower().endswith((".jpg", ".jpeg", ".png", ".gif"))
]
# 检查当前表情包数量
self._update_emoji_count()
if self.emoji_num >= self.emoji_num_max:
logger.warning(f"[警告] 表情包数量已达到上限({self.emoji_num}/{self.emoji_num_max}),跳过注册")
return
# 计算还可以注册的数量
remaining_slots = self.emoji_num_max - self.emoji_num
logger.info(f"[注册] 还可以注册 {remaining_slots} 个表情包")
for filename in files_to_process:
# 如果已经达到上限,停止注册
if self.emoji_num >= self.emoji_num_max:
logger.warning(f"[警告] 表情包数量已达到上限({self.emoji_num}/{self.emoji_num_max}),停止注册")
break
image_path = os.path.join(emoji_dir, filename)
# 获取图片的base64编码和哈希值
image_base64 = image_path_to_base64(image_path)
if image_base64 is None:
os.remove(image_path)
continue
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
# 检查是否已经注册过
existing_emoji_by_path = db["emoji"].find_one({"filename": filename})
existing_emoji_by_hash = db["emoji"].find_one({"hash": image_hash})
if existing_emoji_by_path and existing_emoji_by_hash:
if existing_emoji_by_path["_id"] != existing_emoji_by_hash["_id"]:
logger.error(f"[错误] 表情包已存在但记录不一致: {filename}")
db.emoji.delete_one({"_id": existing_emoji_by_path["_id"]})
db.emoji.delete_one({"_id": existing_emoji_by_hash["_id"]})
existing_emoji = None
else:
existing_emoji = existing_emoji_by_hash
elif existing_emoji_by_hash:
logger.error(f"[错误] 表情包hash已存在但path不存在: {filename}")
db.emoji.delete_one({"_id": existing_emoji_by_hash["_id"]})
existing_emoji = None
elif existing_emoji_by_path:
logger.error(f"[错误] 表情包path已存在但hash不存在: {filename}")
db.emoji.delete_one({"_id": existing_emoji_by_path["_id"]})
existing_emoji = None
else:
existing_emoji = None
description = None
if existing_emoji:
# 即使表情包已存在也检查是否需要同步到images集合
description = existing_emoji.get("description")
# 检查是否在images集合中存在
existing_image = db.images.find_one({"hash": image_hash})
if not existing_image:
# 同步到images集合
image_doc = {
"hash": image_hash,
"path": image_path,
"type": "emoji",
"description": description,
"timestamp": int(time.time()),
}
db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
# 保存描述到image_descriptions集合
image_manager._save_description_to_db(image_hash, description, "emoji")
logger.success(f"[同步] 已同步表情包到images集合: {filename}")
continue
# 检查是否在images集合中已有描述
existing_description = image_manager._get_description_from_db(image_hash, "emoji")
if existing_description:
description = existing_description
else:
# 获取表情包的描述
description = await self._get_emoji_description(image_base64)
if global_config.EMOJI_CHECK:
check = await self._check_emoji(image_base64, image_format)
if "" not in check:
os.remove(image_path)
logger.info(f"[过滤] 表情包描述: {description}")
logger.info(f"[过滤] 表情包不满足规则,已移除: {check}")
continue
logger.info(f"[检查] 表情包检查通过: {check}")
if description is not None:
embedding = await get_embedding(description, request_type="emoji")
if not embedding:
logger.error("获取消息嵌入向量失败")
raise ValueError("获取消息嵌入向量失败")
# 准备数据库记录
emoji_record = {
"filename": filename,
"path": image_path,
"embedding": embedding,
"description": description,
"hash": image_hash,
"timestamp": int(time.time()),
}
# 保存到emoji数据库
db["emoji"].insert_one(emoji_record)
logger.success(f"[注册] 新表情包: {filename}")
logger.info(f"[描述] {description}")
# 更新当前表情包数量
self.emoji_num += 1
logger.info(f"[统计] 当前表情包数量: {self.emoji_num}/{self.emoji_num_max}")
# 保存到images数据库
image_doc = {
"hash": image_hash,
"path": image_path,
"type": "emoji",
"description": description,
"timestamp": int(time.time()),
}
db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
# 保存描述到image_descriptions集合
image_manager._save_description_to_db(image_hash, description, "emoji")
logger.success(f"[同步] 已保存到images集合: {filename}")
else:
logger.warning(f"[跳过] 表情包: {filename}")
except Exception:
logger.exception("[错误] 扫描表情包失败")
def check_emoji_file_integrity(self):
"""检查表情包文件完整性
如果文件已被删除,则从数据库中移除对应记录
"""
try:
self._ensure_db()
# 获取所有表情包记录
all_emojis = list(db.emoji.find())
removed_count = 0
total_count = len(all_emojis)
for emoji in all_emojis:
try:
if "path" not in emoji:
logger.warning(f"[检查] 发现无效记录缺少path字段ID: {emoji.get('_id', 'unknown')}")
db.emoji.delete_one({"_id": emoji["_id"]})
removed_count += 1
continue
if "embedding" not in emoji:
logger.warning(f"[检查] 发现过时记录缺少embedding字段ID: {emoji.get('_id', 'unknown')}")
db.emoji.delete_one({"_id": emoji["_id"]})
removed_count += 1
continue
# 检查文件是否存在
if not os.path.exists(emoji["path"]):
logger.warning(f"[检查] 表情包文件已被删除: {emoji['path']}")
# 从数据库中删除记录
result = db.emoji.delete_one({"_id": emoji["_id"]})
if result.deleted_count > 0:
logger.debug(f"[清理] 成功删除数据库记录: {emoji['_id']}")
removed_count += 1
else:
logger.error(f"[错误] 删除数据库记录失败: {emoji['_id']}")
continue
if "hash" not in emoji:
logger.warning(f"[检查] 发现缺失记录缺少hash字段ID: {emoji.get('_id', 'unknown')}")
hash = hashlib.md5(open(emoji["path"], "rb").read()).hexdigest()
db.emoji.update_one({"_id": emoji["_id"]}, {"$set": {"hash": hash}})
else:
file_hash = hashlib.md5(open(emoji["path"], "rb").read()).hexdigest()
if emoji["hash"] != file_hash:
logger.warning(f"[检查] 表情包文件hash不匹配ID: {emoji.get('_id', 'unknown')}")
db.emoji.delete_one({"_id": emoji["_id"]})
removed_count += 1
# 修复拼写错误
if "discription" in emoji:
desc = emoji["discription"]
db.emoji.update_one(
{"_id": emoji["_id"]}, {"$unset": {"discription": ""}, "$set": {"description": desc}}
)
except Exception as item_error:
logger.error(f"[错误] 处理表情包记录时出错: {str(item_error)}")
continue
# 验证清理结果
remaining_count = db.emoji.count_documents({})
if removed_count > 0:
logger.success(f"[清理] 已清理 {removed_count} 个失效的表情包记录")
logger.info(f"[统计] 清理前: {total_count} | 清理后: {remaining_count}")
else:
logger.info(f"[检查] 已检查 {total_count} 个表情包记录")
except Exception as e:
logger.error(f"[错误] 检查表情包完整性失败: {str(e)}")
logger.error(traceback.format_exc())
def check_emoji_file_full(self):
"""检查表情包文件是否完整,如果数量超出限制且允许删除,则删除多余的表情包
删除规则:
1. 优先删除创建时间更早的表情包
2. 优先删除使用次数少的表情包,但使用次数多的也有小概率被删除
"""
try:
self._ensure_db()
# 更新表情包数量
self._update_emoji_count()
# 检查是否超出限制
if self.emoji_num <= self.emoji_num_max:
return
# 如果超出限制但不允许删除,则只记录警告
if not global_config.max_reach_deletion:
logger.warning(f"[警告] 表情包数量({self.emoji_num})超出限制({self.emoji_num_max}),但未开启自动删除")
return
# 计算需要删除的数量
delete_count = self.emoji_num - self.emoji_num_max
logger.info(f"[清理] 需要删除 {delete_count} 个表情包")
# 获取所有表情包,按时间戳升序(旧的在前)排序
all_emojis = list(db.emoji.find().sort([("timestamp", 1)]))
# 计算权重:使用次数越多,被删除的概率越小
weights = []
max_usage = max((emoji.get("usage_count", 0) for emoji in all_emojis), default=1)
for emoji in all_emojis:
usage_count = emoji.get("usage_count", 0)
# 使用指数衰减函数计算权重,使用次数越多权重越小
weight = 1.0 / (1.0 + usage_count / max(1, max_usage))
weights.append(weight)
# 根据权重随机选择要删除的表情包
to_delete = []
remaining_indices = list(range(len(all_emojis)))
while len(to_delete) < delete_count and remaining_indices:
# 计算当前剩余表情包的权重
current_weights = [weights[i] for i in remaining_indices]
# 归一化权重
total_weight = sum(current_weights)
if total_weight == 0:
break
normalized_weights = [w / total_weight for w in current_weights]
# 随机选择一个表情包
selected_idx = random.choices(remaining_indices, weights=normalized_weights, k=1)[0]
to_delete.append(all_emojis[selected_idx])
remaining_indices.remove(selected_idx)
# 删除选中的表情包
deleted_count = 0
for emoji in to_delete:
try:
# 删除文件
if "path" in emoji and os.path.exists(emoji["path"]):
os.remove(emoji["path"])
logger.info(f"[删除] 文件: {emoji['path']} (使用次数: {emoji.get('usage_count', 0)})")
# 删除数据库记录
db.emoji.delete_one({"_id": emoji["_id"]})
deleted_count += 1
# 同时从images集合中删除
if "hash" in emoji:
db.images.delete_one({"hash": emoji["hash"]})
except Exception as e:
logger.error(f"[错误] 删除表情包失败: {str(e)}")
continue
# 更新表情包数量
self._update_emoji_count()
logger.success(f"[清理] 已删除 {deleted_count} 个表情包,当前数量: {self.emoji_num}")
except Exception as e:
logger.error(f"[错误] 检查表情包数量失败: {str(e)}")
async def start_periodic_check_register(self):
"""定期检查表情包完整性和数量"""
while True:
logger.info("[扫描] 开始检查表情包完整性...")
self.check_emoji_file_integrity()
logger.info("[扫描] 开始删除所有图片缓存...")
await self.delete_all_images()
logger.info("[扫描] 开始扫描新表情包...")
if self.emoji_num < self.emoji_num_max:
await self.scan_new_emojis()
if self.emoji_num > self.emoji_num_max:
logger.warning(f"[警告] 表情包数量超过最大限制: {self.emoji_num} > {self.emoji_num_max},跳过注册")
if not global_config.max_reach_deletion:
logger.warning("表情包数量超过最大限制,终止注册")
break
else:
logger.warning("表情包数量超过最大限制,开始删除表情包")
self.check_emoji_file_full()
await asyncio.sleep(global_config.EMOJI_CHECK_INTERVAL * 60)
@staticmethod
async def delete_all_images():
"""删除 data/image 目录下的所有文件"""
try:
image_dir = os.path.join("data", "image")
if not os.path.exists(image_dir):
logger.warning(f"[警告] 目录不存在: {image_dir}")
return
deleted_count = 0
failed_count = 0
# 遍历目录下的所有文件
for filename in os.listdir(image_dir):
file_path = os.path.join(image_dir, filename)
try:
if os.path.isfile(file_path):
os.remove(file_path)
deleted_count += 1
logger.debug(f"[删除] 文件: {file_path}")
except Exception as e:
failed_count += 1
logger.error(f"[错误] 删除文件失败 {file_path}: {str(e)}")
logger.success(f"[清理] 已删除 {deleted_count} 个文件,失败 {failed_count}")
except Exception as e:
logger.error(f"[错误] 删除图片目录失败: {str(e)}")
# 创建全局单例
emoji_manager = EmojiManager()

View File

@@ -7,7 +7,7 @@ import urllib3
from src.common.logger import get_module_logger
from .chat_stream import ChatStream
from .utils_image import image_manager
from ..message.message_base import Seg, UserInfo, BaseMessageInfo, MessageBase
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
logger = get_module_logger("chat_message")
@@ -127,12 +127,12 @@ class MessageRecv(Message):
# 如果是base64图片数据
if isinstance(seg.data, str):
return await image_manager.get_image_description(seg.data)
return "[图片]"
return "[发了一张图片,网卡了加载不出来]"
elif seg.type == "emoji":
self.is_emoji = True
if isinstance(seg.data, str):
return await image_manager.get_emoji_description(seg.data)
return "[表情]"
return "[发了一个表情包,网卡了加载不出来]"
else:
return f"[{seg.type}:{str(seg.data)}]"
except Exception as e:
@@ -141,14 +141,8 @@ class MessageRecv(Message):
def _generate_detailed_text(self) -> str:
"""生成详细文本,包含时间和用户信息"""
# time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.message_info.time))
timestamp = self.message_info.time
user_info = self.message_info.user_info
# name = (
# f"{user_info.user_nickname}(ta的昵称:{user_info.user_cardname},ta的id:{user_info.user_id})"
# if user_info.user_cardname != None
# else f"{user_info.user_nickname}(ta的id:{user_info.user_id})"
# )
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>"
return f"[{timestamp}] {name}: {self.processed_plain_text}\n"
@@ -222,11 +216,11 @@ class MessageProcessBase(Message):
# 如果是base64图片数据
if isinstance(seg.data, str):
return await image_manager.get_image_description(seg.data)
return "[图片]"
return "[图片,网卡了加载不出来]"
elif seg.type == "emoji":
if isinstance(seg.data, str):
return await image_manager.get_emoji_description(seg.data)
return "[表情]"
return "[表情,网卡了加载不出来]"
elif seg.type == "at":
return f"[@{seg.data}]"
elif seg.type == "reply":

View File

@@ -3,7 +3,7 @@ from src.common.logger import get_module_logger
import asyncio
from dataclasses import dataclass, field
from .message import MessageRecv
from ..message.message_base import BaseMessageInfo, GroupInfo, Seg
from maim_message import BaseMessageInfo, GroupInfo
import hashlib
from typing import Dict
from collections import OrderedDict
@@ -128,58 +128,67 @@ class MessageBuffer:
if result:
async with self.lock: # 再次加锁
# 清理所有早于当前消息的已处理消息, 收集所有早于当前消息的F消息的processed_plain_text
keep_msgs = OrderedDict()
combined_text = []
found = False
type = "seglist"
is_update = True
for msg_id, msg in self.buffer_pool[person_id_].items():
keep_msgs = OrderedDict() # 用于存放 T 消息之后的消息
collected_texts = [] # 用于收集 T 消息及之前 F 消息的文本
process_target_found = False
# 遍历当前用户的所有缓冲消息
for msg_id, cache_msg in self.buffer_pool[person_id_].items():
# 如果找到了目标处理消息 (T 状态)
if msg_id == message.message_info.message_id:
found = True
if msg.message.message_segment.type != "seglist":
type = msg.message.message_segment.type
else:
if (
isinstance(msg.message.message_segment.data, list)
and all(isinstance(x, Seg) for x in msg.message.message_segment.data)
and len(msg.message.message_segment.data) == 1
):
type = msg.message.message_segment.data[0].type
combined_text.append(msg.message.processed_plain_text)
continue
if found:
keep_msgs[msg_id] = msg
elif msg.result == "F":
# 收集F消息的文本内容
f_type = "seglist"
if msg.message.message_segment.type != "seglist":
f_type = msg.message.message_segment.type
else:
if (
isinstance(msg.message.message_segment.data, list)
and all(isinstance(x, Seg) for x in msg.message.message_segment.data)
and len(msg.message.message_segment.data) == 1
):
f_type = msg.message.message_segment.data[0].type
if hasattr(msg.message, "processed_plain_text") and msg.message.processed_plain_text:
if f_type == "text":
combined_text.append(msg.message.processed_plain_text)
elif f_type != "text":
is_update = False
elif msg.result == "U":
logger.debug(f"异常未处理信息id {msg.message.message_info.message_id}")
process_target_found = True
# 收集这条 T 消息的文本 (如果有)
if (
hasattr(cache_msg.message, "processed_plain_text")
and cache_msg.message.processed_plain_text
):
collected_texts.append(cache_msg.message.processed_plain_text)
# 不立即放入 keep_msgs因为它之前的 F 消息也处理完了
# 更新当前消息的processed_plain_text
if combined_text and combined_text[0] != message.processed_plain_text and is_update:
if type == "text":
message.processed_plain_text = "".join(combined_text)
logger.debug(f"整合了{len(combined_text) - 1}条F消息的内容到当前消息")
elif type == "emoji":
combined_text.pop()
message.processed_plain_text = "".join(combined_text)
message.is_emoji = False
logger.debug(f"整合了{len(combined_text) - 1}条F消息的内容覆盖当前emoji消息")
# 如果已经找到了目标 T 消息,之后的消息需要保留
elif process_target_found:
keep_msgs[msg_id] = cache_msg
# 如果还没找到目标 T 消息,说明是之前的消息 (F 或 U)
else:
if cache_msg.result == "F":
# 收集这条 F 消息的文本 (如果有)
if (
hasattr(cache_msg.message, "processed_plain_text")
and cache_msg.message.processed_plain_text
):
collected_texts.append(cache_msg.message.processed_plain_text)
elif cache_msg.result == "U":
# 理论上不应该在 T 消息之前还有 U 消息,记录日志
logger.warning(
f"异常状态:在目标 T 消息 {message.message_info.message_id} 之前发现未处理的 U 消息 {cache_msg.message.message_info.message_id}"
)
# 也可以选择收集其文本
if (
hasattr(cache_msg.message, "processed_plain_text")
and cache_msg.message.processed_plain_text
):
collected_texts.append(cache_msg.message.processed_plain_text)
# 更新当前消息 (message) 的 processed_plain_text
# 只有在收集到的文本多于一条,或者只有一条但与原始文本不同时才合并
if collected_texts:
# 使用 OrderedDict 去重,同时保留原始顺序
unique_texts = list(OrderedDict.fromkeys(collected_texts))
merged_text = "".join(unique_texts)
# 只有在合并后的文本与原始文本不同时才更新
# 并且确保不是空合并
if merged_text and merged_text != message.processed_plain_text:
message.processed_plain_text = merged_text
# 如果合并了文本,原消息不再视为纯 emoji
if hasattr(message, "is_emoji"):
message.is_emoji = False
logger.debug(
f"合并了 {len(unique_texts)} 条消息的文本内容到当前消息 {message.message_info.message_id}"
)
# 更新缓冲池,只保留 T 消息之后的消息
self.buffer_pool[person_id_] = keep_msgs
return result
except asyncio.TimeoutError:

View File

@@ -62,20 +62,10 @@ class MessageSender:
# logger.trace(f"{message.processed_plain_text},{typing_time},等待输入时间结束") # 减少日志
# --- 结束打字延迟 ---
message_json = message.to_dict()
message_preview = truncate_message(message.processed_plain_text)
try:
end_point = global_config.api_urls.get(message.message_info.platform, None)
if end_point:
try:
await global_api.send_message_rest(end_point, message_json)
except Exception as e:
logger.error(f"REST发送失败: {str(e)}")
logger.info(f"[{message.chat_stream.stream_id}] 尝试使用WS发送")
await self.send_via_ws(message)
else:
await self.send_via_ws(message)
await self.send_via_ws(message)
logger.success(f"发送消息 '{message_preview}' 成功") # 调整日志格式
except Exception as e:
logger.error(f"发送消息 '{message_preview}' 失败: {str(e)}")

View File

@@ -12,7 +12,7 @@ from ..models.utils_model import LLMRequest
from ..utils.typo_generator import ChineseTypoGenerator
from ...config.config import global_config
from .message import MessageRecv, Message
from ..message.message_base import UserInfo
from maim_message import UserInfo
from .chat_stream import ChatStream
from ..moods.moods import MoodManager
from ...common.database import db
@@ -234,6 +234,13 @@ def split_into_sentences_w_remove_punctuation(text: str) -> List[str]:
Returns:
List[str]: 分割和合并后的句子列表
"""
# 预处理:处理多余的换行符
# 1. 将连续的换行符替换为单个换行符
text = re.sub(r"\n\s*\n+", "\n", text)
# 2. 处理换行符和其他分隔符的组合
text = re.sub(r"\n\s*([,。;\s])", r"\1", text)
text = re.sub(r"([,。;\s])\s*\n", r"\1", text)
# 处理两个汉字中间的换行符
text = re.sub(r"([\u4e00-\u9fff])\n([\u4e00-\u9fff])", r"\1。\2", text)
@@ -365,12 +372,16 @@ def random_remove_punctuation(text: str) -> str:
def process_llm_response(text: str) -> List[str]:
# 先保护颜文字
protected_text, kaomoji_mapping = protect_kaomoji(text)
logger.trace(f"保护颜文字后的文本: {protected_text}")
if global_config.enable_kaomoji_protection:
protected_text, kaomoji_mapping = protect_kaomoji(text)
logger.trace(f"保护颜文字后的文本: {protected_text}")
else:
protected_text = text
kaomoji_mapping = {}
# 提取被 () 或 [] 包裹且包含中文的内容
pattern = re.compile(r"[\(\[\](?=.*[\u4e00-\u9fff]).*?[\)\]\]")
# _extracted_contents = pattern.findall(text)
extracted_contents = pattern.findall(protected_text) # 在保护后的文本上查找
_extracted_contents = pattern.findall(protected_text) # 在保护后的文本上查找
# 去除 () 和 [] 及其包裹的内容
cleaned_text = pattern.sub("", protected_text)
@@ -413,13 +424,14 @@ def process_llm_response(text: str) -> List[str]:
if len(sentences) > max_sentence_num:
logger.warning(f"分割后消息数量过多 ({len(sentences)} 条),返回默认回复")
return [f"{global_config.BOT_NICKNAME}不知道哦"]
if extracted_contents:
for content in extracted_contents:
sentences.append(content)
# 在所有句子处理完毕后,对包含占位符的列表进行恢复
sentences = recover_kaomoji(sentences, kaomoji_mapping)
print(sentences)
# if extracted_contents:
# for content in extracted_contents:
# sentences.append(content)
# 在所有句子处理完毕后,对包含占位符的列表进行恢复
if global_config.enable_kaomoji_protection:
sentences = recover_kaomoji(sentences, kaomoji_mapping)
return sentences

View File

@@ -121,7 +121,7 @@ class ImageManager:
prompt = "这是一个动态图表情包每一张图代表了动态图的某一帧黑色背景代表透明使用1-2个词描述一下表情包表达的情感和内容简短一些"
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, "jpg")
else:
prompt = "这是一个表情包,请用使用1-2个词描述一下表情包所表达的情感和内容,简短一些"
prompt = "这是一个表情包,请用使用个词描述一下表情包所表达的情感和内容,简短一些"
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
cached_description = self._get_description_from_db(image_hash, "emoji")
@@ -130,7 +130,7 @@ class ImageManager:
return f"[表达了:{cached_description}]"
# 根据配置决定是否保存图片
if global_config.EMOJI_SAVE:
if global_config.save_emoji:
# 生成文件名和路径
timestamp = int(time.time())
filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
@@ -152,7 +152,7 @@ class ImageManager:
"timestamp": timestamp,
}
db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
logger.success(f"保存表情包: {file_path}")
logger.trace(f"保存表情包: {file_path}")
except Exception as e:
logger.error(f"保存表情包文件失败: {str(e)}")
@@ -196,7 +196,7 @@ class ImageManager:
return "[图片]"
# 根据配置决定是否保存图片
if global_config.EMOJI_SAVE:
if global_config.save_pic:
# 生成文件名和路径
timestamp = int(time.time())
filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
@@ -309,11 +309,15 @@ def image_path_to_base64(image_path: str) -> str:
image_path: 图片文件路径
Returns:
str: base64编码的图片数据
Raises:
FileNotFoundError: 当图片文件不存在时
IOError: 当读取图片文件失败时
"""
try:
with open(image_path, "rb") as f:
image_data = f.read()
return base64.b64encode(image_data).decode("utf-8")
except Exception as e:
logger.error(f"读取图片失败: {image_path}, 错误: {str(e)}")
return None
if not os.path.exists(image_path):
raise FileNotFoundError(f"图片文件不存在: {image_path}")
with open(image_path, "rb") as f:
image_data = f.read()
if not image_data:
raise IOError(f"读取图片文件失败: {image_path}")
return base64.b64encode(image_data).decode("utf-8")

View File

@@ -0,0 +1,827 @@
import asyncio
import base64
import hashlib
import os
import random
import time
import traceback
from typing import Optional, Tuple
from PIL import Image
import io
import re
from ...common.database import db
from ...config.config import global_config
from ..chat.utils_image import image_path_to_base64, image_manager
from ..models.utils_model import LLMRequest
from src.common.logger import get_module_logger, LogConfig, EMOJI_STYLE_CONFIG
emoji_log_config = LogConfig(
console_format=EMOJI_STYLE_CONFIG["console_format"],
file_format=EMOJI_STYLE_CONFIG["file_format"],
)
logger = get_module_logger("emoji", config=emoji_log_config)
BASE_DIR = os.path.join("data")
EMOJI_DIR = os.path.join(BASE_DIR, "emoji") # 表情包存储目录
EMOJI_REGISTED_DIR = os.path.join(BASE_DIR, "emoji_registed") # 已注册的表情包注册目录
"""
还没经过测试,有些地方数据库和内存数据同步可能不完全
"""
class MaiEmoji:
"""定义一个表情包"""
def __init__(self, filename: str, path: str):
self.path = path # 存储目录路径
self.filename = filename
self.embedding = []
self.hash = "" # 初始为空,在创建实例时会计算
self.description = ""
self.emotion = []
self.usage_count = 0
self.last_used_time = time.time()
self.register_time = time.time()
self.is_deleted = False # 标记是否已被删除
self.format = ""
async def initialize_hash_format(self):
"""从文件创建表情包实例
参数:
file_path: 文件的完整路径
返回:
MaiEmoji: 创建的表情包实例如果失败则返回None
"""
try:
file_path = os.path.join(self.path, self.filename)
if not os.path.exists(file_path):
logger.error(f"[错误] 表情包文件不存在: {file_path}")
return None
image_base64 = image_path_to_base64(file_path)
if image_base64 is None:
logger.error(f"[错误] 无法读取图片: {file_path}")
return None
# 计算哈希值
image_bytes = base64.b64decode(image_base64)
self.hash = hashlib.md5(image_bytes).hexdigest()
# 获取图片格式
self.format = Image.open(io.BytesIO(image_bytes)).format.lower()
except Exception as e:
logger.error(f"[错误] 初始化表情包失败: {str(e)}")
logger.error(traceback.format_exc())
return None
async def register_to_db(self):
"""
注册表情包
将表情包对应的文件从当前路径移动到EMOJI_REGISTED_DIR目录下
并修改对应的实例属性,然后将表情包信息保存到数据库中
"""
try:
# 确保目标目录存在
os.makedirs(EMOJI_REGISTED_DIR, exist_ok=True)
# 源路径是当前实例的完整路径
source_path = os.path.join(self.path, self.filename)
# 目标路径
destination_path = os.path.join(EMOJI_REGISTED_DIR, self.filename)
# 检查源文件是否存在
if not os.path.exists(source_path):
logger.error(f"[错误] 源文件不存在: {source_path}")
return False
# --- 文件移动 ---
try:
# 如果目标文件已存在,先删除 (确保移动成功)
if os.path.exists(destination_path):
os.remove(destination_path)
os.rename(source_path, destination_path)
logger.info(f"[移动] 文件从 {source_path} 移动到 {destination_path}")
# 更新实例的路径属性为新目录
self.path = EMOJI_REGISTED_DIR
except Exception as move_error:
logger.error(f"[错误] 移动文件失败: {str(move_error)}")
return False # 文件移动失败,不继续
# --- 数据库操作 ---
try:
# 准备数据库记录 for emoji collection
emoji_record = {
"filename": self.filename,
"path": os.path.join(self.path, self.filename), # 使用更新后的路径
"embedding": self.embedding,
"description": self.description,
"emotion": self.emotion, # 添加情感标签字段
"hash": self.hash,
"format": self.format,
"timestamp": int(self.register_time), # 使用实例的注册时间
"usage_count": self.usage_count,
"last_used_time": self.last_used_time,
}
# 使用upsert确保记录存在或被更新
db["emoji"].update_one({"hash": self.hash}, {"$set": emoji_record}, upsert=True)
logger.success(f"[注册] 表情包信息保存到数据库: {self.description}")
return True
except Exception as db_error:
logger.error(f"[错误] 保存数据库失败: {str(db_error)}")
# 考虑是否需要将文件移回?为了简化,暂时只记录错误
return False
except Exception as e:
logger.error(f"[错误] 注册表情包失败: {str(e)}")
logger.error(traceback.format_exc())
return False
async def delete(self):
"""删除表情包
删除表情包的文件和数据库记录
返回:
bool: 是否成功删除
"""
try:
# 1. 删除文件
if os.path.exists(os.path.join(self.path, self.filename)):
try:
os.remove(os.path.join(self.path, self.filename))
logger.info(f"[删除] 文件: {os.path.join(self.path, self.filename)}")
except Exception as e:
logger.error(f"[错误] 删除文件失败 {os.path.join(self.path, self.filename)}: {str(e)}")
# 继续执行,即使文件删除失败也尝试删除数据库记录
# 2. 删除数据库记录
result = db.emoji.delete_one({"hash": self.hash})
deleted_in_db = result.deleted_count > 0
if deleted_in_db:
logger.success(f"[删除] 成功删除表情包记录: {self.description}")
# 3. 标记对象已被删除
self.is_deleted = True
return True
else:
logger.error(f"[错误] 删除表情包记录失败: {self.hash}")
return False
except Exception as e:
logger.error(f"[错误] 删除表情包失败: {str(e)}")
return False
class EmojiManager:
_instance = None
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
self._scan_task = None
self.vlm = LLMRequest(model=global_config.vlm, temperature=0.3, max_tokens=1000, request_type="emoji")
self.llm_emotion_judge = LLMRequest(
model=global_config.llm_emotion_judge, max_tokens=600, temperature=0.8, request_type="emoji"
) # 更高的温度更少的token后续可以根据情绪来调整温度
self.emoji_num = 0
self.emoji_num_max = global_config.max_emoji_num
self.emoji_num_max_reach_deletion = global_config.max_reach_deletion
self.emoji_objects: list[MaiEmoji] = [] # 存储MaiEmoji对象的列表使用类型注解明确列表元素类型
logger.info("启动表情包管理器")
def _ensure_emoji_dir(self):
"""确保表情存储目录存在"""
os.makedirs(EMOJI_DIR, exist_ok=True)
def initialize(self):
"""初始化数据库连接和表情目录"""
if not self._initialized:
try:
self._ensure_emoji_collection()
self._ensure_emoji_dir()
self._initialized = True
# 更新表情包数量
# 启动时执行一次完整性检查
# await self.check_emoji_file_integrity()
except Exception:
logger.exception("初始化表情管理器失败")
def _ensure_db(self):
"""确保数据库已初始化"""
if not self._initialized:
self.initialize()
if not self._initialized:
raise RuntimeError("EmojiManager not initialized")
@staticmethod
def _ensure_emoji_collection():
"""确保emoji集合存在并创建索引
这个函数用于确保MongoDB数据库中存在emoji集合,并创建必要的索引。
索引的作用是加快数据库查询速度:
- embedding字段的2dsphere索引: 用于加速向量相似度搜索,帮助快速找到相似的表情包
- tags字段的普通索引: 加快按标签搜索表情包的速度
- filename字段的唯一索引: 确保文件名不重复,同时加快按文件名查找的速度
没有索引的话,数据库每次查询都需要扫描全部数据,建立索引后可以大大提高查询效率。
"""
if "emoji" not in db.list_collection_names():
db.create_collection("emoji")
db.emoji.create_index([("embedding", "2dsphere")])
db.emoji.create_index([("filename", 1)], unique=True)
def record_usage(self, hash: str):
"""记录表情使用次数"""
try:
db.emoji.update_one({"hash": hash}, {"$inc": {"usage_count": 1}})
for emoji in self.emoji_objects:
if emoji.hash == hash:
emoji.usage_count += 1
break
except Exception as e:
logger.error(f"记录表情使用失败: {str(e)}")
async def get_emoji_for_text(self, text_emotion: str) -> Optional[Tuple[str, str]]:
"""根据文本内容获取相关表情包
Args:
text_emotion: 输入的情感描述文本
Returns:
Optional[Tuple[str, str]]: (表情包文件路径, 表情包描述)如果没有找到则返回None
"""
try:
self._ensure_db()
time_start = time.time()
# 获取所有表情包
all_emojis = self.emoji_objects
if not all_emojis:
logger.warning("数据库中没有任何表情包")
return None
# 计算每个表情包与输入文本的最大情感相似度
emoji_similarities = []
for emoji in all_emojis:
emotions = emoji.emotion
if not emotions:
continue
# 计算与每个emotion标签的相似度取最大值
max_similarity = 0
for emotion in emotions:
# 使用编辑距离计算相似度
distance = self._levenshtein_distance(text_emotion, emotion)
max_len = max(len(text_emotion), len(emotion))
similarity = 1 - (distance / max_len if max_len > 0 else 0)
max_similarity = max(max_similarity, similarity)
emoji_similarities.append((emoji, max_similarity))
# 按相似度降序排序
emoji_similarities.sort(key=lambda x: x[1], reverse=True)
# 获取前5个最相似的表情包
top_5_emojis = emoji_similarities[:10] if len(emoji_similarities) > 10 else emoji_similarities
if not top_5_emojis:
logger.warning("未找到匹配的表情包")
return None
# 从前5个中随机选择一个
selected_emoji, similarity = random.choice(top_5_emojis)
# 更新使用次数
self.record_usage(selected_emoji.hash)
time_end = time.time()
logger.info(
f"找到[{text_emotion}]表情包,用时:{time_end - time_start:.2f}秒: {selected_emoji.description} (相似度: {similarity:.4f})"
)
return selected_emoji.path, f"[ {selected_emoji.description} ]"
except Exception as e:
logger.error(f"[错误] 获取表情包失败: {str(e)}")
return None
def _levenshtein_distance(self, s1: str, s2: str) -> int:
"""计算两个字符串的编辑距离
Args:
s1: 第一个字符串
s2: 第二个字符串
Returns:
int: 编辑距离
"""
if len(s1) < len(s2):
return self._levenshtein_distance(s2, s1)
if len(s2) == 0:
return len(s1)
previous_row = range(len(s2) + 1)
for i, c1 in enumerate(s1):
current_row = [i + 1]
for j, c2 in enumerate(s2):
insertions = previous_row[j + 1] + 1
deletions = current_row[j] + 1
substitutions = previous_row[j] + (c1 != c2)
current_row.append(min(insertions, deletions, substitutions))
previous_row = current_row
return previous_row[-1]
async def check_emoji_file_integrity(self):
"""检查表情包文件完整性
遍历self.emoji_objects中的所有对象检查文件是否存在
如果文件已被删除,则执行对象的删除方法并从列表中移除
"""
try:
if not self.emoji_objects:
logger.warning("[检查] emoji_objects为空跳过完整性检查")
return
total_count = len(self.emoji_objects)
removed_count = 0
# 使用列表复制进行遍历,因为我们会在遍历过程中修改列表
for emoji in self.emoji_objects[:]:
try:
# 检查文件是否存在
if not os.path.exists(emoji.path):
logger.warning(f"[检查] 表情包文件已被删除: {emoji.path}")
# 执行表情包对象的删除方法
await emoji.delete()
# 从列表中移除该对象
self.emoji_objects.remove(emoji)
# 更新计数
self.emoji_num -= 1
removed_count += 1
continue
except Exception as item_error:
logger.error(f"[错误] 处理表情包记录时出错: {str(item_error)}")
continue
# 输出清理结果
if removed_count > 0:
logger.success(f"[清理] 已清理 {removed_count} 个失效的表情包记录")
logger.info(f"[统计] 清理前: {total_count} | 清理后: {len(self.emoji_objects)}")
else:
logger.info(f"[检查] 已检查 {total_count} 个表情包记录,全部完好")
except Exception as e:
logger.error(f"[错误] 检查表情包完整性失败: {str(e)}")
logger.error(traceback.format_exc())
async def start_periodic_check_register(self):
"""定期检查表情包完整性和数量"""
await self.get_all_emoji_from_db()
while True:
logger.info("[扫描] 开始检查表情包完整性...")
await self.check_emoji_file_integrity()
await self.clear_temp_emoji()
logger.info("[扫描] 开始扫描新表情包...")
# 检查表情包目录是否存在
if not os.path.exists(EMOJI_DIR):
logger.warning(f"[警告] 表情包目录不存在: {EMOJI_DIR}")
os.makedirs(EMOJI_DIR, exist_ok=True)
logger.info(f"[创建] 已创建表情包目录: {EMOJI_DIR}")
await asyncio.sleep(global_config.EMOJI_CHECK_INTERVAL * 60)
continue
# 检查目录是否为空
files = os.listdir(EMOJI_DIR)
if not files:
logger.warning(f"[警告] 表情包目录为空: {EMOJI_DIR}")
await asyncio.sleep(global_config.EMOJI_CHECK_INTERVAL * 60)
continue
# 检查是否需要处理表情包(数量超过最大值或不足)
if (self.emoji_num > self.emoji_num_max and global_config.max_reach_deletion) or (
self.emoji_num < self.emoji_num_max
):
try:
# 获取目录下所有图片文件
files_to_process = [
f
for f in files
if os.path.isfile(os.path.join(EMOJI_DIR, f))
and f.lower().endswith((".jpg", ".jpeg", ".png", ".gif"))
]
# 处理每个符合条件的文件
for filename in files_to_process:
# 尝试注册表情包
success = await self.register_emoji_by_filename(filename)
if success:
# 注册成功则跳出循环
break
else:
# 注册失败则删除对应文件
file_path = os.path.join(EMOJI_DIR, filename)
os.remove(file_path)
logger.warning(f"[清理] 删除注册失败的表情包文件: {filename}")
except Exception as e:
logger.error(f"[错误] 扫描表情包目录失败: {str(e)}")
await asyncio.sleep(global_config.EMOJI_CHECK_INTERVAL * 60)
async def get_all_emoji_from_db(self):
"""获取所有表情包并初始化为MaiEmoji类对象
参数:
hash: 可选,如果提供则只返回指定哈希值的表情包
返回:
list[MaiEmoji]: 表情包对象列表
"""
try:
self._ensure_db()
# 获取所有表情包
all_emoji_data = list(db.emoji.find())
# 将数据库记录转换为MaiEmoji对象
emoji_objects = []
for emoji_data in all_emoji_data:
emoji = MaiEmoji(
filename=emoji_data.get("filename", ""),
path=emoji_data.get("path", ""),
)
# 设置额外属性
emoji.hash = emoji_data.get("hash", "")
emoji.usage_count = emoji_data.get("usage_count", 0)
emoji.last_used_time = emoji_data.get("last_used_time", emoji_data.get("timestamp", time.time()))
emoji.register_time = emoji_data.get("timestamp", time.time())
emoji.description = emoji_data.get("description", "")
emoji.emotion = emoji_data.get("emotion", []) # 添加情感标签的加载
emoji_objects.append(emoji)
# 存储到EmojiManager中
self.emoji_objects = emoji_objects
except Exception as e:
logger.error(f"[错误] 获取所有表情包对象失败: {str(e)}")
async def get_emoji_from_db(self, hash=None):
"""获取所有表情包并初始化为MaiEmoji类对象
参数:
hash: 可选,如果提供则只返回指定哈希值的表情包
返回:
list[MaiEmoji]: 表情包对象列表
"""
try:
self._ensure_db()
# 准备查询条件
query = {}
if hash:
query = {"hash": hash}
# 获取所有表情包
all_emoji_data = list(db.emoji.find(query))
# 将数据库记录转换为MaiEmoji对象
emoji_objects = []
for emoji_data in all_emoji_data:
emoji = MaiEmoji(
filename=emoji_data.get("filename", ""),
path=emoji_data.get("path", ""),
)
# 设置额外属性
emoji.usage_count = emoji_data.get("usage_count", 0)
emoji.last_used_time = emoji_data.get("last_used_time", emoji_data.get("timestamp", time.time()))
emoji.register_time = emoji_data.get("timestamp", time.time())
emoji.description = emoji_data.get("description", "")
emoji.emotion = emoji_data.get("emotion", []) # 添加情感标签的加载
emoji_objects.append(emoji)
# 存储到EmojiManager中
self.emoji_objects = emoji_objects
return emoji_objects
except Exception as e:
logger.error(f"[错误] 获取所有表情包对象失败: {str(e)}")
return []
async def get_emoji_from_manager(self, hash) -> MaiEmoji:
"""从EmojiManager中获取表情包
参数:
hash:如果提供则只返回指定哈希值的表情包
"""
for emoji in self.emoji_objects:
if emoji.hash == hash:
return emoji
return None
async def delete_emoji(self, emoji_hash: str) -> bool:
"""根据哈希值删除表情包
Args:
emoji_hash: 表情包的哈希值
Returns:
bool: 是否成功删除
"""
try:
self._ensure_db()
# 从emoji_objects中查找表情包对象
emoji = await self.get_emoji_from_manager(emoji_hash)
if not emoji:
logger.warning(f"[警告] 未找到哈希值为 {emoji_hash} 的表情包")
return False
# 使用MaiEmoji对象的delete方法删除表情包
success = await emoji.delete()
if success:
# 从emoji_objects列表中移除该对象
self.emoji_objects = [e for e in self.emoji_objects if e.hash != emoji_hash]
# 更新计数
self.emoji_num -= 1
logger.info(f"[统计] 当前表情包数量: {self.emoji_num}")
return True
else:
logger.error(f"[错误] 删除表情包失败: {emoji_hash}")
return False
except Exception as e:
logger.error(f"[错误] 删除表情包失败: {str(e)}")
logger.error(traceback.format_exc())
return False
def _emoji_objects_to_readable_list(self, emoji_objects):
"""将表情包对象列表转换为可读的字符串列表
参数:
emoji_objects: MaiEmoji对象列表
返回:
list[str]: 可读的表情包信息字符串列表
"""
emoji_info_list = []
for i, emoji in enumerate(emoji_objects):
# 转换时间戳为可读时间
time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(emoji.register_time))
# 构建每个表情包的信息字符串
emoji_info = (
f"编号: {i + 1}\n描述: {emoji.description}\n使用次数: {emoji.usage_count}\n添加时间: {time_str}\n"
)
emoji_info_list.append(emoji_info)
return emoji_info_list
async def replace_a_emoji(self, new_emoji: MaiEmoji):
"""替换一个表情包
Args:
new_emoji: 新表情包对象
Returns:
bool: 是否成功替换表情包
"""
try:
self._ensure_db()
# 获取所有表情包对象
all_emojis = self.emoji_objects
# 将表情包信息转换为可读的字符串
emoji_info_list = self._emoji_objects_to_readable_list(all_emojis)
# 构建提示词
prompt = (
f"{global_config.BOT_NICKNAME}的表情包存储已满({self.emoji_num}/{self.emoji_num_max})"
f"需要决定是否删除一个旧表情包来为新表情包腾出空间。\n\n"
f"新表情包信息:\n"
f"描述: {new_emoji.description}\n\n"
f"现有表情包列表:\n" + "\n".join(emoji_info_list) + "\n\n"
"请决定:\n"
"1. 是否要删除某个现有表情包来为新表情包腾出空间?\n"
"2. 如果要删除,应该删除哪一个(给出编号)\n"
"请只回答:'不删除''删除编号X'(X为表情包编号)。"
)
# 调用大模型进行决策
decision, _ = await self.llm_emotion_judge.generate_response_async(prompt, temperature=0.8)
logger.info(f"[决策] 大模型决策结果: {decision}")
# 解析决策结果
if "不删除" in decision:
logger.info("[决策] 决定不删除任何表情包")
return False
# 尝试从决策中提取表情包编号
match = re.search(r"删除编号(\d+)", decision)
if match:
emoji_index = int(match.group(1)) - 1 # 转换为0-based索引
# 检查索引是否有效
if 0 <= emoji_index < len(all_emojis):
emoji_to_delete = all_emojis[emoji_index]
# 删除选定的表情包
logger.info(f"[决策] 决定删除表情包: {emoji_to_delete.description}")
delete_success = await self.delete_emoji(emoji_to_delete.hash)
if delete_success:
# 修复:等待异步注册完成
register_success = await new_emoji.register_to_db()
if register_success:
self.emoji_objects.append(new_emoji)
self.emoji_num += 1
logger.success(f"[成功] 注册表情包: {new_emoji.description}")
return True
else:
logger.error(f"[错误] 注册表情包到数据库失败: {new_emoji.filename}")
return False
else:
logger.error("[错误] 删除表情包失败,无法完成替换")
return False
else:
logger.error(f"[错误] 无效的表情包编号: {emoji_index + 1}")
else:
logger.error(f"[错误] 无法从决策中提取表情包编号: {decision}")
return False
except Exception as e:
logger.error(f"[错误] 替换表情包失败: {str(e)}")
logger.error(traceback.format_exc())
return False
async def build_emoji_description(self, image_base64: str) -> Tuple[str, list]:
"""获取表情包描述和情感列表
Args:
image_base64: 图片的base64编码
Returns:
Tuple[str, list]: 返回表情包描述和情感列表
"""
try:
# 解码图片并获取格式
image_bytes = base64.b64decode(image_base64)
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
# 调用AI获取描述
if image_format == "gif" or image_format == "GIF":
image_base64 = image_manager.transform_gif(image_base64)
prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,详细描述一下表情包表达的情感和内容,请关注其幽默和讽刺意味"
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, "jpg")
else:
prompt = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,请关注其幽默和讽刺意味"
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
# 审核表情包
if global_config.EMOJI_CHECK:
prompt = f'''
这是一个表情包,请对这个表情包进行审核,标准如下:
1. 必须符合"{global_config.EMOJI_CHECK_PROMPT}"的要求
2. 不能是色情、暴力、等违法违规内容,必须符合公序良俗
3. 不能是任何形式的截图,聊天记录或视频截图
4. 不要出现5个以上文字
请回答这个表情包是否满足上述要求,是则回答是,否则回答否,不要出现任何其他内容
'''
content, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
if content == "":
return None, []
# 分析情感含义
emotion_prompt = f"""
基于这个表情包的描述:'{description}'请列出1-2个可能的情感标签每个标签用一个词组表示格式如下
幽默的讽刺
悲伤的无奈
愤怒的抗议
愤怒的讽刺
直接输出词组,词组检用逗号分隔。"""
emotions_text, _ = await self.llm_emotion_judge.generate_response_async(emotion_prompt, temperature=0.7)
# 处理情感列表
emotions = [e.strip() for e in emotions_text.split(",") if e.strip()]
return f"[表情包:{description}]", emotions
except Exception as e:
logger.error(f"获取表情包描述失败: {str(e)}")
return "", []
async def register_emoji_by_filename(self, filename: str) -> bool:
"""读取指定文件名的表情包图片,分析并注册到数据库
Args:
filename: 表情包文件名必须位于EMOJI_DIR目录下
Returns:
bool: 注册是否成功
"""
try:
# 使用MaiEmoji类创建表情包实例
new_emoji = MaiEmoji(filename, EMOJI_DIR)
await new_emoji.initialize_hash_format()
emoji_base64 = image_path_to_base64(os.path.join(EMOJI_DIR, filename))
description, emotions = await self.build_emoji_description(emoji_base64)
if description == "":
return False
new_emoji.description = description
new_emoji.emotion = emotions
# 检查是否已经注册过
# 对比内存中是否存在相同哈希值的表情包
if await self.get_emoji_from_manager(new_emoji.hash):
logger.warning(f"[警告] 表情包已存在: {filename}")
return False
if self.emoji_num >= self.emoji_num_max:
logger.warning(f"表情包数量已达到上限({self.emoji_num}/{self.emoji_num_max})")
replaced = await self.replace_a_emoji(new_emoji)
if not replaced:
logger.error("[错误] 替换表情包失败,无法完成注册")
return False
else:
# 修复:等待异步注册完成
register_success = await new_emoji.register_to_db()
if register_success:
self.emoji_objects.append(new_emoji)
self.emoji_num += 1
logger.success(f"[成功] 注册表情包: {filename}")
return True
else:
logger.error(f"[错误] 注册表情包到数据库失败: {filename}")
return False
except Exception as e:
logger.error(f"[错误] 注册表情包失败: {str(e)}")
logger.error(traceback.format_exc())
return False
async def clear_temp_emoji(self):
"""每天清理临时表情包
清理/data/emoji和/data/image目录下的所有文件
当目录中文件数超过50时会全部删除
"""
logger.info("[清理] 开始清理临时表情包...")
# 清理emoji目录
emoji_dir = os.path.join(BASE_DIR, "emoji")
if os.path.exists(emoji_dir):
files = os.listdir(emoji_dir)
# 如果文件数超过50就全部删除
if len(files) > 50:
for filename in files:
file_path = os.path.join(emoji_dir, filename)
if os.path.isfile(file_path):
os.remove(file_path)
logger.debug(f"[清理] 删除表情包文件: {filename}")
# 清理image目录
image_dir = os.path.join(BASE_DIR, "image")
if os.path.exists(image_dir):
files = os.listdir(image_dir)
# 如果文件数超过50就全部删除
if len(files) > 50:
for filename in files:
file_path = os.path.join(image_dir, filename)
if os.path.isfile(file_path):
os.remove(file_path)
logger.debug(f"[清理] 删除图片文件: {filename}")
logger.success("[清理] 临时文件清理完成")
# 创建全局单例
emoji_manager = EmojiManager()

View File

@@ -0,0 +1,74 @@
import time
from typing import List, Optional, Dict, Any
class CycleInfo:
"""循环信息记录类"""
def __init__(self, cycle_id: int):
self.cycle_id = cycle_id
self.start_time = time.time()
self.end_time: Optional[float] = None
self.action_taken = False
self.action_type = "unknown"
self.reasoning = ""
self.timers: Dict[str, float] = {}
self.thinking_id = ""
self.replanned = False
# 添加响应信息相关字段
self.response_info: Dict[str, Any] = {
"response_text": [], # 回复的文本列表
"emoji_info": "", # 表情信息
"anchor_message_id": "", # 锚点消息ID
"reply_message_ids": [], # 回复消息ID列表
"sub_mind_thinking": "", # 子思维思考内容
}
def to_dict(self) -> Dict[str, Any]:
"""将循环信息转换为字典格式"""
return {
"cycle_id": self.cycle_id,
"start_time": self.start_time,
"end_time": self.end_time,
"action_taken": self.action_taken,
"action_type": self.action_type,
"reasoning": self.reasoning,
"timers": self.timers,
"thinking_id": self.thinking_id,
"response_info": self.response_info,
}
def complete_cycle(self):
"""完成循环,记录结束时间"""
self.end_time = time.time()
def set_action_info(self, action_type: str, reasoning: str, action_taken: bool):
"""设置动作信息"""
self.action_type = action_type
self.reasoning = reasoning
self.action_taken = action_taken
def set_thinking_id(self, thinking_id: str):
"""设置思考消息ID"""
self.thinking_id = thinking_id
def set_response_info(
self,
response_text: Optional[List[str]] = None,
emoji_info: Optional[str] = None,
anchor_message_id: Optional[str] = None,
reply_message_ids: Optional[List[str]] = None,
sub_mind_thinking: Optional[str] = None,
):
"""设置响应信息"""
if response_text is not None:
self.response_info["response_text"] = response_text
if emoji_info is not None:
self.response_info["emoji_info"] = emoji_info
if anchor_message_id is not None:
self.response_info["anchor_message_id"] = anchor_message_id
if reply_message_ids is not None:
self.response_info["reply_message_ids"] = reply_message_ids
if sub_mind_thinking is not None:
self.response_info["sub_mind_thinking"] = sub_mind_thinking

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,92 @@
# HeartFChatting 逻辑详解
`HeartFChatting` 类是心流系统Heart Flow System中实现**专注聊天**`ChatState.FOCUSED`)功能的核心。顾名思义,其职责乃是在特定聊天流(`stream_id`)中,模拟更为连贯深入之对话。此非凭空臆造,而是依赖一个持续不断的 **思考(Think)-规划(Plan)-执行(Execute)** 循环。当其所系的 `SubHeartflow` 进入 `FOCUSED` 状态时,便会创建并启动 `HeartFChatting` 实例;若状态转为他途(譬如 `CHAT``ABSENT`),则会将其关闭。
## 1. 初始化简述 (`__init__`, `_initialize`)
创生之初,`HeartFChatting` 需注入若干关键之物:`chat_id`(亦即 `stream_id`)、关联的 `SubMind` 实例,以及 `Observation` 实例(用以观察环境)。
其内部核心组件包括:
- `ActionManager`: 管理当前循环可选之策(如:不应、言语、表情)。
- `HeartFCGenerator` (`self.gpt_instance`): 专司生成回复文本之职。
- `ToolUser` (`self.tool_user`): 虽主要用于获取工具定义,然亦备 `SubMind` 调用之需(实际执行由 `SubMind` 操持)。
- `HeartFCSender` (`self.heart_fc_sender`): 负责消息发送诸般事宜,含"正在思考"之态。
- `LLMRequest` (`self.planner_llm`): 配置用于执行"规划"任务的大语言模型。
*初始化过程采取懒加载策略,仅在首次需要访问 `ChatStream` 时(通常在 `start` 方法中)进行。*
## 2. 生命周期 (`start`, `shutdown`)
- **启动 (`start`)**: 外部调用此法,以启 `HeartFChatting` 之流程。内部会安全地启动主循环任务。
- **关闭 (`shutdown`)**: 外部调用此法,以止其运行。会取消主循环任务,清理状态,并释放锁。
## 3. 核心循环 (`_hfc_loop`) 与 循环记录 (`CycleInfo`)
`_hfc_loop``HeartFChatting` 之脉搏,以异步方式不舍昼夜运行(直至 `shutdown` 被调用)。其核心在于周而复始地执行 **思考-规划-执行** 之周期。
每一轮循环,皆会创建一个 `CycleInfo` 对象。此对象犹如史官,详细记载该次循环之点滴:
- **身份标识**: 循环 ID (`cycle_id`)。
- **时间轨迹**: 起止时刻 (`start_time`, `end_time`)。
- **行动细节**: 是否执行动作 (`action_taken`)、动作类型 (`action_type`)、决策理由 (`reasoning`)。
- **耗时考量**: 各阶段计时 (`timers`)。
- **关联信息**: 思考消息 ID (`thinking_id`)、是否重新规划 (`replanned`)、详尽响应信息 (`response_info`含生成文本、表情、锚点、实际发送ID、`SubMind`思考等)。
这些 `CycleInfo` 被存入一个队列 (`_cycle_history`),近者得观。此记录不仅便于调试,更关键的是,它会作为**上下文信息**传递给下一次循环的"思考"阶段,使得 `SubMind` 能鉴往知来,做出更连贯的决策。
*循环间会根据执行情况智能引入延迟,避免空耗资源。*
## 4. 思考-规划-执行周期 (`_think_plan_execute_loop`)
此乃 `HeartFChatting` 最核心的逻辑单元,每一循环皆按序执行以下三步:
### 4.1. 思考 (`_get_submind_thinking`)
* **第一步:观察环境**: 调用 `Observation``observe()` 方法,感知聊天室是否有新动态(如新消息)。
* **第二步:触发子思维**: 调用关联 `SubMind``do_thinking_before_reply()` 方法。
* **关键点**: 会将**上一个循环**的 `CycleInfo` 传入,让 `SubMind` 了解上次行动的决策、理由及是否重新规划,从而实现"承前启后"的思考。
* `SubMind` 在此阶段不仅进行思考,还可能**调用其配置的工具**来收集信息。
* **第三步:获取成果**: `SubMind` 返回两部分重要信息:
1. 当前的内心想法 (`current_mind`)。
2. 通过工具调用收集到的结构化信息 (`structured_info`)。
### 4.2. 规划 (`_planner`)
* **输入**: 接收来自"思考"阶段的 `current_mind``structured_info`,以及"观察"到的最新消息。
* **目标**: 基于当前想法、已知信息、聊天记录、机器人个性以及可用动作,决定**接下来要做什么**。
* **决策方式**:
1. 构建一个精心设计的提示词 (`_build_planner_prompt`)。
2. 获取 `ActionManager` 中定义的当前可用动作(如 `no_reply`, `text_reply`, `emoji_reply`)作为"工具"选项。
3. 调用大语言模型 (`self.planner_llm`)**强制**其选择一个动作"工具"并提供理由。可选动作包括:
* `no_reply`: 不回复(例如,自己刚说过话或对方未回应)。
* `text_reply`: 发送文本回复。
* `emoji_reply`: 仅发送表情。
* 文本回复亦可附带表情(通过 `emoji_query` 参数指定)。
* **动态调整(重新规划)**:
* 在做出初步决策后,会检查自规划开始后是否有新消息 (`_check_new_messages`)。
* 若有新消息,则有一定概率触发**重新规划**。此时会再次调用规划器,但提示词会包含之前决策的信息,要求 LLM 重新考虑。
* **输出**: 返回一个包含最终决策的字典,主要包括:
* `action`: 选定的动作类型。
* `reasoning`: 做出此决策的理由。
* `emoji_query`: (可选) 如果需要发送表情,指定表情的主题。
### 4.3. 执行 (`_handle_action`)
* **输入**: 接收"规划"阶段输出的 `action``reasoning``emoji_query`
* **行动**: 根据 `action` 的类型,分派到不同的处理函数:
* **文本回复 (`_handle_text_reply`)**:
1. 获取锚点消息(当前实现为系统触发的占位符)。
2. 调用 `HeartFCSender``register_thinking` 标记开始思考。
3. 调用 `HeartFCGenerator` (`_replier_work`) 生成回复文本。**注意**: 回复器逻辑 (`_replier_work`) 本身并非独立复杂组件,主要是调用 `HeartFCGenerator` 完成文本生成。
4. 调用 `HeartFCSender` (`_sender`) 发送生成的文本和可能的表情。**注意**: 发送逻辑 (`_sender`, `_send_response_messages`, `_handle_emoji`) 同样委托给 `HeartFCSender` 实例处理,包含模拟打字、实际发送、存储消息等细节。
* **仅表情回复 (`_handle_emoji_reply`)**:
1. 获取锚点消息。
2. 调用 `HeartFCSender` 发送表情。
* **不回复 (`_handle_no_reply`)**:
1. 记录理由。
2. 进入等待状态 (`_wait_for_new_message`)直到检测到新消息或超时目前300秒期间会监听关闭信号。
## 总结
`HeartFChatting` 通过 **观察 -> 思考(含工具)-> 规划 -> 执行** 的闭环,并利用 `CycleInfo` 进行上下文传递,实现了更加智能和连贯的专注聊天行为。其核心在于利用 `SubMind` 进行深度思考和信息收集,再通过 LLM 规划器进行决策,最后由 `HeartFCSender` 可靠地执行消息发送任务。

View File

@@ -8,7 +8,7 @@ from .heartflow_prompt_builder import prompt_builder
from ..chat.utils import process_llm_response
from src.common.logger import get_module_logger, LogConfig, LLM_STYLE_CONFIG
from src.plugins.respon_info_catcher.info_catcher import info_catcher_manager
from ..utils.timer_calculater import Timer
from ..utils.timer_calculator import Timer
from src.plugins.moods.moods import MoodManager
@@ -49,17 +49,13 @@ class HeartFCGenerator:
arousal_multiplier = MoodManager.get_instance().get_arousal_multiplier()
with Timer() as t_generate_response:
current_model = self.model_normal
current_model.temperature = global_config.llm_normal["temp"] * arousal_multiplier # 激活度越高,温度越高
model_response = await self._generate_response_with_model(
structured_info, current_mind_info, reason, message, current_model, thinking_id
)
current_model = self.model_normal
current_model.temperature = global_config.llm_normal["temp"] * arousal_multiplier # 激活度越高,温度越高
model_response = await self._generate_response_with_model(
structured_info, current_mind_info, reason, message, current_model, thinking_id
)
if model_response:
logger.info(
f"{global_config.BOT_NICKNAME}的回复是:{model_response},生成回复时间: {t_generate_response.human_readable}"
)
model_processed_response = await self._process_response(model_response)
return model_processed_response
@@ -78,7 +74,7 @@ class HeartFCGenerator:
) -> str:
info_catcher = info_catcher_manager.get_info_catcher(thinking_id)
with Timer() as t_build_prompt:
with Timer() as _build_prompt:
prompt = await prompt_builder.build_prompt(
build_mode="focus",
reason=reason,

View File

@@ -0,0 +1,159 @@
# HeartFC_chat 工作原理文档
HeartFC_chat 是一个基于心流理论的聊天系统通过模拟人类的思维过程和情感变化来实现自然的对话交互。系统采用Plan-Replier-Sender循环机制实现了智能化的对话决策和生成。
## 核心工作流程
### 1. 消息处理与存储 (HeartFCProcessor)
[代码位置: src/plugins/heartFC_chat/heartflow_processor.py]
消息处理器负责接收和预处理消息,主要完成以下工作:
```mermaid
graph TD
A[接收原始消息] --> B[解析为MessageRecv对象]
B --> C[消息缓冲处理]
C --> D[过滤检查]
D --> E[存储到数据库]
```
核心实现:
- 消息处理入口:`process_message()` [行号: 38-215]
- 消息解析和缓冲:`message_buffer.start_caching_messages()` [行号: 63]
- 过滤检查:`_check_ban_words()`, `_check_ban_regex()` [行号: 196-215]
- 消息存储:`storage.store_message()` [行号: 108]
### 2. 对话管理循环 (HeartFChatting)
[代码位置: src/plugins/heartFC_chat/heartFC_chat.py]
HeartFChatting是系统的核心组件实现了完整的对话管理循环
```mermaid
graph TD
A[Plan阶段] -->|决策是否回复| B[Replier阶段]
B -->|生成回复内容| C[Sender阶段]
C -->|发送消息| D[等待新消息]
D --> A
```
#### Plan阶段 [行号: 282-386]
- 主要函数:`_planner()`
- 功能实现:
* 获取观察信息:`observation.observe()` [行号: 297]
* 思维处理:`sub_mind.do_thinking_before_reply()` [行号: 301]
* LLM决策使用`PLANNER_TOOL_DEFINITION`进行动作规划 [行号: 13-42]
#### Replier阶段 [行号: 388-416]
- 主要函数:`_replier_work()`
- 调用生成器:`gpt_instance.generate_response()` [行号: 394]
- 处理生成结果和错误情况
#### Sender阶段 [行号: 418-450]
- 主要函数:`_sender()`
- 发送实现:
* 创建消息:`_create_thinking_message()` [行号: 452-477]
* 发送回复:`_send_response_messages()` [行号: 479-525]
* 处理表情:`_handle_emoji()` [行号: 527-567]
### 3. 回复生成机制 (HeartFCGenerator)
[代码位置: src/plugins/heartFC_chat/heartFC_generator.py]
回复生成器负责产生高质量的回复内容:
```mermaid
graph TD
A[获取上下文信息] --> B[构建提示词]
B --> C[调用LLM生成]
C --> D[后处理优化]
D --> E[返回回复集]
```
核心实现:
- 生成入口:`generate_response()` [行号: 39-67]
* 情感调节:`arousal_multiplier = MoodManager.get_instance().get_arousal_multiplier()` [行号: 47]
* 模型生成:`_generate_response_with_model()` [行号: 69-95]
* 响应处理:`_process_response()` [行号: 97-106]
### 4. 提示词构建系统 (HeartFlowPromptBuilder)
[代码位置: src/plugins/heartFC_chat/heartflow_prompt_builder.py]
提示词构建器支持两种工作模式HeartFC_chat专门使用Focus模式而Normal模式是为normal_chat设计的
#### 专注模式 (Focus Mode) - HeartFC_chat专用
- 实现函数:`_build_prompt_focus()` [行号: 116-141]
- 特点:
* 专注于当前对话状态和思维
* 更强的目标导向性
* 用于HeartFC_chat的Plan-Replier-Sender循环
* 简化的上下文处理,专注于决策
#### 普通模式 (Normal Mode) - Normal_chat专用
- 实现函数:`_build_prompt_normal()` [行号: 143-215]
- 特点:
* 用于normal_chat的常规对话
* 完整的个性化处理
* 关系系统集成
* 知识库检索:`get_prompt_info()` [行号: 217-591]
HeartFC_chat的Focus模式工作流程
```mermaid
graph TD
A[获取结构化信息] --> B[获取当前思维状态]
B --> C[构建专注模式提示词]
C --> D[用于Plan阶段决策]
D --> E[用于Replier阶段生成]
```
## 智能特性
### 1. 对话决策机制
- LLM决策工具定义`PLANNER_TOOL_DEFINITION` [heartFC_chat.py 行号: 13-42]
- 决策执行:`_planner()` [heartFC_chat.py 行号: 282-386]
- 考虑因素:
* 上下文相关性
* 情感状态
* 兴趣程度
* 对话时机
### 2. 状态管理
[代码位置: src/plugins/heartFC_chat/heartFC_chat.py]
- 状态机实现:`HeartFChatting`类 [行号: 44-567]
- 核心功能:
* 初始化:`_initialize()` [行号: 89-112]
* 循环控制:`_run_pf_loop()` [行号: 192-281]
* 状态转换:`_handle_loop_completion()` [行号: 166-190]
### 3. 回复生成策略
[代码位置: src/plugins/heartFC_chat/heartFC_generator.py]
- 温度调节:`current_model.temperature = global_config.llm_normal["temp"] * arousal_multiplier` [行号: 48]
- 生成控制:`_generate_response_with_model()` [行号: 69-95]
- 响应处理:`_process_response()` [行号: 97-106]
## 系统配置
### 关键参数
- LLM配置`model_normal` [heartFC_generator.py 行号: 32-37]
- 过滤规则:`_check_ban_words()`, `_check_ban_regex()` [heartflow_processor.py 行号: 196-215]
- 状态控制:`INITIAL_DURATION = 60.0` [heartFC_chat.py 行号: 11]
### 优化建议
1. 调整LLM参数`temperature``max_tokens`
2. 优化提示词模板:`init_prompt()` [heartflow_prompt_builder.py 行号: 8-115]
3. 配置状态转换条件
4. 维护过滤规则
## 注意事项
1. 系统稳定性
- 异常处理各主要函数都包含try-except块
- 状态检查:`_processing_lock`确保并发安全
- 循环控制:`_loop_active``_loop_task`管理
2. 性能优化
- 缓存使用:`message_buffer`系统
- LLM调用优化批量处理和复用
- 异步处理:使用`asyncio`
3. 质量控制
- 日志记录:使用`get_module_logger()`
- 错误追踪:详细的异常记录
- 响应监控:完整的状态跟踪

View File

@@ -0,0 +1,153 @@
# src/plugins/heartFC_chat/heartFC_sender.py
import asyncio # 重新导入 asyncio
from typing import Dict, Optional # 重新导入类型
from src.common.logger import get_module_logger
from ..message.api import global_api
from ..chat.message import MessageSending, MessageThinking # 只保留 MessageSending 和 MessageThinking
from ..storage.storage import MessageStorage
from ..chat.utils import truncate_message
from src.common.logger import LogConfig, SENDER_STYLE_CONFIG
from src.plugins.chat.utils import calculate_typing_time
# 定义日志配置
sender_config = LogConfig(
# 使用消息发送专用样式
console_format=SENDER_STYLE_CONFIG["console_format"],
file_format=SENDER_STYLE_CONFIG["file_format"],
)
logger = get_module_logger("msg_sender", config=sender_config)
class HeartFCSender:
"""管理消息的注册、即时处理、发送和存储,并跟踪思考状态。"""
def __init__(self):
self.storage = MessageStorage()
# 用于存储活跃的思考消息
self.thinking_messages: Dict[str, Dict[str, MessageThinking]] = {}
self._thinking_lock = asyncio.Lock() # 保护 thinking_messages 的锁
async def send_message(self, message: MessageSending) -> None:
"""合并后的消息发送函数包含WS发送和日志记录"""
message_preview = truncate_message(message.processed_plain_text)
try:
# 直接调用API发送消息
await global_api.send_message(message)
logger.success(f"发送消息 '{message_preview}' 成功")
except Exception as e:
logger.error(f"发送消息 '{message_preview}' 失败: {str(e)}")
if not message.message_info.platform:
raise ValueError(f"未找到平台:{message.message_info.platform} 的url配置请检查配置文件") from e
raise e # 重新抛出其他异常
async def register_thinking(self, thinking_message: MessageThinking):
"""注册一个思考中的消息。"""
if not thinking_message.chat_stream or not thinking_message.message_info.message_id:
logger.error("无法注册缺少 chat_stream 或 message_id 的思考消息")
return
chat_id = thinking_message.chat_stream.stream_id
message_id = thinking_message.message_info.message_id
async with self._thinking_lock:
if chat_id not in self.thinking_messages:
self.thinking_messages[chat_id] = {}
if message_id in self.thinking_messages[chat_id]:
logger.warning(f"[{chat_id}] 尝试注册已存在的思考消息 ID: {message_id}")
self.thinking_messages[chat_id][message_id] = thinking_message
logger.debug(f"[{chat_id}] Registered thinking message: {message_id}")
async def complete_thinking(self, chat_id: str, message_id: str):
"""完成并移除一个思考中的消息记录。"""
async with self._thinking_lock:
if chat_id in self.thinking_messages and message_id in self.thinking_messages[chat_id]:
del self.thinking_messages[chat_id][message_id]
logger.debug(f"[{chat_id}] Completed thinking message: {message_id}")
if not self.thinking_messages[chat_id]:
del self.thinking_messages[chat_id]
logger.debug(f"[{chat_id}] Removed empty thinking message container.")
def is_thinking(self, chat_id: str, message_id: str) -> bool:
"""检查指定的消息 ID 是否当前正处于思考状态。"""
return chat_id in self.thinking_messages and message_id in self.thinking_messages[chat_id]
async def get_thinking_start_time(self, chat_id: str, message_id: str) -> Optional[float]:
"""获取已注册思考消息的开始时间。"""
async with self._thinking_lock:
thinking_message = self.thinking_messages.get(chat_id, {}).get(message_id)
return thinking_message.thinking_start_time if thinking_message else None
async def type_and_send_message(self, message: MessageSending, type=False):
"""
立即处理、发送并存储单个 MessageSending 消息。
调用此方法前,应先调用 register_thinking 注册对应的思考消息。
此方法执行后会调用 complete_thinking 清理思考状态。
"""
if not message.chat_stream:
logger.error("消息缺少 chat_stream无法发送")
return
if not message.message_info or not message.message_info.message_id:
logger.error("消息缺少 message_info 或 message_id无法发送")
return
chat_id = message.chat_stream.stream_id
message_id = message.message_info.message_id
try:
_ = message.update_thinking_time()
# --- 条件应用 set_reply 逻辑 ---
if message.apply_set_reply_logic and message.is_head and not message.is_private_message():
logger.debug(f"[{chat_id}] 应用 set_reply 逻辑: {message.processed_plain_text[:20]}...")
message.set_reply()
# --- 结束条件 set_reply ---
await message.process()
if type:
typing_time = calculate_typing_time(
input_string=message.processed_plain_text,
thinking_start_time=message.thinking_start_time,
is_emoji=message.is_emoji,
)
await asyncio.sleep(typing_time)
await self.send_message(message)
await self.storage.store_message(message, message.chat_stream)
except Exception as e:
logger.error(f"[{chat_id}] 处理或存储消息 {message_id} 时出错: {e}")
raise e
finally:
await self.complete_thinking(chat_id, message_id)
async def send_and_store(self, message: MessageSending):
"""处理、发送并存储单个消息,不涉及思考状态管理。"""
if not message.chat_stream:
logger.error(f"[{message.message_info.platform or 'UnknownPlatform'}] 消息缺少 chat_stream无法发送")
return
if not message.message_info or not message.message_info.message_id:
logger.error(
f"[{message.chat_stream.stream_id if message.chat_stream else 'UnknownStream'}] 消息缺少 message_info 或 message_id无法发送"
)
return
chat_id = message.chat_stream.stream_id
message_id = message.message_info.message_id # 获取消息ID用于日志
try:
await message.process()
await asyncio.sleep(0.5)
await self.send_message(message) # 使用现有的发送方法
await self.storage.store_message(message, message.chat_stream) # 使用现有的存储方法
except Exception as e:
logger.error(f"[{chat_id}] 处理或存储消息 {message_id} 时出错: {e}")
# 重新抛出异常,让调用者知道失败了
raise e

View File

@@ -5,13 +5,14 @@ from ...config.config import global_config
from ..chat.message import MessageRecv
from ..storage.storage import MessageStorage
from ..chat.utils import is_mentioned_bot_in_message
from ..message import Seg
from maim_message import Seg
from src.heart_flow.heartflow import heartflow
from src.common.logger import get_module_logger, CHAT_STYLE_CONFIG, LogConfig
from ..chat.chat_stream import chat_manager
from ..chat.message_buffer import message_buffer
from ..utils.timer_calculater import Timer
from ..utils.timer_calculator import Timer
from src.plugins.person_info.relationship_manager import relationship_manager
from typing import Optional, Tuple
# 定义日志配置
processor_config = LogConfig(
@@ -22,193 +23,202 @@ logger = get_module_logger("heartflow_processor", config=processor_config)
class HeartFCProcessor:
"""心流处理器,负责处理接收到的消息并计算兴趣度"""
def __init__(self):
"""初始化心流处理器,创建消息存储实例"""
self.storage = MessageStorage()
async def process_message(self, message_data: str) -> None:
"""处理接收到的原始消息数据,完成消息解析、缓冲、过滤、存储、兴趣度计算与更新等核心流程。
此函数是消息处理的核心入口,负责接收原始字符串格式的消息数据,并将其转化为结构化的 `MessageRecv` 对象。
主要执行步骤包括:
1. 解析 `message_data` 为 `MessageRecv` 对象,提取用户信息、群组信息等。
2. 将消息加入 `message_buffer` 进行缓冲处理,以应对消息轰炸或者某些人一条消息分几次发等情况。
3. 获取或创建对应的 `chat_stream` 和 `subheartflow` 实例,用于管理会话状态和心流。
4. 对消息内容进行初步处理(如提取纯文本)。
5. 应用全局配置中的过滤词和正则表达式,过滤不符合规则的消息。
6. 查询消息缓冲结果,如果消息被缓冲器拦截(例如,判断为消息轰炸的一部分),则中止后续处理。
7. 对于通过缓冲的消息,将其存储到 `MessageStorage` 中。
8. 调用海马体(`HippocampusManager`)计算消息内容的记忆激活率。(这部分算法后续会进行优化)
9. 根据是否被提及(@)和记忆激活率,计算最终的兴趣度增量。(提及的额外兴趣增幅)
10. 使用计算出的增量更新 `InterestManager` 中对应会话的兴趣度。
11. 记录处理后的消息信息及当前的兴趣度到日志。
注意:此函数本身不负责生成和发送回复。回复的决策和生成逻辑被移至 `HeartFC_Chat` 类中的监控任务,
该任务会根据 `InterestManager` 中的兴趣度变化来决定何时触发回复。
async def _handle_error(self, error: Exception, context: str, message: Optional[MessageRecv] = None) -> None:
"""统一的错误处理函数
Args:
message_data: str: 从消息源接收到的原始消息字符串。
error: 捕获到的异常
context: 错误发生的上下文描述
message: 可选的消息对象,用于记录相关消息内容
"""
logger.error(f"{context}: {error}")
logger.error(traceback.format_exc())
if message and hasattr(message, "raw_message"):
logger.error(f"相关消息原始内容: {message.raw_message}")
async def _process_relationship(self, message: MessageRecv) -> None:
"""处理用户关系逻辑
Args:
message: 消息对象,包含用户信息
"""
platform = message.message_info.platform
user_id = message.message_info.user_info.user_id
nickname = message.message_info.user_info.user_nickname
cardname = message.message_info.user_info.user_cardname or nickname
is_known = await relationship_manager.is_known_some_one(platform, user_id)
if not is_known:
logger.info(f"首次认识用户: {nickname}")
await relationship_manager.first_knowing_some_one(platform, user_id, nickname, cardname, "")
elif not await relationship_manager.is_qved_name(platform, user_id):
logger.info(f"给用户({nickname},{cardname})取名: {nickname}")
await relationship_manager.first_knowing_some_one(platform, user_id, nickname, cardname, "")
async def _calculate_interest(self, message: MessageRecv) -> Tuple[float, bool]:
"""计算消息的兴趣度
Args:
message: 待处理的消息对象
Returns:
Tuple[float, bool]: (兴趣度, 是否被提及)
"""
is_mentioned, _ = is_mentioned_bot_in_message(message)
interested_rate = 0.0
with Timer("记忆激活"):
interested_rate = await HippocampusManager.get_instance().get_activate_from_text(
message.processed_plain_text,
fast_retrieval=True,
)
logger.trace(f"记忆激活率: {interested_rate:.2f}")
if is_mentioned:
interest_increase_on_mention = 1
interested_rate += interest_increase_on_mention
return interested_rate, is_mentioned
def _get_message_type(self, message: MessageRecv) -> str:
"""获取消息类型
Args:
message: 消息对象
Returns:
str: 消息类型
"""
if message.message_segment.type != "seglist":
return message.message_segment.type
if (
isinstance(message.message_segment.data, list)
and all(isinstance(x, Seg) for x in message.message_segment.data)
and len(message.message_segment.data) == 1
):
return message.message_segment.data[0].type
return "seglist"
async def process_message(self, message_data: str) -> None:
"""处理接收到的原始消息数据
主要流程:
1. 消息解析与初始化
2. 消息缓冲处理
3. 过滤检查
4. 兴趣度计算
5. 关系处理
Args:
message_data: 原始消息字符串
"""
timing_results = {} # 初始化 timing_results
message = None
try:
# 1. 消息解析与初始化
message = MessageRecv(message_data)
groupinfo = message.message_info.group_info
userinfo = message.message_info.user_info
messageinfo = message.message_info
# 消息加入缓冲池
# 2. 消息缓冲与流程序化
await message_buffer.start_caching_messages(message)
# 创建聊天流
chat = await chat_manager.get_or_create_stream(
platform=messageinfo.platform,
user_info=userinfo,
group_info=groupinfo,
)
subheartflow = await heartflow.create_subheartflow(chat.stream_id)
subheartflow = await heartflow.get_or_create_subheartflow(chat.stream_id)
message.update_chat_stream(chat)
await heartflow.create_subheartflow(chat.stream_id)
await message.process()
logger.trace(f"消息处理成功: {message.processed_plain_text}")
# 过滤词/正则表达式过滤
# 3. 过滤检查
if self._check_ban_words(message.processed_plain_text, chat, userinfo) or self._check_ban_regex(
message.raw_message, chat, userinfo
):
return
# 查询缓冲器结果
# 4. 缓冲检查
buffer_result = await message_buffer.query_buffer_result(message)
# 处理缓冲器结果 (Bombing logic)
if not buffer_result:
f_type = "seglist"
if message.message_segment.type != "seglist":
f_type = message.message_segment.type
else:
if (
isinstance(message.message_segment.data, list)
and all(isinstance(x, Seg) for x in message.message_segment.data)
and len(message.message_segment.data) == 1
):
f_type = message.message_segment.data[0].type
if f_type == "text":
logger.debug(f"触发缓冲,消息:{message.processed_plain_text}")
elif f_type == "image":
logger.debug("触发缓冲,表情包/图片等待中")
elif f_type == "seglist":
logger.debug("触发缓冲,消息列表等待中")
return # 被缓冲器拦截,不生成回复
# ---- 只有通过缓冲的消息才进行存储和后续处理 ----
# 存储消息 (使用可能被缓冲器更新过的 message)
try:
await self.storage.store_message(message, chat)
logger.trace(f"存储成功 (通过缓冲后): {message.processed_plain_text}")
except Exception as e:
logger.error(f"存储消息失败: {e}")
logger.error(traceback.format_exc())
# 存储失败可能仍需考虑是否继续,暂时返回
msg_type = self._get_message_type(message)
type_messages = {
"text": f"触发缓冲,消息:{message.processed_plain_text}",
"image": "触发缓冲,表情包/图片等待中",
"seglist": "触发缓冲,消息列表等待中",
}
logger.debug(type_messages.get(msg_type, "触发未知类型缓冲"))
return
# 激活度计算 (使用可能被缓冲器更新过的 message.processed_plain_text)
is_mentioned, _ = is_mentioned_bot_in_message(message)
interested_rate = 0.0 # 默认值
try:
with Timer("记忆激活", timing_results):
interested_rate = await HippocampusManager.get_instance().get_activate_from_text(
message.processed_plain_text,
fast_retrieval=True, # 使用更新后的文本
)
logger.trace(f"记忆激活率 (通过缓冲后): {interested_rate:.2f}")
except Exception as e:
logger.error(f"计算记忆激活率失败: {e}")
logger.error(traceback.format_exc())
# 5. 消息存储
await self.storage.store_message(message, chat)
logger.trace(f"存储成功: {message.processed_plain_text}")
# --- 修改:兴趣度更新逻辑 --- #
if is_mentioned:
interest_increase_on_mention = 1
mentioned_boost = interest_increase_on_mention # 从配置获取提及增加值
interested_rate += mentioned_boost
# 6. 兴趣度计算与更新
interested_rate, is_mentioned = await self._calculate_interest(message)
await subheartflow.interest_chatting.increase_interest(value=interested_rate)
subheartflow.interest_chatting.add_interest_dict(message, interested_rate, is_mentioned)
# 更新兴趣度 (调用 SubHeartflow 的方法)
current_time = time.time()
await subheartflow.interest_chatting.increase_interest(current_time, value=interested_rate)
# 添加到 SubHeartflow 的 interest_dict给normal_chat处理
await subheartflow.add_interest_dict_entry(message, interested_rate, is_mentioned)
# 打印消息接收和处理信息
# 7. 日志记录
mes_name = chat.group_info.group_name if chat.group_info else "私聊"
current_time = time.strftime("%H:%M:%S", time.localtime(message.message_info.time))
current_time = time.strftime("%H%M%S", time.localtime(message.message_info.time))
logger.info(
f"[{current_time}][{mes_name}]"
f"{message.message_info.user_info.user_nickname}:"
f"{userinfo.user_nickname}:"
f"{message.processed_plain_text}"
f"[兴趣度: {interested_rate:.2f}]"
)
try:
is_known = await relationship_manager.is_known_some_one(
message.message_info.platform, message.message_info.user_info.user_id
)
if not is_known:
logger.info(f"首次认识用户: {message.message_info.user_info.user_nickname}")
await relationship_manager.first_knowing_some_one(
message.message_info.platform,
message.message_info.user_info.user_id,
message.message_info.user_info.user_nickname,
message.message_info.user_info.user_cardname or message.message_info.user_info.user_nickname,
"",
)
else:
# logger.debug(f"已认识用户: {message.message_info.user_info.user_nickname}")
if not await relationship_manager.is_qved_name(
message.message_info.platform, message.message_info.user_info.user_id
):
logger.info(f"更新已认识但未取名的用户: {message.message_info.user_info.user_nickname}")
await relationship_manager.first_knowing_some_one(
message.message_info.platform,
message.message_info.user_info.user_id,
message.message_info.user_info.user_nickname,
message.message_info.user_info.user_cardname
or message.message_info.user_info.user_nickname,
"",
)
except Exception as e:
logger.error(f"处理认识关系失败: {e}")
logger.error(traceback.format_exc())
# 8. 关系处理
await self._process_relationship(message)
except Exception as e:
logger.error(f"消息处理失败 (process_message V3): {e}")
logger.error(traceback.format_exc())
if message: # 记录失败的消息内容
logger.error(f"失败消息原始内容: {message.raw_message}")
await self._handle_error(e, "消息处理失败", message)
def _check_ban_words(self, text: str, chat, userinfo) -> bool:
"""检查消息是否包含过滤词"""
"""检查消息是否包含过滤词
Args:
text: 待检查的文本
chat: 聊天对象
userinfo: 用户信息
Returns:
bool: 是否包含过滤词
"""
for word in global_config.ban_words:
if word in text:
logger.info(
f"[{chat.group_info.group_name if chat.group_info else '私聊'}]{userinfo.user_nickname}:{text}"
)
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
logger.info(f"[过滤词识别]消息中含有{word}filtered")
return True
return False
def _check_ban_regex(self, text: str, chat, userinfo) -> bool:
"""检查消息是否匹配过滤正则表达式"""
"""检查消息是否匹配过滤正则表达式
Args:
text: 待检查的文本
chat: 聊天对象
userinfo: 用户信息
Returns:
bool: 是否匹配过滤正则
"""
for pattern in global_config.ban_msgs_regex:
if pattern.search(text):
logger.info(
f"[{chat.group_info.group_name if chat.group_info else '私聊'}]{userinfo.user_nickname}:{text}"
)
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
logger.info(f"[正则表达式过滤]消息匹配到{pattern}filtered")
return True
return False

View File

@@ -21,42 +21,80 @@ logger = get_module_logger("prompt")
def init_prompt():
Prompt(
"""
你有以下信息可供参考:
{structured_info}
{info_from_tools}
{chat_target}
{chat_talking_prompt}
现在你想要在群里发言或者回复。\n
你需要扮演一位网名叫{bot_name}的人进行回复,这个人的特点是:"{prompt_personality} {prompt_identity}"
你正在{chat_target_2},现在请你读读之前的聊天记录,然后给出日常且口语化的回复,平淡一些,你可以参考贴吧,小红书或者微博的回复风格。
你刚刚脑子里在想:
你正在{chat_target_2},现在请你读读之前的聊天记录,然后给出日常且口语化的回复,平淡一些,你可以参考贴吧,知乎或者微博的回复风格。
看到以上聊天记录,你刚刚在想:
{current_mind_info}
{reason}
因为上述想法,你决定发言,原因是:{reason}
回复尽量简短一些。请注意把握聊天内容,不要回复的太有条理,可以有个性。请一次只回复一个话题,不要同时回复多个人,不用指出你回复的是谁。{prompt_ger}
请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景,不要说你说过的话题 ,注意只输出回复内容。
{moderation_prompt}。注意:不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 )。""",
{moderation_prompt}。注意:回复不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 )。""",
"heart_flow_prompt",
)
# Planner提示词
Prompt(
"""你的名字是 {bot_name}。是一个女大二学生正在QQ聊天正在决定是否以及如何回应当前的聊天。
"""
你有以下信息可供参考:
{structured_info}
以上的消息是你获取到的消息,或许可以帮助你更好地回复。
""",
"info_from_tools",
)
# Planner提示词 - 优化版
Prompt(
"""你的名字是{bot_name}一名大二女生正在QQ群聊中。需要基于以下信息决定如何参与对话
{structured_info_block}
{chat_content_block}
看了以上内容,你产生的内心想法
的内心想法:
{current_mind_block}
请结合你的内心想法和观察到的聊天内容,分析情况并使用 'decide_reply_action' 工具来决定你的最终行动。
决策依据:
1. 如果聊天内容无聊、与你无关、或者你的内心想法认为不适合回复(例如在讨论你不懂或不感兴趣的话题),选择 'no_reply'
2. 如果聊天内容值得回应,且适合用文字表达(参考你的内心想法),选择 'text_reply'。如果你有情绪想表达,想在文字后追加一个表达情绪的表情,请同时提供 'emoji_query' (例如:'开心的''惊讶的')。
3. 如果聊天内容或你的内心想法适合用一个表情来回应(例如表示赞同、惊讶、无语等),选择 'emoji_reply' 并提供表情主题 'emoji_query'
4. 如果最后一条消息是你自己发的,观察到的内容只有你自己的发言,并且之后没有人回复你,通常选择 'no_reply',除非有特殊原因需要追问。
5. 如果聊天记录中最新的消息是你自己发送的,并且你还想继续回复,你应该紧紧衔接你发送的消息,进行话题的深入,补充,或追问等等;。
6. 表情包是用来表达情绪的,不要直接回复或评价别人的表情包,而是根据对话内容和情绪选择是否用表情回应
7. 不要回复你自己的话,不要把自己的话当做别人说的。
必须调用 'decide_reply_action' 工具并提供 'action''reasoning'。如果选择了 'emoji_reply' 或者选择了 'text_reply' 并想追加表情,则必须提供 'emoji_query'""",
{replan}
请综合分析聊天内容和你看到的新消息,参考内心想法,使用'decide_reply_action'工具做出决策。决策时请注意:
【回复原则】
1. 不回复(no_reply)适用:
- 话题无关/无聊/不感兴趣
- 最后一条消息是你自己发的且无人回应
- 讨论你不懂的专业话题
- 你发送了太多消息,且无人回复
2. 文字回复(text_reply)适用:
- 有实质性内容需要表达
- 有人提到你,但你还没有回应他
- 可以追加emoji_query表达情绪(格式:情绪描述,如"俏皮的调侃")
- 不要追加太多表情
3. 纯表情回复(emoji_reply)适用:
- 适合用表情回应的场景
- 需提供明确的emoji_query
4. 自我对话处理:
- 如果是自己发的消息想继续,需自然衔接
- 避免重复或评价自己的发言
- 不要和自己聊天
【必须遵守】
- 遵守回复原则
- 必须调用工具并包含action和reasoning
- 你可以选择文字回复(text_reply),纯表情回复(emoji_reply),不回复(no_reply)
- 选择text_reply或emoji_reply时必须提供emoji_query
- 保持回复自然,符合日常聊天习惯""",
"planner_prompt",
)
Prompt(
"""你原本打算{action},因为:{reasoning}
但是你看到了新的消息,你决定重新决定行动。""",
"replan_prompt",
)
Prompt("你正在qq群里聊天下面是群里在聊的内容", "chat_target_group1")
Prompt("和群里聊天", "chat_target_group2")
Prompt("你正在和{sender_name}聊天,这是你们之前聊的内容:", "chat_target_private1")
@@ -79,9 +117,9 @@ def init_prompt():
你的网名叫{bot_name},有人也叫你{bot_other_names}{prompt_personality}
你正在{chat_target_2},现在请你读读之前的聊天记录,{mood_prompt},然后给出日常且口语化的回复,平淡一些,
尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,不要回复的太有条理,可以有个性。{prompt_ger}
请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景,尽量不要说你说过的话
请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景,不要浮夸,平淡一些 ,不要重复自己说过的话
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
{moderation_prompt}不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 )。""",
{moderation_prompt}不要输出多余内容(包括前后缀,冒号和引号,括号()表情包at或 @等 )。,只输出回复内容""",
"reasoning_prompt_main",
)
Prompt(
@@ -116,13 +154,14 @@ class PromptBuilder:
elif build_mode == "focus":
return await self._build_prompt_focus(
reason, current_mind_info, structured_info, chat_stream,
reason,
current_mind_info,
structured_info,
chat_stream,
)
return None
async def _build_prompt_focus(
self, reason, current_mind_info, structured_info, chat_stream
) -> tuple[str, str]:
async def _build_prompt_focus(self, reason, current_mind_info, structured_info, chat_stream) -> tuple[str, str]:
individuality = Individuality.get_instance()
prompt_personality = individuality.get_prompt(type="personality", x_person=2, level=1)
prompt_identity = individuality.get_prompt(type="identity", x_person=2, level=1)
@@ -145,7 +184,7 @@ class PromptBuilder:
message_list_before_now,
replace_bot_name=True,
merge_messages=False,
timestamp_mode="relative",
timestamp_mode="normal",
read_mark=0.0,
)
@@ -156,11 +195,18 @@ class PromptBuilder:
if random.random() < 0.02:
prompt_ger += "你喜欢用反问句"
if structured_info:
structured_info_prompt = await global_prompt_manager.format_prompt(
"info_from_tools", structured_info=structured_info
)
else:
structured_info_prompt = ""
logger.debug("开始构建prompt")
prompt = await global_prompt_manager.format_prompt(
"heart_flow_prompt",
structured_info=structured_info,
info_from_tools=structured_info_prompt,
chat_target=await global_prompt_manager.get_prompt_async("chat_target_group1")
if chat_in_group
else await global_prompt_manager.get_prompt_async("chat_target_private1"),
@@ -490,23 +536,36 @@ class PromptBuilder:
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
# 从LPMM知识库获取知识
found_knowledge_from_lpmm = qa_manager.get_knowledge(message)
try:
found_knowledge_from_lpmm = qa_manager.get_knowledge(message)
end_time = time.time()
if found_knowledge_from_lpmm is not None:
logger.debug(
f"从LPMM知识库获取知识相关信息{found_knowledge_from_lpmm[:100]}...,信息长度: {len(found_knowledge_from_lpmm)}"
)
related_info += found_knowledge_from_lpmm
logger.debug(f"获取知识库内容耗时: {(end_time - start_time):.3f}")
logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
return related_info
else:
logger.debug("从LPMM知识库获取知识失败使用旧版数据库进行检索")
knowledge_from_old = await self.get_prompt_info_old(message, threshold=0.38)
related_info += knowledge_from_old
logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
return related_info
end_time = time.time()
if found_knowledge_from_lpmm is not None:
logger.debug(
f"从LPMM知识库获取知识相关信息{found_knowledge_from_lpmm[:100]}...,信息长度: {len(found_knowledge_from_lpmm)}"
)
related_info += found_knowledge_from_lpmm
logger.debug(f"获取知识库内容耗时: {(end_time - start_time):.3f}")
logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
return related_info
else:
logger.debug("从LPMM知识库获取知识失败使用旧版数据库进行检索")
knowledge_from_old = await self.get_prompt_info_old(message, threshold=0.38)
related_info += knowledge_from_old
logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
return related_info
except Exception as e:
logger.error(f"获取知识库内容时发生异常: {str(e)}")
try:
knowledge_from_old = await self.get_prompt_info_old(message, threshold=0.38)
related_info += knowledge_from_old
logger.debug(
f"异常后使用旧版数据库获取知识,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}"
)
return related_info
except Exception as e2:
logger.error(f"使用旧版数据库获取知识时也发生异常: {str(e2)}")
return ""
@staticmethod
def get_info_from_db(

View File

@@ -1,23 +1,24 @@
import time
import asyncio
import traceback
import statistics # 导入 statistics 模块
from random import random
from typing import List, Optional # 导入 Optional
from ..moods.moods import MoodManager
from ...config.config import global_config
from ..chat.emoji_manager import emoji_manager
from ..emoji_system.emoji_manager import emoji_manager
from .normal_chat_generator import NormalChatGenerator
from ..chat.message import MessageSending, MessageRecv, MessageThinking, MessageSet
from ..chat.message_sender import message_manager
from ..chat.utils_image import image_path_to_base64
from ..willing.willing_manager import willing_manager
from ..message import UserInfo, Seg
from maim_message import UserInfo, Seg
from src.common.logger import get_module_logger, CHAT_STYLE_CONFIG, LogConfig
from src.plugins.chat.chat_stream import ChatStream, chat_manager
from src.plugins.person_info.relationship_manager import relationship_manager
from src.plugins.respon_info_catcher.info_catcher import info_catcher_manager
from src.plugins.utils.timer_calculater import Timer
from src.plugins.utils.timer_calculator import Timer
# 定义日志配置
chat_config = LogConfig(
@@ -46,6 +47,8 @@ class NormalChat:
self.gpt = NormalChatGenerator()
self.mood_manager = MoodManager.get_instance() # MoodManager 保持单例
# 存储此实例的兴趣监控任务
self.start_time = time.time()
self._chat_task: Optional[asyncio.Task] = None
logger.info(f"[{self.stream_name}] NormalChat 实例初始化完成。")
@@ -164,14 +167,13 @@ class NormalChat:
)
self.mood_manager.update_mood_from_emotion(emotion, global_config.mood_intensity_factor)
async def _find_interested_message(self) -> None:
async def _reply_interested_message(self) -> None:
"""
后台任务方法轮询当前实例关联chat的兴趣消息
通常由start_monitoring_interest()启动
"""
while True:
await asyncio.sleep(1) # 每秒检查一次
await asyncio.sleep(0.5) # 每秒检查一次
# 检查任务是否已被取消
if self._chat_task is None or self._chat_task.cancelled():
logger.info(f"[{self.stream_name}] 兴趣监控任务被取消或置空,退出")
@@ -318,6 +320,68 @@ class NormalChat:
# 意愿管理器注销当前message信息 (无论是否回复,只要处理过就删除)
willing_manager.delete(message.message_info.message_id)
# --- 新增:处理初始高兴趣消息的私有方法 ---
async def _process_initial_interest_messages(self):
"""处理启动时存在于 interest_dict 中的高兴趣消息。"""
items_to_process = list(self.interest_dict.items())
if not items_to_process:
return # 没有初始消息,直接返回
logger.info(f"[{self.stream_name}] 发现 {len(items_to_process)} 条初始兴趣消息,开始处理高兴趣部分...")
interest_values = [item[1][1] for item in items_to_process] # 提取兴趣值列表
messages_to_reply = [] # 需要立即回复的消息
if len(interest_values) == 1:
# 如果只有一个消息,直接处理
messages_to_reply.append(items_to_process[0])
logger.info(f"[{self.stream_name}] 只有一条初始消息,直接处理。")
elif len(interest_values) > 1:
# 计算均值和标准差
try:
mean_interest = statistics.mean(interest_values)
stdev_interest = statistics.stdev(interest_values)
threshold = mean_interest + stdev_interest
logger.info(
f"[{self.stream_name}] 初始兴趣值 均值: {mean_interest:.2f}, 标准差: {stdev_interest:.2f}, 阈值: {threshold:.2f}"
)
# 找出高于阈值的消息
for item in items_to_process:
msg_id, (message, interest_value, is_mentioned) = item
if interest_value > threshold:
messages_to_reply.append(item)
logger.info(f"[{self.stream_name}] 找到 {len(messages_to_reply)} 条高于阈值的初始消息进行处理。")
except statistics.StatisticsError as e:
logger.error(f"[{self.stream_name}] 计算初始兴趣统计值时出错: {e},跳过初始处理。")
# 处理需要回复的消息
processed_count = 0
# --- 修改迭代前创建要处理的ID列表副本防止迭代时修改 ---
messages_to_process_initially = list(messages_to_reply) # 创建副本
# --- 修改结束 ---
for item in messages_to_process_initially: # 使用副本迭代
msg_id, (message, interest_value, is_mentioned) = item
# --- 修改:在处理前尝试 pop防止竞争 ---
popped_item = self.interest_dict.pop(msg_id, None)
if popped_item is None:
logger.warning(f"[{self.stream_name}] 初始兴趣消息 {msg_id} 在处理前已被移除,跳过。")
continue # 如果消息已被其他任务处理pop则跳过
# --- 修改结束 ---
try:
logger.info(f"[{self.stream_name}] 处理初始高兴趣消息 {msg_id} (兴趣值: {interest_value:.2f})")
await self.normal_response(message=message, is_mentioned=is_mentioned, interested_rate=interest_value)
processed_count += 1
except Exception as e:
logger.error(f"[{self.stream_name}] 处理初始兴趣消息 {msg_id} 时出错: {e}\\n{traceback.format_exc()}")
logger.info(
f"[{self.stream_name}] 初始高兴趣消息处理完毕,共处理 {processed_count} 条。剩余 {len(self.interest_dict)} 条待轮询。"
)
# --- 新增结束 ---
# 保持 staticmethod, 因为不依赖实例状态, 但需要 chat 对象来获取日志上下文
@staticmethod
def _check_ban_words(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
@@ -351,38 +415,42 @@ class NormalChat:
# 改为实例方法, 移除 chat 参数
async def start_chat(self):
"""为此 NormalChat 实例关联的 ChatStream 启动聊天任务(如果尚未运行)"""
"""为此 NormalChat 实例关联的 ChatStream 启动聊天任务(如果尚未运行)
并在后台处理一次初始的高兴趣消息。""" # 文言文注释示例:启聊之始,若有遗珠,当于暗处拂拭,勿碍正途。
if self._chat_task is None or self._chat_task.done():
logger.info(f"[{self.stream_name}] 启动聊天任务...")
task = asyncio.create_task(self._find_interested_message())
task.add_done_callback(lambda t: self._handle_task_completion(t)) # 回调现在是实例方法
self._chat_task = task
# --- 修改:使用 create_task 启动初始消息处理 ---
logger.info(f"[{self.stream_name}] 开始后台处理初始兴趣消息...")
# 创建一个任务来处理初始消息,不阻塞当前流程
_initial_process_task = asyncio.create_task(self._process_initial_interest_messages())
# 可以考虑给这个任务也添加完成回调来记录日志或处理错误
# initial_process_task.add_done_callback(...)
# --- 修改结束 ---
# 启动后台轮询任务 (这部分不变)
logger.info(f"[{self.stream_name}] 启动后台兴趣消息轮询任务...")
polling_task = asyncio.create_task(self._reply_interested_message()) # 注意变量名区分
polling_task.add_done_callback(lambda t: self._handle_task_completion(t))
self._chat_task = polling_task # self._chat_task 仍然指向主要的轮询任务
else:
logger.info(f"[{self.stream_name}] 聊天轮询任务已在运行中。")
# 改为实例方法, 移除 stream_id 参数
def _handle_task_completion(self, task: asyncio.Task):
"""兴趣监控任务完成时的回调函数。"""
# 检查完成的任务是否是当前实例的任务
"""任务完成回调处理"""
if task is not self._chat_task:
logger.warning(f"[{self.stream_name}] 收到一个未知或过时任务的完成回调")
logger.warning(f"[{self.stream_name}] 收到未知任务回调")
return
try:
# 检查任务是否因异常而结束
exception = task.exception()
if exception:
logger.error(f"[{self.stream_name}] 兴趣监控任务因异常结束: {exception}")
logger.error(traceback.format_exc()) # 记录完整的 traceback
# else: # 减少日志
# logger.info(f"[{self.stream_name}] 兴趣监控任务正常结束。")
if exc := task.exception():
logger.error(f"[{self.stream_name}] 任务异常: {exc}")
logger.error(traceback.format_exc())
except asyncio.CancelledError:
logger.info(f"[{self.stream_name}] 兴趣监控任务取消")
logger.info(f"[{self.stream_name}] 任务取消")
except Exception as e:
logger.error(f"[{self.stream_name}] 处理任务完成回调时出错: {e}")
logger.error(f"[{self.stream_name}] 回调处理错误: {e}")
finally:
# 标记任务已完成/移除
if self._chat_task is task: # 再次确认是当前任务
if self._chat_task is task:
self._chat_task = None
logger.debug(f"[{self.stream_name}] 聊天任务已被标记为完成/移除。")
logger.debug(f"[{self.stream_name}] 任务清理完成")
# 改为实例方法, 移除 stream_id 参数
async def stop_chat(self):
@@ -402,7 +470,7 @@ class NormalChat:
# 确保任务状态更新,即使等待出错 (回调函数也会尝试更新)
if self._chat_task is task:
self._chat_task = None
# 清理所有未处理的思考消息
try:
container = await message_manager.get_container(self.stream_id)

View File

@@ -5,7 +5,7 @@ from ...config.config import global_config
from ..chat.message import MessageThinking
from .heartflow_prompt_builder import prompt_builder
from ..chat.utils import process_llm_response
from ..utils.timer_calculater import Timer
from ..utils.timer_calculator import Timer
from src.common.logger import get_module_logger, LogConfig, LLM_STYLE_CONFIG
from src.plugins.respon_info_catcher.info_catcher import info_catcher_manager

View File

@@ -404,7 +404,7 @@ class Hippocampus:
# logger.info("没有找到有效的关键词节点")
return []
logger.info(f"有效的关键词: {', '.join(valid_keywords)}")
logger.debug(f"有效的关键词: {', '.join(valid_keywords)}")
# 从每个关键词获取记忆
all_memories = []
@@ -576,7 +576,7 @@ class Hippocampus:
# logger.info("没有找到有效的关键词节点")
return []
logger.info(f"有效的关键词: {', '.join(valid_keywords)}")
logger.debug(f"有效的关键词: {', '.join(valid_keywords)}")
# 从每个关键词获取记忆
all_memories = []
@@ -761,7 +761,7 @@ class Hippocampus:
# logger.info("没有找到有效的关键词节点")
return 0
logger.info(f"有效的关键词: {', '.join(valid_keywords)}")
logger.debug(f"有效的关键词: {', '.join(valid_keywords)}")
# 从每个关键词获取记忆
activate_map = {} # 存储每个词的累计激活值

View File

@@ -3,23 +3,8 @@
__version__ = "0.1.0"
from .api import global_api
from .message_base import (
Seg,
GroupInfo,
UserInfo,
FormatInfo,
TemplateInfo,
BaseMessageInfo,
MessageBase,
)
__all__ = [
"Seg",
"global_api",
"GroupInfo",
"UserInfo",
"FormatInfo",
"TemplateInfo",
"BaseMessageInfo",
"MessageBase",
]

View File

@@ -1,250 +1,6 @@
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from typing import Dict, Any, Callable, List, Set, Optional
from src.common.logger import get_module_logger
from src.plugins.message.message_base import MessageBase
from src.common.server import global_server
import aiohttp
import asyncio
import uvicorn
import os
import traceback
logger = get_module_logger("api")
class BaseMessageHandler:
"""消息处理基类"""
def __init__(self):
self.message_handlers: List[Callable] = []
self.background_tasks = set()
def register_message_handler(self, handler: Callable):
"""注册消息处理函数"""
self.message_handlers.append(handler)
async def process_message(self, message: Dict[str, Any]):
"""处理单条消息"""
tasks = []
for handler in self.message_handlers:
try:
tasks.append(handler(message))
except Exception as e:
logger.error(f"消息处理出错: {str(e)}")
logger.error(traceback.format_exc())
# 不抛出异常,而是记录错误并继续处理其他消息
continue
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
async def _handle_message(self, message: Dict[str, Any]):
"""后台处理单个消息"""
try:
await self.process_message(message)
except Exception as e:
raise RuntimeError(str(e)) from e
class MessageServer(BaseMessageHandler):
"""WebSocket服务端"""
_class_handlers: List[Callable] = [] # 类级别的消息处理器
def __init__(
self,
host: str = "0.0.0.0",
port: int = 18000,
enable_token=False,
app: Optional[FastAPI] = None,
path: str = "/ws",
):
super().__init__()
# 将类级别的处理器添加到实例处理器中
self.message_handlers.extend(self._class_handlers)
self.host = host
self.port = port
self.path = path
self.app = app or FastAPI()
self.own_app = app is None # 标记是否使用自己创建的app
self.active_websockets: Set[WebSocket] = set()
self.platform_websockets: Dict[str, WebSocket] = {} # 平台到websocket的映射
self.valid_tokens: Set[str] = set()
self.enable_token = enable_token
self._setup_routes()
self._running = False
def _setup_routes(self):
@self.app.post("/api/message")
async def handle_message(message: Dict[str, Any]):
try:
# 创建后台任务处理消息
asyncio.create_task(self._handle_message(message))
return {"status": "success"}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e
@self.app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
headers = dict(websocket.headers)
token = headers.get("authorization")
platform = headers.get("platform", "default") # 获取platform标识
if self.enable_token:
if not token or not await self.verify_token(token):
await websocket.close(code=1008, reason="Invalid or missing token")
return
await websocket.accept()
self.active_websockets.add(websocket)
# 添加到platform映射
if platform not in self.platform_websockets:
self.platform_websockets[platform] = websocket
try:
while True:
message = await websocket.receive_json()
# print(f"Received message: {message}")
asyncio.create_task(self._handle_message(message))
except WebSocketDisconnect:
self._remove_websocket(websocket, platform)
except Exception as e:
self._remove_websocket(websocket, platform)
raise RuntimeError(str(e)) from e
finally:
self._remove_websocket(websocket, platform)
@classmethod
def register_class_handler(cls, handler: Callable):
"""注册类级别的消息处理器"""
if handler not in cls._class_handlers:
cls._class_handlers.append(handler)
def register_message_handler(self, handler: Callable):
"""注册实例级别的消息处理器"""
if handler not in self.message_handlers:
self.message_handlers.append(handler)
async def verify_token(self, token: str) -> bool:
if not self.enable_token:
return True
return token in self.valid_tokens
def add_valid_token(self, token: str):
self.valid_tokens.add(token)
def remove_valid_token(self, token: str):
self.valid_tokens.discard(token)
def run_sync(self):
"""同步方式运行服务器"""
if not self.own_app:
raise RuntimeError("当使用外部FastAPI实例时请使用该实例的运行方法")
uvicorn.run(self.app, host=self.host, port=self.port)
async def run(self):
"""异步方式运行服务器"""
self._running = True
try:
if self.own_app:
# 如果使用自己的 FastAPI 实例,运行 uvicorn 服务器
# 禁用 uvicorn 默认日志和访问日志
config = uvicorn.Config(
self.app, host=self.host, port=self.port, loop="asyncio", log_config=None, access_log=False
)
self.server = uvicorn.Server(config)
await self.server.serve()
else:
# 如果使用外部 FastAPI 实例,保持运行状态以处理消息
while self._running:
await asyncio.sleep(1)
except KeyboardInterrupt:
await self.stop()
raise
except Exception as e:
await self.stop()
raise RuntimeError(f"服务器运行错误: {str(e)}") from e
finally:
await self.stop()
async def start_server(self):
"""启动服务器的异步方法"""
if not self._running:
self._running = True
await self.run()
async def stop(self):
"""停止服务器"""
# 清理platform映射
self.platform_websockets.clear()
# 取消所有后台任务
for task in self.background_tasks:
task.cancel()
# 等待所有任务完成
await asyncio.gather(*self.background_tasks, return_exceptions=True)
self.background_tasks.clear()
# 关闭所有WebSocket连接
for websocket in self.active_websockets:
await websocket.close()
self.active_websockets.clear()
if hasattr(self, "server") and self.own_app:
self._running = False
# 正确关闭 uvicorn 服务器
self.server.should_exit = True
await self.server.shutdown()
# 等待服务器完全停止
if hasattr(self.server, "started") and self.server.started:
await self.server.main_loop()
# 清理处理程序
self.message_handlers.clear()
def _remove_websocket(self, websocket: WebSocket, platform: str):
"""从所有集合中移除websocket"""
if websocket in self.active_websockets:
self.active_websockets.remove(websocket)
if platform in self.platform_websockets:
if self.platform_websockets[platform] == websocket:
del self.platform_websockets[platform]
async def broadcast_message(self, message: Dict[str, Any]):
disconnected = set()
for websocket in self.active_websockets:
try:
await websocket.send_json(message)
except Exception:
disconnected.add(websocket)
for websocket in disconnected:
self.active_websockets.remove(websocket)
async def broadcast_to_platform(self, platform: str, message: Dict[str, Any]):
"""向指定平台的所有WebSocket客户端广播消息"""
if platform not in self.platform_websockets:
raise ValueError(f"平台:{platform} 未连接")
disconnected = set()
try:
await self.platform_websockets[platform].send_json(message)
except Exception:
disconnected.add(self.platform_websockets[platform])
# 清理断开的连接
for websocket in disconnected:
self._remove_websocket(websocket, platform)
async def send_message(self, message: MessageBase):
await self.broadcast_to_platform(message.message_info.platform, message.to_dict())
@staticmethod
async def send_message_rest(url: str, data: Dict[str, Any]) -> Dict[str, Any]:
"""发送消息到指定端点"""
async with aiohttp.ClientSession() as session:
try:
async with session.post(url, json=data, headers={"Content-Type": "application/json"}) as response:
return await response.json()
except Exception as e:
raise e
from maim_message import MessageServer
global_api = MessageServer(host=os.environ["HOST"], port=int(os.environ["PORT"]), app=global_server.get_app())

View File

@@ -1,247 +0,0 @@
from dataclasses import dataclass, asdict
from typing import List, Optional, Union, Dict
@dataclass
class Seg:
"""消息片段类,用于表示消息的不同部分
Attributes:
type: 片段类型,可以是 'text''image''seglist'
data: 片段的具体内容
- 对于 text 类型data 是字符串
- 对于 image 类型data 是 base64 字符串
- 对于 seglist 类型data 是 Seg 列表
"""
type: str
data: Union[str, List["Seg"]]
# def __init__(self, type: str, data: Union[str, List['Seg']],):
# """初始化实例,确保字典和属性同步"""
# # 先初始化字典
# self.type = type
# self.data = data
@classmethod
def from_dict(cls, data: Dict) -> "Seg":
"""从字典创建Seg实例"""
type = data.get("type")
data = data.get("data")
if type == "seglist":
data = [Seg.from_dict(seg) for seg in data]
return cls(type=type, data=data)
def to_dict(self) -> Dict:
"""转换为字典格式"""
result = {"type": self.type}
if self.type == "seglist":
result["data"] = [seg.to_dict() for seg in self.data]
else:
result["data"] = self.data
return result
@dataclass
class GroupInfo:
"""群组信息类"""
platform: Optional[str] = None
group_id: Optional[int] = None
group_name: Optional[str] = None # 群名称
def to_dict(self) -> Dict:
"""转换为字典格式"""
return {k: v for k, v in asdict(self).items() if v is not None}
@classmethod
def from_dict(cls, data: Dict) -> "GroupInfo":
"""从字典创建GroupInfo实例
Args:
data: 包含必要字段的字典
Returns:
GroupInfo: 新的实例
"""
if data.get("group_id") is None:
return None
return cls(
platform=data.get("platform"), group_id=data.get("group_id"), group_name=data.get("group_name", None)
)
@dataclass
class UserInfo:
"""用户信息类"""
platform: Optional[str] = None
user_id: Optional[int] = None
user_nickname: Optional[str] = None # 用户昵称
user_cardname: Optional[str] = None # 用户群昵称
def to_dict(self) -> Dict:
"""转换为字典格式"""
return {k: v for k, v in asdict(self).items() if v is not None}
@classmethod
def from_dict(cls, data: Dict) -> "UserInfo":
"""从字典创建UserInfo实例
Args:
data: 包含必要字段的字典
Returns:
UserInfo: 新的实例
"""
return cls(
platform=data.get("platform"),
user_id=data.get("user_id"),
user_nickname=data.get("user_nickname", None),
user_cardname=data.get("user_cardname", None),
)
@dataclass
class FormatInfo:
"""格式信息类"""
"""
目前maimcore可接受的格式为text,image,emoji
可发送的格式为text,emoji,reply
"""
content_format: Optional[str] = None
accept_format: Optional[str] = None
def to_dict(self) -> Dict:
"""转换为字典格式"""
return {k: v for k, v in asdict(self).items() if v is not None}
@classmethod
def from_dict(cls, data: Dict) -> "FormatInfo":
"""从字典创建FormatInfo实例
Args:
data: 包含必要字段的字典
Returns:
FormatInfo: 新的实例
"""
return cls(
content_format=data.get("content_format"),
accept_format=data.get("accept_format"),
)
@dataclass
class TemplateInfo:
"""模板信息类"""
template_items: Optional[Dict] = None
template_name: Optional[str] = None
template_default: bool = True
def to_dict(self) -> Dict:
"""转换为字典格式"""
return {k: v for k, v in asdict(self).items() if v is not None}
@classmethod
def from_dict(cls, data: Dict) -> "TemplateInfo":
"""从字典创建TemplateInfo实例
Args:
data: 包含必要字段的字典
Returns:
TemplateInfo: 新的实例
"""
return cls(
template_items=data.get("template_items"),
template_name=data.get("template_name"),
template_default=data.get("template_default", True),
)
@dataclass
class BaseMessageInfo:
"""消息信息类"""
platform: Optional[str] = None
message_id: Union[str, int, None] = None
time: Optional[float] = None
group_info: Optional[GroupInfo] = None
user_info: Optional[UserInfo] = None
format_info: Optional[FormatInfo] = None
template_info: Optional[TemplateInfo] = None
additional_config: Optional[dict] = None
def to_dict(self) -> Dict:
"""转换为字典格式"""
result = {}
for field, value in asdict(self).items():
if value is not None:
if isinstance(value, (GroupInfo, UserInfo, FormatInfo, TemplateInfo)):
result[field] = value.to_dict()
else:
result[field] = value
return result
@classmethod
def from_dict(cls, data: Dict) -> "BaseMessageInfo":
"""从字典创建BaseMessageInfo实例
Args:
data: 包含必要字段的字典
Returns:
BaseMessageInfo: 新的实例
"""
group_info = GroupInfo.from_dict(data.get("group_info", {}))
user_info = UserInfo.from_dict(data.get("user_info", {}))
format_info = FormatInfo.from_dict(data.get("format_info", {}))
template_info = TemplateInfo.from_dict(data.get("template_info", {}))
return cls(
platform=data.get("platform"),
message_id=data.get("message_id"),
time=data.get("time"),
additional_config=data.get("additional_config", None),
group_info=group_info,
user_info=user_info,
format_info=format_info,
template_info=template_info,
)
@dataclass
class MessageBase:
"""消息类"""
message_info: BaseMessageInfo
message_segment: Seg
raw_message: Optional[str] = None # 原始消息包含未解析的cq码
def to_dict(self) -> Dict:
"""转换为字典格式
Returns:
Dict: 包含所有非None字段的字典其中
- message_info: 转换为字典格式
- message_segment: 转换为字典格式
- raw_message: 如果存在则包含
"""
result = {"message_info": self.message_info.to_dict(), "message_segment": self.message_segment.to_dict()}
if self.raw_message is not None:
result["raw_message"] = self.raw_message
return result
@classmethod
def from_dict(cls, data: Dict) -> "MessageBase":
"""从字典创建MessageBase实例
Args:
data: 包含必要字段的字典
Returns:
MessageBase: 新的实例
"""
message_info = BaseMessageInfo.from_dict(data.get("message_info", {}))
message_segment = Seg.from_dict(data.get("message_segment", {}))
raw_message = data.get("raw_message", None)
return cls(message_info=message_info, message_segment=message_segment, raw_message=raw_message)

View File

@@ -178,395 +178,6 @@ class LLMRequest:
output_cost = (completion_tokens / 1000000) * self.pri_out
return round(input_cost + output_cost, 6)
'''
async def _execute_request(
self,
endpoint: str,
prompt: str = None,
image_base64: str = None,
image_format: str = None,
payload: dict = None,
retry_policy: dict = None,
response_handler: callable = None,
user_id: str = "system",
request_type: str = None,
):
"""统一请求执行入口
Args:
endpoint: API端点路径 (如 "chat/completions")
prompt: prompt文本
image_base64: 图片的base64编码
image_format: 图片格式
payload: 请求体数据
retry_policy: 自定义重试策略
response_handler: 自定义响应处理器
user_id: 用户ID
request_type: 请求类型
"""
if request_type is None:
request_type = self.request_type
# 合并重试策略
default_retry = {
"max_retries": 3,
"base_wait": 10,
"retry_codes": [429, 413, 500, 503],
"abort_codes": [400, 401, 402, 403],
}
policy = {**default_retry, **(retry_policy or {})}
# 常见Error Code Mapping
error_code_mapping = {
400: "参数不正确",
401: "API key 错误,认证失败,请检查/config/bot_config.toml和.env中的配置是否正确哦~",
402: "账号余额不足",
403: "需要实名,或余额不足",
404: "Not Found",
429: "请求过于频繁,请稍后再试",
500: "服务器内部故障",
503: "服务器负载过高",
}
api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}"
# 判断是否为流式
stream_mode = self.stream
# logger_msg = "进入流式输出模式," if stream_mode else ""
# logger.debug(f"{logger_msg}发送请求到URL: {api_url}")
# logger.info(f"使用模型: {self.model_name}")
# 构建请求体
if image_base64:
payload = await self._build_payload(prompt, image_base64, image_format)
elif payload is None:
payload = await self._build_payload(prompt)
# 流式输出标志
# 先构建payload再添加流式输出标志
if stream_mode:
payload["stream"] = stream_mode
for retry in range(policy["max_retries"]):
try:
# 使用上下文管理器处理会话
headers = await self._build_headers()
# 似乎是openai流式必须要的东西,不过阿里云的qwq-plus加了这个没有影响
if stream_mode:
headers["Accept"] = "text/event-stream"
async with aiohttp.ClientSession() as session:
try:
async with session.post(api_url, headers=headers, json=payload) as response:
# 处理需要重试的状态码
if response.status in policy["retry_codes"]:
wait_time = policy["base_wait"] * (2**retry)
logger.warning(
f"模型 {self.model_name} 错误码: {response.status}, 等待 {wait_time}秒后重试"
)
if response.status == 413:
logger.warning("请求体过大,尝试压缩...")
image_base64 = compress_base64_image_by_scale(image_base64)
payload = await self._build_payload(prompt, image_base64, image_format)
elif response.status in [500, 503]:
logger.error(
f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}"
)
raise RuntimeError("服务器负载过高模型恢复失败QAQ")
else:
logger.warning(f"模型 {self.model_name} 请求限制(429),等待{wait_time}秒后重试...")
await asyncio.sleep(wait_time)
continue
elif response.status in policy["abort_codes"]:
logger.error(
f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}"
)
# 尝试获取并记录服务器返回的详细错误信息
try:
error_json = await response.json()
if error_json and isinstance(error_json, list) and len(error_json) > 0:
for error_item in error_json:
if "error" in error_item and isinstance(error_item["error"], dict):
error_obj = error_item["error"]
error_code = error_obj.get("code")
error_message = error_obj.get("message")
error_status = error_obj.get("status")
logger.error(
f"服务器错误详情: 代码={error_code}, 状态={error_status}, "
f"消息={error_message}"
)
elif isinstance(error_json, dict) and "error" in error_json:
# 处理单个错误对象的情况
error_obj = error_json.get("error", {})
error_code = error_obj.get("code")
error_message = error_obj.get("message")
error_status = error_obj.get("status")
logger.error(
f"服务器错误详情: 代码={error_code}, 状态={error_status}, 消息={error_message}"
)
else:
# 记录原始错误响应内容
logger.error(f"服务器错误响应: {error_json}")
except Exception as e:
logger.warning(f"无法解析服务器错误响应: {str(e)}")
if response.status == 403:
# 只针对硅基流动的V3和R1进行降级处理
if (
self.model_name.startswith("Pro/deepseek-ai")
and self.base_url == "https://api.siliconflow.cn/v1/"
):
old_model_name = self.model_name
self.model_name = self.model_name[4:] # 移除"Pro/"前缀
logger.warning(
f"检测到403错误模型从 {old_model_name} 降级为 {self.model_name}"
)
# 对全局配置进行更新
if global_config.llm_normal.get("name") == old_model_name:
global_config.llm_normal["name"] = self.model_name
logger.warning(f"将全局配置中的 llm_normal 模型临时降级至{self.model_name}")
if global_config.llm_reasoning.get("name") == old_model_name:
global_config.llm_reasoning["name"] = self.model_name
logger.warning(
f"将全局配置中的 llm_reasoning 模型临时降级至{self.model_name}"
)
# 更新payload中的模型名
if payload and "model" in payload:
payload["model"] = self.model_name
# 重新尝试请求
retry -= 1 # 不计入重试次数
continue
raise RuntimeError(f"请求被拒绝: {error_code_mapping.get(response.status)}")
response.raise_for_status()
reasoning_content = ""
# 将流式输出转化为非流式输出
if stream_mode:
flag_delta_content_finished = False
accumulated_content = ""
usage = None # 初始化usage变量避免未定义错误
async for line_bytes in response.content:
try:
line = line_bytes.decode("utf-8").strip()
if not line:
continue
if line.startswith("data:"):
data_str = line[5:].strip()
if data_str == "[DONE]":
break
try:
chunk = json.loads(data_str)
if flag_delta_content_finished:
chunk_usage = chunk.get("usage", None)
if chunk_usage:
usage = chunk_usage # 获取token用量
else:
delta = chunk["choices"][0]["delta"]
delta_content = delta.get("content")
if delta_content is None:
delta_content = ""
accumulated_content += delta_content
# 检测流式输出文本是否结束
finish_reason = chunk["choices"][0].get("finish_reason")
if delta.get("reasoning_content", None):
reasoning_content += delta["reasoning_content"]
if finish_reason == "stop":
chunk_usage = chunk.get("usage", None)
if chunk_usage:
usage = chunk_usage
break
# 部分平台在文本输出结束前不会返回token用量此时需要再获取一次chunk
flag_delta_content_finished = True
except Exception as e:
logger.exception(f"模型 {self.model_name} 解析流式输出错误: {str(e)}")
except GeneratorExit:
logger.warning("模型 {self.model_name} 流式输出被中断,正在清理资源...")
# 确保资源被正确清理
await response.release()
# 返回已经累积的内容
result = {
"choices": [
{
"message": {
"content": accumulated_content,
"reasoning_content": reasoning_content,
# 流式输出可能没有工具调用此处不需要添加tool_calls字段
}
}
],
"usage": usage,
}
return (
response_handler(result)
if response_handler
else self._default_response_handler(result, user_id, request_type, endpoint)
)
except Exception as e:
logger.error(f"模型 {self.model_name} 处理流式输出时发生错误: {str(e)}")
# 确保在发生错误时也能正确清理资源
try:
await response.release()
except Exception as cleanup_error:
logger.error(f"清理资源时发生错误: {cleanup_error}")
# 返回已经累积的内容
result = {
"choices": [
{
"message": {
"content": accumulated_content,
"reasoning_content": reasoning_content,
# 流式输出可能没有工具调用此处不需要添加tool_calls字段
}
}
],
"usage": usage,
}
return (
response_handler(result)
if response_handler
else self._default_response_handler(result, user_id, request_type, endpoint)
)
content = accumulated_content
think_match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
if think_match:
reasoning_content = think_match.group(1).strip()
content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip()
# 构造一个伪result以便调用自定义响应处理器或默认处理器
result = {
"choices": [
{
"message": {
"content": content,
"reasoning_content": reasoning_content,
# 流式输出可能没有工具调用此处不需要添加tool_calls字段
}
}
],
"usage": usage,
}
return (
response_handler(result)
if response_handler
else self._default_response_handler(result, user_id, request_type, endpoint)
)
else:
result = await response.json()
# 使用自定义处理器或默认处理
return (
response_handler(result)
if response_handler
else self._default_response_handler(result, user_id, request_type, endpoint)
)
except (aiohttp.ClientError, asyncio.TimeoutError) as e:
if retry < policy["max_retries"] - 1:
wait_time = policy["base_wait"] * (2**retry)
logger.error(f"模型 {self.model_name} 网络错误,等待{wait_time}秒后重试... 错误: {str(e)}")
await asyncio.sleep(wait_time)
continue
else:
logger.critical(f"模型 {self.model_name} 网络错误达到最大重试次数: {str(e)}")
raise RuntimeError(f"网络请求失败: {str(e)}") from e
except Exception as e:
logger.critical(f"模型 {self.model_name} 未预期的错误: {str(e)}")
raise RuntimeError(f"请求过程中发生错误: {str(e)}") from e
except aiohttp.ClientResponseError as e:
# 处理aiohttp抛出的响应错误
if retry < policy["max_retries"] - 1:
wait_time = policy["base_wait"] * (2**retry)
logger.error(
f"模型 {self.model_name} HTTP响应错误等待{wait_time}秒后重试... 状态码: {e.status}, 错误: {e.message}"
)
try:
if hasattr(e, "response") and e.response and hasattr(e.response, "text"):
error_text = await e.response.text()
try:
error_json = json.loads(error_text)
if isinstance(error_json, list) and len(error_json) > 0:
for error_item in error_json:
if "error" in error_item and isinstance(error_item["error"], dict):
error_obj = error_item["error"]
logger.error(
f"模型 {self.model_name} 服务器错误详情: 代码={error_obj.get('code')}, "
f"状态={error_obj.get('status')}, "
f"消息={error_obj.get('message')}"
)
elif isinstance(error_json, dict) and "error" in error_json:
error_obj = error_json.get("error", {})
logger.error(
f"模型 {self.model_name} 服务器错误详情: 代码={error_obj.get('code')}, "
f"状态={error_obj.get('status')}, "
f"消息={error_obj.get('message')}"
)
else:
logger.error(f"模型 {self.model_name} 服务器错误响应: {error_json}")
except (json.JSONDecodeError, TypeError) as json_err:
logger.warning(
f"模型 {self.model_name} 响应不是有效的JSON: {str(json_err)}, 原始内容: {error_text[:200]}"
)
except (AttributeError, TypeError, ValueError) as parse_err:
logger.warning(f"模型 {self.model_name} 无法解析响应错误内容: {str(parse_err)}")
await asyncio.sleep(wait_time)
else:
logger.critical(
f"模型 {self.model_name} HTTP响应错误达到最大重试次数: 状态码: {e.status}, 错误: {e.message}"
)
# 安全地检查和记录请求详情
if (
image_base64
and payload
and isinstance(payload, dict)
and "messages" in payload
and len(payload["messages"]) > 0
):
if isinstance(payload["messages"][0], dict) and "content" in payload["messages"][0]:
content = payload["messages"][0]["content"]
if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]:
payload["messages"][0]["content"][1]["image_url"]["url"] = (
f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,"
f"{image_base64[:10]}...{image_base64[-10:]}"
)
logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {payload}")
raise RuntimeError(f"模型 {self.model_name} API请求失败: 状态码 {e.status}, {e.message}") from e
except Exception as e:
if retry < policy["max_retries"] - 1:
wait_time = policy["base_wait"] * (2**retry)
logger.error(f"模型 {self.model_name} 请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
await asyncio.sleep(wait_time)
else:
logger.critical(f"模型 {self.model_name} 请求失败: {str(e)}")
# 安全地检查和记录请求详情
if (
image_base64
and payload
and isinstance(payload, dict)
and "messages" in payload
and len(payload["messages"]) > 0
):
if isinstance(payload["messages"][0], dict) and "content" in payload["messages"][0]:
content = payload["messages"][0]["content"]
if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]:
payload["messages"][0]["content"][1]["image_url"]["url"] = (
f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,"
f"{image_base64[:10]}...{image_base64[-10:]}"
)
logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {payload}")
raise RuntimeError(f"模型 {self.model_name} API请求失败: {str(e)}") from e
logger.error(f"模型 {self.model_name} 达到最大重试次数,请求仍然失败")
raise RuntimeError(f"模型 {self.model_name} 达到最大重试次数API请求仍然失败")
'''
async def _prepare_request(
self,
endpoint: str,
@@ -820,6 +431,7 @@ class LLMRequest:
policy = request_content["policy"]
payload = request_content["payload"]
wait_time = policy["base_wait"] * (2**retry_count)
keep_request = False
if retry_count < policy["max_retries"] - 1:
keep_request = True
if isinstance(exception, RequestAbortException):

View File

@@ -256,7 +256,7 @@ class MoodManager:
def print_mood_status(self) -> None:
"""打印当前情绪状态"""
logger.info(
f"[情绪状态]愉悦度: {self.current_mood.valence:.2f}, "
f"愉悦度: {self.current_mood.valence:.2f}, "
f"唤醒度: {self.current_mood.arousal:.2f}, "
f"心情: {self.current_mood.text}"
)

View File

@@ -53,7 +53,7 @@ person_info_default = {
# "impression" : None,
# "gender" : Unkown,
"konw_time": 0,
"msg_interval": 3000,
"msg_interval": 2000,
"msg_interval_list": [],
} # 个人信息的各项与默认值在此定义,以下处理会自动创建/补全每一项
@@ -384,18 +384,30 @@ class PersonInfoManager:
if delta > 0:
time_interval.append(delta)
time_interval = [t for t in time_interval if 500 <= t <= 8000]
if len(time_interval) >= 30:
time_interval = [t for t in time_interval if 200 <= t <= 8000]
# --- 修改后的逻辑 ---
# 数据量检查 (至少需要 30 条有效间隔,并且足够进行头尾截断)
if len(time_interval) >= 30 + 10: # 至少30条有效+头尾各5条
time_interval.sort()
# 画图(log)
# 画图(log) - 这部分保留
msg_interval_map = True
log_dir = Path("logs/person_info")
log_dir.mkdir(parents=True, exist_ok=True)
plt.figure(figsize=(10, 6))
time_series = pd.Series(time_interval)
plt.hist(time_series, bins=50, density=True, alpha=0.4, color="pink", label="Histogram")
time_series.plot(kind="kde", color="mediumpurple", linewidth=1, label="Density")
# 使用截断前的数据画图,更能反映原始分布
time_series_original = pd.Series(time_interval)
plt.hist(
time_series_original,
bins=50,
density=True,
alpha=0.4,
color="pink",
label="Histogram (Original Filtered)",
)
time_series_original.plot(
kind="kde", color="mediumpurple", linewidth=1, label="Density (Original Filtered)"
)
plt.grid(True, alpha=0.2)
plt.xlim(0, 8000)
plt.title(f"Message Interval Distribution (User: {person_id[:8]}...)")
@@ -405,15 +417,24 @@ class PersonInfoManager:
img_path = log_dir / f"interval_distribution_{person_id[:8]}.png"
plt.savefig(img_path)
plt.close()
# 画图
# 画图结束
q25, q75 = np.percentile(time_interval, [25, 75])
iqr = q75 - q25
filtered = [x for x in time_interval if (q25 - 1.5 * iqr) <= x <= (q75 + 1.5 * iqr)]
# 去掉头尾各 5 个数据点
trimmed_interval = time_interval[5:-5]
msg_interval = int(round(np.percentile(filtered, 80)))
await self.update_one_field(person_id, "msg_interval", msg_interval)
logger.trace(f"用户{person_id}的msg_interval已经被更新为{msg_interval}")
# 计算截断后数据的 37% 分位数
if trimmed_interval: # 确保截断后列表不为空
msg_interval = int(round(np.percentile(trimmed_interval, 37)))
# 更新数据库
await self.update_one_field(person_id, "msg_interval", msg_interval)
logger.trace(f"用户{person_id}的msg_interval通过头尾截断和37分位数更新为{msg_interval}")
else:
logger.trace(f"用户{person_id}截断后数据为空无法计算msg_interval")
else:
logger.trace(
f"用户{person_id}有效消息间隔数量 ({len(time_interval)}) 不足进行推断 (需要至少 {30 + 10} 条)"
)
# --- 修改结束 ---
except Exception as e:
logger.trace(f"用户{person_id}消息间隔计算失败: {type(e).__name__}: {str(e)}")
continue

View File

@@ -168,7 +168,10 @@ async def _build_readable_messages_internal(
user_info = msg.get("user_info", {})
platform = user_info.get("platform")
user_id = user_info.get("user_id")
user_nickname = user_info.get("nickname")
user_nickname = user_info.get("user_nickname")
user_cardname = user_info.get("user_cardname")
timestamp = msg.get("time")
content = msg.get("processed_plain_text", "") # 默认空字符串
@@ -186,7 +189,12 @@ async def _build_readable_messages_internal(
# 如果 person_name 未设置,则使用消息中的 nickname 或默认名称
if not person_name:
person_name = user_nickname
if user_cardname:
person_name = f"昵称:{user_cardname}"
elif user_nickname:
person_name = f"{user_nickname}"
else:
person_name = "某人"
message_details.append((timestamp, person_name, content))
@@ -303,9 +311,7 @@ async def build_readable_messages(
)
readable_read_mark = translate_timestamp_to_human_readable(read_mark, mode=timestamp_mode)
read_mark_line = (
f"\n\n--- 以上消息已读 (标记时间: {readable_read_mark}) ---\n--- 请关注你上次思考之后以下的新消息---\n"
)
read_mark_line = f"\n--- 以上消息是你已经思考过的内容已读 (标记时间: {readable_read_mark}) ---\n--- 请关注以下未读的新消息---\n"
# 组合结果,确保空部分不引入多余的标记或换行
if formatted_before and formatted_after: