Merge branch 'dev' into dev
This commit is contained in:
@@ -8,6 +8,7 @@ from .pfc_utils import get_items_from_json
|
||||
from src.individuality.individuality import Individuality
|
||||
from .observation_info import ObservationInfo
|
||||
from .conversation_info import ConversationInfo
|
||||
from src.plugins.utils.chat_message_builder import build_readable_messages
|
||||
|
||||
pfc_action_log_config = LogConfig(
|
||||
console_format=PFC_ACTION_PLANNER_STYLE_CONFIG["console_format"],
|
||||
@@ -132,12 +133,7 @@ class ActionPlanner:
|
||||
chat_history_text = ""
|
||||
try:
|
||||
if hasattr(observation_info, "chat_history") and observation_info.chat_history:
|
||||
chat_history_list = observation_info.chat_history[-20:]
|
||||
for msg in chat_history_list:
|
||||
if isinstance(msg, dict) and "detailed_plain_text" in msg:
|
||||
chat_history_text += f"{msg.get('detailed_plain_text', '')}\n"
|
||||
elif isinstance(msg, str):
|
||||
chat_history_text += f"{msg}\n"
|
||||
chat_history_text = observation_info.chat_history_str
|
||||
if not chat_history_text: # 如果历史记录是空列表
|
||||
chat_history_text = "还没有聊天记录。\n"
|
||||
else:
|
||||
@@ -146,12 +142,16 @@ class ActionPlanner:
|
||||
if hasattr(observation_info, "new_messages_count") and observation_info.new_messages_count > 0:
|
||||
if hasattr(observation_info, "unprocessed_messages") and observation_info.unprocessed_messages:
|
||||
new_messages_list = observation_info.unprocessed_messages
|
||||
chat_history_text += f"--- 以下是 {observation_info.new_messages_count} 条新消息 ---\n"
|
||||
for msg in new_messages_list:
|
||||
if isinstance(msg, dict) and "detailed_plain_text" in msg:
|
||||
chat_history_text += f"{msg.get('detailed_plain_text', '')}\n"
|
||||
elif isinstance(msg, str):
|
||||
chat_history_text += f"{msg}\n"
|
||||
new_messages_str = await build_readable_messages(
|
||||
new_messages_list,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
)
|
||||
chat_history_text += (
|
||||
f"\n--- 以下是 {observation_info.new_messages_count} 条新消息 ---\n{new_messages_str}"
|
||||
)
|
||||
# 清理消息应该由调用者或 observation_info 内部逻辑处理,这里不再调用 clear
|
||||
# if hasattr(observation_info, 'clear_unprocessed_messages'):
|
||||
# observation_info.clear_unprocessed_messages()
|
||||
|
||||
@@ -3,7 +3,7 @@ import asyncio
|
||||
import traceback
|
||||
from typing import Optional, Dict, Any, List
|
||||
from src.common.logger import get_module_logger
|
||||
from ..message.message_base import UserInfo
|
||||
from maim_message import UserInfo
|
||||
from ...config.config import global_config
|
||||
from .chat_states import NotificationManager, create_new_message_notification, create_cold_chat_notification
|
||||
from .message_storage import MongoDBMessageStorage
|
||||
|
||||
@@ -98,7 +98,7 @@ class NotificationManager:
|
||||
notification_type: 要处理的通知类型
|
||||
handler: 处理器实例
|
||||
"""
|
||||
print(1145145511114445551111444)
|
||||
# print(1145145511114445551111444)
|
||||
if target not in self._handlers:
|
||||
# print("没11有target")
|
||||
self._handlers[target] = {}
|
||||
@@ -146,9 +146,9 @@ class NotificationManager:
|
||||
if target in self._handlers:
|
||||
handlers = self._handlers[target].get(notification.type, [])
|
||||
# print(1111111)
|
||||
print(handlers)
|
||||
# print(handlers)
|
||||
for handler in handlers:
|
||||
print(f"调用处理器: {handler}")
|
||||
# print(f"调用处理器: {handler}")
|
||||
await handler.handle_notification(notification)
|
||||
|
||||
def get_active_states(self) -> Set[NotificationType]:
|
||||
|
||||
@@ -2,7 +2,7 @@ import time
|
||||
import asyncio
|
||||
import datetime
|
||||
# from .message_storage import MongoDBMessageStorage
|
||||
from src.plugins.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat
|
||||
from src.plugins.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_ch
|
||||
from ...config.config import global_config
|
||||
from typing import Dict, Any, Optional
|
||||
from ..chat.message import Message
|
||||
@@ -14,7 +14,7 @@ from .observation_info import ObservationInfo
|
||||
from .conversation_info import ConversationInfo
|
||||
from .reply_generator import ReplyGenerator
|
||||
from ..chat.chat_stream import ChatStream
|
||||
from ..message.message_base import UserInfo
|
||||
from maim_message import UserInfo
|
||||
from src.plugins.chat.chat_stream import chat_manager
|
||||
from .pfc_KnowledgeFetcher import KnowledgeFetcher
|
||||
from .waiter import Waiter
|
||||
@@ -77,14 +77,22 @@ class Conversation:
|
||||
raise
|
||||
try:
|
||||
logger.info(f"为 {self.stream_id} 加载初始聊天记录...")
|
||||
initial_messages = await get_raw_msg_before_timestamp_with_chat( #
|
||||
initial_messages = get_raw_msg_before_timestamp_with_chat( #
|
||||
chat_id=self.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=30, # 加载最近30条作为初始上下文,可以调整
|
||||
)
|
||||
chat_talking_prompt = await build_readable_messages(
|
||||
initial_messages,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
)
|
||||
if initial_messages:
|
||||
# 将加载的消息填充到 ObservationInfo 的 chat_history
|
||||
self.observation_info.chat_history = initial_messages
|
||||
self.observation_info.chat_history_str = chat_talking_prompt + "\n"
|
||||
self.observation_info.chat_history_count = len(initial_messages)
|
||||
|
||||
# 更新 ObservationInfo 中的时间戳等信息
|
||||
@@ -174,7 +182,7 @@ class Conversation:
|
||||
if hasattr(self.observation_info, "clear_unprocessed_messages"):
|
||||
# 确保 clear_unprocessed_messages 方法存在
|
||||
logger.debug(f"准备执行 direct_reply,清理 {initial_new_message_count} 条规划时已知的新消息。")
|
||||
self.observation_info.clear_unprocessed_messages()
|
||||
await self.observation_info.clear_unprocessed_messages()
|
||||
# 手动重置计数器,确保状态一致性(理想情况下 clear 方法会做这个)
|
||||
if hasattr(self.observation_info, "new_messages_count"):
|
||||
self.observation_info.new_messages_count = 0
|
||||
@@ -259,7 +267,7 @@ class Conversation:
|
||||
|
||||
# --- 根据不同的 action 执行 ---
|
||||
if action == "direct_reply":
|
||||
max_reply_attempts = 3 # 设置最大尝试次数(与 reply_checker.py 中的 max_retries 保持一致或稍大)
|
||||
max_reply_attempts = 3 # 设置最大尝试次数(与 reply_checker.py 中的 max_retries 保持一致或稍大)
|
||||
reply_attempt_count = 0
|
||||
is_suitable = False
|
||||
need_replan = False
|
||||
@@ -284,17 +292,20 @@ class Conversation:
|
||||
reply=self.generated_reply,
|
||||
goal=current_goal_str,
|
||||
chat_history=observation_info.chat_history,
|
||||
retry_count=reply_attempt_count - 1, # 传递当前尝试次数(从0开始计数)
|
||||
chat_history_str=observation_info.chat_history_str,
|
||||
retry_count=reply_attempt_count - 1, # 传递当前尝试次数(从0开始计数)
|
||||
)
|
||||
logger.info(
|
||||
f"第 {reply_attempt_count} 次检查结果: 合适={is_suitable}, 原因='{check_reason}', 需重新规划={need_replan}"
|
||||
)
|
||||
logger.info(f"第 {reply_attempt_count} 次检查结果: 合适={is_suitable}, 原因='{check_reason}', 需重新规划={need_replan}")
|
||||
|
||||
if is_suitable:
|
||||
final_reply_to_send = self.generated_reply # 保存合适的回复
|
||||
break # 回复合适,跳出循环
|
||||
final_reply_to_send = self.generated_reply # 保存合适的回复
|
||||
break # 回复合适,跳出循环
|
||||
|
||||
elif need_replan:
|
||||
logger.warning(f"第 {reply_attempt_count} 次检查建议重新规划,停止尝试。原因: {check_reason}")
|
||||
break # 如果检查器建议重新规划,也停止尝试
|
||||
logger.warning(f"第 {reply_attempt_count} 次检查建议重新规划,停止尝试。原因: {check_reason}")
|
||||
break # 如果检查器建议重新规划,也停止尝试
|
||||
|
||||
# 如果不合适但不需要重新规划,循环会继续进行下一次尝试
|
||||
except Exception as check_err:
|
||||
@@ -321,7 +332,7 @@ class Conversation:
|
||||
return
|
||||
|
||||
# 发送合适的回复
|
||||
self.generated_reply = final_reply_to_send # 确保 self.generated_reply 是最终要发送的内容
|
||||
self.generated_reply = final_reply_to_send # 确保 self.generated_reply 是最终要发送的内容
|
||||
await self._send_reply()
|
||||
|
||||
# 更新 action 历史状态为 done
|
||||
@@ -337,7 +348,7 @@ class Conversation:
|
||||
logger.warning(f"经过 {reply_attempt_count} 次尝试,未能生成合适的回复。最终原因: {check_reason}")
|
||||
conversation_info.done_action[action_index].update(
|
||||
{
|
||||
"status": "recall", # 标记为 recall 因为没有成功发送
|
||||
"status": "recall", # 标记为 recall 因为没有成功发送
|
||||
"final_reason": f"尝试{reply_attempt_count}次后失败: {check_reason}",
|
||||
"time": datetime.datetime.now().strftime("%H:%M:%S"),
|
||||
}
|
||||
@@ -352,7 +363,7 @@ class Conversation:
|
||||
wait_action_record = {
|
||||
"action": "wait",
|
||||
"plan_reason": "因 direct_reply 多次尝试失败而执行的后备等待",
|
||||
"status": "done", # wait 完成后可以认为是 done
|
||||
"status": "done", # wait 完成后可以认为是 done
|
||||
"time": datetime.datetime.now().strftime("%H:%M:%S"),
|
||||
"final_reason": None,
|
||||
}
|
||||
@@ -461,42 +472,11 @@ class Conversation:
|
||||
|
||||
try:
|
||||
# 外层 try: 捕获发送消息和后续处理中的主要错误
|
||||
current_time = time.time() # 获取当前时间戳
|
||||
_current_time = time.time() # 获取当前时间戳
|
||||
reply_content = self.generated_reply # 获取要发送的内容
|
||||
|
||||
# 发送消息
|
||||
await self.direct_sender.send_message(chat_stream=self.chat_stream, content=reply_content)
|
||||
logger.info(f"消息已发送: {reply_content}") # 可以在发送后加个日志确认
|
||||
|
||||
# --- 添加的立即更新状态逻辑开始 ---
|
||||
try:
|
||||
# 内层 try: 专门捕获手动更新状态时可能出现的错误
|
||||
# 创建一个代表刚刚发送的消息的字典
|
||||
bot_message_info = {
|
||||
"message_id": f"bot_sent_{current_time}", # 创建一个简单的唯一ID
|
||||
"time": current_time,
|
||||
"user_info": UserInfo( # 使用 UserInfo 类构建用户信息
|
||||
user_id=str(global_config.BOT_QQ),
|
||||
user_nickname=global_config.BOT_NICKNAME,
|
||||
platform=self.chat_stream.platform, # 从 chat_stream 获取平台信息
|
||||
).to_dict(), # 转换为字典格式存储
|
||||
"processed_plain_text": reply_content, # 使用发送的内容
|
||||
"detailed_plain_text": f"{int(current_time)},{global_config.BOT_NICKNAME}:{reply_content}", # 构造一个简单的详细文本, 时间戳取整
|
||||
# 可以根据需要添加其他字段,保持与 observation_info.chat_history 中其他消息结构一致
|
||||
}
|
||||
|
||||
# 直接更新 ObservationInfo 实例
|
||||
if self.observation_info:
|
||||
self.observation_info.chat_history.append(bot_message_info) # 将消息添加到历史记录末尾
|
||||
self.observation_info.last_bot_speak_time = current_time # 更新 Bot 最后发言时间
|
||||
self.observation_info.last_message_time = current_time # 更新最后消息时间
|
||||
logger.debug("已手动将Bot发送的消息添加到 ObservationInfo")
|
||||
else:
|
||||
logger.warning("无法手动更新 ObservationInfo:实例不存在")
|
||||
|
||||
except Exception as update_err:
|
||||
logger.error(f"手动更新 ObservationInfo 时出错: {update_err}")
|
||||
# --- 添加的立即更新状态逻辑结束 ---
|
||||
|
||||
# 原有的触发更新和等待代码
|
||||
self.chat_observer.trigger_update()
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import Optional
|
||||
from src.common.logger import get_module_logger
|
||||
from ..chat.chat_stream import ChatStream
|
||||
from ..chat.message import Message
|
||||
from ..message.message_base import Seg
|
||||
from maim_message import Seg
|
||||
from src.plugins.chat.message import MessageSending, MessageSet
|
||||
from src.plugins.chat.message_sender import message_manager
|
||||
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
# Programmable Friendly Conversationalist
|
||||
# Prefrontal cortex
|
||||
from typing import List, Optional, Dict, Any, Set
|
||||
from ..message.message_base import UserInfo
|
||||
from maim_message import UserInfo
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from src.common.logger import get_module_logger
|
||||
from .chat_observer import ChatObserver
|
||||
from .chat_states import NotificationHandler, NotificationType
|
||||
from src.plugins.utils.chat_message_builder import build_readable_messages
|
||||
|
||||
logger = get_module_logger("observation_info")
|
||||
|
||||
@@ -97,6 +98,7 @@ class ObservationInfo:
|
||||
|
||||
# data_list
|
||||
chat_history: List[str] = field(default_factory=list)
|
||||
chat_history_str: str = ""
|
||||
unprocessed_messages: List[Dict[str, Any]] = field(default_factory=list)
|
||||
active_users: Set[str] = field(default_factory=set)
|
||||
|
||||
@@ -223,11 +225,18 @@ class ObservationInfo:
|
||||
return None
|
||||
return time.time() - self.last_bot_speak_time
|
||||
|
||||
def clear_unprocessed_messages(self):
|
||||
async def clear_unprocessed_messages(self):
|
||||
"""清空未处理消息列表"""
|
||||
# 将未处理消息添加到历史记录中
|
||||
for message in self.unprocessed_messages:
|
||||
self.chat_history.append(message)
|
||||
self.chat_history_str = await build_readable_messages(
|
||||
self.chat_history[-20:] if len(self.chat_history) > 20 else self.chat_history,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
)
|
||||
# 清空未处理消息列表
|
||||
self.has_unread_messages = False
|
||||
self.unprocessed_messages.clear()
|
||||
|
||||
@@ -6,7 +6,7 @@ import datetime
|
||||
from typing import List, Optional, Tuple, TYPE_CHECKING
|
||||
from src.common.logger import get_module_logger
|
||||
from ..chat.chat_stream import ChatStream
|
||||
from ..message.message_base import UserInfo, Seg
|
||||
from maim_message import UserInfo, Seg
|
||||
from ..chat.message import Message
|
||||
from ..models.utils_model import LLMRequest
|
||||
from ...config.config import global_config
|
||||
@@ -19,6 +19,7 @@ from src.individuality.individuality import Individuality
|
||||
from .conversation_info import ConversationInfo
|
||||
from .observation_info import ObservationInfo
|
||||
import time
|
||||
from src.plugins.utils.chat_message_builder import build_readable_messages
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
@@ -80,19 +81,20 @@ class GoalAnalyzer:
|
||||
goals_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n"
|
||||
|
||||
# 获取聊天历史记录
|
||||
chat_history_list = observation_info.chat_history
|
||||
chat_history_text = ""
|
||||
for msg in chat_history_list:
|
||||
chat_history_text += f"{msg}\n"
|
||||
chat_history_text = observation_info.chat_history
|
||||
|
||||
if observation_info.new_messages_count > 0:
|
||||
new_messages_list = observation_info.unprocessed_messages
|
||||
new_messages_str = await build_readable_messages(
|
||||
new_messages_list,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
)
|
||||
chat_history_text += f"\n--- 以下是 {observation_info.new_messages_count} 条新消息 ---\n{new_messages_str}"
|
||||
|
||||
chat_history_text += f"有{observation_info.new_messages_count}条新消息:\n"
|
||||
for msg in new_messages_list:
|
||||
chat_history_text += f"{msg}\n"
|
||||
|
||||
observation_info.clear_unprocessed_messages()
|
||||
# await observation_info.clear_unprocessed_messages()
|
||||
|
||||
identity_details_only = self.identity_detail_info
|
||||
identity_addon = ""
|
||||
@@ -371,22 +373,12 @@ class DirectMessageSender:
|
||||
# 处理消息
|
||||
await message.process()
|
||||
|
||||
message_json = message.to_dict()
|
||||
_message_json = message.to_dict()
|
||||
|
||||
# 发送消息
|
||||
try:
|
||||
end_point = global_config.api_urls.get(message.message_info.platform, None)
|
||||
if end_point:
|
||||
# logger.info(f"发送消息到{end_point}")
|
||||
# logger.info(message_json)
|
||||
try:
|
||||
await global_api.send_message_REST(end_point, message_json)
|
||||
except Exception as e:
|
||||
logger.error(f"REST方式发送失败,出现错误: {str(e)}")
|
||||
logger.info("尝试使用ws发送")
|
||||
await self.send_via_ws(message)
|
||||
else:
|
||||
await self.send_via_ws(message)
|
||||
await self.send_via_ws(message)
|
||||
await self.storage.store_message(message, chat_stream)
|
||||
logger.success(f"PFC消息已发送: {content}")
|
||||
except Exception as e:
|
||||
logger.error(f"PFC消息发送失败: {str(e)}")
|
||||
|
||||
@@ -4,6 +4,7 @@ from src.plugins.memory_system.Hippocampus import HippocampusManager
|
||||
from ..models.utils_model import LLMRequest
|
||||
from ...config.config import global_config
|
||||
from ..chat.message import Message
|
||||
from ..knowledge.knowledge_lib import qa_manager
|
||||
|
||||
logger = get_module_logger("knowledge_fetcher")
|
||||
|
||||
@@ -19,6 +20,25 @@ class KnowledgeFetcher:
|
||||
request_type="knowledge_fetch",
|
||||
)
|
||||
|
||||
def _lpmm_get_knowledge(self, query: str) -> str:
    """Look up knowledge related to ``query`` in the LPMM knowledge base.

    Args:
        query: Text to search the knowledge base for.

    Returns:
        str: The value returned by ``qa_manager.get_knowledge(query)``, or the
        fallback text "未找到匹配的知识" when the lookup (or the logging of its
        result) raises.
    """
    logger.debug("正在从LPMM知识库中获取知识")
    try:
        knowledge_info = qa_manager.get_knowledge(query)
        # Log a preview of at most 150 characters. A ``:150`` format spec is
        # only a minimum field width: it pads short text and never shortens
        # long text, so slicing is what actually bounds the log line.
        logger.debug(f"LPMM知识库查询结果: {knowledge_info[:150]}")
        return knowledge_info
    except Exception as e:
        logger.error(f"LPMM知识库搜索工具执行失败: {str(e)}")
        return "未找到匹配的知识"
|
||||
|
||||
async def fetch(self, query: str, chat_history: List[Message]) -> Tuple[str, str]:
|
||||
"""获取相关知识
|
||||
|
||||
@@ -43,13 +63,16 @@ class KnowledgeFetcher:
|
||||
max_depth=3,
|
||||
fast_retrieval=False,
|
||||
)
|
||||
|
||||
knowledge = ""
|
||||
if related_memory:
|
||||
knowledge = ""
|
||||
sources = []
|
||||
for memory in related_memory:
|
||||
knowledge += memory[1] + "\n"
|
||||
sources.append(f"记忆片段{memory[0]}")
|
||||
return knowledge.strip(), ",".join(sources)
|
||||
knowledge = knowledge.strip(), ",".join(sources)
|
||||
|
||||
knowledge += "现在有以下**知识**可供参考:\n "
|
||||
knowledge += self._lpmm_get_knowledge(query)
|
||||
knowledge += "请记住这些**知识**,并根据**知识**回答问题。\n"
|
||||
|
||||
return "未找到相关知识", "无记忆匹配"
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import json
|
||||
import datetime
|
||||
from typing import Tuple, List, Dict, Any
|
||||
from src.common.logger import get_module_logger
|
||||
from ..models.utils_model import LLMRequest
|
||||
from ...config.config import global_config
|
||||
from .chat_observer import ChatObserver
|
||||
from ..message.message_base import UserInfo
|
||||
from maim_message import UserInfo
|
||||
|
||||
logger = get_module_logger("reply_checker")
|
||||
|
||||
@@ -22,7 +21,7 @@ class ReplyChecker:
|
||||
self.max_retries = 3 # 最大重试次数
|
||||
|
||||
async def check(
|
||||
self, reply: str, goal: str, chat_history: List[Dict[str, Any]], retry_count: int = 0
|
||||
self, reply: str, goal: str, chat_history: List[Dict[str, Any]], chat_history_text: str, retry_count: int = 0
|
||||
) -> Tuple[bool, str, bool]:
|
||||
"""检查生成的回复是否合适
|
||||
|
||||
@@ -36,7 +35,6 @@ class ReplyChecker:
|
||||
"""
|
||||
# 不再从 observer 获取,直接使用传入的 chat_history
|
||||
# messages = self.chat_observer.get_cached_messages(limit=20)
|
||||
chat_history_text = ""
|
||||
try:
|
||||
# 筛选出最近由 Bot 自己发送的消息
|
||||
bot_messages = []
|
||||
@@ -78,16 +76,9 @@ class ReplyChecker:
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
logger.error(f"检查回复时出错: 类型={type(e)}, 值={e}")
|
||||
logger.error(traceback.format_exc()) # 打印详细的回溯信息
|
||||
|
||||
for msg in chat_history[-20:]:
|
||||
time_str = datetime.datetime.fromtimestamp(msg["time"]).strftime("%H:%M:%S")
|
||||
user_info = UserInfo.from_dict(msg.get("user_info", {}))
|
||||
sender = user_info.user_nickname or f"用户{user_info.user_id}"
|
||||
if sender == self.name:
|
||||
sender = "你说"
|
||||
chat_history_text += f"{time_str},{sender}:{msg.get('processed_plain_text', '')}\n"
|
||||
logger.error(f"检查回复时出错: 类型={type(e)}, 值={e}")
|
||||
logger.error(traceback.format_exc()) # 打印详细的回溯信息
|
||||
|
||||
prompt = f"""请检查以下回复或消息是否合适:
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from .reply_checker import ReplyChecker
|
||||
from src.individuality.individuality import Individuality
|
||||
from .observation_info import ObservationInfo
|
||||
from .conversation_info import ConversationInfo
|
||||
from src.plugins.utils.chat_message_builder import build_readable_messages
|
||||
|
||||
logger = get_module_logger("reply_generator")
|
||||
|
||||
@@ -68,23 +69,19 @@ class ReplyGenerator:
|
||||
goals_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n"
|
||||
|
||||
# 获取聊天历史记录
|
||||
chat_history_list = (
|
||||
observation_info.chat_history[-20:]
|
||||
if len(observation_info.chat_history) >= 20
|
||||
else observation_info.chat_history
|
||||
)
|
||||
chat_history_text = ""
|
||||
for msg in chat_history_list:
|
||||
chat_history_text += f"{msg.get('detailed_plain_text', '')}\n"
|
||||
chat_history_text = observation_info.chat_history_str
|
||||
|
||||
if observation_info.new_messages_count > 0:
|
||||
new_messages_list = observation_info.unprocessed_messages
|
||||
|
||||
chat_history_text += f"有{observation_info.new_messages_count}条新消息:\n"
|
||||
for msg in new_messages_list:
|
||||
chat_history_text += f"{msg.get('detailed_plain_text', '')}\n"
|
||||
|
||||
observation_info.clear_unprocessed_messages()
|
||||
new_messages_str = await build_readable_messages(
|
||||
new_messages_list,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
)
|
||||
chat_history_text += f"\n--- 以下是 {observation_info.new_messages_count} 条新消息 ---\n{new_messages_str}"
|
||||
# await observation_info.clear_unprocessed_messages()
|
||||
|
||||
identity_details_only = self.identity_detail_info
|
||||
identity_addon = ""
|
||||
@@ -173,7 +170,7 @@ class ReplyGenerator:
|
||||
return "抱歉,我现在有点混乱,让我重新思考一下..."
|
||||
|
||||
async def check_reply(
|
||||
self, reply: str, goal: str, chat_history: List[Dict[str, Any]], retry_count: int = 0
|
||||
self, reply: str, goal: str, chat_history: List[Dict[str, Any]], chat_history_str: str, retry_count: int = 0
|
||||
) -> Tuple[bool, str, bool]:
|
||||
"""检查回复是否合适
|
||||
|
||||
@@ -185,4 +182,4 @@ class ReplyGenerator:
|
||||
Returns:
|
||||
Tuple[bool, str, bool]: (是否合适, 原因, 是否需要重新规划)
|
||||
"""
|
||||
return await self.reply_checker.check(reply, goal, chat_history, retry_count)
|
||||
return await self.reply_checker.check(reply, goal, chat_history, chat_history_str, retry_count)
|
||||
|
||||
@@ -4,7 +4,7 @@ MaiMBot插件系统
|
||||
"""
|
||||
|
||||
from .chat.chat_stream import chat_manager
|
||||
from .chat.emoji_manager import emoji_manager
|
||||
from .emoji_system.emoji_manager import emoji_manager
|
||||
from .person_info.relationship_manager import relationship_manager
|
||||
from .moods.moods import MoodManager
|
||||
from .willing.willing_manager import willing_manager
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from .emoji_manager import emoji_manager
|
||||
from ..emoji_system.emoji_manager import emoji_manager
|
||||
from ..person_info.relationship_manager import relationship_manager
|
||||
from .chat_stream import chat_manager
|
||||
from .message_sender import message_manager
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Dict, Optional
|
||||
|
||||
|
||||
from ...common.database import db
|
||||
from ..message.message_base import GroupInfo, UserInfo
|
||||
from maim_message import GroupInfo, UserInfo
|
||||
|
||||
from src.common.logger import get_module_logger, LogConfig, CHAT_STREAM_STYLE_CONFIG
|
||||
|
||||
|
||||
@@ -1,595 +0,0 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import traceback
|
||||
from typing import Optional, Tuple
|
||||
from PIL import Image
|
||||
import io
|
||||
|
||||
from ...common.database import db
|
||||
from ...config.config import global_config
|
||||
from ..chat.utils import get_embedding
|
||||
from ..chat.utils_image import ImageManager, image_path_to_base64
|
||||
from ..models.utils_model import LLMRequest
|
||||
from src.common.logger import get_module_logger, LogConfig, EMOJI_STYLE_CONFIG
|
||||
|
||||
# Log formatting for the emoji subsystem: both formats come from the shared
# EMOJI_STYLE_CONFIG table.
emoji_log_config = LogConfig(
    file_format=EMOJI_STYLE_CONFIG["file_format"],
    console_format=EMOJI_STYLE_CONFIG["console_format"],
)

# Module-wide logger registered under the name "emoji".
logger = get_module_logger("emoji", config=emoji_log_config)

# Image helper instance used by this module to describe emoji images.
image_manager = ImageManager()
|
||||
|
||||
|
||||
class EmojiManager:
|
||||
_instance = None
|
||||
EMOJI_DIR = os.path.join("data", "emoji") # 表情包存储目录
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
    """Create the LLM clients and reset the emoji counters.

    Reads the model settings and emoji limits from ``global_config`` and
    logs a start-up message. No database access happens here.
    """
    self._scan_task = None

    # Vision-language model client.
    vlm_model = global_config.vlm
    self.vlm = LLMRequest(model=vlm_model, temperature=0.3, max_tokens=1000, request_type="emoji")

    # Emotion judge: higher temperature and fewer tokens than the client
    # above (the temperature may later be tuned according to mood).
    emotion_judge_options = {
        "model": global_config.llm_emotion_judge,
        "max_tokens": 600,
        "temperature": 0.8,
        "request_type": "emoji",
    }
    self.llm_emotion_judge = LLMRequest(**emotion_judge_options)

    # Current count plus the two configured limits.
    self.emoji_num = 0
    self.emoji_num_max = global_config.max_emoji_num
    self.emoji_num_max_reach_deletion = global_config.max_reach_deletion

    logger.info("启动表情包管理器")
|
||||
|
||||
def _ensure_emoji_dir(self):
|
||||
"""确保表情存储目录存在"""
|
||||
os.makedirs(self.EMOJI_DIR, exist_ok=True)
|
||||
|
||||
def _update_emoji_count(self):
    """Refresh ``self.emoji_num`` from the database.

    Counts all documents in ``db.emoji`` and stores the total on the
    instance. Any exception is logged and swallowed.
    """
    try:
        self._ensure_db()
        total = db.emoji.count_documents({})
        self.emoji_num = total
        logger.info(f"[统计] 当前表情包数量: {total}")
    except Exception as err:
        logger.error(f"[错误] 更新表情包数量失败: {err!s}")
|
||||
|
||||
def initialize(self):
    """Prepare the emoji collection and storage directory, once.

    Does nothing when ``_initialized`` is already set. Otherwise it ensures
    the collection and directory exist, marks the manager initialized,
    refreshes the emoji count and runs the file integrity check. Any
    exception is logged with its traceback and not re-raised.
    """
    if self._initialized:
        return

    try:
        self._ensure_emoji_collection()
        self._ensure_emoji_dir()
        # Raise the flag before the next steps: _update_emoji_count goes
        # through _ensure_db, which calls initialize() again while the
        # flag is still False.
        self._initialized = True
        self._update_emoji_count()
        # Run a full integrity check once at start-up.
        self.check_emoji_file_integrity()
    except Exception:
        logger.exception("初始化表情管理器失败")
|
||||
|
||||
def _ensure_db(self):
|
||||
"""确保数据库已初始化"""
|
||||
if not self._initialized:
|
||||
self.initialize()
|
||||
if not self._initialized:
|
||||
raise RuntimeError("EmojiManager not initialized")
|
||||
|
||||
@staticmethod
def _ensure_emoji_collection():
    """Create the ``emoji`` MongoDB collection and its indexes if missing.

    When ``"emoji"`` is not among ``db.list_collection_names()``, the
    collection is created together with two indexes:

    - a ``2dsphere`` index on ``embedding``;
    - a unique ascending index on ``filename``, so file names cannot repeat.

    Nothing happens when the collection already exists, so the indexes are
    only created alongside a brand-new collection.

    NOTE(review): no index on ``tags`` is created here; confirm whether tag
    lookups need one.
    """
    if "emoji" in db.list_collection_names():
        return

    db.create_collection("emoji")
    emoji_collection = db.emoji
    emoji_collection.create_index([("embedding", "2dsphere")])
    emoji_collection.create_index([("filename", 1)], unique=True)
|
||||
|
||||
def record_usage(self, emoji_id: str):
    """Increment the usage counter of one emoji record.

    Args:
        emoji_id: ``_id`` of the emoji document to update.

    Failures are logged and not propagated.
    """
    try:
        self._ensure_db()
        selector = {"_id": emoji_id}
        increment = {"$inc": {"usage_count": 1}}
        db.emoji.update_one(selector, increment)
    except Exception as err:
        logger.error(f"记录表情使用失败: {err!s}")
|
||||
|
||||
async def get_emoji_for_text(self, text: str) -> Optional[Tuple[str, str]]:
    """Pick a stored emoji that matches the mood of ``text``.

    An LLM first turns the text into a short emotion description, which is
    then embedded. Every emoji record without a ``blacklist`` field is
    scored by cosine similarity against that embedding, and one of the ten
    best-scoring records is chosen at random. The chosen record's
    ``usage_count`` is incremented.

    Args:
        text: Message text the emoji should accompany.

    Returns:
        ``(path, "[ description ]")`` for the chosen emoji, or ``None`` when
        no emotion description or embedding is available, no emoji is
        stored, the chosen record has no ``path``, or any step raises.

    Open idea carried over from the original notes: the selection logic
    could be made customisable through directives in the config file.
    """

    def _cosine(vec_a, vec_b):
        # Empty/missing vectors and zero-length vectors score 0.
        if not vec_a or not vec_b:
            return 0
        dot = sum(x * y for x, y in zip(vec_a, vec_b))
        len_a = sum(x * x for x in vec_a) ** 0.5
        len_b = sum(y * y for y in vec_b) ** 0.5
        if len_a == 0 or len_b == 0:
            return 0
        return dot / (len_a * len_b)

    try:
        self._ensure_db()

        # Describe the emotion of the text, then embed that description.
        emotion_text = await self._get_kimoji_for_text(text)
        if not emotion_text:
            logger.error("无法获取文本的情绪")
            return None

        text_embedding = await get_embedding(emotion_text, request_type="emoji")
        if not text_embedding:
            logger.error("无法获取文本的embedding")
            return None

        try:
            # Load every record and drop the ones carrying a blacklist field.
            projection = {"_id": 1, "path": 1, "embedding": 1, "description": 1, "blacklist": 1}
            candidates = [doc for doc in db.emoji.find({}, projection) if "blacklist" not in doc]
            if not candidates:
                logger.warning("数据库中没有任何表情包")
                return None

            # Score each candidate and order by similarity, best first.
            scored = sorted(
                ((doc, _cosine(text_embedding, doc.get("embedding", []))) for doc in candidates),
                key=lambda pair: pair[1],
                reverse=True,
            )

            # Keep at most the ten most similar records.
            shortlist = scored[:10]
            if not shortlist:
                logger.warning("未找到匹配的表情包")
                return None

            # Random pick among the shortlist.
            selected_emoji, similarity = random.choice(shortlist)
            if selected_emoji and "path" in selected_emoji:
                # Count this use.
                db.emoji.update_one({"_id": selected_emoji["_id"]}, {"$inc": {"usage_count": 1}})

                description = selected_emoji.get("description", "无描述")
                logger.info(f"[匹配] 找到表情包: {description} (相似度: {similarity:.4f})")
                # The description is wrapped in plain brackets only; it
                # already names itself an emoji, and repeating that tends
                # to cause hallucinations downstream.
                return selected_emoji["path"], "[ %s ]" % description

        except Exception as search_error:
            logger.error(f"[错误] 搜索表情包失败: {search_error!s}")
            return None

        return None

    except Exception as err:
        logger.error(f"[错误] 获取表情包失败: {err!s}")
        return None
|
||||
|
||||
@staticmethod
async def _get_emoji_description(image_base64: str) -> str:
    """Describe an emoji image using ``image_manager``.

    Args:
        image_base64: Base64-encoded image data.

    Returns:
        The description with leading/trailing square brackets stripped and
        every "表情包:" occurrence removed. ``None`` is returned when anything
        raises, despite the ``str`` annotation.
    """
    try:
        raw_description = await image_manager.get_emoji_description(image_base64)
        # The description is expected in the form "[表情包:xxx]"; keep "xxx".
        return raw_description.strip("[]").replace("表情包:", "")
    except Exception as err:
        logger.error(f"[错误] 获取表情包描述失败: {err!s}")
        return None
|
||||
|
||||
async def _check_emoji(self, image_base64: str, image_format: str) -> str:
    """Ask the vision model whether an emoji meets the configured requirement.

    Args:
        image_base64: Base64-encoded image data.
        image_format: Image format passed through to the model call.

    Returns:
        The model's raw answer (the prompt asks for a bare yes/no in
        Chinese), or ``None`` when the request raises.
    """
    try:
        requirement = global_config.EMOJI_CHECK_PROMPT
        prompt = (
            f'这是一个表情包,请回答这个表情包是否满足"{requirement}"的要求,是则回答是,'
            "否则回答否,不要出现任何其他内容"
        )

        verdict, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
        logger.debug(f"[检查] 表情包检查结果: {verdict}")
        return verdict

    except Exception as err:
        logger.error(f"[错误] 表情包检查失败: {err!s}")
        return None
|
||||
|
||||
async def _get_kimoji_for_text(self, text: str):
    """Ask the emotion-judge LLM what feeling an emoji for ``text`` should convey.

    Args:
        text: The message the bot is about to send.

    Returns:
        The model's emotion description, or ``None`` when the request raises.
    """
    try:
        bot_name = global_config.BOT_NICKNAME
        prompt = (
            f"这是{bot_name}将要发送的消息内容:\n{text}\n若要为其配上表情包,"
            "请你输出这个表情包应该表达怎样的情感,应该给人什么样的感觉,不要太简洁也不要太长,"
            '注意不要输出任何对消息内容的分析内容,只输出"一种什么样的感觉"中间的形容词部分。'
        )

        # The call overrides the temperature to 1.5 for this request.
        feeling, _ = await self.llm_emotion_judge.generate_response_async(prompt, temperature=1.5)
        logger.info(f"[情感] 表情包情感描述: {feeling}")
        return feeling

    except Exception as err:
        logger.error(f"[错误] 获取表情包情感失败: {err!s}")
        return None
|
||||
|
||||
async def scan_new_emojis(self):
    """Scan the emoji directory and register image files not yet in the database.

    For every ``.jpg``/``.jpeg``/``.png``/``.gif`` file in ``self.EMOJI_DIR``:

    1. The file is read, hashed (MD5) and its image format detected.
    2. Existing ``emoji`` records found by filename and by hash are reconciled;
       inconsistent records are deleted so the file is registered afresh.
    3. A consistent existing record is only synced to the ``images``
       collection when it is missing there.
    4. Otherwise a description is taken from the description cache or
       generated, the optional rule check is applied, an embedding is
       computed, and the emoji is stored in ``emoji`` and ``images``.

    Registration stops once ``self.emoji_num`` reaches ``self.emoji_num_max``.
    Any exception is logged with its traceback and ends the scan; nothing is
    raised to the caller.
    """
    try:
        emoji_dir = self.EMOJI_DIR
        os.makedirs(emoji_dir, exist_ok=True)

        # Collect every file with a supported image extension.
        files_to_process = [
            f for f in os.listdir(emoji_dir) if f.lower().endswith((".jpg", ".jpeg", ".png", ".gif"))
        ]

        # Refresh the counter and bail out early when the collection is full.
        self._update_emoji_count()
        if self.emoji_num >= self.emoji_num_max:
            logger.warning(f"[警告] 表情包数量已达到上限({self.emoji_num}/{self.emoji_num_max}),跳过注册")
            return

        remaining_slots = self.emoji_num_max - self.emoji_num
        logger.info(f"[注册] 还可以注册 {remaining_slots} 个表情包")

        for filename in files_to_process:
            # Stop as soon as the limit is reached during this scan.
            if self.emoji_num >= self.emoji_num_max:
                logger.warning(f"[警告] 表情包数量已达到上限({self.emoji_num}/{self.emoji_num_max}),停止注册")
                break

            image_path = os.path.join(emoji_dir, filename)

            # Unreadable files are removed from the directory.
            image_base64 = image_path_to_base64(image_path)
            if image_base64 is None:
                os.remove(image_path)
                continue

            image_bytes = base64.b64decode(image_base64)
            image_hash = hashlib.md5(image_bytes).hexdigest()
            # Close the PIL image explicitly instead of relying on garbage collection.
            with Image.open(io.BytesIO(image_bytes)) as image:
                image_format = image.format.lower()

            # Reconcile records found by filename and by content hash.
            existing_emoji_by_path = db["emoji"].find_one({"filename": filename})
            existing_emoji_by_hash = db["emoji"].find_one({"hash": image_hash})
            if existing_emoji_by_path and existing_emoji_by_hash:
                if existing_emoji_by_path["_id"] != existing_emoji_by_hash["_id"]:
                    # Two different records claim this file: drop both.
                    logger.error(f"[错误] 表情包已存在但记录不一致: {filename}")
                    db.emoji.delete_one({"_id": existing_emoji_by_path["_id"]})
                    db.emoji.delete_one({"_id": existing_emoji_by_hash["_id"]})
                    existing_emoji = None
                else:
                    existing_emoji = existing_emoji_by_hash
            elif existing_emoji_by_hash:
                logger.error(f"[错误] 表情包hash已存在但path不存在: {filename}")
                db.emoji.delete_one({"_id": existing_emoji_by_hash["_id"]})
                existing_emoji = None
            elif existing_emoji_by_path:
                logger.error(f"[错误] 表情包path已存在但hash不存在: {filename}")
                db.emoji.delete_one({"_id": existing_emoji_by_path["_id"]})
                existing_emoji = None
            else:
                existing_emoji = None

            description = None

            if existing_emoji:
                # Already registered: only make sure the images collection knows it.
                description = existing_emoji.get("description")
                existing_image = db.images.find_one({"hash": image_hash})
                if not existing_image:
                    image_doc = {
                        "hash": image_hash,
                        "path": image_path,
                        "type": "emoji",
                        "description": description,
                        "timestamp": int(time.time()),
                    }
                    db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
                    # Also store the description in the description cache.
                    image_manager._save_description_to_db(image_hash, description, "emoji")
                    logger.success(f"[同步] 已同步表情包到images集合: {filename}")
                continue

            # Reuse a cached description when one exists, otherwise generate it.
            existing_description = image_manager._get_description_from_db(image_hash, "emoji")
            if existing_description:
                description = existing_description
            else:
                description = await self._get_emoji_description(image_base64)

            if global_config.EMOJI_CHECK:
                check = await self._check_emoji(image_base64, image_format)
                if check is None:
                    # The check itself failed (no answer). Keep the file so a later
                    # scan can retry; previously this raised TypeError and aborted
                    # the scan for every remaining file.
                    logger.warning(f"[跳过] 表情包检查失败,暂不注册: {filename}")
                    continue
                if "是" not in check:
                    os.remove(image_path)
                    logger.info(f"[过滤] 表情包描述: {description}")
                    logger.info(f"[过滤] 表情包不满足规则,已移除: {check}")
                    continue
                logger.info(f"[检查] 表情包检查通过: {check}")

            if description is not None:
                embedding = await get_embedding(description, request_type="emoji")
                if not embedding:
                    logger.error("获取消息嵌入向量失败")
                    raise ValueError("获取消息嵌入向量失败")

                emoji_record = {
                    "filename": filename,
                    "path": image_path,
                    "embedding": embedding,
                    "description": description,
                    "hash": image_hash,
                    "timestamp": int(time.time()),
                }

                # Store in the emoji collection.
                db["emoji"].insert_one(emoji_record)
                logger.success(f"[注册] 新表情包: {filename}")
                logger.info(f"[描述] {description}")

                # Keep the in-memory counter in step with the database.
                self.emoji_num += 1
                logger.info(f"[统计] 当前表情包数量: {self.emoji_num}/{self.emoji_num_max}")

                # Mirror the record into the images collection.
                image_doc = {
                    "hash": image_hash,
                    "path": image_path,
                    "type": "emoji",
                    "description": description,
                    "timestamp": int(time.time()),
                }
                db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
                # Also store the description in the description cache.
                image_manager._save_description_to_db(image_hash, description, "emoji")
                logger.success(f"[同步] 已保存到images集合: {filename}")
            else:
                logger.warning(f"[跳过] 表情包: {filename}")

    except Exception:
        logger.exception("[错误] 扫描表情包失败")
|
||||
|
||||
def check_emoji_file_integrity(self):
    """Verify every ``emoji`` record against the file on disk and repair or drop it.

    A record is deleted when it lacks a ``path`` or ``embedding`` field, when
    its file no longer exists, or when the stored hash differs from the MD5 of
    the file. A record without a ``hash`` field gets the computed hash written
    back, and a misspelled ``discription`` field is renamed to ``description``.

    Errors on a single record are logged and the loop moves on; any other
    error is logged with its traceback. Nothing is raised to the caller.
    """
    try:
        self._ensure_db()
        all_emojis = list(db.emoji.find())
        removed_count = 0
        total_count = len(all_emojis)

        for emoji in all_emojis:
            try:
                if "path" not in emoji:
                    logger.warning(f"[检查] 发现无效记录(缺少path字段),ID: {emoji.get('_id', 'unknown')}")
                    db.emoji.delete_one({"_id": emoji["_id"]})
                    removed_count += 1
                    continue

                if "embedding" not in emoji:
                    logger.warning(f"[检查] 发现过时记录(缺少embedding字段),ID: {emoji.get('_id', 'unknown')}")
                    db.emoji.delete_one({"_id": emoji["_id"]})
                    removed_count += 1
                    continue

                # The file backing this record has disappeared.
                if not os.path.exists(emoji["path"]):
                    logger.warning(f"[检查] 表情包文件已被删除: {emoji['path']}")
                    result = db.emoji.delete_one({"_id": emoji["_id"]})
                    if result.deleted_count > 0:
                        logger.debug(f"[清理] 成功删除数据库记录: {emoji['_id']}")
                        removed_count += 1
                    else:
                        logger.error(f"[错误] 删除数据库记录失败: {emoji['_id']}")
                    continue

                # Hash the file once, closing the handle deterministically.
                with open(emoji["path"], "rb") as emoji_file:
                    file_hash = hashlib.md5(emoji_file.read()).hexdigest()

                if "hash" not in emoji:
                    logger.warning(f"[检查] 发现缺失记录(缺少hash字段),ID: {emoji.get('_id', 'unknown')}")
                    db.emoji.update_one({"_id": emoji["_id"]}, {"$set": {"hash": file_hash}})
                elif emoji["hash"] != file_hash:
                    logger.warning(f"[检查] 表情包文件hash不匹配,ID: {emoji.get('_id', 'unknown')}")
                    db.emoji.delete_one({"_id": emoji["_id"]})
                    removed_count += 1
                    # The record is gone, so there is nothing left to repair.
                    continue

                # Repair the historical field-name typo.
                if "discription" in emoji:
                    desc = emoji["discription"]
                    db.emoji.update_one(
                        {"_id": emoji["_id"]}, {"$unset": {"discription": ""}, "$set": {"description": desc}}
                    )

            except Exception as item_error:
                logger.error(f"[错误] 处理表情包记录时出错: {str(item_error)}")
                continue

        # Report the outcome using the count currently in the database.
        remaining_count = db.emoji.count_documents({})
        if removed_count > 0:
            logger.success(f"[清理] 已清理 {removed_count} 个失效的表情包记录")
            logger.info(f"[统计] 清理前: {total_count} | 清理后: {remaining_count}")
        else:
            logger.info(f"[检查] 已检查 {total_count} 个表情包记录")

    except Exception as e:
        logger.error(f"[错误] 检查表情包完整性失败: {str(e)}")
        logger.error(traceback.format_exc())
|
||||
|
||||
def check_emoji_file_full(self):
    """Delete surplus emojis when the collection exceeds ``self.emoji_num_max``.

    Nothing happens while the count is within the limit, or when
    ``global_config.max_reach_deletion`` is off (a warning is logged instead).

    Victims are drawn one at a time, without replacement, by weighted random
    choice. Each record's weight is ``1 / (1 + usage_count / max_usage)``, so
    heavily used emojis are less likely, but not immune, to be removed. For
    each victim the file (if present), the ``emoji`` record and the matching
    ``images`` record are deleted.

    All errors are logged; nothing is raised to the caller.
    """
    try:
        self._ensure_db()
        self._update_emoji_count()

        # Within the limit: nothing to do.
        if self.emoji_num <= self.emoji_num_max:
            return

        # Over the limit but automatic deletion is disabled: warn only.
        if not global_config.max_reach_deletion:
            logger.warning(f"[警告] 表情包数量({self.emoji_num})超出限制({self.emoji_num_max}),但未开启自动删除")
            return

        delete_count = self.emoji_num - self.emoji_num_max
        logger.info(f"[清理] 需要删除 {delete_count} 个表情包")

        # Load every record, oldest timestamp first.
        all_emojis = list(db.emoji.find().sort([("timestamp", 1)]))

        # Reciprocal weighting: the more an emoji was used, the smaller its weight.
        max_usage = max((record.get("usage_count", 0) for record in all_emojis), default=1)
        weights = [1.0 / (1.0 + record.get("usage_count", 0) / max(1, max_usage)) for record in all_emojis]

        # Weighted sampling without replacement over the record indices.
        to_delete = []
        candidates = list(range(len(all_emojis)))
        while candidates and len(to_delete) < delete_count:
            candidate_weights = [weights[index] for index in candidates]
            weight_sum = sum(candidate_weights)
            if weight_sum == 0:
                break
            probabilities = [value / weight_sum for value in candidate_weights]

            picked = random.choices(candidates, weights=probabilities, k=1)[0]
            to_delete.append(all_emojis[picked])
            candidates.remove(picked)

        # Remove the selected emojis from disk and from both collections.
        deleted_count = 0
        for emoji in to_delete:
            try:
                if "path" in emoji and os.path.exists(emoji["path"]):
                    os.remove(emoji["path"])
                    logger.info(f"[删除] 文件: {emoji['path']} (使用次数: {emoji.get('usage_count', 0)})")

                db.emoji.delete_one({"_id": emoji["_id"]})
                deleted_count += 1

                if "hash" in emoji:
                    db.images.delete_one({"hash": emoji["hash"]})

            except Exception as e:
                logger.error(f"[错误] 删除表情包失败: {str(e)}")
                continue

        # Refresh the counter after the cleanup.
        self._update_emoji_count()
        logger.success(f"[清理] 已删除 {deleted_count} 个表情包,当前数量: {self.emoji_num}")

    except Exception as e:
        logger.error(f"[错误] 检查表情包数量失败: {str(e)}")
|
||||
|
||||
async def start_periodic_check_register(self):
    """Run the emoji maintenance cycle forever, sleeping between rounds.

    Each round checks record integrity, clears the image cache directory,
    scans for new emojis while below the limit, and handles an over-limit
    collection. When the collection is over the limit and
    ``global_config.max_reach_deletion`` is off, the loop ends for good.
    The pause between rounds is ``global_config.EMOJI_CHECK_INTERVAL`` minutes.
    """
    while True:
        logger.info("[扫描] 开始检查表情包完整性...")
        self.check_emoji_file_integrity()

        logger.info("[扫描] 开始删除所有图片缓存...")
        await self.delete_all_images()

        logger.info("[扫描] 开始扫描新表情包...")
        if self.emoji_num < self.emoji_num_max:
            await self.scan_new_emojis()

        # Evaluated after the scan, which may have changed the counter.
        if self.emoji_num > self.emoji_num_max:
            logger.warning(f"[警告] 表情包数量超过最大限制: {self.emoji_num} > {self.emoji_num_max},跳过注册")
            if global_config.max_reach_deletion:
                logger.warning("表情包数量超过最大限制,开始删除表情包")
                self.check_emoji_file_full()
            else:
                logger.warning("表情包数量超过最大限制,终止注册")
                break

        await asyncio.sleep(global_config.EMOJI_CHECK_INTERVAL * 60)
|
||||
|
||||
@staticmethod
async def delete_all_images():
    """Remove every regular file directly inside ``data/image``.

    Sub-directories are left untouched. A missing directory only produces a
    warning. Per-file failures are counted and logged without stopping the
    sweep; nothing is raised to the caller.
    """
    try:
        image_dir = os.path.join("data", "image")
        if not os.path.exists(image_dir):
            logger.warning(f"[警告] 目录不存在: {image_dir}")
            return

        deleted_count, failed_count = 0, 0

        for entry in os.listdir(image_dir):
            file_path = os.path.join(image_dir, entry)
            try:
                # Only plain files are removed; anything else is skipped.
                if not os.path.isfile(file_path):
                    continue
                os.remove(file_path)
                deleted_count += 1
                logger.debug(f"[删除] 文件: {file_path}")
            except Exception as e:
                failed_count += 1
                logger.error(f"[错误] 删除文件失败 {file_path}: {str(e)}")

        logger.success(f"[清理] 已删除 {deleted_count} 个文件,失败 {failed_count} 个")

    except Exception as e:
        logger.error(f"[错误] 删除图片目录失败: {str(e)}")
|
||||
|
||||
|
||||
# Create the global singleton instance.
emoji_manager = EmojiManager()
|
||||
@@ -7,7 +7,7 @@ import urllib3
|
||||
from src.common.logger import get_module_logger
|
||||
from .chat_stream import ChatStream
|
||||
from .utils_image import image_manager
|
||||
from ..message.message_base import Seg, UserInfo, BaseMessageInfo, MessageBase
|
||||
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
|
||||
|
||||
logger = get_module_logger("chat_message")
|
||||
|
||||
@@ -127,12 +127,12 @@ class MessageRecv(Message):
|
||||
# 如果是base64图片数据
|
||||
if isinstance(seg.data, str):
|
||||
return await image_manager.get_image_description(seg.data)
|
||||
return "[图片]"
|
||||
return "[发了一张图片,网卡了加载不出来]"
|
||||
elif seg.type == "emoji":
|
||||
self.is_emoji = True
|
||||
if isinstance(seg.data, str):
|
||||
return await image_manager.get_emoji_description(seg.data)
|
||||
return "[表情]"
|
||||
return "[发了一个表情包,网卡了加载不出来]"
|
||||
else:
|
||||
return f"[{seg.type}:{str(seg.data)}]"
|
||||
except Exception as e:
|
||||
@@ -141,14 +141,8 @@ class MessageRecv(Message):
|
||||
|
||||
def _generate_detailed_text(self) -> str:
|
||||
"""生成详细文本,包含时间和用户信息"""
|
||||
# time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.message_info.time))
|
||||
timestamp = self.message_info.time
|
||||
user_info = self.message_info.user_info
|
||||
# name = (
|
||||
# f"{user_info.user_nickname}(ta的昵称:{user_info.user_cardname},ta的id:{user_info.user_id})"
|
||||
# if user_info.user_cardname != None
|
||||
# else f"{user_info.user_nickname}(ta的id:{user_info.user_id})"
|
||||
# )
|
||||
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>"
|
||||
return f"[{timestamp}] {name}: {self.processed_plain_text}\n"
|
||||
|
||||
@@ -222,11 +216,11 @@ class MessageProcessBase(Message):
|
||||
# 如果是base64图片数据
|
||||
if isinstance(seg.data, str):
|
||||
return await image_manager.get_image_description(seg.data)
|
||||
return "[图片]"
|
||||
return "[图片,网卡了加载不出来]"
|
||||
elif seg.type == "emoji":
|
||||
if isinstance(seg.data, str):
|
||||
return await image_manager.get_emoji_description(seg.data)
|
||||
return "[表情]"
|
||||
return "[表情,网卡了加载不出来]"
|
||||
elif seg.type == "at":
|
||||
return f"[@{seg.data}]"
|
||||
elif seg.type == "reply":
|
||||
|
||||
@@ -3,7 +3,7 @@ from src.common.logger import get_module_logger
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
from .message import MessageRecv
|
||||
from ..message.message_base import BaseMessageInfo, GroupInfo, Seg
|
||||
from maim_message import BaseMessageInfo, GroupInfo
|
||||
import hashlib
|
||||
from typing import Dict
|
||||
from collections import OrderedDict
|
||||
@@ -128,58 +128,67 @@ class MessageBuffer:
|
||||
if result:
|
||||
async with self.lock: # 再次加锁
|
||||
# 清理所有早于当前消息的已处理消息, 收集所有早于当前消息的F消息的processed_plain_text
|
||||
keep_msgs = OrderedDict()
|
||||
combined_text = []
|
||||
found = False
|
||||
type = "seglist"
|
||||
is_update = True
|
||||
for msg_id, msg in self.buffer_pool[person_id_].items():
|
||||
keep_msgs = OrderedDict() # 用于存放 T 消息之后的消息
|
||||
collected_texts = [] # 用于收集 T 消息及之前 F 消息的文本
|
||||
process_target_found = False
|
||||
|
||||
# 遍历当前用户的所有缓冲消息
|
||||
for msg_id, cache_msg in self.buffer_pool[person_id_].items():
|
||||
# 如果找到了目标处理消息 (T 状态)
|
||||
if msg_id == message.message_info.message_id:
|
||||
found = True
|
||||
if msg.message.message_segment.type != "seglist":
|
||||
type = msg.message.message_segment.type
|
||||
else:
|
||||
if (
|
||||
isinstance(msg.message.message_segment.data, list)
|
||||
and all(isinstance(x, Seg) for x in msg.message.message_segment.data)
|
||||
and len(msg.message.message_segment.data) == 1
|
||||
):
|
||||
type = msg.message.message_segment.data[0].type
|
||||
combined_text.append(msg.message.processed_plain_text)
|
||||
continue
|
||||
if found:
|
||||
keep_msgs[msg_id] = msg
|
||||
elif msg.result == "F":
|
||||
# 收集F消息的文本内容
|
||||
f_type = "seglist"
|
||||
if msg.message.message_segment.type != "seglist":
|
||||
f_type = msg.message.message_segment.type
|
||||
else:
|
||||
if (
|
||||
isinstance(msg.message.message_segment.data, list)
|
||||
and all(isinstance(x, Seg) for x in msg.message.message_segment.data)
|
||||
and len(msg.message.message_segment.data) == 1
|
||||
):
|
||||
f_type = msg.message.message_segment.data[0].type
|
||||
if hasattr(msg.message, "processed_plain_text") and msg.message.processed_plain_text:
|
||||
if f_type == "text":
|
||||
combined_text.append(msg.message.processed_plain_text)
|
||||
elif f_type != "text":
|
||||
is_update = False
|
||||
elif msg.result == "U":
|
||||
logger.debug(f"异常未处理信息id: {msg.message.message_info.message_id}")
|
||||
process_target_found = True
|
||||
# 收集这条 T 消息的文本 (如果有)
|
||||
if (
|
||||
hasattr(cache_msg.message, "processed_plain_text")
|
||||
and cache_msg.message.processed_plain_text
|
||||
):
|
||||
collected_texts.append(cache_msg.message.processed_plain_text)
|
||||
# 不立即放入 keep_msgs,因为它之前的 F 消息也处理完了
|
||||
|
||||
# 更新当前消息的processed_plain_text
|
||||
if combined_text and combined_text[0] != message.processed_plain_text and is_update:
|
||||
if type == "text":
|
||||
message.processed_plain_text = ",".join(combined_text)
|
||||
logger.debug(f"整合了{len(combined_text) - 1}条F消息的内容到当前消息")
|
||||
elif type == "emoji":
|
||||
combined_text.pop()
|
||||
message.processed_plain_text = ",".join(combined_text)
|
||||
message.is_emoji = False
|
||||
logger.debug(f"整合了{len(combined_text) - 1}条F消息的内容,覆盖当前emoji消息")
|
||||
# 如果已经找到了目标 T 消息,之后的消息需要保留
|
||||
elif process_target_found:
|
||||
keep_msgs[msg_id] = cache_msg
|
||||
|
||||
# 如果还没找到目标 T 消息,说明是之前的消息 (F 或 U)
|
||||
else:
|
||||
if cache_msg.result == "F":
|
||||
# 收集这条 F 消息的文本 (如果有)
|
||||
if (
|
||||
hasattr(cache_msg.message, "processed_plain_text")
|
||||
and cache_msg.message.processed_plain_text
|
||||
):
|
||||
collected_texts.append(cache_msg.message.processed_plain_text)
|
||||
elif cache_msg.result == "U":
|
||||
# 理论上不应该在 T 消息之前还有 U 消息,记录日志
|
||||
logger.warning(
|
||||
f"异常状态:在目标 T 消息 {message.message_info.message_id} 之前发现未处理的 U 消息 {cache_msg.message.message_info.message_id}"
|
||||
)
|
||||
# 也可以选择收集其文本
|
||||
if (
|
||||
hasattr(cache_msg.message, "processed_plain_text")
|
||||
and cache_msg.message.processed_plain_text
|
||||
):
|
||||
collected_texts.append(cache_msg.message.processed_plain_text)
|
||||
|
||||
# 更新当前消息 (message) 的 processed_plain_text
|
||||
# 只有在收集到的文本多于一条,或者只有一条但与原始文本不同时才合并
|
||||
if collected_texts:
|
||||
# 使用 OrderedDict 去重,同时保留原始顺序
|
||||
unique_texts = list(OrderedDict.fromkeys(collected_texts))
|
||||
merged_text = ",".join(unique_texts)
|
||||
|
||||
# 只有在合并后的文本与原始文本不同时才更新
|
||||
# 并且确保不是空合并
|
||||
if merged_text and merged_text != message.processed_plain_text:
|
||||
message.processed_plain_text = merged_text
|
||||
# 如果合并了文本,原消息不再视为纯 emoji
|
||||
if hasattr(message, "is_emoji"):
|
||||
message.is_emoji = False
|
||||
logger.debug(
|
||||
f"合并了 {len(unique_texts)} 条消息的文本内容到当前消息 {message.message_info.message_id}"
|
||||
)
|
||||
|
||||
# 更新缓冲池,只保留 T 消息之后的消息
|
||||
self.buffer_pool[person_id_] = keep_msgs
|
||||
return result
|
||||
except asyncio.TimeoutError:
|
||||
|
||||
@@ -62,20 +62,10 @@ class MessageSender:
|
||||
# logger.trace(f"{message.processed_plain_text},{typing_time},等待输入时间结束") # 减少日志
|
||||
# --- 结束打字延迟 ---
|
||||
|
||||
message_json = message.to_dict()
|
||||
message_preview = truncate_message(message.processed_plain_text)
|
||||
|
||||
try:
|
||||
end_point = global_config.api_urls.get(message.message_info.platform, None)
|
||||
if end_point:
|
||||
try:
|
||||
await global_api.send_message_rest(end_point, message_json)
|
||||
except Exception as e:
|
||||
logger.error(f"REST发送失败: {str(e)}")
|
||||
logger.info(f"[{message.chat_stream.stream_id}] 尝试使用WS发送")
|
||||
await self.send_via_ws(message)
|
||||
else:
|
||||
await self.send_via_ws(message)
|
||||
await self.send_via_ws(message)
|
||||
logger.success(f"发送消息 '{message_preview}' 成功") # 调整日志格式
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息 '{message_preview}' 失败: {str(e)}")
|
||||
|
||||
@@ -12,7 +12,7 @@ from ..models.utils_model import LLMRequest
|
||||
from ..utils.typo_generator import ChineseTypoGenerator
|
||||
from ...config.config import global_config
|
||||
from .message import MessageRecv, Message
|
||||
from ..message.message_base import UserInfo
|
||||
from maim_message import UserInfo
|
||||
from .chat_stream import ChatStream
|
||||
from ..moods.moods import MoodManager
|
||||
from ...common.database import db
|
||||
@@ -234,6 +234,13 @@ def split_into_sentences_w_remove_punctuation(text: str) -> List[str]:
|
||||
Returns:
|
||||
List[str]: 分割和合并后的句子列表
|
||||
"""
|
||||
# 预处理:处理多余的换行符
|
||||
# 1. 将连续的换行符替换为单个换行符
|
||||
text = re.sub(r"\n\s*\n+", "\n", text)
|
||||
# 2. 处理换行符和其他分隔符的组合
|
||||
text = re.sub(r"\n\s*([,,。;\s])", r"\1", text)
|
||||
text = re.sub(r"([,,。;\s])\s*\n", r"\1", text)
|
||||
|
||||
# 处理两个汉字中间的换行符
|
||||
text = re.sub(r"([\u4e00-\u9fff])\n([\u4e00-\u9fff])", r"\1。\2", text)
|
||||
|
||||
@@ -365,12 +372,16 @@ def random_remove_punctuation(text: str) -> str:
|
||||
|
||||
def process_llm_response(text: str) -> List[str]:
|
||||
# 先保护颜文字
|
||||
protected_text, kaomoji_mapping = protect_kaomoji(text)
|
||||
logger.trace(f"保护颜文字后的文本: {protected_text}")
|
||||
if global_config.enable_kaomoji_protection:
|
||||
protected_text, kaomoji_mapping = protect_kaomoji(text)
|
||||
logger.trace(f"保护颜文字后的文本: {protected_text}")
|
||||
else:
|
||||
protected_text = text
|
||||
kaomoji_mapping = {}
|
||||
# 提取被 () 或 [] 包裹且包含中文的内容
|
||||
pattern = re.compile(r"[\(\[\(](?=.*[\u4e00-\u9fff]).*?[\)\]\)]")
|
||||
# _extracted_contents = pattern.findall(text)
|
||||
extracted_contents = pattern.findall(protected_text) # 在保护后的文本上查找
|
||||
_extracted_contents = pattern.findall(protected_text) # 在保护后的文本上查找
|
||||
# 去除 () 和 [] 及其包裹的内容
|
||||
cleaned_text = pattern.sub("", protected_text)
|
||||
|
||||
@@ -413,13 +424,14 @@ def process_llm_response(text: str) -> List[str]:
|
||||
if len(sentences) > max_sentence_num:
|
||||
logger.warning(f"分割后消息数量过多 ({len(sentences)} 条),返回默认回复")
|
||||
return [f"{global_config.BOT_NICKNAME}不知道哦"]
|
||||
if extracted_contents:
|
||||
for content in extracted_contents:
|
||||
sentences.append(content)
|
||||
# 在所有句子处理完毕后,对包含占位符的列表进行恢复
|
||||
sentences = recover_kaomoji(sentences, kaomoji_mapping)
|
||||
|
||||
print(sentences)
|
||||
# if extracted_contents:
|
||||
# for content in extracted_contents:
|
||||
# sentences.append(content)
|
||||
|
||||
# 在所有句子处理完毕后,对包含占位符的列表进行恢复
|
||||
if global_config.enable_kaomoji_protection:
|
||||
sentences = recover_kaomoji(sentences, kaomoji_mapping)
|
||||
|
||||
return sentences
|
||||
|
||||
|
||||
@@ -121,7 +121,7 @@ class ImageManager:
|
||||
prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,使用1-2个词描述一下表情包表达的情感和内容,简短一些"
|
||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, "jpg")
|
||||
else:
|
||||
prompt = "这是一个表情包,请用使用1-2个词描述一下表情包所表达的情感和内容,简短一些"
|
||||
prompt = "这是一个表情包,请用使用几个词描述一下表情包所表达的情感和内容,简短一些"
|
||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
|
||||
|
||||
cached_description = self._get_description_from_db(image_hash, "emoji")
|
||||
@@ -130,7 +130,7 @@ class ImageManager:
|
||||
return f"[表达了:{cached_description}]"
|
||||
|
||||
# 根据配置决定是否保存图片
|
||||
if global_config.EMOJI_SAVE:
|
||||
if global_config.save_emoji:
|
||||
# 生成文件名和路径
|
||||
timestamp = int(time.time())
|
||||
filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
|
||||
@@ -152,7 +152,7 @@ class ImageManager:
|
||||
"timestamp": timestamp,
|
||||
}
|
||||
db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
|
||||
logger.success(f"保存表情包: {file_path}")
|
||||
logger.trace(f"保存表情包: {file_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"保存表情包文件失败: {str(e)}")
|
||||
|
||||
@@ -196,7 +196,7 @@ class ImageManager:
|
||||
return "[图片]"
|
||||
|
||||
# 根据配置决定是否保存图片
|
||||
if global_config.EMOJI_SAVE:
|
||||
if global_config.save_pic:
|
||||
# 生成文件名和路径
|
||||
timestamp = int(time.time())
|
||||
filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
|
||||
@@ -309,11 +309,15 @@ def image_path_to_base64(image_path: str) -> str:
|
||||
image_path: 图片文件路径
|
||||
Returns:
|
||||
str: base64编码的图片数据
|
||||
Raises:
|
||||
FileNotFoundError: 当图片文件不存在时
|
||||
IOError: 当读取图片文件失败时
|
||||
"""
|
||||
try:
|
||||
with open(image_path, "rb") as f:
|
||||
image_data = f.read()
|
||||
return base64.b64encode(image_data).decode("utf-8")
|
||||
except Exception as e:
|
||||
logger.error(f"读取图片失败: {image_path}, 错误: {str(e)}")
|
||||
return None
|
||||
if not os.path.exists(image_path):
|
||||
raise FileNotFoundError(f"图片文件不存在: {image_path}")
|
||||
|
||||
with open(image_path, "rb") as f:
|
||||
image_data = f.read()
|
||||
if not image_data:
|
||||
raise IOError(f"读取图片文件失败: {image_path}")
|
||||
return base64.b64encode(image_data).decode("utf-8")
|
||||
|
||||
827
src/plugins/emoji_system/emoji_manager.py
Normal file
827
src/plugins/emoji_system/emoji_manager.py
Normal file
@@ -0,0 +1,827 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import traceback
|
||||
from typing import Optional, Tuple
|
||||
from PIL import Image
|
||||
import io
|
||||
import re
|
||||
|
||||
from ...common.database import db
|
||||
from ...config.config import global_config
|
||||
from ..chat.utils_image import image_path_to_base64, image_manager
|
||||
from ..models.utils_model import LLMRequest
|
||||
from src.common.logger import get_module_logger, LogConfig, EMOJI_STYLE_CONFIG
|
||||
|
||||
|
||||
# Logger configuration for this module, using the console/file formats
# defined in EMOJI_STYLE_CONFIG.
emoji_log_config = LogConfig(
    console_format=EMOJI_STYLE_CONFIG["console_format"],
    file_format=EMOJI_STYLE_CONFIG["file_format"],
)

logger = get_module_logger("emoji", config=emoji_log_config)

BASE_DIR = os.path.join("data")
EMOJI_DIR = os.path.join(BASE_DIR, "emoji")  # directory where emoji files are stored
EMOJI_REGISTED_DIR = os.path.join(BASE_DIR, "emoji_registed")  # directory for emojis that have been registered
|
||||
|
||||
|
||||
"""
|
||||
还没经过测试,有些地方数据库和内存数据同步可能不完全
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class MaiEmoji:
    """A single emoji image together with its metadata.

    ``path`` is the directory that contains the file and ``filename`` the file
    name inside it; the full file path is always ``os.path.join(path, filename)``.
    """

    def __init__(self, filename: str, path: str):
        self.path = path  # directory containing the file (not the full file path)
        self.filename = filename
        self.embedding = []
        self.hash = ""  # empty until initialize_hash_format() computes it
        self.description = ""
        self.emotion = []
        self.usage_count = 0
        self.last_used_time = time.time()
        self.register_time = time.time()
        self.is_deleted = False  # set to True once delete() has succeeded
        self.format = ""

    async def initialize_hash_format(self):
        """Compute ``self.hash`` (MD5 of the file bytes) and ``self.format`` from the file.

        Takes no arguments: the file is located via ``self.path`` and
        ``self.filename``. The method always returns ``None``. On failure the
        error is logged and ``hash``/``format`` keep whatever value they had
        when the error occurred, so callers must inspect those attributes.
        """
        try:
            file_path = os.path.join(self.path, self.filename)
            if not os.path.exists(file_path):
                logger.error(f"[错误] 表情包文件不存在: {file_path}")
                return None

            image_base64 = image_path_to_base64(file_path)
            if image_base64 is None:
                logger.error(f"[错误] 无法读取图片: {file_path}")
                return None

            # Hash the decoded image bytes.
            image_bytes = base64.b64decode(image_base64)
            self.hash = hashlib.md5(image_bytes).hexdigest()

            # Detect the image format; the context manager closes the image.
            with Image.open(io.BytesIO(image_bytes)) as image:
                self.format = image.format.lower()

        except Exception as e:
            logger.error(f"[错误] 初始化表情包失败: {str(e)}")
            logger.error(traceback.format_exc())
            return None

    async def register_to_db(self):
        """Move the file into ``EMOJI_REGISTED_DIR`` and upsert its database record.

        After a successful move ``self.path`` points at ``EMOJI_REGISTED_DIR``.
        The record in the ``emoji`` collection is keyed by ``self.hash``.

        Returns:
            bool: ``True`` when both the move and the database write succeed,
            ``False`` otherwise (the error is logged). If the database write
            fails the file is NOT moved back.
        """
        try:
            # Make sure the target directory exists.
            os.makedirs(EMOJI_REGISTED_DIR, exist_ok=True)

            source_path = os.path.join(self.path, self.filename)
            destination_path = os.path.join(EMOJI_REGISTED_DIR, self.filename)

            if not os.path.exists(source_path):
                logger.error(f"[错误] 源文件不存在: {source_path}")
                return False

            # --- file move ---
            try:
                # When the file already lives in the target directory there is
                # nothing to move. Removing the "existing destination" first, as
                # was done before, would delete the source file itself.
                if os.path.abspath(source_path) != os.path.abspath(destination_path):
                    # os.replace overwrites an existing destination in one step,
                    # so a failed move can no longer leave the destination deleted.
                    os.replace(source_path, destination_path)
                    logger.info(f"[移动] 文件从 {source_path} 移动到 {destination_path}")
                # Point the instance at the registered directory.
                self.path = EMOJI_REGISTED_DIR
            except Exception as move_error:
                logger.error(f"[错误] 移动文件失败: {str(move_error)}")
                return False  # do not touch the database when the move failed

            # --- database write ---
            try:
                emoji_record = {
                    "filename": self.filename,
                    "path": os.path.join(self.path, self.filename),  # path after the move
                    "embedding": self.embedding,
                    "description": self.description,
                    "emotion": self.emotion,
                    "hash": self.hash,
                    "format": self.format,
                    "timestamp": int(self.register_time),
                    "usage_count": self.usage_count,
                    "last_used_time": self.last_used_time,
                }

                # Upsert so an existing record with the same hash is updated.
                db["emoji"].update_one({"hash": self.hash}, {"$set": emoji_record}, upsert=True)
                logger.success(f"[注册] 表情包信息保存到数据库: {self.description}")

                return True

            except Exception as db_error:
                logger.error(f"[错误] 保存数据库失败: {str(db_error)}")
                # NOTE(review): the moved file is left in place here; consider moving it back.
                return False

        except Exception as e:
            logger.error(f"[错误] 注册表情包失败: {str(e)}")
            logger.error(traceback.format_exc())
            return False

    async def delete(self):
        """Delete the emoji file and its database record.

        A failure to remove the file is logged but does not stop the database
        deletion.

        Returns:
            bool: ``True`` when a database record was deleted (the instance is
            then marked ``is_deleted``), ``False`` otherwise.
        """
        try:
            file_path = os.path.join(self.path, self.filename)

            # 1. Remove the file, if it is still there.
            if os.path.exists(file_path):
                try:
                    os.remove(file_path)
                    logger.info(f"[删除] 文件: {file_path}")
                except Exception as e:
                    logger.error(f"[错误] 删除文件失败 {file_path}: {str(e)}")
                    # Carry on: the database record is still removed below.

            # 2. Remove the database record.
            result = db.emoji.delete_one({"hash": self.hash})
            deleted_in_db = result.deleted_count > 0

            if deleted_in_db:
                logger.success(f"[删除] 成功删除表情包记录: {self.description}")

                # 3. Mark this object as deleted.
                self.is_deleted = True
                return True
            else:
                logger.error(f"[错误] 删除表情包记录失败: {self.hash}")
                return False

        except Exception as e:
            logger.error(f"[错误] 删除表情包失败: {str(e)}")
            return False
|
||||
|
||||
|
||||
class EmojiManager:
|
||||
_instance = None
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
    """Set up the model clients and the in-memory emoji bookkeeping."""
    self._scan_task = None

    # Vision model client.
    self.vlm = LLMRequest(
        model=global_config.vlm,
        temperature=0.3,
        max_tokens=1000,
        request_type="emoji",
    )
    # Emotion-judging model: higher temperature and fewer tokens than the
    # vision model (the temperature could later be tuned by mood).
    self.llm_emotion_judge = LLMRequest(
        model=global_config.llm_emotion_judge,
        max_tokens=600,
        temperature=0.8,
        request_type="emoji",
    )

    # Counters and limits taken from the global configuration.
    self.emoji_num = 0
    self.emoji_num_max = global_config.max_emoji_num
    self.emoji_num_max_reach_deletion = global_config.max_reach_deletion
    # In-memory list of MaiEmoji objects.
    self.emoji_objects: list[MaiEmoji] = []

    logger.info("启动表情包管理器")
|
||||
|
||||
def _ensure_emoji_dir(self):
    """Create the emoji storage directory (``EMOJI_DIR``) if it does not exist yet."""
    os.makedirs(EMOJI_DIR, exist_ok=True)
|
||||
|
||||
def initialize(self):
|
||||
"""初始化数据库连接和表情目录"""
|
||||
if not self._initialized:
|
||||
try:
|
||||
self._ensure_emoji_collection()
|
||||
self._ensure_emoji_dir()
|
||||
self._initialized = True
|
||||
# 更新表情包数量
|
||||
# 启动时执行一次完整性检查
|
||||
# await self.check_emoji_file_integrity()
|
||||
except Exception:
|
||||
logger.exception("初始化表情管理器失败")
|
||||
|
||||
def _ensure_db(self):
|
||||
"""确保数据库已初始化"""
|
||||
if not self._initialized:
|
||||
self.initialize()
|
||||
if not self._initialized:
|
||||
raise RuntimeError("EmojiManager not initialized")
|
||||
|
||||
@staticmethod
def _ensure_emoji_collection():
    """Create the MongoDB ``emoji`` collection and its indexes when it is missing.

    Indexes created for a new collection:
    - ``embedding``: a ``2dsphere`` index.
    - ``filename``: an ascending unique index, so file names cannot repeat and
      lookups by file name are fast.

    NOTE(review): earlier documentation also promised an index on ``tags``;
    no such index is created here.
    """
    if "emoji" not in db.list_collection_names():
        db.create_collection("emoji")
        db.emoji.create_index([("embedding", "2dsphere")])
        db.emoji.create_index([("filename", 1)], unique=True)
|
||||
|
||||
def record_usage(self, hash: str):
    """Increment the usage counter of an emoji in the database and in memory.

    Args:
        hash: hash of the emoji that was just used.

    Errors are logged and swallowed; this method never raises.
    """
    try:
        db.emoji.update_one({"hash": hash}, {"$inc": {"usage_count": 1}})
        # Only the first cached object with this hash is updated.
        cached = next((item for item in self.emoji_objects if item.hash == hash), None)
        if cached is not None:
            cached.usage_count += 1
    except Exception as e:
        logger.error(f"记录表情使用失败: {str(e)}")
|
||||
|
||||
async def get_emoji_for_text(self, text_emotion: str) -> Optional[Tuple[str, str]]:
    """Pick an emoji whose emotion labels resemble the given emotion text.

    Every cached emoji is scored by its best normalized edit-distance
    similarity against ``text_emotion``; one of the ten best-scoring emojis
    is then chosen at random and its usage is recorded.

    Args:
        text_emotion: text describing the emotion to express.

    Returns:
        ``(file path, "[ description ]")`` for the chosen emoji, or None when
        nothing is available or an error occurs.
    """
    try:
        self._ensure_db()
        started_at = time.time()

        candidates = self.emoji_objects
        if not candidates:
            logger.warning("数据库中没有任何表情包")
            return None

        scored = []
        for candidate in candidates:
            labels = candidate.emotion
            if not labels:
                continue

            # Best similarity across all labels of this emoji.
            best = 0
            for label in labels:
                distance = self._levenshtein_distance(text_emotion, label)
                longest = max(len(text_emotion), len(label))
                penalty = distance / longest if longest > 0 else 0
                best = max(best, 1 - penalty)
            scored.append((candidate, best))

        # Highest similarity first; keep at most ten candidates.
        scored.sort(key=lambda pair: pair[1], reverse=True)
        shortlist = scored[:10]
        if not shortlist:
            logger.warning("未找到匹配的表情包")
            return None

        selected_emoji, similarity = random.choice(shortlist)
        self.record_usage(selected_emoji.hash)

        finished_at = time.time()
        logger.info(
            f"找到[{text_emotion}]表情包,用时:{finished_at - started_at:.2f}秒: {selected_emoji.description} (相似度: {similarity:.4f})"
        )
        return selected_emoji.path, f"[ {selected_emoji.description} ]"

    except Exception as e:
        logger.error(f"[错误] 获取表情包失败: {str(e)}")
        return None
|
||||
|
||||
def _levenshtein_distance(self, s1: str, s2: str) -> int:
|
||||
"""计算两个字符串的编辑距离
|
||||
|
||||
Args:
|
||||
s1: 第一个字符串
|
||||
s2: 第二个字符串
|
||||
|
||||
Returns:
|
||||
int: 编辑距离
|
||||
"""
|
||||
if len(s1) < len(s2):
|
||||
return self._levenshtein_distance(s2, s1)
|
||||
|
||||
if len(s2) == 0:
|
||||
return len(s1)
|
||||
|
||||
previous_row = range(len(s2) + 1)
|
||||
for i, c1 in enumerate(s1):
|
||||
current_row = [i + 1]
|
||||
for j, c2 in enumerate(s2):
|
||||
insertions = previous_row[j + 1] + 1
|
||||
deletions = current_row[j] + 1
|
||||
substitutions = previous_row[j] + (c1 != c2)
|
||||
current_row.append(min(insertions, deletions, substitutions))
|
||||
previous_row = current_row
|
||||
|
||||
return previous_row[-1]
|
||||
|
||||
async def check_emoji_file_integrity(self):
    """Drop cached emojis whose image file no longer exists on disk.

    For every cached object with a missing file, the object's own ``delete()``
    is awaited, the object is removed from ``self.emoji_objects`` and
    ``self.emoji_num`` is decremented. A per-item failure is logged and does
    not stop the scan.
    """
    try:
        if not self.emoji_objects:
            logger.warning("[检查] emoji_objects为空,跳过完整性检查")
            return

        total_count = len(self.emoji_objects)
        purged = 0
        # Walk a snapshot: entries are removed from the live list while scanning.
        for entry in list(self.emoji_objects):
            try:
                if os.path.exists(entry.path):
                    continue
                logger.warning(f"[检查] 表情包文件已被删除: {entry.path}")
                await entry.delete()
                self.emoji_objects.remove(entry)
                self.emoji_num -= 1
                purged += 1
            except Exception as item_error:
                logger.error(f"[错误] 处理表情包记录时出错: {str(item_error)}")

        if purged > 0:
            logger.success(f"[清理] 已清理 {purged} 个失效的表情包记录")
            logger.info(f"[统计] 清理前: {total_count} | 清理后: {len(self.emoji_objects)}")
        else:
            logger.info(f"[检查] 已检查 {total_count} 个表情包记录,全部完好")

    except Exception as e:
        logger.error(f"[错误] 检查表情包完整性失败: {str(e)}")
        logger.error(traceback.format_exc())
|
||||
|
||||
async def start_periodic_check_register(self):
    """Run the emoji maintenance loop forever.

    After loading all emojis from the database once, each pass:
    1. verifies cached emojis still have their files,
    2. clears temporary emoji/image files,
    3. if the emoji directory exists and is non-empty, and the store is either
       below capacity or over capacity with deletion enabled, tries the image
       files in order until one registers; files that fail to register are
       removed from disk,
    4. sleeps ``global_config.EMOJI_CHECK_INTERVAL`` minutes.
    """
    await self.get_all_emoji_from_db()
    image_suffixes = (".jpg", ".jpeg", ".png", ".gif")

    while True:
        logger.info("[扫描] 开始检查表情包完整性...")
        await self.check_emoji_file_integrity()
        await self.clear_temp_emoji()
        logger.info("[扫描] 开始扫描新表情包...")

        if not os.path.exists(EMOJI_DIR):
            logger.warning(f"[警告] 表情包目录不存在: {EMOJI_DIR}")
            os.makedirs(EMOJI_DIR, exist_ok=True)
            logger.info(f"[创建] 已创建表情包目录: {EMOJI_DIR}")
        else:
            files = os.listdir(EMOJI_DIR)
            over_capacity = self.emoji_num > self.emoji_num_max and global_config.max_reach_deletion
            if not files:
                logger.warning(f"[警告] 表情包目录为空: {EMOJI_DIR}")
            elif over_capacity or self.emoji_num < self.emoji_num_max:
                try:
                    pending = [
                        name
                        for name in files
                        if os.path.isfile(os.path.join(EMOJI_DIR, name)) and name.lower().endswith(image_suffixes)
                    ]
                    for filename in pending:
                        if await self.register_emoji_by_filename(filename):
                            # One successful registration per pass is enough.
                            break
                        # Registration failed: discard the file so it is not retried.
                        os.remove(os.path.join(EMOJI_DIR, filename))
                        logger.warning(f"[清理] 删除注册失败的表情包文件: {filename}")
                except Exception as e:
                    logger.error(f"[错误] 扫描表情包目录失败: {str(e)}")

        await asyncio.sleep(global_config.EMOJI_CHECK_INTERVAL * 60)
|
||||
|
||||
async def get_all_emoji_from_db(self):
    """Load every emoji record from the database into the in-memory cache.

    Each record becomes a ``MaiEmoji`` carrying its hash, usage statistics,
    timestamps, description and emotion labels. On success
    ``self.emoji_objects`` is replaced and ``self.emoji_num`` is set to the
    number of loaded emojis, so capacity checks reflect what is actually
    stored. Errors are logged and leave the current cache untouched.

    Returns:
        None. The result is stored on the manager.
    """
    try:
        self._ensure_db()

        emoji_objects = []
        for emoji_data in db.emoji.find():
            emoji = MaiEmoji(
                filename=emoji_data.get("filename", ""),
                path=emoji_data.get("path", ""),
            )

            # Fall back to the registration timestamp, then to "now", for missing fields.
            registered_at = emoji_data.get("timestamp", time.time())
            emoji.hash = emoji_data.get("hash", "")
            emoji.usage_count = emoji_data.get("usage_count", 0)
            emoji.last_used_time = emoji_data.get("last_used_time", registered_at)
            emoji.register_time = registered_at
            emoji.description = emoji_data.get("description", "")
            emoji.emotion = emoji_data.get("emotion", [])
            emoji_objects.append(emoji)

        self.emoji_objects = emoji_objects
        # Keep the counter in sync with the cache; it previously stayed at 0 after loading.
        self.emoji_num = len(emoji_objects)

    except Exception as e:
        logger.error(f"[错误] 获取所有表情包对象失败: {str(e)}")
|
||||
|
||||
async def get_emoji_from_db(self, hash=None):
    """Fetch emoji records from the database as ``MaiEmoji`` objects.

    Args:
        hash: optional emoji hash. When given, only records with that hash
            are returned; otherwise every record is returned.

    Returns:
        list[MaiEmoji]: the matching emojis, or an empty list on error.

    Side effects:
        An unfiltered call refreshes ``self.emoji_objects`` and
        ``self.emoji_num``. A filtered call leaves the cache alone, because
        replacing the whole cache with a single-hash result would discard
        every other emoji.
    """
    try:
        self._ensure_db()

        query = {"hash": hash} if hash else {}

        emoji_objects = []
        for emoji_data in db.emoji.find(query):
            emoji = MaiEmoji(
                filename=emoji_data.get("filename", ""),
                path=emoji_data.get("path", ""),
            )

            # Fall back to the registration timestamp, then to "now", for missing fields.
            registered_at = emoji_data.get("timestamp", time.time())
            # The hash must be loaded too, otherwise hash-based lookups never match.
            emoji.hash = emoji_data.get("hash", "")
            emoji.usage_count = emoji_data.get("usage_count", 0)
            emoji.last_used_time = emoji_data.get("last_used_time", registered_at)
            emoji.register_time = registered_at
            emoji.description = emoji_data.get("description", "")
            emoji.emotion = emoji_data.get("emotion", [])
            emoji_objects.append(emoji)

        if not hash:
            self.emoji_objects = emoji_objects
            self.emoji_num = len(emoji_objects)

        return emoji_objects

    except Exception as e:
        logger.error(f"[错误] 获取所有表情包对象失败: {str(e)}")
        return []
|
||||
|
||||
async def get_emoji_from_manager(self, hash) -> "Optional[MaiEmoji]":
    """Look up a cached emoji by hash.

    Args:
        hash: hash of the emoji to find.

    Returns:
        The first cached emoji whose ``hash`` equals the argument, or None
        when no cached emoji matches.
    """
    return next((item for item in self.emoji_objects if item.hash == hash), None)
|
||||
|
||||
async def delete_emoji(self, emoji_hash: str) -> bool:
    """Delete the emoji with the given hash.

    The emoji is looked up in the in-memory cache and its own ``delete()`` is
    awaited. On success it is dropped from the cache and the counter is
    decremented.

    Args:
        emoji_hash: hash of the emoji to delete.

    Returns:
        True when the emoji was deleted; False when it was not found, deletion
        failed, or an error occurred.
    """
    try:
        self._ensure_db()

        target = await self.get_emoji_from_manager(emoji_hash)
        if not target:
            logger.warning(f"[警告] 未找到哈希值为 {emoji_hash} 的表情包")
            return False

        if not await target.delete():
            logger.error(f"[错误] 删除表情包失败: {emoji_hash}")
            return False

        self.emoji_objects = [kept for kept in self.emoji_objects if kept.hash != emoji_hash]
        self.emoji_num -= 1
        logger.info(f"[统计] 当前表情包数量: {self.emoji_num}")
        return True

    except Exception as e:
        logger.error(f"[错误] 删除表情包失败: {str(e)}")
        logger.error(traceback.format_exc())
        return False
|
||||
|
||||
def _emoji_objects_to_readable_list(self, emoji_objects):
|
||||
"""将表情包对象列表转换为可读的字符串列表
|
||||
|
||||
参数:
|
||||
emoji_objects: MaiEmoji对象列表
|
||||
|
||||
返回:
|
||||
list[str]: 可读的表情包信息字符串列表
|
||||
"""
|
||||
emoji_info_list = []
|
||||
for i, emoji in enumerate(emoji_objects):
|
||||
# 转换时间戳为可读时间
|
||||
time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(emoji.register_time))
|
||||
# 构建每个表情包的信息字符串
|
||||
emoji_info = (
|
||||
f"编号: {i + 1}\n描述: {emoji.description}\n使用次数: {emoji.usage_count}\n添加时间: {time_str}\n"
|
||||
)
|
||||
emoji_info_list.append(emoji_info)
|
||||
return emoji_info_list
|
||||
|
||||
async def replace_a_emoji(self, new_emoji: MaiEmoji):
    """Ask the LLM which stored emoji to evict, then store the new one in its place.

    The model sees the new emoji's description and a numbered list of all
    cached emojis, and must answer either "不删除" or "删除编号X". When a valid
    number comes back, that emoji is deleted and ``new_emoji`` is registered.

    Args:
        new_emoji: the emoji waiting to be stored.

    Returns:
        bool: True only when an old emoji was deleted and the new one was
        registered; False in every other case, including errors.
    """
    try:
        self._ensure_db()

        all_emojis = self.emoji_objects
        listing = "\n".join(self._emoji_objects_to_readable_list(all_emojis))

        prompt = (
            f"{global_config.BOT_NICKNAME}的表情包存储已满({self.emoji_num}/{self.emoji_num_max}),"
            f"需要决定是否删除一个旧表情包来为新表情包腾出空间。\n\n"
            f"新表情包信息:\n"
            f"描述: {new_emoji.description}\n\n"
            f"现有表情包列表:\n"
            f"{listing}"
            "\n\n"
            "请决定:\n"
            "1. 是否要删除某个现有表情包来为新表情包腾出空间?\n"
            "2. 如果要删除,应该删除哪一个(给出编号)?\n"
            "请只回答:'不删除'或'删除编号X'(X为表情包编号)。"
        )

        decision, _ = await self.llm_emotion_judge.generate_response_async(prompt, temperature=0.8)
        logger.info(f"[决策] 大模型决策结果: {decision}")

        if "不删除" in decision:
            logger.info("[决策] 决定不删除任何表情包")
            return False

        match = re.search(r"删除编号(\d+)", decision)
        if not match:
            logger.error(f"[错误] 无法从决策中提取表情包编号: {decision}")
            return False

        # The model answers with a 1-based number.
        emoji_index = int(match.group(1)) - 1
        if not 0 <= emoji_index < len(all_emojis):
            logger.error(f"[错误] 无效的表情包编号: {emoji_index + 1}")
            return False

        emoji_to_delete = all_emojis[emoji_index]
        logger.info(f"[决策] 决定删除表情包: {emoji_to_delete.description}")
        if not await self.delete_emoji(emoji_to_delete.hash):
            logger.error("[错误] 删除表情包失败,无法完成替换")
            return False

        if not await new_emoji.register_to_db():
            logger.error(f"[错误] 注册表情包到数据库失败: {new_emoji.filename}")
            return False

        self.emoji_objects.append(new_emoji)
        self.emoji_num += 1
        logger.success(f"[成功] 注册表情包: {new_emoji.description}")
        return True

    except Exception as e:
        logger.error(f"[错误] 替换表情包失败: {str(e)}")
        logger.error(traceback.format_exc())
        return False
|
||||
|
||||
async def build_emoji_description(self, image_base64: str) -> Tuple[Optional[str], list]:
    """Describe an emoji image and derive its emotion labels.

    Steps: detect the image format, ask the vision model for a description
    (GIFs are first converted with ``image_manager.transform_gif``),
    optionally run the content review, then ask the emotion model for one or
    two emotion labels.

    Args:
        image_base64: the image, base64-encoded.

    Returns:
        ``("[表情包:<description>]", [labels...])`` on success;
        ``(None, [])`` when the review rejects the image;
        ``("", [])`` when an error occurs.
    """
    try:
        image_bytes = base64.b64decode(image_base64)
        image_format = Image.open(io.BytesIO(image_bytes)).format.lower()

        if image_format == "gif":
            image_base64 = image_manager.transform_gif(image_base64)
            # The converted payload is sent as "jpg". Remember that, so the review
            # call below labels the same data consistently instead of as "gif".
            image_format = "jpg"
            prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,详细描述一下表情包表达的情感和内容,请关注其幽默和讽刺意味"
        else:
            prompt = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,请关注其幽默和讽刺意味"
        description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)

        # Optional content review.
        if global_config.EMOJI_CHECK:
            prompt = f'''
这是一个表情包,请对这个表情包进行审核,标准如下:
1. 必须符合"{global_config.EMOJI_CHECK_PROMPT}"的要求
2. 不能是色情、暴力、等违法违规内容,必须符合公序良俗
3. 不能是任何形式的截图,聊天记录或视频截图
4. 不要出现5个以上文字
请回答这个表情包是否满足上述要求,是则回答是,否则回答否,不要出现任何其他内容
'''
            content, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
            # Tolerate surrounding whitespace or trailing punctuation in the verdict.
            if (content or "").strip().startswith("否"):
                return None, []

        emotion_prompt = f"""
基于这个表情包的描述:'{description}',请列出1-2个可能的情感标签,每个标签用一个词组表示,格式如下:
幽默的讽刺
悲伤的无奈
愤怒的抗议
愤怒的讽刺
直接输出词组,词组间用逗号分隔。"""
        emotions_text, _ = await self.llm_emotion_judge.generate_response_async(emotion_prompt, temperature=0.7)

        # Models often answer with full-width commas, "、" or one label per line.
        emotions = [label.strip() for label in re.split(r"[,,、\n]", emotions_text) if label.strip()]

        return f"[表情包:{description}]", emotions

    except Exception as e:
        logger.error(f"获取表情包描述失败: {str(e)}")
        return "", []
|
||||
|
||||
async def register_emoji_by_filename(self, filename: str) -> bool:
    """Analyze an image file from EMOJI_DIR and register it as an emoji.

    Args:
        filename: name of the image file; it must be located in EMOJI_DIR.

    Returns:
        bool: True when the emoji was stored, either directly or by replacing
        an existing one. False when it is a duplicate, was rejected, could not
        be described, could not be stored, or an error occurred.
    """
    try:
        new_emoji = MaiEmoji(filename, EMOJI_DIR)
        await new_emoji.initialize_hash_format()

        # Check for duplicates first so a known image costs no model calls.
        if await self.get_emoji_from_manager(new_emoji.hash):
            logger.warning(f"[警告] 表情包已存在: {filename}")
            return False

        emoji_base64 = image_path_to_base64(os.path.join(EMOJI_DIR, filename))
        description, emotions = await self.build_emoji_description(emoji_base64)
        # "" means the description failed and None means the review rejected the
        # image; neither may be registered.
        if not description:
            return False
        new_emoji.description = description
        new_emoji.emotion = emotions

        if self.emoji_num >= self.emoji_num_max:
            logger.warning(f"表情包数量已达到上限({self.emoji_num}/{self.emoji_num_max})")
            if not await self.replace_a_emoji(new_emoji):
                logger.error("[错误] 替换表情包失败,无法完成注册")
                return False
            # replace_a_emoji already stored the new emoji; report success so the
            # caller does not treat this as a failed registration.
            return True

        if await new_emoji.register_to_db():
            self.emoji_objects.append(new_emoji)
            self.emoji_num += 1
            logger.success(f"[成功] 注册表情包: {filename}")
            return True

        logger.error(f"[错误] 注册表情包到数据库失败: {filename}")
        return False

    except Exception as e:
        logger.error(f"[错误] 注册表情包失败: {str(e)}")
        logger.error(traceback.format_exc())
        return False
|
||||
|
||||
async def clear_temp_emoji(self):
    """Purge the temporary ``emoji`` and ``image`` directories under BASE_DIR.

    A directory is emptied only when it holds more than 50 entries; in that
    case every regular file directly inside it is removed.
    """
    logger.info("[清理] 开始清理临时表情包...")

    # (sub-directory, log prefix used for each removed file)
    targets = (
        ("emoji", "[清理] 删除表情包文件: "),
        ("image", "[清理] 删除图片文件: "),
    )
    for subdir, log_prefix in targets:
        directory = os.path.join(BASE_DIR, subdir)
        if not os.path.exists(directory):
            continue
        entries = os.listdir(directory)
        if len(entries) <= 50:
            continue
        for filename in entries:
            candidate = os.path.join(directory, filename)
            if os.path.isfile(candidate):
                os.remove(candidate)
                logger.debug(f"{log_prefix}{filename}")

    logger.success("[清理] 临时文件清理完成")
|
||||
|
||||
|
||||
# 创建全局单例
|
||||
emoji_manager = EmojiManager()
|
||||
74
src/plugins/heartFC_chat/heartFC_Cycleinfo.py
Normal file
74
src/plugins/heartFC_chat/heartFC_Cycleinfo.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import time
|
||||
from typing import List, Optional, Dict, Any
|
||||
|
||||
|
||||
class CycleInfo:
    """Record of a single think-plan-execute cycle."""

    def __init__(self, cycle_id: int):
        """Start a new cycle record; the start time is captured immediately.

        Args:
            cycle_id: identifier of this cycle.
        """
        self.cycle_id = cycle_id
        self.start_time = time.time()
        self.end_time: Optional[float] = None  # set by complete_cycle()
        self.action_taken = False
        self.action_type = "unknown"
        self.reasoning = ""
        self.timers: Dict[str, float] = {}
        self.thinking_id = ""
        self.replanned = False

        # Details of the response produced during this cycle.
        self.response_info: Dict[str, Any] = {
            "response_text": [],  # list of reply texts
            "emoji_info": "",  # emoji information
            "anchor_message_id": "",  # id of the anchor message
            "reply_message_ids": [],  # ids of the reply messages
            "sub_mind_thinking": "",  # sub-mind thinking content
        }

    def to_dict(self) -> Dict[str, Any]:
        """Return the cycle record as a plain dictionary.

        The ``replanned`` flag is included so the serialized form carries
        every field stored on the object.
        """
        return {
            "cycle_id": self.cycle_id,
            "start_time": self.start_time,
            "end_time": self.end_time,
            "action_taken": self.action_taken,
            "action_type": self.action_type,
            "reasoning": self.reasoning,
            "timers": self.timers,
            "thinking_id": self.thinking_id,
            "replanned": self.replanned,
            "response_info": self.response_info,
        }

    def complete_cycle(self):
        """Mark the cycle as finished by recording the end time."""
        self.end_time = time.time()

    def set_action_info(self, action_type: str, reasoning: str, action_taken: bool):
        """Store which action was chosen, why, and whether it was carried out."""
        self.action_type = action_type
        self.reasoning = reasoning
        self.action_taken = action_taken

    def set_thinking_id(self, thinking_id: str):
        """Store the id of the thinking message for this cycle."""
        self.thinking_id = thinking_id

    def set_response_info(
        self,
        response_text: Optional[List[str]] = None,
        emoji_info: Optional[str] = None,
        anchor_message_id: Optional[str] = None,
        reply_message_ids: Optional[List[str]] = None,
        sub_mind_thinking: Optional[str] = None,
    ):
        """Update response details; arguments left as None keep their current value."""
        updates = {
            "response_text": response_text,
            "emoji_info": emoji_info,
            "anchor_message_id": anchor_message_id,
            "reply_message_ids": reply_message_ids,
            "sub_mind_thinking": sub_mind_thinking,
        }
        for key, value in updates.items():
            if value is not None:
                self.response_info[key] = value
|
||||
File diff suppressed because it is too large
Load Diff
92
src/plugins/heartFC_chat/heartFC_chatting_logic.md
Normal file
92
src/plugins/heartFC_chat/heartFC_chatting_logic.md
Normal file
@@ -0,0 +1,92 @@
|
||||
# HeartFChatting 逻辑详解
|
||||
|
||||
`HeartFChatting` 类是心流系统(Heart Flow System)中实现**专注聊天**(`ChatState.FOCUSED`)功能的核心。顾名思义,其职责乃是在特定聊天流(`stream_id`)中,模拟更为连贯深入之对话。此非凭空臆造,而是依赖一个持续不断的 **思考(Think)-规划(Plan)-执行(Execute)** 循环。当其所系的 `SubHeartflow` 进入 `FOCUSED` 状态时,便会创建并启动 `HeartFChatting` 实例;若状态转为他途(譬如 `CHAT` 或 `ABSENT`),则会将其关闭。
|
||||
|
||||
## 1. 初始化简述 (`__init__`, `_initialize`)
|
||||
|
||||
创生之初,`HeartFChatting` 需注入若干关键之物:`chat_id`(亦即 `stream_id`)、关联的 `SubMind` 实例,以及 `Observation` 实例(用以观察环境)。
|
||||
|
||||
其内部核心组件包括:
|
||||
|
||||
- `ActionManager`: 管理当前循环可选之策(如:不应、言语、表情)。
|
||||
- `HeartFCGenerator` (`self.gpt_instance`): 专司生成回复文本之职。
|
||||
- `ToolUser` (`self.tool_user`): 虽主要用于获取工具定义,然亦备 `SubMind` 调用之需(实际执行由 `SubMind` 操持)。
|
||||
- `HeartFCSender` (`self.heart_fc_sender`): 负责消息发送诸般事宜,含"正在思考"之态。
|
||||
- `LLMRequest` (`self.planner_llm`): 配置用于执行"规划"任务的大语言模型。
|
||||
|
||||
*初始化过程采取懒加载策略,仅在首次需要访问 `ChatStream` 时(通常在 `start` 方法中)进行。*
|
||||
|
||||
## 2. 生命周期 (`start`, `shutdown`)
|
||||
|
||||
- **启动 (`start`)**: 外部调用此法,以启 `HeartFChatting` 之流程。内部会安全地启动主循环任务。
|
||||
- **关闭 (`shutdown`)**: 外部调用此法,以止其运行。会取消主循环任务,清理状态,并释放锁。
|
||||
|
||||
## 3. 核心循环 (`_hfc_loop`) 与 循环记录 (`CycleInfo`)
|
||||
|
||||
`_hfc_loop` 乃 `HeartFChatting` 之脉搏,以异步方式不舍昼夜运行(直至 `shutdown` 被调用)。其核心在于周而复始地执行 **思考-规划-执行** 之周期。
|
||||
|
||||
每一轮循环,皆会创建一个 `CycleInfo` 对象。此对象犹如史官,详细记载该次循环之点滴:
|
||||
|
||||
- **身份标识**: 循环 ID (`cycle_id`)。
|
||||
- **时间轨迹**: 起止时刻 (`start_time`, `end_time`)。
|
||||
- **行动细节**: 是否执行动作 (`action_taken`)、动作类型 (`action_type`)、决策理由 (`reasoning`)。
|
||||
- **耗时考量**: 各阶段计时 (`timers`)。
|
||||
- **关联信息**: 思考消息 ID (`thinking_id`)、是否重新规划 (`replanned`)、详尽响应信息 (`response_info`,含生成文本、表情、锚点、实际发送ID、`SubMind`思考等)。
|
||||
|
||||
这些 `CycleInfo` 被存入一个队列 (`_cycle_history`),近者得观。此记录不仅便于调试,更关键的是,它会作为**上下文信息**传递给下一次循环的"思考"阶段,使得 `SubMind` 能鉴往知来,做出更连贯的决策。
|
||||
|
||||
*循环间会根据执行情况智能引入延迟,避免空耗资源。*
|
||||
|
||||
## 4. 思考-规划-执行周期 (`_think_plan_execute_loop`)
|
||||
|
||||
此乃 `HeartFChatting` 最核心的逻辑单元,每一循环皆按序执行以下三步:
|
||||
|
||||
### 4.1. 思考 (`_get_submind_thinking`)
|
||||
|
||||
* **第一步:观察环境**: 调用 `Observation` 的 `observe()` 方法,感知聊天室是否有新动态(如新消息)。
|
||||
* **第二步:触发子思维**: 调用关联 `SubMind` 的 `do_thinking_before_reply()` 方法。
|
||||
* **关键点**: 会将**上一个循环**的 `CycleInfo` 传入,让 `SubMind` 了解上次行动的决策、理由及是否重新规划,从而实现"承前启后"的思考。
|
||||
* `SubMind` 在此阶段不仅进行思考,还可能**调用其配置的工具**来收集信息。
|
||||
* **第三步:获取成果**: `SubMind` 返回两部分重要信息:
|
||||
1. 当前的内心想法 (`current_mind`)。
|
||||
2. 通过工具调用收集到的结构化信息 (`structured_info`)。
|
||||
|
||||
### 4.2. 规划 (`_planner`)
|
||||
|
||||
* **输入**: 接收来自"思考"阶段的 `current_mind` 和 `structured_info`,以及"观察"到的最新消息。
|
||||
* **目标**: 基于当前想法、已知信息、聊天记录、机器人个性以及可用动作,决定**接下来要做什么**。
|
||||
* **决策方式**:
|
||||
1. 构建一个精心设计的提示词 (`_build_planner_prompt`)。
|
||||
2. 获取 `ActionManager` 中定义的当前可用动作(如 `no_reply`, `text_reply`, `emoji_reply`)作为"工具"选项。
|
||||
3. 调用大语言模型 (`self.planner_llm`),**强制**其选择一个动作"工具"并提供理由。可选动作包括:
|
||||
* `no_reply`: 不回复(例如,自己刚说过话或对方未回应)。
|
||||
* `text_reply`: 发送文本回复。
|
||||
* `emoji_reply`: 仅发送表情。
|
||||
* 文本回复亦可附带表情(通过 `emoji_query` 参数指定)。
|
||||
* **动态调整(重新规划)**:
|
||||
* 在做出初步决策后,会检查自规划开始后是否有新消息 (`_check_new_messages`)。
|
||||
* 若有新消息,则有一定概率触发**重新规划**。此时会再次调用规划器,但提示词会包含之前决策的信息,要求 LLM 重新考虑。
|
||||
* **输出**: 返回一个包含最终决策的字典,主要包括:
|
||||
* `action`: 选定的动作类型。
|
||||
* `reasoning`: 做出此决策的理由。
|
||||
* `emoji_query`: (可选) 如果需要发送表情,指定表情的主题。
|
||||
|
||||
### 4.3. 执行 (`_handle_action`)
|
||||
|
||||
* **输入**: 接收"规划"阶段输出的 `action`、`reasoning` 和 `emoji_query`。
|
||||
* **行动**: 根据 `action` 的类型,分派到不同的处理函数:
|
||||
* **文本回复 (`_handle_text_reply`)**:
|
||||
1. 获取锚点消息(当前实现为系统触发的占位符)。
|
||||
2. 调用 `HeartFCSender` 的 `register_thinking` 标记开始思考。
|
||||
3. 调用 `HeartFCGenerator` (`_replier_work`) 生成回复文本。**注意**: 回复器逻辑 (`_replier_work`) 本身并非独立复杂组件,主要是调用 `HeartFCGenerator` 完成文本生成。
|
||||
4. 调用 `HeartFCSender` (`_sender`) 发送生成的文本和可能的表情。**注意**: 发送逻辑 (`_sender`, `_send_response_messages`, `_handle_emoji`) 同样委托给 `HeartFCSender` 实例处理,包含模拟打字、实际发送、存储消息等细节。
|
||||
* **仅表情回复 (`_handle_emoji_reply`)**:
|
||||
1. 获取锚点消息。
|
||||
2. 调用 `HeartFCSender` 发送表情。
|
||||
* **不回复 (`_handle_no_reply`)**:
|
||||
1. 记录理由。
|
||||
2. 进入等待状态 (`_wait_for_new_message`),直到检测到新消息或超时(目前300秒),期间会监听关闭信号。
|
||||
|
||||
## 总结
|
||||
|
||||
`HeartFChatting` 通过 **观察 -> 思考(含工具)-> 规划 -> 执行** 的闭环,并利用 `CycleInfo` 进行上下文传递,实现了更加智能和连贯的专注聊天行为。其核心在于利用 `SubMind` 进行深度思考和信息收集,再通过 LLM 规划器进行决策,最后由 `HeartFCSender` 可靠地执行消息发送任务。
|
||||
@@ -8,7 +8,7 @@ from .heartflow_prompt_builder import prompt_builder
|
||||
from ..chat.utils import process_llm_response
|
||||
from src.common.logger import get_module_logger, LogConfig, LLM_STYLE_CONFIG
|
||||
from src.plugins.respon_info_catcher.info_catcher import info_catcher_manager
|
||||
from ..utils.timer_calculater import Timer
|
||||
from ..utils.timer_calculator import Timer
|
||||
|
||||
from src.plugins.moods.moods import MoodManager
|
||||
|
||||
@@ -49,17 +49,13 @@ class HeartFCGenerator:
|
||||
|
||||
arousal_multiplier = MoodManager.get_instance().get_arousal_multiplier()
|
||||
|
||||
with Timer() as t_generate_response:
|
||||
current_model = self.model_normal
|
||||
current_model.temperature = global_config.llm_normal["temp"] * arousal_multiplier # 激活度越高,温度越高
|
||||
model_response = await self._generate_response_with_model(
|
||||
structured_info, current_mind_info, reason, message, current_model, thinking_id
|
||||
)
|
||||
current_model = self.model_normal
|
||||
current_model.temperature = global_config.llm_normal["temp"] * arousal_multiplier # 激活度越高,温度越高
|
||||
model_response = await self._generate_response_with_model(
|
||||
structured_info, current_mind_info, reason, message, current_model, thinking_id
|
||||
)
|
||||
|
||||
if model_response:
|
||||
logger.info(
|
||||
f"{global_config.BOT_NICKNAME}的回复是:{model_response},生成回复时间: {t_generate_response.human_readable}"
|
||||
)
|
||||
model_processed_response = await self._process_response(model_response)
|
||||
|
||||
return model_processed_response
|
||||
@@ -78,7 +74,7 @@ class HeartFCGenerator:
|
||||
) -> str:
|
||||
info_catcher = info_catcher_manager.get_info_catcher(thinking_id)
|
||||
|
||||
with Timer() as t_build_prompt:
|
||||
with Timer() as _build_prompt:
|
||||
prompt = await prompt_builder.build_prompt(
|
||||
build_mode="focus",
|
||||
reason=reason,
|
||||
|
||||
159
src/plugins/heartFC_chat/heartFC_readme.md
Normal file
159
src/plugins/heartFC_chat/heartFC_readme.md
Normal file
@@ -0,0 +1,159 @@
|
||||
# HeartFC_chat 工作原理文档
|
||||
|
||||
HeartFC_chat 是一个基于心流理论的聊天系统,通过模拟人类的思维过程和情感变化来实现自然的对话交互。系统采用Plan-Replier-Sender循环机制,实现了智能化的对话决策和生成。
|
||||
|
||||
## 核心工作流程
|
||||
|
||||
### 1. 消息处理与存储 (HeartFCProcessor)
|
||||
[代码位置: src/plugins/heartFC_chat/heartflow_processor.py]
|
||||
|
||||
消息处理器负责接收和预处理消息,主要完成以下工作:
|
||||
```mermaid
|
||||
graph TD
|
||||
A[接收原始消息] --> B[解析为MessageRecv对象]
|
||||
B --> C[消息缓冲处理]
|
||||
C --> D[过滤检查]
|
||||
D --> E[存储到数据库]
|
||||
```
|
||||
|
||||
核心实现:
|
||||
- 消息处理入口:`process_message()` [行号: 38-215]
|
||||
- 消息解析和缓冲:`message_buffer.start_caching_messages()` [行号: 63]
|
||||
- 过滤检查:`_check_ban_words()`, `_check_ban_regex()` [行号: 196-215]
|
||||
- 消息存储:`storage.store_message()` [行号: 108]
|
||||
|
||||
### 2. 对话管理循环 (HeartFChatting)
|
||||
[代码位置: src/plugins/heartFC_chat/heartFC_chat.py]
|
||||
|
||||
HeartFChatting是系统的核心组件,实现了完整的对话管理循环:
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
A[Plan阶段] -->|决策是否回复| B[Replier阶段]
|
||||
B -->|生成回复内容| C[Sender阶段]
|
||||
C -->|发送消息| D[等待新消息]
|
||||
D --> A
|
||||
```
|
||||
|
||||
#### Plan阶段 [行号: 282-386]
|
||||
- 主要函数:`_planner()`
|
||||
- 功能实现:
|
||||
* 获取观察信息:`observation.observe()` [行号: 297]
|
||||
* 思维处理:`sub_mind.do_thinking_before_reply()` [行号: 301]
|
||||
* LLM决策:使用`PLANNER_TOOL_DEFINITION`进行动作规划 [行号: 13-42]
|
||||
|
||||
#### Replier阶段 [行号: 388-416]
|
||||
- 主要函数:`_replier_work()`
|
||||
- 调用生成器:`gpt_instance.generate_response()` [行号: 394]
|
||||
- 处理生成结果和错误情况
|
||||
|
||||
#### Sender阶段 [行号: 418-450]
|
||||
- 主要函数:`_sender()`
|
||||
- 发送实现:
|
||||
* 创建消息:`_create_thinking_message()` [行号: 452-477]
|
||||
* 发送回复:`_send_response_messages()` [行号: 479-525]
|
||||
* 处理表情:`_handle_emoji()` [行号: 527-567]
|
||||
|
||||
### 3. 回复生成机制 (HeartFCGenerator)
|
||||
[代码位置: src/plugins/heartFC_chat/heartFC_generator.py]
|
||||
|
||||
回复生成器负责产生高质量的回复内容:
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
A[获取上下文信息] --> B[构建提示词]
|
||||
B --> C[调用LLM生成]
|
||||
C --> D[后处理优化]
|
||||
D --> E[返回回复集]
|
||||
```
|
||||
|
||||
核心实现:
|
||||
- 生成入口:`generate_response()` [行号: 39-67]
|
||||
* 情感调节:`arousal_multiplier = MoodManager.get_instance().get_arousal_multiplier()` [行号: 47]
|
||||
* 模型生成:`_generate_response_with_model()` [行号: 69-95]
|
||||
* 响应处理:`_process_response()` [行号: 97-106]
|
||||
|
||||
### 4. 提示词构建系统 (HeartFlowPromptBuilder)
|
||||
[代码位置: src/plugins/heartFC_chat/heartflow_prompt_builder.py]
|
||||
|
||||
提示词构建器支持两种工作模式,HeartFC_chat专门使用Focus模式,而Normal模式是为normal_chat设计的:
|
||||
|
||||
#### 专注模式 (Focus Mode) - HeartFC_chat专用
|
||||
- 实现函数:`_build_prompt_focus()` [行号: 116-141]
|
||||
- 特点:
|
||||
* 专注于当前对话状态和思维
|
||||
* 更强的目标导向性
|
||||
* 用于HeartFC_chat的Plan-Replier-Sender循环
|
||||
* 简化的上下文处理,专注于决策
|
||||
|
||||
#### 普通模式 (Normal Mode) - Normal_chat专用
|
||||
- 实现函数:`_build_prompt_normal()` [行号: 143-215]
|
||||
- 特点:
|
||||
* 用于normal_chat的常规对话
|
||||
* 完整的个性化处理
|
||||
* 关系系统集成
|
||||
* 知识库检索:`get_prompt_info()` [行号: 217-591]
|
||||
|
||||
HeartFC_chat的Focus模式工作流程:
|
||||
```mermaid
|
||||
graph TD
|
||||
A[获取结构化信息] --> B[获取当前思维状态]
|
||||
B --> C[构建专注模式提示词]
|
||||
C --> D[用于Plan阶段决策]
|
||||
D --> E[用于Replier阶段生成]
|
||||
```
|
||||
|
||||
## 智能特性
|
||||
|
||||
### 1. 对话决策机制
|
||||
- LLM决策工具定义:`PLANNER_TOOL_DEFINITION` [heartFC_chat.py 行号: 13-42]
|
||||
- 决策执行:`_planner()` [heartFC_chat.py 行号: 282-386]
|
||||
- 考虑因素:
|
||||
* 上下文相关性
|
||||
* 情感状态
|
||||
* 兴趣程度
|
||||
* 对话时机
|
||||
|
||||
### 2. 状态管理
|
||||
[代码位置: src/plugins/heartFC_chat/heartFC_chat.py]
|
||||
- 状态机实现:`HeartFChatting`类 [行号: 44-567]
|
||||
- 核心功能:
|
||||
* 初始化:`_initialize()` [行号: 89-112]
|
||||
* 循环控制:`_run_pf_loop()` [行号: 192-281]
|
||||
* 状态转换:`_handle_loop_completion()` [行号: 166-190]
|
||||
|
||||
### 3. 回复生成策略
|
||||
[代码位置: src/plugins/heartFC_chat/heartFC_generator.py]
|
||||
- 温度调节:`current_model.temperature = global_config.llm_normal["temp"] * arousal_multiplier` [行号: 48]
|
||||
- 生成控制:`_generate_response_with_model()` [行号: 69-95]
|
||||
- 响应处理:`_process_response()` [行号: 97-106]
|
||||
|
||||
## 系统配置
|
||||
|
||||
### 关键参数
|
||||
- LLM配置:`model_normal` [heartFC_generator.py 行号: 32-37]
|
||||
- 过滤规则:`_check_ban_words()`, `_check_ban_regex()` [heartflow_processor.py 行号: 196-215]
|
||||
- 状态控制:`INITIAL_DURATION = 60.0` [heartFC_chat.py 行号: 11]
|
||||
|
||||
### 优化建议
|
||||
1. 调整LLM参数:`temperature`和`max_tokens`
|
||||
2. 优化提示词模板:`init_prompt()` [heartflow_prompt_builder.py 行号: 8-115]
|
||||
3. 配置状态转换条件
|
||||
4. 维护过滤规则
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. 系统稳定性
|
||||
- 异常处理:各主要函数都包含try-except块
|
||||
- 状态检查:`_processing_lock`确保并发安全
|
||||
- 循环控制:`_loop_active`和`_loop_task`管理
|
||||
|
||||
2. 性能优化
|
||||
- 缓存使用:`message_buffer`系统
|
||||
- LLM调用优化:批量处理和复用
|
||||
- 异步处理:使用`asyncio`
|
||||
|
||||
3. 质量控制
|
||||
- 日志记录:使用`get_module_logger()`
|
||||
- 错误追踪:详细的异常记录
|
||||
- 响应监控:完整的状态跟踪
|
||||
153
src/plugins/heartFC_chat/heartFC_sender.py
Normal file
153
src/plugins/heartFC_chat/heartFC_sender.py
Normal file
@@ -0,0 +1,153 @@
|
||||
# src/plugins/heartFC_chat/heartFC_sender.py
|
||||
import asyncio # 重新导入 asyncio
|
||||
from typing import Dict, Optional # 重新导入类型
|
||||
|
||||
from src.common.logger import get_module_logger
|
||||
from ..message.api import global_api
|
||||
from ..chat.message import MessageSending, MessageThinking # 只保留 MessageSending 和 MessageThinking
|
||||
from ..storage.storage import MessageStorage
|
||||
from ..chat.utils import truncate_message
|
||||
from src.common.logger import LogConfig, SENDER_STYLE_CONFIG
|
||||
from src.plugins.chat.utils import calculate_typing_time
|
||||
|
||||
# 定义日志配置
|
||||
sender_config = LogConfig(
|
||||
# 使用消息发送专用样式
|
||||
console_format=SENDER_STYLE_CONFIG["console_format"],
|
||||
file_format=SENDER_STYLE_CONFIG["file_format"],
|
||||
)
|
||||
|
||||
logger = get_module_logger("msg_sender", config=sender_config)
|
||||
|
||||
|
||||
class HeartFCSender:
    """Register, process, send and store outgoing messages, tracking "thinking" placeholders.

    ``thinking_messages`` maps a chat (stream) id to a dict of
    message id -> ``MessageThinking``. Every mutation of that mapping happens
    while holding ``_thinking_lock``.
    """

    def __init__(self):
        self.storage = MessageStorage()
        # Active thinking messages: chat id -> {message id -> MessageThinking}.
        self.thinking_messages: Dict[str, Dict[str, MessageThinking]] = {}
        self._thinking_lock = asyncio.Lock()  # guards thinking_messages

    async def send_message(self, message: MessageSending) -> None:
        """Send ``message`` through ``global_api`` and log the outcome.

        Raises:
            ValueError: if sending failed and the message has no platform set;
                the original error is attached as the cause.
            Exception: any other sending failure is logged and re-raised as is.
        """
        message_preview = truncate_message(message.processed_plain_text)

        try:
            await global_api.send_message(message)
            logger.success(f"发送消息 '{message_preview}' 成功")
        except Exception as e:
            logger.error(f"发送消息 '{message_preview}' 失败: {str(e)}")
            if not message.message_info.platform:
                raise ValueError(f"未找到平台:{message.message_info.platform} 的url配置,请检查配置文件") from e
            raise  # propagate every other failure unchanged

    async def register_thinking(self, thinking_message: MessageThinking):
        """Record ``thinking_message`` as an in-progress thinking message.

        Messages without a ``chat_stream`` or a ``message_id`` are rejected
        with an error log. Registering an id that already exists logs a
        warning and overwrites the previous entry.
        """
        if not thinking_message.chat_stream or not thinking_message.message_info.message_id:
            logger.error("无法注册缺少 chat_stream 或 message_id 的思考消息")
            return

        chat_id = thinking_message.chat_stream.stream_id
        message_id = thinking_message.message_info.message_id

        async with self._thinking_lock:
            if chat_id not in self.thinking_messages:
                self.thinking_messages[chat_id] = {}
            if message_id in self.thinking_messages[chat_id]:
                logger.warning(f"[{chat_id}] 尝试注册已存在的思考消息 ID: {message_id}")
            self.thinking_messages[chat_id][message_id] = thinking_message
            logger.debug(f"[{chat_id}] Registered thinking message: {message_id}")

    async def complete_thinking(self, chat_id: str, message_id: str):
        """Remove the thinking record for ``message_id`` in ``chat_id``, if present.

        The per-chat container is dropped as well once it becomes empty.
        Unknown ids are ignored silently.
        """
        async with self._thinking_lock:
            if chat_id in self.thinking_messages and message_id in self.thinking_messages[chat_id]:
                del self.thinking_messages[chat_id][message_id]
                logger.debug(f"[{chat_id}] Completed thinking message: {message_id}")
                if not self.thinking_messages[chat_id]:
                    del self.thinking_messages[chat_id]
                    logger.debug(f"[{chat_id}] Removed empty thinking message container.")

    def is_thinking(self, chat_id: str, message_id: str) -> bool:
        """Return True if ``message_id`` is currently registered as thinking in ``chat_id``."""
        return chat_id in self.thinking_messages and message_id in self.thinking_messages[chat_id]

    async def get_thinking_start_time(self, chat_id: str, message_id: str) -> Optional[float]:
        """Return the ``thinking_start_time`` of a registered thinking message, or None if unknown."""
        async with self._thinking_lock:
            thinking_message = self.thinking_messages.get(chat_id, {}).get(message_id)
            return thinking_message.thinking_start_time if thinking_message else None

    async def type_and_send_message(self, message: MessageSending, type=False):
        """Process, send and store a single message, then clear its thinking record.

        ``register_thinking`` is expected to have been called for the matching
        thinking message beforehand; ``complete_thinking`` is always invoked
        afterwards, whether or not sending succeeded.

        Args:
            message: the message to send.
            type: when true, sleep for the duration returned by
                ``calculate_typing_time`` before sending.

        Raises:
            Exception: any error from processing, sending or storing is logged
                and re-raised.
        """
        if not message.chat_stream:
            logger.error("消息缺少 chat_stream,无法发送")
            return
        if not message.message_info or not message.message_info.message_id:
            logger.error("消息缺少 message_info 或 message_id,无法发送")
            return

        chat_id = message.chat_stream.stream_id
        message_id = message.message_info.message_id

        try:
            message.update_thinking_time()

            # Apply set_reply only to the head message of a non-private chat
            # that opted into this logic.
            if message.apply_set_reply_logic and message.is_head and not message.is_private_message():
                logger.debug(f"[{chat_id}] 应用 set_reply 逻辑: {message.processed_plain_text[:20]}...")
                message.set_reply()

            await message.process()

            if type:
                typing_time = calculate_typing_time(
                    input_string=message.processed_plain_text,
                    thinking_start_time=message.thinking_start_time,
                    is_emoji=message.is_emoji,
                )
                await asyncio.sleep(typing_time)

            await self.send_message(message)
            await self.storage.store_message(message, message.chat_stream)
        except Exception as e:
            logger.error(f"[{chat_id}] 处理或存储消息 {message_id} 时出错: {e}")
            raise
        finally:
            await self.complete_thinking(chat_id, message_id)

    async def send_and_store(self, message: MessageSending):
        """Process, send and store a single message without touching thinking state.

        Raises:
            Exception: any error from processing, sending or storing is logged
                and re-raised so the caller learns about the failure.
        """
        if not message.chat_stream:
            # message_info is only validated below, so it may be missing here;
            # read the platform defensively so this log line cannot raise.
            platform = getattr(message.message_info, "platform", None) or "UnknownPlatform"
            logger.error(f"[{platform}] 消息缺少 chat_stream,无法发送")
            return
        if not message.message_info or not message.message_info.message_id:
            # chat_stream is known to be present at this point.
            logger.error(f"[{message.chat_stream.stream_id}] 消息缺少 message_info 或 message_id,无法发送")
            return

        chat_id = message.chat_stream.stream_id
        message_id = message.message_info.message_id  # used in log output only

        try:
            await message.process()

            await asyncio.sleep(0.5)

            await self.send_message(message)
            await self.storage.store_message(message, message.chat_stream)
        except Exception as e:
            logger.error(f"[{chat_id}] 处理或存储消息 {message_id} 时出错: {e}")
            raise
|
||||
@@ -5,13 +5,14 @@ from ...config.config import global_config
|
||||
from ..chat.message import MessageRecv
|
||||
from ..storage.storage import MessageStorage
|
||||
from ..chat.utils import is_mentioned_bot_in_message
|
||||
from ..message import Seg
|
||||
from maim_message import Seg
|
||||
from src.heart_flow.heartflow import heartflow
|
||||
from src.common.logger import get_module_logger, CHAT_STYLE_CONFIG, LogConfig
|
||||
from ..chat.chat_stream import chat_manager
|
||||
from ..chat.message_buffer import message_buffer
|
||||
from ..utils.timer_calculater import Timer
|
||||
from ..utils.timer_calculator import Timer
|
||||
from src.plugins.person_info.relationship_manager import relationship_manager
|
||||
from typing import Optional, Tuple
|
||||
|
||||
# 定义日志配置
|
||||
processor_config = LogConfig(
|
||||
@@ -22,193 +23,202 @@ logger = get_module_logger("heartflow_processor", config=processor_config)
|
||||
|
||||
|
||||
class HeartFCProcessor:
|
||||
"""心流处理器,负责处理接收到的消息并计算兴趣度"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化心流处理器,创建消息存储实例"""
|
||||
self.storage = MessageStorage()
|
||||
|
||||
async def process_message(self, message_data: str) -> None:
|
||||
"""处理接收到的原始消息数据,完成消息解析、缓冲、过滤、存储、兴趣度计算与更新等核心流程。
|
||||
|
||||
此函数是消息处理的核心入口,负责接收原始字符串格式的消息数据,并将其转化为结构化的 `MessageRecv` 对象。
|
||||
主要执行步骤包括:
|
||||
1. 解析 `message_data` 为 `MessageRecv` 对象,提取用户信息、群组信息等。
|
||||
2. 将消息加入 `message_buffer` 进行缓冲处理,以应对消息轰炸或者某些人一条消息分几次发等情况。
|
||||
3. 获取或创建对应的 `chat_stream` 和 `subheartflow` 实例,用于管理会话状态和心流。
|
||||
4. 对消息内容进行初步处理(如提取纯文本)。
|
||||
5. 应用全局配置中的过滤词和正则表达式,过滤不符合规则的消息。
|
||||
6. 查询消息缓冲结果,如果消息被缓冲器拦截(例如,判断为消息轰炸的一部分),则中止后续处理。
|
||||
7. 对于通过缓冲的消息,将其存储到 `MessageStorage` 中。
|
||||
|
||||
8. 调用海马体(`HippocampusManager`)计算消息内容的记忆激活率。(这部分算法后续会进行优化)
|
||||
9. 根据是否被提及(@)和记忆激活率,计算最终的兴趣度增量。(提及的额外兴趣增幅)
|
||||
10. 使用计算出的增量更新 `InterestManager` 中对应会话的兴趣度。
|
||||
11. 记录处理后的消息信息及当前的兴趣度到日志。
|
||||
|
||||
注意:此函数本身不负责生成和发送回复。回复的决策和生成逻辑被移至 `HeartFC_Chat` 类中的监控任务,
|
||||
该任务会根据 `InterestManager` 中的兴趣度变化来决定何时触发回复。
|
||||
async def _handle_error(self, error: Exception, context: str, message: Optional[MessageRecv] = None) -> None:
|
||||
"""统一的错误处理函数
|
||||
|
||||
Args:
|
||||
message_data: str: 从消息源接收到的原始消息字符串。
|
||||
error: 捕获到的异常
|
||||
context: 错误发生的上下文描述
|
||||
message: 可选的消息对象,用于记录相关消息内容
|
||||
"""
|
||||
logger.error(f"{context}: {error}")
|
||||
logger.error(traceback.format_exc())
|
||||
if message and hasattr(message, "raw_message"):
|
||||
logger.error(f"相关消息原始内容: {message.raw_message}")
|
||||
|
||||
async def _process_relationship(self, message: MessageRecv) -> None:
    """Make sure the sender of ``message`` is recorded by the relationship manager.

    ``first_knowing_some_one`` is called when the user is not yet known, or
    when the user is known but ``is_qved_name`` reports false. Nothing is
    done otherwise.

    Args:
        message: the received message carrying platform and user information.
    """
    info = message.message_info
    sender = info.user_info
    platform = info.platform
    user_id = sender.user_id
    nickname = sender.user_nickname
    cardname = sender.user_cardname or nickname  # fall back to the nickname

    if not await relationship_manager.is_known_some_one(platform, user_id):
        logger.info(f"首次认识用户: {nickname}")
    elif not await relationship_manager.is_qved_name(platform, user_id):
        logger.info(f"给用户({nickname},{cardname})取名: {nickname}")
    else:
        return

    # Both remaining cases register the user with the same arguments.
    await relationship_manager.first_knowing_some_one(platform, user_id, nickname, cardname, "")
|
||||
|
||||
async def _calculate_interest(self, message: MessageRecv) -> Tuple[float, bool]:
    """Compute the interest value of ``message``.

    The base value is the activation returned by the hippocampus manager for
    the message's processed plain text; it is increased by 1 when the bot is
    mentioned.

    Args:
        message: the message to evaluate.

    Returns:
        Tuple[float, bool]: ``(interest value, whether the bot was mentioned)``.
    """
    mentioned, _ = is_mentioned_bot_in_message(message)

    with Timer("记忆激活"):
        hippocampus = HippocampusManager.get_instance()
        rate = await hippocampus.get_activate_from_text(
            message.processed_plain_text,
            fast_retrieval=True,
        )
        logger.trace(f"记忆激活率: {rate:.2f}")

    if mentioned:
        # Fixed bonus added whenever the bot is mentioned.
        rate += 1

    return rate, mentioned
|
||||
|
||||
def _get_message_type(self, message: MessageRecv) -> str:
|
||||
"""获取消息类型
|
||||
|
||||
Args:
|
||||
message: 消息对象
|
||||
|
||||
Returns:
|
||||
str: 消息类型
|
||||
"""
|
||||
if message.message_segment.type != "seglist":
|
||||
return message.message_segment.type
|
||||
|
||||
if (
|
||||
isinstance(message.message_segment.data, list)
|
||||
and all(isinstance(x, Seg) for x in message.message_segment.data)
|
||||
and len(message.message_segment.data) == 1
|
||||
):
|
||||
return message.message_segment.data[0].type
|
||||
|
||||
return "seglist"
|
||||
|
||||
async def process_message(self, message_data: str) -> None:
|
||||
"""处理接收到的原始消息数据
|
||||
|
||||
主要流程:
|
||||
1. 消息解析与初始化
|
||||
2. 消息缓冲处理
|
||||
3. 过滤检查
|
||||
4. 兴趣度计算
|
||||
5. 关系处理
|
||||
|
||||
Args:
|
||||
message_data: 原始消息字符串
|
||||
"""
|
||||
timing_results = {} # 初始化 timing_results
|
||||
message = None
|
||||
try:
|
||||
# 1. 消息解析与初始化
|
||||
message = MessageRecv(message_data)
|
||||
groupinfo = message.message_info.group_info
|
||||
userinfo = message.message_info.user_info
|
||||
messageinfo = message.message_info
|
||||
|
||||
# 消息加入缓冲池
|
||||
# 2. 消息缓冲与流程序化
|
||||
await message_buffer.start_caching_messages(message)
|
||||
|
||||
# 创建聊天流
|
||||
chat = await chat_manager.get_or_create_stream(
|
||||
platform=messageinfo.platform,
|
||||
user_info=userinfo,
|
||||
group_info=groupinfo,
|
||||
)
|
||||
|
||||
subheartflow = await heartflow.create_subheartflow(chat.stream_id)
|
||||
|
||||
subheartflow = await heartflow.get_or_create_subheartflow(chat.stream_id)
|
||||
message.update_chat_stream(chat)
|
||||
|
||||
await heartflow.create_subheartflow(chat.stream_id)
|
||||
|
||||
await message.process()
|
||||
logger.trace(f"消息处理成功: {message.processed_plain_text}")
|
||||
|
||||
# 过滤词/正则表达式过滤
|
||||
# 3. 过滤检查
|
||||
if self._check_ban_words(message.processed_plain_text, chat, userinfo) or self._check_ban_regex(
|
||||
message.raw_message, chat, userinfo
|
||||
):
|
||||
return
|
||||
|
||||
# 查询缓冲器结果
|
||||
# 4. 缓冲检查
|
||||
buffer_result = await message_buffer.query_buffer_result(message)
|
||||
|
||||
# 处理缓冲器结果 (Bombing logic)
|
||||
if not buffer_result:
|
||||
f_type = "seglist"
|
||||
if message.message_segment.type != "seglist":
|
||||
f_type = message.message_segment.type
|
||||
else:
|
||||
if (
|
||||
isinstance(message.message_segment.data, list)
|
||||
and all(isinstance(x, Seg) for x in message.message_segment.data)
|
||||
and len(message.message_segment.data) == 1
|
||||
):
|
||||
f_type = message.message_segment.data[0].type
|
||||
if f_type == "text":
|
||||
logger.debug(f"触发缓冲,消息:{message.processed_plain_text}")
|
||||
elif f_type == "image":
|
||||
logger.debug("触发缓冲,表情包/图片等待中")
|
||||
elif f_type == "seglist":
|
||||
logger.debug("触发缓冲,消息列表等待中")
|
||||
return # 被缓冲器拦截,不生成回复
|
||||
|
||||
# ---- 只有通过缓冲的消息才进行存储和后续处理 ----
|
||||
|
||||
# 存储消息 (使用可能被缓冲器更新过的 message)
|
||||
try:
|
||||
await self.storage.store_message(message, chat)
|
||||
logger.trace(f"存储成功 (通过缓冲后): {message.processed_plain_text}")
|
||||
except Exception as e:
|
||||
logger.error(f"存储消息失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
# 存储失败可能仍需考虑是否继续,暂时返回
|
||||
msg_type = self._get_message_type(message)
|
||||
type_messages = {
|
||||
"text": f"触发缓冲,消息:{message.processed_plain_text}",
|
||||
"image": "触发缓冲,表情包/图片等待中",
|
||||
"seglist": "触发缓冲,消息列表等待中",
|
||||
}
|
||||
logger.debug(type_messages.get(msg_type, "触发未知类型缓冲"))
|
||||
return
|
||||
|
||||
# 激活度计算 (使用可能被缓冲器更新过的 message.processed_plain_text)
|
||||
is_mentioned, _ = is_mentioned_bot_in_message(message)
|
||||
interested_rate = 0.0 # 默认值
|
||||
try:
|
||||
with Timer("记忆激活", timing_results):
|
||||
interested_rate = await HippocampusManager.get_instance().get_activate_from_text(
|
||||
message.processed_plain_text,
|
||||
fast_retrieval=True, # 使用更新后的文本
|
||||
)
|
||||
logger.trace(f"记忆激活率 (通过缓冲后): {interested_rate:.2f}")
|
||||
except Exception as e:
|
||||
logger.error(f"计算记忆激活率失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
# 5. 消息存储
|
||||
await self.storage.store_message(message, chat)
|
||||
logger.trace(f"存储成功: {message.processed_plain_text}")
|
||||
|
||||
# --- 修改:兴趣度更新逻辑 --- #
|
||||
if is_mentioned:
|
||||
interest_increase_on_mention = 1
|
||||
mentioned_boost = interest_increase_on_mention # 从配置获取提及增加值
|
||||
interested_rate += mentioned_boost
|
||||
# 6. 兴趣度计算与更新
|
||||
interested_rate, is_mentioned = await self._calculate_interest(message)
|
||||
await subheartflow.interest_chatting.increase_interest(value=interested_rate)
|
||||
subheartflow.interest_chatting.add_interest_dict(message, interested_rate, is_mentioned)
|
||||
|
||||
# 更新兴趣度 (调用 SubHeartflow 的方法)
|
||||
current_time = time.time()
|
||||
await subheartflow.interest_chatting.increase_interest(current_time, value=interested_rate)
|
||||
|
||||
# 添加到 SubHeartflow 的 interest_dict,给normal_chat处理
|
||||
await subheartflow.add_interest_dict_entry(message, interested_rate, is_mentioned)
|
||||
|
||||
# 打印消息接收和处理信息
|
||||
# 7. 日志记录
|
||||
mes_name = chat.group_info.group_name if chat.group_info else "私聊"
|
||||
current_time = time.strftime("%H:%M:%S", time.localtime(message.message_info.time))
|
||||
current_time = time.strftime("%H点%M分%S秒", time.localtime(message.message_info.time))
|
||||
logger.info(
|
||||
f"[{current_time}][{mes_name}]"
|
||||
f"{message.message_info.user_info.user_nickname}:"
|
||||
f"{userinfo.user_nickname}:"
|
||||
f"{message.processed_plain_text}"
|
||||
f"[兴趣度: {interested_rate:.2f}]"
|
||||
)
|
||||
|
||||
try:
|
||||
is_known = await relationship_manager.is_known_some_one(
|
||||
message.message_info.platform, message.message_info.user_info.user_id
|
||||
)
|
||||
if not is_known:
|
||||
logger.info(f"首次认识用户: {message.message_info.user_info.user_nickname}")
|
||||
await relationship_manager.first_knowing_some_one(
|
||||
message.message_info.platform,
|
||||
message.message_info.user_info.user_id,
|
||||
message.message_info.user_info.user_nickname,
|
||||
message.message_info.user_info.user_cardname or message.message_info.user_info.user_nickname,
|
||||
"",
|
||||
)
|
||||
else:
|
||||
# logger.debug(f"已认识用户: {message.message_info.user_info.user_nickname}")
|
||||
if not await relationship_manager.is_qved_name(
|
||||
message.message_info.platform, message.message_info.user_info.user_id
|
||||
):
|
||||
logger.info(f"更新已认识但未取名的用户: {message.message_info.user_info.user_nickname}")
|
||||
await relationship_manager.first_knowing_some_one(
|
||||
message.message_info.platform,
|
||||
message.message_info.user_info.user_id,
|
||||
message.message_info.user_info.user_nickname,
|
||||
message.message_info.user_info.user_cardname
|
||||
or message.message_info.user_info.user_nickname,
|
||||
"",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"处理认识关系失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
# 8. 关系处理
|
||||
await self._process_relationship(message)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"消息处理失败 (process_message V3): {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
if message: # 记录失败的消息内容
|
||||
logger.error(f"失败消息原始内容: {message.raw_message}")
|
||||
await self._handle_error(e, "消息处理失败", message)
|
||||
|
||||
def _check_ban_words(self, text: str, chat, userinfo) -> bool:
|
||||
"""检查消息中是否包含过滤词"""
|
||||
"""检查消息是否包含过滤词
|
||||
|
||||
Args:
|
||||
text: 待检查的文本
|
||||
chat: 聊天对象
|
||||
userinfo: 用户信息
|
||||
|
||||
Returns:
|
||||
bool: 是否包含过滤词
|
||||
"""
|
||||
for word in global_config.ban_words:
|
||||
if word in text:
|
||||
logger.info(
|
||||
f"[{chat.group_info.group_name if chat.group_info else '私聊'}]{userinfo.user_nickname}:{text}"
|
||||
)
|
||||
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
|
||||
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
|
||||
logger.info(f"[过滤词识别]消息中含有{word},filtered")
|
||||
return True
|
||||
return False
|
||||
|
||||
def _check_ban_regex(self, text: str, chat, userinfo) -> bool:
|
||||
"""检查消息是否匹配过滤正则表达式"""
|
||||
"""检查消息是否匹配过滤正则表达式
|
||||
|
||||
Args:
|
||||
text: 待检查的文本
|
||||
chat: 聊天对象
|
||||
userinfo: 用户信息
|
||||
|
||||
Returns:
|
||||
bool: 是否匹配过滤正则
|
||||
"""
|
||||
for pattern in global_config.ban_msgs_regex:
|
||||
if pattern.search(text):
|
||||
logger.info(
|
||||
f"[{chat.group_info.group_name if chat.group_info else '私聊'}]{userinfo.user_nickname}:{text}"
|
||||
)
|
||||
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
|
||||
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
|
||||
logger.info(f"[正则表达式过滤]消息匹配到{pattern},filtered")
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -21,42 +21,80 @@ logger = get_module_logger("prompt")
|
||||
def init_prompt():
|
||||
Prompt(
|
||||
"""
|
||||
你有以下信息可供参考:
|
||||
{structured_info}
|
||||
{info_from_tools}
|
||||
{chat_target}
|
||||
{chat_talking_prompt}
|
||||
现在你想要在群里发言或者回复。\n
|
||||
你需要扮演一位网名叫{bot_name}的人进行回复,这个人的特点是:"{prompt_personality} {prompt_identity}"。
|
||||
你正在{chat_target_2},现在请你读读之前的聊天记录,然后给出日常且口语化的回复,平淡一些,你可以参考贴吧,小红书或者微博的回复风格。
|
||||
你刚刚脑子里在想:
|
||||
你正在{chat_target_2},现在请你读读之前的聊天记录,然后给出日常且口语化的回复,平淡一些,你可以参考贴吧,知乎或者微博的回复风格。
|
||||
看到以上聊天记录,你刚刚在想:
|
||||
|
||||
{current_mind_info}
|
||||
{reason}
|
||||
因为上述想法,你决定发言,原因是:{reason}
|
||||
|
||||
回复尽量简短一些。请注意把握聊天内容,不要回复的太有条理,可以有个性。请一次只回复一个话题,不要同时回复多个人,不用指出你回复的是谁。{prompt_ger}
|
||||
请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景,不要说你说过的话题 ,注意只输出回复内容。
|
||||
{moderation_prompt}。注意:不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。""",
|
||||
{moderation_prompt}。注意:回复不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。""",
|
||||
"heart_flow_prompt",
|
||||
)
|
||||
|
||||
# Planner提示词
|
||||
|
||||
Prompt(
|
||||
"""你的名字是 {bot_name}。是一个女大二学生,正在QQ聊天,正在决定是否以及如何回应当前的聊天。
|
||||
"""
|
||||
你有以下信息可供参考:
|
||||
{structured_info}
|
||||
以上的消息是你获取到的消息,或许可以帮助你更好地回复。
|
||||
""",
|
||||
"info_from_tools",
|
||||
)
|
||||
|
||||
# Planner提示词 - 优化版
|
||||
Prompt(
|
||||
"""你的名字是{bot_name},一名大二女生,正在QQ群聊中。需要基于以下信息决定如何参与对话:
|
||||
{structured_info_block}
|
||||
{chat_content_block}
|
||||
看了以上内容,你产生的内心想法是:
|
||||
你的内心想法:
|
||||
{current_mind_block}
|
||||
请结合你的内心想法和观察到的聊天内容,分析情况并使用 'decide_reply_action' 工具来决定你的最终行动。
|
||||
决策依据:
|
||||
1. 如果聊天内容无聊、与你无关、或者你的内心想法认为不适合回复(例如在讨论你不懂或不感兴趣的话题),选择 'no_reply'。
|
||||
2. 如果聊天内容值得回应,且适合用文字表达(参考你的内心想法),选择 'text_reply'。如果你有情绪想表达,想在文字后追加一个表达情绪的表情,请同时提供 'emoji_query' (例如:'开心的'、'惊讶的')。
|
||||
3. 如果聊天内容或你的内心想法适合用一个表情来回应(例如表示赞同、惊讶、无语等),选择 'emoji_reply' 并提供表情主题 'emoji_query'。
|
||||
4. 如果最后一条消息是你自己发的,观察到的内容只有你自己的发言,并且之后没有人回复你,通常选择 'no_reply',除非有特殊原因需要追问。
|
||||
5. 如果聊天记录中最新的消息是你自己发送的,并且你还想继续回复,你应该紧紧衔接你发送的消息,进行话题的深入,补充,或追问等等;。
|
||||
6. 表情包是用来表达情绪的,不要直接回复或评价别人的表情包,而是根据对话内容和情绪选择是否用表情回应。
|
||||
7. 不要回复你自己的话,不要把自己的话当做别人说的。
|
||||
必须调用 'decide_reply_action' 工具并提供 'action' 和 'reasoning'。如果选择了 'emoji_reply' 或者选择了 'text_reply' 并想追加表情,则必须提供 'emoji_query'。""",
|
||||
{replan}
|
||||
|
||||
请综合分析聊天内容和你看到的新消息,参考内心想法,使用'decide_reply_action'工具做出决策。决策时请注意:
|
||||
|
||||
【回复原则】
|
||||
1. 不回复(no_reply)适用:
|
||||
- 话题无关/无聊/不感兴趣
|
||||
- 最后一条消息是你自己发的且无人回应你
|
||||
- 讨论你不懂的专业话题
|
||||
- 你发送了太多消息,且无人回复
|
||||
|
||||
2. 文字回复(text_reply)适用:
|
||||
- 有实质性内容需要表达
|
||||
- 有人提到你,但你还没有回应他
|
||||
- 可以追加emoji_query表达情绪(格式:情绪描述,如"俏皮的调侃")
|
||||
- 不要追加太多表情
|
||||
|
||||
3. 纯表情回复(emoji_reply)适用:
|
||||
- 适合用表情回应的场景
|
||||
- 需提供明确的emoji_query
|
||||
|
||||
4. 自我对话处理:
|
||||
- 如果是自己发的消息想继续,需自然衔接
|
||||
- 避免重复或评价自己的发言
|
||||
- 不要和自己聊天
|
||||
|
||||
【必须遵守】
|
||||
- 遵守回复原则
|
||||
- 必须调用工具并包含action和reasoning
|
||||
- 你可以选择文字回复(text_reply),纯表情回复(emoji_reply),不回复(no_reply)
|
||||
- 选择text_reply或emoji_reply时必须提供emoji_query
|
||||
- 保持回复自然,符合日常聊天习惯""",
|
||||
"planner_prompt",
|
||||
)
|
||||
|
||||
|
||||
Prompt(
|
||||
"""你原本打算{action},因为:{reasoning}
|
||||
但是你看到了新的消息,你决定重新决定行动。""",
|
||||
"replan_prompt",
|
||||
)
|
||||
|
||||
Prompt("你正在qq群里聊天,下面是群里在聊的内容:", "chat_target_group1")
|
||||
Prompt("和群里聊天", "chat_target_group2")
|
||||
Prompt("你正在和{sender_name}聊天,这是你们之前聊的内容:", "chat_target_private1")
|
||||
@@ -79,9 +117,9 @@ def init_prompt():
|
||||
你的网名叫{bot_name},有人也叫你{bot_other_names},{prompt_personality}。
|
||||
你正在{chat_target_2},现在请你读读之前的聊天记录,{mood_prompt},然后给出日常且口语化的回复,平淡一些,
|
||||
尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,不要回复的太有条理,可以有个性。{prompt_ger}
|
||||
请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景,尽量不要说你说过的话
|
||||
请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景,不要浮夸,平淡一些 ,不要重复自己说过的话。
|
||||
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
|
||||
{moderation_prompt}不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。""",
|
||||
{moderation_prompt}不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。,只输出回复内容""",
|
||||
"reasoning_prompt_main",
|
||||
)
|
||||
Prompt(
|
||||
@@ -116,13 +154,14 @@ class PromptBuilder:
|
||||
|
||||
elif build_mode == "focus":
|
||||
return await self._build_prompt_focus(
|
||||
reason, current_mind_info, structured_info, chat_stream,
|
||||
reason,
|
||||
current_mind_info,
|
||||
structured_info,
|
||||
chat_stream,
|
||||
)
|
||||
return None
|
||||
|
||||
async def _build_prompt_focus(
|
||||
self, reason, current_mind_info, structured_info, chat_stream
|
||||
) -> tuple[str, str]:
|
||||
async def _build_prompt_focus(self, reason, current_mind_info, structured_info, chat_stream) -> tuple[str, str]:
|
||||
individuality = Individuality.get_instance()
|
||||
prompt_personality = individuality.get_prompt(type="personality", x_person=2, level=1)
|
||||
prompt_identity = individuality.get_prompt(type="identity", x_person=2, level=1)
|
||||
@@ -145,7 +184,7 @@ class PromptBuilder:
|
||||
message_list_before_now,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="relative",
|
||||
timestamp_mode="normal",
|
||||
read_mark=0.0,
|
||||
)
|
||||
|
||||
@@ -156,11 +195,18 @@ class PromptBuilder:
|
||||
if random.random() < 0.02:
|
||||
prompt_ger += "你喜欢用反问句"
|
||||
|
||||
if structured_info:
|
||||
structured_info_prompt = await global_prompt_manager.format_prompt(
|
||||
"info_from_tools", structured_info=structured_info
|
||||
)
|
||||
else:
|
||||
structured_info_prompt = ""
|
||||
|
||||
logger.debug("开始构建prompt")
|
||||
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"heart_flow_prompt",
|
||||
structured_info=structured_info,
|
||||
info_from_tools=structured_info_prompt,
|
||||
chat_target=await global_prompt_manager.get_prompt_async("chat_target_group1")
|
||||
if chat_in_group
|
||||
else await global_prompt_manager.get_prompt_async("chat_target_private1"),
|
||||
@@ -490,23 +536,36 @@ class PromptBuilder:
|
||||
|
||||
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
|
||||
# 从LPMM知识库获取知识
|
||||
found_knowledge_from_lpmm = qa_manager.get_knowledge(message)
|
||||
try:
|
||||
found_knowledge_from_lpmm = qa_manager.get_knowledge(message)
|
||||
|
||||
end_time = time.time()
|
||||
if found_knowledge_from_lpmm is not None:
|
||||
logger.debug(
|
||||
f"从LPMM知识库获取知识,相关信息:{found_knowledge_from_lpmm[:100]}...,信息长度: {len(found_knowledge_from_lpmm)}"
|
||||
)
|
||||
related_info += found_knowledge_from_lpmm
|
||||
logger.debug(f"获取知识库内容耗时: {(end_time - start_time):.3f}秒")
|
||||
logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
|
||||
return related_info
|
||||
else:
|
||||
logger.debug("从LPMM知识库获取知识失败,使用旧版数据库进行检索")
|
||||
knowledge_from_old = await self.get_prompt_info_old(message, threshold=0.38)
|
||||
related_info += knowledge_from_old
|
||||
logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
|
||||
return related_info
|
||||
end_time = time.time()
|
||||
if found_knowledge_from_lpmm is not None:
|
||||
logger.debug(
|
||||
f"从LPMM知识库获取知识,相关信息:{found_knowledge_from_lpmm[:100]}...,信息长度: {len(found_knowledge_from_lpmm)}"
|
||||
)
|
||||
related_info += found_knowledge_from_lpmm
|
||||
logger.debug(f"获取知识库内容耗时: {(end_time - start_time):.3f}秒")
|
||||
logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
|
||||
return related_info
|
||||
else:
|
||||
logger.debug("从LPMM知识库获取知识失败,使用旧版数据库进行检索")
|
||||
knowledge_from_old = await self.get_prompt_info_old(message, threshold=0.38)
|
||||
related_info += knowledge_from_old
|
||||
logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
|
||||
return related_info
|
||||
except Exception as e:
|
||||
logger.error(f"获取知识库内容时发生异常: {str(e)}")
|
||||
try:
|
||||
knowledge_from_old = await self.get_prompt_info_old(message, threshold=0.38)
|
||||
related_info += knowledge_from_old
|
||||
logger.debug(
|
||||
f"异常后使用旧版数据库获取知识,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}"
|
||||
)
|
||||
return related_info
|
||||
except Exception as e2:
|
||||
logger.error(f"使用旧版数据库获取知识时也发生异常: {str(e2)}")
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def get_info_from_db(
|
||||
|
||||
@@ -1,23 +1,24 @@
|
||||
import time
|
||||
import asyncio
|
||||
import traceback
|
||||
import statistics # 导入 statistics 模块
|
||||
from random import random
|
||||
from typing import List, Optional # 导入 Optional
|
||||
|
||||
from ..moods.moods import MoodManager
|
||||
from ...config.config import global_config
|
||||
from ..chat.emoji_manager import emoji_manager
|
||||
from ..emoji_system.emoji_manager import emoji_manager
|
||||
from .normal_chat_generator import NormalChatGenerator
|
||||
from ..chat.message import MessageSending, MessageRecv, MessageThinking, MessageSet
|
||||
from ..chat.message_sender import message_manager
|
||||
from ..chat.utils_image import image_path_to_base64
|
||||
from ..willing.willing_manager import willing_manager
|
||||
from ..message import UserInfo, Seg
|
||||
from maim_message import UserInfo, Seg
|
||||
from src.common.logger import get_module_logger, CHAT_STYLE_CONFIG, LogConfig
|
||||
from src.plugins.chat.chat_stream import ChatStream, chat_manager
|
||||
from src.plugins.person_info.relationship_manager import relationship_manager
|
||||
from src.plugins.respon_info_catcher.info_catcher import info_catcher_manager
|
||||
from src.plugins.utils.timer_calculater import Timer
|
||||
from src.plugins.utils.timer_calculator import Timer
|
||||
|
||||
# 定义日志配置
|
||||
chat_config = LogConfig(
|
||||
@@ -46,6 +47,8 @@ class NormalChat:
|
||||
self.gpt = NormalChatGenerator()
|
||||
self.mood_manager = MoodManager.get_instance() # MoodManager 保持单例
|
||||
# 存储此实例的兴趣监控任务
|
||||
self.start_time = time.time()
|
||||
|
||||
self._chat_task: Optional[asyncio.Task] = None
|
||||
logger.info(f"[{self.stream_name}] NormalChat 实例初始化完成。")
|
||||
|
||||
@@ -164,14 +167,13 @@ class NormalChat:
|
||||
)
|
||||
self.mood_manager.update_mood_from_emotion(emotion, global_config.mood_intensity_factor)
|
||||
|
||||
async def _find_interested_message(self) -> None:
|
||||
async def _reply_interested_message(self) -> None:
|
||||
"""
|
||||
后台任务方法,轮询当前实例关联chat的兴趣消息
|
||||
通常由start_monitoring_interest()启动
|
||||
"""
|
||||
while True:
|
||||
await asyncio.sleep(1) # 每秒检查一次
|
||||
|
||||
await asyncio.sleep(0.5) # poll every 0.5 seconds
|
||||
# 检查任务是否已被取消
|
||||
if self._chat_task is None or self._chat_task.cancelled():
|
||||
logger.info(f"[{self.stream_name}] 兴趣监控任务被取消或置空,退出")
|
||||
@@ -318,6 +320,68 @@ class NormalChat:
|
||||
# 意愿管理器:注销当前message信息 (无论是否回复,只要处理过就删除)
|
||||
willing_manager.delete(message.message_info.message_id)
|
||||
|
||||
# --- 新增:处理初始高兴趣消息的私有方法 ---
|
||||
async def _process_initial_interest_messages(self):
|
||||
"""处理启动时存在于 interest_dict 中的高兴趣消息。"""
|
||||
items_to_process = list(self.interest_dict.items())
|
||||
if not items_to_process:
|
||||
return # 没有初始消息,直接返回
|
||||
|
||||
logger.info(f"[{self.stream_name}] 发现 {len(items_to_process)} 条初始兴趣消息,开始处理高兴趣部分...")
|
||||
interest_values = [item[1][1] for item in items_to_process] # 提取兴趣值列表
|
||||
|
||||
messages_to_reply = [] # 需要立即回复的消息
|
||||
|
||||
if len(interest_values) == 1:
|
||||
# 如果只有一个消息,直接处理
|
||||
messages_to_reply.append(items_to_process[0])
|
||||
logger.info(f"[{self.stream_name}] 只有一条初始消息,直接处理。")
|
||||
elif len(interest_values) > 1:
|
||||
# 计算均值和标准差
|
||||
try:
|
||||
mean_interest = statistics.mean(interest_values)
|
||||
stdev_interest = statistics.stdev(interest_values)
|
||||
threshold = mean_interest + stdev_interest
|
||||
logger.info(
|
||||
f"[{self.stream_name}] 初始兴趣值 均值: {mean_interest:.2f}, 标准差: {stdev_interest:.2f}, 阈值: {threshold:.2f}"
|
||||
)
|
||||
|
||||
# 找出高于阈值的消息
|
||||
for item in items_to_process:
|
||||
msg_id, (message, interest_value, is_mentioned) = item
|
||||
if interest_value > threshold:
|
||||
messages_to_reply.append(item)
|
||||
logger.info(f"[{self.stream_name}] 找到 {len(messages_to_reply)} 条高于阈值的初始消息进行处理。")
|
||||
except statistics.StatisticsError as e:
|
||||
logger.error(f"[{self.stream_name}] 计算初始兴趣统计值时出错: {e},跳过初始处理。")
|
||||
|
||||
# 处理需要回复的消息
|
||||
processed_count = 0
|
||||
# --- 修改:迭代前创建要处理的ID列表副本,防止迭代时修改 ---
|
||||
messages_to_process_initially = list(messages_to_reply) # 创建副本
|
||||
# --- 修改结束 ---
|
||||
for item in messages_to_process_initially: # 使用副本迭代
|
||||
msg_id, (message, interest_value, is_mentioned) = item
|
||||
# --- 修改:在处理前尝试 pop,防止竞争 ---
|
||||
popped_item = self.interest_dict.pop(msg_id, None)
|
||||
if popped_item is None:
|
||||
logger.warning(f"[{self.stream_name}] 初始兴趣消息 {msg_id} 在处理前已被移除,跳过。")
|
||||
continue # 如果消息已被其他任务处理(pop),则跳过
|
||||
# --- 修改结束 ---
|
||||
|
||||
try:
|
||||
logger.info(f"[{self.stream_name}] 处理初始高兴趣消息 {msg_id} (兴趣值: {interest_value:.2f})")
|
||||
await self.normal_response(message=message, is_mentioned=is_mentioned, interested_rate=interest_value)
|
||||
processed_count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.stream_name}] 处理初始兴趣消息 {msg_id} 时出错: {e}\\n{traceback.format_exc()}")
|
||||
|
||||
logger.info(
|
||||
f"[{self.stream_name}] 初始高兴趣消息处理完毕,共处理 {processed_count} 条。剩余 {len(self.interest_dict)} 条待轮询。"
|
||||
)
|
||||
|
||||
# --- 新增结束 ---
|
||||
|
||||
# 保持 staticmethod, 因为不依赖实例状态, 但需要 chat 对象来获取日志上下文
|
||||
@staticmethod
|
||||
def _check_ban_words(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
|
||||
@@ -351,38 +415,42 @@ class NormalChat:
|
||||
# 改为实例方法, 移除 chat 参数
|
||||
|
||||
async def start_chat(self):
    """Start the background chat tasks for this instance's chat stream.

    Does nothing (apart from logging) when ``self._chat_task`` is still
    running. Otherwise it launches two tasks:

    * a one-off task running ``_process_initial_interest_messages``; a
      reference is kept on ``self._initial_process_task`` so the event loop
      cannot garbage-collect it mid-flight, and its failure is logged;
    * the long-lived polling task running ``_reply_interested_message``,
      which is stored in ``self._chat_task``.
    """
    if self._chat_task is not None and not self._chat_task.done():
        logger.info(f"[{self.stream_name}] 聊天轮询任务已在运行中。")
        return

    logger.info(f"[{self.stream_name}] 启动聊天任务...")

    # Process the initial interest messages in the background so that the
    # caller is not blocked.
    logger.info(f"[{self.stream_name}] 开始后台处理初始兴趣消息...")

    def _on_initial_done(done: asyncio.Task) -> None:
        # Retrieve the result so a failure is reported here instead of as an
        # "exception was never retrieved" warning at garbage collection.
        if done.cancelled():
            return
        if (exc := done.exception()) is not None:
            logger.error(f"[{self.stream_name}] 初始兴趣消息处理任务异常: {exc}")

    # asyncio only keeps weak references to tasks; hold a strong one.
    self._initial_process_task = asyncio.create_task(self._process_initial_interest_messages())
    self._initial_process_task.add_done_callback(_on_initial_done)

    # Start the main polling task; self._chat_task always points at it.
    logger.info(f"[{self.stream_name}] 启动后台兴趣消息轮询任务...")
    polling_task = asyncio.create_task(self._reply_interested_message())
    polling_task.add_done_callback(lambda t: self._handle_task_completion(t))
    self._chat_task = polling_task
|
||||
|
||||
# 改为实例方法, 移除 stream_id 参数
|
||||
def _handle_task_completion(self, task: asyncio.Task):
    """Done-callback for the chat task: log its outcome and clear the slot.

    Args:
        task: The task that just finished. Callbacks from any task other
            than the one currently stored in ``self._chat_task`` are
            ignored with a warning.
    """
    if task is not self._chat_task:
        logger.warning(f"[{self.stream_name}] 收到未知任务回调")
        return

    try:
        # task.exception() raises CancelledError for a cancelled task.
        if exc := task.exception():
            logger.error(f"[{self.stream_name}] 任务异常: {exc}")
            # No exception is being handled here, so format_exc() would only
            # yield "NoneType: None"; format the task's own exception.
            logger.error("".join(traceback.format_exception(type(exc), exc, exc.__traceback__)))
    except asyncio.CancelledError:
        logger.info(f"[{self.stream_name}] 任务已取消")
    except Exception as e:
        logger.error(f"[{self.stream_name}] 回调处理错误: {e}")
    finally:
        # Release the slot only if it still refers to this task.
        if self._chat_task is task:
            self._chat_task = None
            logger.debug(f"[{self.stream_name}] 任务清理完成")
|
||||
|
||||
# 改为实例方法, 移除 stream_id 参数
|
||||
async def stop_chat(self):
|
||||
@@ -402,7 +470,7 @@ class NormalChat:
|
||||
# 确保任务状态更新,即使等待出错 (回调函数也会尝试更新)
|
||||
if self._chat_task is task:
|
||||
self._chat_task = None
|
||||
|
||||
|
||||
# 清理所有未处理的思考消息
|
||||
try:
|
||||
container = await message_manager.get_container(self.stream_id)
|
||||
|
||||
@@ -5,7 +5,7 @@ from ...config.config import global_config
|
||||
from ..chat.message import MessageThinking
|
||||
from .heartflow_prompt_builder import prompt_builder
|
||||
from ..chat.utils import process_llm_response
|
||||
from ..utils.timer_calculater import Timer
|
||||
from ..utils.timer_calculator import Timer
|
||||
from src.common.logger import get_module_logger, LogConfig, LLM_STYLE_CONFIG
|
||||
from src.plugins.respon_info_catcher.info_catcher import info_catcher_manager
|
||||
|
||||
|
||||
@@ -404,7 +404,7 @@ class Hippocampus:
|
||||
# logger.info("没有找到有效的关键词节点")
|
||||
return []
|
||||
|
||||
logger.info(f"有效的关键词: {', '.join(valid_keywords)}")
|
||||
logger.debug(f"有效的关键词: {', '.join(valid_keywords)}")
|
||||
|
||||
# 从每个关键词获取记忆
|
||||
all_memories = []
|
||||
@@ -576,7 +576,7 @@ class Hippocampus:
|
||||
# logger.info("没有找到有效的关键词节点")
|
||||
return []
|
||||
|
||||
logger.info(f"有效的关键词: {', '.join(valid_keywords)}")
|
||||
logger.debug(f"有效的关键词: {', '.join(valid_keywords)}")
|
||||
|
||||
# 从每个关键词获取记忆
|
||||
all_memories = []
|
||||
@@ -761,7 +761,7 @@ class Hippocampus:
|
||||
# logger.info("没有找到有效的关键词节点")
|
||||
return 0
|
||||
|
||||
logger.info(f"有效的关键词: {', '.join(valid_keywords)}")
|
||||
logger.debug(f"有效的关键词: {', '.join(valid_keywords)}")
|
||||
|
||||
# 从每个关键词获取记忆
|
||||
activate_map = {} # 存储每个词的累计激活值
|
||||
|
||||
@@ -3,23 +3,8 @@
|
||||
__version__ = "0.1.0"
|
||||
|
||||
from .api import global_api
|
||||
from .message_base import (
|
||||
Seg,
|
||||
GroupInfo,
|
||||
UserInfo,
|
||||
FormatInfo,
|
||||
TemplateInfo,
|
||||
BaseMessageInfo,
|
||||
MessageBase,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Seg",
|
||||
"global_api",
|
||||
"GroupInfo",
|
||||
"UserInfo",
|
||||
"FormatInfo",
|
||||
"TemplateInfo",
|
||||
"BaseMessageInfo",
|
||||
"MessageBase",
|
||||
]
|
||||
|
||||
@@ -1,250 +1,6 @@
|
||||
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
|
||||
from typing import Dict, Any, Callable, List, Set, Optional
|
||||
from src.common.logger import get_module_logger
|
||||
from src.plugins.message.message_base import MessageBase
|
||||
from src.common.server import global_server
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import uvicorn
|
||||
import os
|
||||
import traceback
|
||||
|
||||
logger = get_module_logger("api")
|
||||
|
||||
|
||||
class BaseMessageHandler:
    """Base class that fans each incoming message out to registered handlers."""

    def __init__(self):
        # Callables invoked with the message dict; each must return an awaitable.
        self.message_handlers: List[Callable] = []
        # Bookkeeping set for background tasks owned by this handler.
        self.background_tasks = set()

    def register_message_handler(self, handler: Callable):
        """Register a message handler callable."""
        self.message_handlers.append(handler)

    async def process_message(self, message: Dict[str, Any]):
        """Run every registered handler on one message, concurrently.

        A handler that fails, either while being called or while being
        awaited, is logged and does not prevent the other handlers from
        running.
        """
        tasks = []
        for handler in self.message_handlers:
            try:
                tasks.append(handler(message))
            except Exception as e:
                logger.error(f"消息处理出错: {str(e)}")
                logger.error(traceback.format_exc())
                # Log and carry on with the remaining handlers.
                continue
        if not tasks:
            return

        results = await asyncio.gather(*tasks, return_exceptions=True)
        for result in results:
            # return_exceptions=True hands failures back as values; without
            # this loop they would be dropped silently.
            if isinstance(result, Exception):
                logger.error(f"消息处理出错: {str(result)}")
                logger.error("".join(traceback.format_exception(type(result), result, result.__traceback__)))

    async def _handle_message(self, message: Dict[str, Any]):
        """Process a single message; intended to run as a background task.

        Raises:
            RuntimeError: wraps any exception escaping ``process_message``.
        """
        try:
            await self.process_message(message)
        except Exception as e:
            raise RuntimeError(str(e)) from e
|
||||
|
||||
|
||||
class MessageServer(BaseMessageHandler):
    """WebSocket server that receives messages and dispatches them to handlers.

    Messages arrive either through ``POST /api/message`` or over a WebSocket
    connection; each one is processed in a tracked background task. Connected
    sockets are indexed by the ``platform`` request header so replies can be
    routed with ``broadcast_to_platform``.
    """

    _class_handlers: List[Callable] = []  # handlers shared by every instance

    def __init__(
        self,
        host: str = "0.0.0.0",
        port: int = 18000,
        enable_token=False,
        app: Optional[FastAPI] = None,
        path: str = "/ws",
    ):
        """Create the server and register its routes.

        Args:
            host: Bind address used when the server runs its own app.
            port: Bind port used when the server runs its own app.
            enable_token: When true, WebSocket clients must present a token
                previously added with ``add_valid_token``.
            app: Existing FastAPI app to attach the routes to; a new app is
                created (and owned by this server) when omitted.
            path: URL path of the WebSocket endpoint.
        """
        super().__init__()
        # Seed the instance handlers with the class-level ones.
        self.message_handlers.extend(self._class_handlers)
        self.host = host
        self.port = port
        self.path = path
        self.app = app or FastAPI()
        self.own_app = app is None  # True when this server created the app itself
        self.active_websockets: Set[WebSocket] = set()
        self.platform_websockets: Dict[str, WebSocket] = {}  # platform -> websocket
        self.valid_tokens: Set[str] = set()
        self.enable_token = enable_token
        self._setup_routes()
        self._running = False

    def _spawn_message_task(self, message: Dict[str, Any]) -> "asyncio.Task":
        """Process ``message`` in a background task tracked in ``background_tasks``."""
        task = asyncio.create_task(self._handle_message(message))
        # Keep a strong reference so the task cannot be garbage-collected
        # while running and so stop() can cancel it.
        self.background_tasks.add(task)
        task.add_done_callback(self._on_message_task_done)
        return task

    def _on_message_task_done(self, task: "asyncio.Task") -> None:
        """Forget a finished message task and log its failure, if any."""
        self.background_tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            logger.error(f"消息处理出错: {str(exc)}")

    def _setup_routes(self):
        """Attach the REST and WebSocket endpoints to ``self.app``."""

        @self.app.post("/api/message")
        async def handle_message(message: Dict[str, Any]):
            try:
                # Hand the message to a background task and answer immediately.
                self._spawn_message_task(message)
                return {"status": "success"}
            except Exception as e:
                raise HTTPException(status_code=500, detail=str(e)) from e

        # Honour the configured path (defaults to "/ws").
        @self.app.websocket(self.path)
        async def websocket_endpoint(websocket: WebSocket):
            headers = dict(websocket.headers)
            token = headers.get("authorization")
            platform = headers.get("platform", "default")  # platform identifier
            if self.enable_token:
                if not token or not await self.verify_token(token):
                    await websocket.close(code=1008, reason="Invalid or missing token")
                    return

            await websocket.accept()
            self.active_websockets.add(websocket)

            # Only the first connection of a platform is used for routing.
            if platform not in self.platform_websockets:
                self.platform_websockets[platform] = websocket

            try:
                while True:
                    message = await websocket.receive_json()
                    self._spawn_message_task(message)
            except WebSocketDisconnect:
                pass  # normal client disconnect; cleanup happens in finally
            except Exception as e:
                raise RuntimeError(str(e)) from e
            finally:
                # Single cleanup point for every exit path.
                self._remove_websocket(websocket, platform)

    @classmethod
    def register_class_handler(cls, handler: Callable):
        """Register a class-level handler, ignoring duplicates."""
        if handler not in cls._class_handlers:
            cls._class_handlers.append(handler)

    def register_message_handler(self, handler: Callable):
        """Register an instance-level handler, ignoring duplicates."""
        if handler not in self.message_handlers:
            self.message_handlers.append(handler)

    async def verify_token(self, token: str) -> bool:
        """Return True when token checking is disabled or ``token`` is known."""
        if not self.enable_token:
            return True
        return token in self.valid_tokens

    def add_valid_token(self, token: str):
        """Allow ``token`` for WebSocket authentication."""
        self.valid_tokens.add(token)

    def remove_valid_token(self, token: str):
        """Revoke ``token``; unknown tokens are ignored."""
        self.valid_tokens.discard(token)

    def run_sync(self):
        """Run the server synchronously.

        Raises:
            RuntimeError: if the server was given an external FastAPI app.
        """
        if not self.own_app:
            raise RuntimeError("当使用外部FastAPI实例时,请使用该实例的运行方法")
        uvicorn.run(self.app, host=self.host, port=self.port)

    async def run(self):
        """Run the server asynchronously until it stops, then clean up.

        Raises:
            RuntimeError: wraps any exception raised while serving.
        """
        self._running = True
        try:
            if self.own_app:
                # Own app: serve it with uvicorn, with uvicorn's default and
                # access logging disabled.
                config = uvicorn.Config(
                    self.app, host=self.host, port=self.port, loop="asyncio", log_config=None, access_log=False
                )
                self.server = uvicorn.Server(config)
                await self.server.serve()
            else:
                # External app: just stay alive until stop() clears the flag.
                while self._running:
                    await asyncio.sleep(1)
        except Exception as e:
            raise RuntimeError(f"服务器运行错误: {str(e)}") from e
        finally:
            # Runs for every exit path (including KeyboardInterrupt), once.
            await self.stop()

    async def start_server(self):
        """Start the server unless it is already running."""
        if not self._running:
            self._running = True
            await self.run()

    async def stop(self):
        """Stop the server: cancel pending work and close every connection."""
        # Clear the flag unconditionally so run() also terminates when the
        # server is attached to an external app.
        self._running = False

        self.platform_websockets.clear()

        # Cancel outstanding message tasks and wait for them to finish.
        # Snapshot first: done-callbacks mutate the set.
        pending = list(self.background_tasks)
        for task in pending:
            task.cancel()
        await asyncio.gather(*pending, return_exceptions=True)
        self.background_tasks.clear()

        # Close all WebSocket connections. Iterate over a copy because the
        # endpoint coroutines remove themselves from the set while we await.
        for websocket in list(self.active_websockets):
            try:
                await websocket.close()
            except Exception as e:
                logger.warning(f"关闭WebSocket连接时出错: {str(e)}")
        self.active_websockets.clear()

        if hasattr(self, "server") and self.own_app:
            # Ask uvicorn to exit and shut it down.
            self.server.should_exit = True
            await self.server.shutdown()
            # NOTE(review): re-entering main_loop() after shutdown() is kept
            # from the original; confirm against the uvicorn version in use.
            if hasattr(self.server, "started") and self.server.started:
                await self.server.main_loop()

        self.message_handlers.clear()

    def _remove_websocket(self, websocket: WebSocket, platform: str):
        """Drop ``websocket`` from the active set and from ``platform``'s mapping."""
        if websocket in self.active_websockets:
            self.active_websockets.remove(websocket)
        if platform in self.platform_websockets:
            if self.platform_websockets[platform] == websocket:
                del self.platform_websockets[platform]

    async def broadcast_message(self, message: Dict[str, Any]):
        """Send ``message`` to every active WebSocket, pruning dead ones."""
        disconnected = set()
        # Iterate over a copy: the set may change while send_json is awaited.
        for websocket in list(self.active_websockets):
            try:
                await websocket.send_json(message)
            except Exception:
                disconnected.add(websocket)
        for websocket in disconnected:
            self.active_websockets.discard(websocket)
            # Also drop platform mappings that still point at the dead socket.
            for platform in [p for p, ws in self.platform_websockets.items() if ws == websocket]:
                del self.platform_websockets[platform]

    async def broadcast_to_platform(self, platform: str, message: Dict[str, Any]):
        """Send ``message`` to the WebSocket registered for ``platform``.

        Raises:
            ValueError: if no WebSocket is connected for ``platform``.
        """
        if platform not in self.platform_websockets:
            raise ValueError(f"平台:{platform} 未连接")

        websocket = self.platform_websockets[platform]
        try:
            await websocket.send_json(message)
        except Exception:
            # A failed send means the connection is gone; forget it.
            self._remove_websocket(websocket, platform)

    async def send_message(self, message: MessageBase):
        """Send ``message`` to the platform named in its ``message_info``."""
        await self.broadcast_to_platform(message.message_info.platform, message.to_dict())

    @staticmethod
    async def send_message_rest(url: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """POST ``data`` as JSON to ``url`` and return the decoded JSON reply."""
        async with aiohttp.ClientSession() as session:
            async with session.post(url, json=data, headers={"Content-Type": "application/json"}) as response:
                return await response.json()
|
||||
from maim_message import MessageServer
|
||||
|
||||
|
||||
global_api = MessageServer(host=os.environ["HOST"], port=int(os.environ["PORT"]), app=global_server.get_app())
|
||||
|
||||
@@ -1,247 +0,0 @@
|
||||
from dataclasses import dataclass, asdict
|
||||
from typing import List, Optional, Union, Dict
|
||||
|
||||
|
||||
@dataclass
class Seg:
    """A message segment: one typed part of a message.

    Attributes:
        type: Segment type, e.g. 'text', 'image' or 'seglist'.
        data: Segment payload.
            - for 'text', a string
            - for 'image', a base64 string
            - for 'seglist', a list of Seg
    """

    type: str
    data: Union[str, List["Seg"]]

    @classmethod
    def from_dict(cls, data: Dict) -> "Seg":
        """Build a Seg from its dict form, recursing into 'seglist' payloads.

        A 'seglist' whose payload is missing or ``None`` yields an empty list.
        """
        # Distinct local names: do not shadow the builtin ``type`` or the
        # ``data`` parameter.
        seg_type = data.get("type")
        payload = data.get("data")
        if seg_type == "seglist":
            payload = [Seg.from_dict(seg) for seg in payload or []]
        return cls(type=seg_type, data=payload)

    def to_dict(self) -> Dict:
        """Convert to a plain dict, recursing into 'seglist' payloads."""
        result = {"type": self.type}
        if self.type == "seglist":
            result["data"] = [seg.to_dict() for seg in self.data or []]
        else:
            result["data"] = self.data
        return result
|
||||
|
||||
|
||||
@dataclass
class GroupInfo:
    """Group (chat room) information."""

    platform: Optional[str] = None
    group_id: Optional[int] = None
    group_name: Optional[str] = None  # group display name

    def to_dict(self) -> Dict:
        """Convert to a dict, omitting fields that are None."""
        return {k: v for k, v in asdict(self).items() if v is not None}

    @classmethod
    def from_dict(cls, data: Optional[Dict]) -> Optional["GroupInfo"]:
        """Build a GroupInfo from a dict.

        Args:
            data: Dict with the group fields; may be empty or None.

        Returns:
            A new GroupInfo, or None when ``data`` is empty/None or carries
            no ``group_id`` (i.e. the message is not a group message).
        """
        if not data or data.get("group_id") is None:
            return None
        return cls(
            platform=data.get("platform"),
            group_id=data.get("group_id"),
            group_name=data.get("group_name", None),
        )
|
||||
|
||||
|
||||
@dataclass
class UserInfo:
    """User information."""

    platform: Optional[str] = None
    user_id: Optional[int] = None
    user_nickname: Optional[str] = None  # user nickname
    user_cardname: Optional[str] = None  # user's per-group card name

    def to_dict(self) -> Dict:
        """Convert to a dict, omitting fields that are None."""
        return {k: v for k, v in asdict(self).items() if v is not None}

    @classmethod
    def from_dict(cls, data: Optional[Dict]) -> "UserInfo":
        """Build a UserInfo from a dict.

        Args:
            data: Dict with the user fields; None is treated like an empty
                dict, so every field defaults to None.

        Returns:
            UserInfo: the new instance.
        """
        data = data or {}
        return cls(
            platform=data.get("platform"),
            user_id=data.get("user_id"),
            user_nickname=data.get("user_nickname", None),
            user_cardname=data.get("user_cardname", None),
        )
|
||||
|
||||
|
||||
@dataclass
class FormatInfo:
    """Message format information.

    Formats currently accepted by maimcore: text, image, emoji.
    Formats that can be sent: text, emoji, reply.
    """

    content_format: Optional[str] = None
    accept_format: Optional[str] = None

    def to_dict(self) -> Dict:
        """Convert to a dict, omitting fields that are None."""
        return {k: v for k, v in asdict(self).items() if v is not None}

    @classmethod
    def from_dict(cls, data: Optional[Dict]) -> "FormatInfo":
        """Build a FormatInfo from a dict.

        Args:
            data: Dict with the format fields; None is treated like an
                empty dict.

        Returns:
            FormatInfo: the new instance.
        """
        data = data or {}
        return cls(
            content_format=data.get("content_format"),
            accept_format=data.get("accept_format"),
        )
|
||||
|
||||
|
||||
@dataclass
class TemplateInfo:
    """Template information."""

    template_items: Optional[Dict] = None
    template_name: Optional[str] = None
    template_default: bool = True

    def to_dict(self) -> Dict:
        """Convert to a dict, omitting fields that are None."""
        return {k: v for k, v in asdict(self).items() if v is not None}

    @classmethod
    def from_dict(cls, data: Optional[Dict]) -> "TemplateInfo":
        """Build a TemplateInfo from a dict.

        Args:
            data: Dict with the template fields; None is treated like an
                empty dict. ``template_default`` defaults to True.

        Returns:
            TemplateInfo: the new instance.
        """
        data = data or {}
        return cls(
            template_items=data.get("template_items"),
            template_name=data.get("template_name"),
            template_default=data.get("template_default", True),
        )
|
||||
|
||||
|
||||
@dataclass
class BaseMessageInfo:
    """Message metadata: origin, identity and formatting information."""

    platform: Optional[str] = None
    message_id: Union[str, int, None] = None
    time: Optional[float] = None
    group_info: Optional[GroupInfo] = None
    user_info: Optional[UserInfo] = None
    format_info: Optional[FormatInfo] = None
    template_info: Optional[TemplateInfo] = None
    additional_config: Optional[dict] = None

    def to_dict(self) -> Dict:
        """Convert to a dict, omitting None fields at every level.

        Nested info objects are serialized through their own ``to_dict`` so
        that their None-valued fields are dropped as well.
        """
        result = {}
        for name, copied in asdict(self).items():
            if copied is None:
                continue
            # asdict() has already flattened nested dataclasses into plain
            # dicts (None values included), so the type check has to look at
            # the real attribute, not at the copied value.
            attr = getattr(self, name)
            if isinstance(attr, (GroupInfo, UserInfo, FormatInfo, TemplateInfo)):
                result[name] = attr.to_dict()
            else:
                result[name] = copied
        return result

    @classmethod
    def from_dict(cls, data: Dict) -> "BaseMessageInfo":
        """Build a BaseMessageInfo from a dict.

        Args:
            data: Dict with the message-info fields. Sub-dicts that are
                missing or explicitly None are treated as empty.

        Returns:
            BaseMessageInfo: the new instance.
        """
        # ``or {}`` also covers keys that are present with a null value,
        # which ``data.get(key, {})`` would pass through as None.
        group_info = GroupInfo.from_dict(data.get("group_info") or {})
        user_info = UserInfo.from_dict(data.get("user_info") or {})
        format_info = FormatInfo.from_dict(data.get("format_info") or {})
        template_info = TemplateInfo.from_dict(data.get("template_info") or {})
        return cls(
            platform=data.get("platform"),
            message_id=data.get("message_id"),
            time=data.get("time"),
            additional_config=data.get("additional_config", None),
            group_info=group_info,
            user_info=user_info,
            format_info=format_info,
            template_info=template_info,
        )
|
||||
|
||||
|
||||
@dataclass
class MessageBase:
    """A message: metadata plus its (possibly nested) content segment."""

    message_info: BaseMessageInfo
    message_segment: Seg
    raw_message: Optional[str] = None  # raw message text, including unparsed CQ codes

    def to_dict(self) -> Dict:
        """Convert to a dict.

        Returns:
            Dict: contains
                - message_info: in dict form
                - message_segment: in dict form
                - raw_message: only when it is not None
        """
        result = {"message_info": self.message_info.to_dict(), "message_segment": self.message_segment.to_dict()}
        if self.raw_message is not None:
            result["raw_message"] = self.raw_message
        return result

    @classmethod
    def from_dict(cls, data: Dict) -> "MessageBase":
        """Build a MessageBase from a dict.

        Args:
            data: Dict with ``message_info``, ``message_segment`` and,
                optionally, ``raw_message``. Sub-dicts that are missing or
                explicitly None are treated as empty.

        Returns:
            MessageBase: the new instance.
        """
        # ``or {}`` also covers keys present with a null value.
        message_info = BaseMessageInfo.from_dict(data.get("message_info") or {})
        message_segment = Seg.from_dict(data.get("message_segment") or {})
        raw_message = data.get("raw_message", None)
        return cls(message_info=message_info, message_segment=message_segment, raw_message=raw_message)
|
||||
@@ -178,395 +178,6 @@ class LLMRequest:
|
||||
output_cost = (completion_tokens / 1000000) * self.pri_out
|
||||
return round(input_cost + output_cost, 6)
|
||||
|
||||
'''
|
||||
async def _execute_request(
|
||||
self,
|
||||
endpoint: str,
|
||||
prompt: str = None,
|
||||
image_base64: str = None,
|
||||
image_format: str = None,
|
||||
payload: dict = None,
|
||||
retry_policy: dict = None,
|
||||
response_handler: callable = None,
|
||||
user_id: str = "system",
|
||||
request_type: str = None,
|
||||
):
|
||||
"""统一请求执行入口
|
||||
Args:
|
||||
endpoint: API端点路径 (如 "chat/completions")
|
||||
prompt: prompt文本
|
||||
image_base64: 图片的base64编码
|
||||
image_format: 图片格式
|
||||
payload: 请求体数据
|
||||
retry_policy: 自定义重试策略
|
||||
response_handler: 自定义响应处理器
|
||||
user_id: 用户ID
|
||||
request_type: 请求类型
|
||||
"""
|
||||
|
||||
if request_type is None:
|
||||
request_type = self.request_type
|
||||
|
||||
# 合并重试策略
|
||||
default_retry = {
|
||||
"max_retries": 3,
|
||||
"base_wait": 10,
|
||||
"retry_codes": [429, 413, 500, 503],
|
||||
"abort_codes": [400, 401, 402, 403],
|
||||
}
|
||||
policy = {**default_retry, **(retry_policy or {})}
|
||||
|
||||
# 常见Error Code Mapping
|
||||
error_code_mapping = {
|
||||
400: "参数不正确",
|
||||
401: "API key 错误,认证失败,请检查/config/bot_config.toml和.env中的配置是否正确哦~",
|
||||
402: "账号余额不足",
|
||||
403: "需要实名,或余额不足",
|
||||
404: "Not Found",
|
||||
429: "请求过于频繁,请稍后再试",
|
||||
500: "服务器内部故障",
|
||||
503: "服务器负载过高",
|
||||
}
|
||||
|
||||
api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}"
|
||||
# 判断是否为流式
|
||||
stream_mode = self.stream
|
||||
# logger_msg = "进入流式输出模式," if stream_mode else ""
|
||||
# logger.debug(f"{logger_msg}发送请求到URL: {api_url}")
|
||||
# logger.info(f"使用模型: {self.model_name}")
|
||||
|
||||
# 构建请求体
|
||||
if image_base64:
|
||||
payload = await self._build_payload(prompt, image_base64, image_format)
|
||||
elif payload is None:
|
||||
payload = await self._build_payload(prompt)
|
||||
|
||||
# 流式输出标志
|
||||
# 先构建payload,再添加流式输出标志
|
||||
if stream_mode:
|
||||
payload["stream"] = stream_mode
|
||||
|
||||
for retry in range(policy["max_retries"]):
|
||||
try:
|
||||
# 使用上下文管理器处理会话
|
||||
headers = await self._build_headers()
|
||||
# 似乎是openai流式必须要的东西,不过阿里云的qwq-plus加了这个没有影响
|
||||
if stream_mode:
|
||||
headers["Accept"] = "text/event-stream"
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
async with session.post(api_url, headers=headers, json=payload) as response:
|
||||
# 处理需要重试的状态码
|
||||
if response.status in policy["retry_codes"]:
|
||||
wait_time = policy["base_wait"] * (2**retry)
|
||||
logger.warning(
|
||||
f"模型 {self.model_name} 错误码: {response.status}, 等待 {wait_time}秒后重试"
|
||||
)
|
||||
if response.status == 413:
|
||||
logger.warning("请求体过大,尝试压缩...")
|
||||
image_base64 = compress_base64_image_by_scale(image_base64)
|
||||
payload = await self._build_payload(prompt, image_base64, image_format)
|
||||
elif response.status in [500, 503]:
|
||||
logger.error(
|
||||
f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}"
|
||||
)
|
||||
raise RuntimeError("服务器负载过高,模型恢复失败QAQ")
|
||||
else:
|
||||
logger.warning(f"模型 {self.model_name} 请求限制(429),等待{wait_time}秒后重试...")
|
||||
|
||||
await asyncio.sleep(wait_time)
|
||||
continue
|
||||
elif response.status in policy["abort_codes"]:
|
||||
logger.error(
|
||||
f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}"
|
||||
)
|
||||
# 尝试获取并记录服务器返回的详细错误信息
|
||||
try:
|
||||
error_json = await response.json()
|
||||
if error_json and isinstance(error_json, list) and len(error_json) > 0:
|
||||
for error_item in error_json:
|
||||
if "error" in error_item and isinstance(error_item["error"], dict):
|
||||
error_obj = error_item["error"]
|
||||
error_code = error_obj.get("code")
|
||||
error_message = error_obj.get("message")
|
||||
error_status = error_obj.get("status")
|
||||
logger.error(
|
||||
f"服务器错误详情: 代码={error_code}, 状态={error_status}, "
|
||||
f"消息={error_message}"
|
||||
)
|
||||
elif isinstance(error_json, dict) and "error" in error_json:
|
||||
# 处理单个错误对象的情况
|
||||
error_obj = error_json.get("error", {})
|
||||
error_code = error_obj.get("code")
|
||||
error_message = error_obj.get("message")
|
||||
error_status = error_obj.get("status")
|
||||
logger.error(
|
||||
f"服务器错误详情: 代码={error_code}, 状态={error_status}, 消息={error_message}"
|
||||
)
|
||||
else:
|
||||
# 记录原始错误响应内容
|
||||
logger.error(f"服务器错误响应: {error_json}")
|
||||
except Exception as e:
|
||||
logger.warning(f"无法解析服务器错误响应: {str(e)}")
|
||||
|
||||
if response.status == 403:
|
||||
# 只针对硅基流动的V3和R1进行降级处理
|
||||
if (
|
||||
self.model_name.startswith("Pro/deepseek-ai")
|
||||
and self.base_url == "https://api.siliconflow.cn/v1/"
|
||||
):
|
||||
old_model_name = self.model_name
|
||||
self.model_name = self.model_name[4:] # 移除"Pro/"前缀
|
||||
logger.warning(
|
||||
f"检测到403错误,模型从 {old_model_name} 降级为 {self.model_name}"
|
||||
)
|
||||
|
||||
# 对全局配置进行更新
|
||||
if global_config.llm_normal.get("name") == old_model_name:
|
||||
global_config.llm_normal["name"] = self.model_name
|
||||
logger.warning(f"将全局配置中的 llm_normal 模型临时降级至{self.model_name}")
|
||||
|
||||
if global_config.llm_reasoning.get("name") == old_model_name:
|
||||
global_config.llm_reasoning["name"] = self.model_name
|
||||
logger.warning(
|
||||
f"将全局配置中的 llm_reasoning 模型临时降级至{self.model_name}"
|
||||
)
|
||||
|
||||
# 更新payload中的模型名
|
||||
if payload and "model" in payload:
|
||||
payload["model"] = self.model_name
|
||||
|
||||
# 重新尝试请求
|
||||
retry -= 1 # 不计入重试次数
|
||||
continue
|
||||
|
||||
raise RuntimeError(f"请求被拒绝: {error_code_mapping.get(response.status)}")
|
||||
|
||||
response.raise_for_status()
|
||||
reasoning_content = ""
|
||||
|
||||
# 将流式输出转化为非流式输出
|
||||
if stream_mode:
|
||||
flag_delta_content_finished = False
|
||||
accumulated_content = ""
|
||||
usage = None # 初始化usage变量,避免未定义错误
|
||||
|
||||
async for line_bytes in response.content:
|
||||
try:
|
||||
line = line_bytes.decode("utf-8").strip()
|
||||
if not line:
|
||||
continue
|
||||
if line.startswith("data:"):
|
||||
data_str = line[5:].strip()
|
||||
if data_str == "[DONE]":
|
||||
break
|
||||
try:
|
||||
chunk = json.loads(data_str)
|
||||
if flag_delta_content_finished:
|
||||
chunk_usage = chunk.get("usage", None)
|
||||
if chunk_usage:
|
||||
usage = chunk_usage # 获取token用量
|
||||
else:
|
||||
delta = chunk["choices"][0]["delta"]
|
||||
delta_content = delta.get("content")
|
||||
if delta_content is None:
|
||||
delta_content = ""
|
||||
accumulated_content += delta_content
|
||||
# 检测流式输出文本是否结束
|
||||
finish_reason = chunk["choices"][0].get("finish_reason")
|
||||
if delta.get("reasoning_content", None):
|
||||
reasoning_content += delta["reasoning_content"]
|
||||
if finish_reason == "stop":
|
||||
chunk_usage = chunk.get("usage", None)
|
||||
if chunk_usage:
|
||||
usage = chunk_usage
|
||||
break
|
||||
# 部分平台在文本输出结束前不会返回token用量,此时需要再获取一次chunk
|
||||
flag_delta_content_finished = True
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"模型 {self.model_name} 解析流式输出错误: {str(e)}")
|
||||
except GeneratorExit:
|
||||
logger.warning("模型 {self.model_name} 流式输出被中断,正在清理资源...")
|
||||
# 确保资源被正确清理
|
||||
await response.release()
|
||||
# 返回已经累积的内容
|
||||
result = {
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": accumulated_content,
|
||||
"reasoning_content": reasoning_content,
|
||||
# 流式输出可能没有工具调用,此处不需要添加tool_calls字段
|
||||
}
|
||||
}
|
||||
],
|
||||
"usage": usage,
|
||||
}
|
||||
return (
|
||||
response_handler(result)
|
||||
if response_handler
|
||||
else self._default_response_handler(result, user_id, request_type, endpoint)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"模型 {self.model_name} 处理流式输出时发生错误: {str(e)}")
|
||||
# 确保在发生错误时也能正确清理资源
|
||||
try:
|
||||
await response.release()
|
||||
except Exception as cleanup_error:
|
||||
logger.error(f"清理资源时发生错误: {cleanup_error}")
|
||||
# 返回已经累积的内容
|
||||
result = {
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": accumulated_content,
|
||||
"reasoning_content": reasoning_content,
|
||||
# 流式输出可能没有工具调用,此处不需要添加tool_calls字段
|
||||
}
|
||||
}
|
||||
],
|
||||
"usage": usage,
|
||||
}
|
||||
return (
|
||||
response_handler(result)
|
||||
if response_handler
|
||||
else self._default_response_handler(result, user_id, request_type, endpoint)
|
||||
)
|
||||
content = accumulated_content
|
||||
think_match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
|
||||
if think_match:
|
||||
reasoning_content = think_match.group(1).strip()
|
||||
content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip()
|
||||
# 构造一个伪result以便调用自定义响应处理器或默认处理器
|
||||
result = {
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": content,
|
||||
"reasoning_content": reasoning_content,
|
||||
# 流式输出可能没有工具调用,此处不需要添加tool_calls字段
|
||||
}
|
||||
}
|
||||
],
|
||||
"usage": usage,
|
||||
}
|
||||
return (
|
||||
response_handler(result)
|
||||
if response_handler
|
||||
else self._default_response_handler(result, user_id, request_type, endpoint)
|
||||
)
|
||||
else:
|
||||
result = await response.json()
|
||||
# 使用自定义处理器或默认处理
|
||||
return (
|
||||
response_handler(result)
|
||||
if response_handler
|
||||
else self._default_response_handler(result, user_id, request_type, endpoint)
|
||||
)
|
||||
|
||||
except (aiohttp.ClientError, asyncio.TimeoutError) as e:
|
||||
if retry < policy["max_retries"] - 1:
|
||||
wait_time = policy["base_wait"] * (2**retry)
|
||||
logger.error(f"模型 {self.model_name} 网络错误,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||
await asyncio.sleep(wait_time)
|
||||
continue
|
||||
else:
|
||||
logger.critical(f"模型 {self.model_name} 网络错误达到最大重试次数: {str(e)}")
|
||||
raise RuntimeError(f"网络请求失败: {str(e)}") from e
|
||||
except Exception as e:
|
||||
logger.critical(f"模型 {self.model_name} 未预期的错误: {str(e)}")
|
||||
raise RuntimeError(f"请求过程中发生错误: {str(e)}") from e
|
||||
|
||||
except aiohttp.ClientResponseError as e:
|
||||
# 处理aiohttp抛出的响应错误
|
||||
if retry < policy["max_retries"] - 1:
|
||||
wait_time = policy["base_wait"] * (2**retry)
|
||||
logger.error(
|
||||
f"模型 {self.model_name} HTTP响应错误,等待{wait_time}秒后重试... 状态码: {e.status}, 错误: {e.message}"
|
||||
)
|
||||
try:
|
||||
if hasattr(e, "response") and e.response and hasattr(e.response, "text"):
|
||||
error_text = await e.response.text()
|
||||
try:
|
||||
error_json = json.loads(error_text)
|
||||
if isinstance(error_json, list) and len(error_json) > 0:
|
||||
for error_item in error_json:
|
||||
if "error" in error_item and isinstance(error_item["error"], dict):
|
||||
error_obj = error_item["error"]
|
||||
logger.error(
|
||||
f"模型 {self.model_name} 服务器错误详情: 代码={error_obj.get('code')}, "
|
||||
f"状态={error_obj.get('status')}, "
|
||||
f"消息={error_obj.get('message')}"
|
||||
)
|
||||
elif isinstance(error_json, dict) and "error" in error_json:
|
||||
error_obj = error_json.get("error", {})
|
||||
logger.error(
|
||||
f"模型 {self.model_name} 服务器错误详情: 代码={error_obj.get('code')}, "
|
||||
f"状态={error_obj.get('status')}, "
|
||||
f"消息={error_obj.get('message')}"
|
||||
)
|
||||
else:
|
||||
logger.error(f"模型 {self.model_name} 服务器错误响应: {error_json}")
|
||||
except (json.JSONDecodeError, TypeError) as json_err:
|
||||
logger.warning(
|
||||
f"模型 {self.model_name} 响应不是有效的JSON: {str(json_err)}, 原始内容: {error_text[:200]}"
|
||||
)
|
||||
except (AttributeError, TypeError, ValueError) as parse_err:
|
||||
logger.warning(f"模型 {self.model_name} 无法解析响应错误内容: {str(parse_err)}")
|
||||
|
||||
await asyncio.sleep(wait_time)
|
||||
else:
|
||||
logger.critical(
|
||||
f"模型 {self.model_name} HTTP响应错误达到最大重试次数: 状态码: {e.status}, 错误: {e.message}"
|
||||
)
|
||||
# 安全地检查和记录请求详情
|
||||
if (
|
||||
image_base64
|
||||
and payload
|
||||
and isinstance(payload, dict)
|
||||
and "messages" in payload
|
||||
and len(payload["messages"]) > 0
|
||||
):
|
||||
if isinstance(payload["messages"][0], dict) and "content" in payload["messages"][0]:
|
||||
content = payload["messages"][0]["content"]
|
||||
if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]:
|
||||
payload["messages"][0]["content"][1]["image_url"]["url"] = (
|
||||
f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,"
|
||||
f"{image_base64[:10]}...{image_base64[-10:]}"
|
||||
)
|
||||
logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {payload}")
|
||||
raise RuntimeError(f"模型 {self.model_name} API请求失败: 状态码 {e.status}, {e.message}") from e
|
||||
except Exception as e:
|
||||
if retry < policy["max_retries"] - 1:
|
||||
wait_time = policy["base_wait"] * (2**retry)
|
||||
logger.error(f"模型 {self.model_name} 请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||
await asyncio.sleep(wait_time)
|
||||
else:
|
||||
logger.critical(f"模型 {self.model_name} 请求失败: {str(e)}")
|
||||
# 安全地检查和记录请求详情
|
||||
if (
|
||||
image_base64
|
||||
and payload
|
||||
and isinstance(payload, dict)
|
||||
and "messages" in payload
|
||||
and len(payload["messages"]) > 0
|
||||
):
|
||||
if isinstance(payload["messages"][0], dict) and "content" in payload["messages"][0]:
|
||||
content = payload["messages"][0]["content"]
|
||||
if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]:
|
||||
payload["messages"][0]["content"][1]["image_url"]["url"] = (
|
||||
f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,"
|
||||
f"{image_base64[:10]}...{image_base64[-10:]}"
|
||||
)
|
||||
logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {payload}")
|
||||
raise RuntimeError(f"模型 {self.model_name} API请求失败: {str(e)}") from e
|
||||
|
||||
logger.error(f"模型 {self.model_name} 达到最大重试次数,请求仍然失败")
|
||||
raise RuntimeError(f"模型 {self.model_name} 达到最大重试次数,API请求仍然失败")
|
||||
'''
|
||||
|
||||
async def _prepare_request(
|
||||
self,
|
||||
endpoint: str,
|
||||
@@ -820,6 +431,7 @@ class LLMRequest:
|
||||
policy = request_content["policy"]
|
||||
payload = request_content["payload"]
|
||||
wait_time = policy["base_wait"] * (2**retry_count)
|
||||
keep_request = False
|
||||
if retry_count < policy["max_retries"] - 1:
|
||||
keep_request = True
|
||||
if isinstance(exception, RequestAbortException):
|
||||
|
||||
@@ -256,7 +256,7 @@ class MoodManager:
|
||||
def print_mood_status(self) -> None:
|
||||
"""打印当前情绪状态"""
|
||||
logger.info(
|
||||
f"[情绪状态]愉悦度: {self.current_mood.valence:.2f}, "
|
||||
f"愉悦度: {self.current_mood.valence:.2f}, "
|
||||
f"唤醒度: {self.current_mood.arousal:.2f}, "
|
||||
f"心情: {self.current_mood.text}"
|
||||
)
|
||||
|
||||
@@ -53,7 +53,7 @@ person_info_default = {
|
||||
# "impression" : None,
|
||||
# "gender" : Unkown,
|
||||
"konw_time": 0,
|
||||
"msg_interval": 3000,
|
||||
"msg_interval": 2000,
|
||||
"msg_interval_list": [],
|
||||
} # 个人信息的各项与默认值在此定义,以下处理会自动创建/补全每一项
|
||||
|
||||
@@ -384,18 +384,30 @@ class PersonInfoManager:
|
||||
if delta > 0:
|
||||
time_interval.append(delta)
|
||||
|
||||
time_interval = [t for t in time_interval if 500 <= t <= 8000]
|
||||
if len(time_interval) >= 30:
|
||||
time_interval = [t for t in time_interval if 200 <= t <= 8000]
|
||||
# --- 修改后的逻辑 ---
|
||||
# 数据量检查 (至少需要 30 条有效间隔,并且足够进行头尾截断)
|
||||
if len(time_interval) >= 30 + 10: # 至少30条有效+头尾各5条
|
||||
time_interval.sort()
|
||||
|
||||
# 画图(log)
|
||||
# 画图(log) - 这部分保留
|
||||
msg_interval_map = True
|
||||
log_dir = Path("logs/person_info")
|
||||
log_dir.mkdir(parents=True, exist_ok=True)
|
||||
plt.figure(figsize=(10, 6))
|
||||
time_series = pd.Series(time_interval)
|
||||
plt.hist(time_series, bins=50, density=True, alpha=0.4, color="pink", label="Histogram")
|
||||
time_series.plot(kind="kde", color="mediumpurple", linewidth=1, label="Density")
|
||||
# 使用截断前的数据画图,更能反映原始分布
|
||||
time_series_original = pd.Series(time_interval)
|
||||
plt.hist(
|
||||
time_series_original,
|
||||
bins=50,
|
||||
density=True,
|
||||
alpha=0.4,
|
||||
color="pink",
|
||||
label="Histogram (Original Filtered)",
|
||||
)
|
||||
time_series_original.plot(
|
||||
kind="kde", color="mediumpurple", linewidth=1, label="Density (Original Filtered)"
|
||||
)
|
||||
plt.grid(True, alpha=0.2)
|
||||
plt.xlim(0, 8000)
|
||||
plt.title(f"Message Interval Distribution (User: {person_id[:8]}...)")
|
||||
@@ -405,15 +417,24 @@ class PersonInfoManager:
|
||||
img_path = log_dir / f"interval_distribution_{person_id[:8]}.png"
|
||||
plt.savefig(img_path)
|
||||
plt.close()
|
||||
# 画图
|
||||
# 画图结束
|
||||
|
||||
q25, q75 = np.percentile(time_interval, [25, 75])
|
||||
iqr = q75 - q25
|
||||
filtered = [x for x in time_interval if (q25 - 1.5 * iqr) <= x <= (q75 + 1.5 * iqr)]
|
||||
# 去掉头尾各 5 个数据点
|
||||
trimmed_interval = time_interval[5:-5]
|
||||
|
||||
msg_interval = int(round(np.percentile(filtered, 80)))
|
||||
await self.update_one_field(person_id, "msg_interval", msg_interval)
|
||||
logger.trace(f"用户{person_id}的msg_interval已经被更新为{msg_interval}")
|
||||
# 计算截断后数据的 37% 分位数
|
||||
if trimmed_interval: # 确保截断后列表不为空
|
||||
msg_interval = int(round(np.percentile(trimmed_interval, 37)))
|
||||
# 更新数据库
|
||||
await self.update_one_field(person_id, "msg_interval", msg_interval)
|
||||
logger.trace(f"用户{person_id}的msg_interval通过头尾截断和37分位数更新为{msg_interval}")
|
||||
else:
|
||||
logger.trace(f"用户{person_id}截断后数据为空,无法计算msg_interval")
|
||||
else:
|
||||
logger.trace(
|
||||
f"用户{person_id}有效消息间隔数量 ({len(time_interval)}) 不足进行推断 (需要至少 {30 + 10} 条)"
|
||||
)
|
||||
# --- 修改结束 ---
|
||||
except Exception as e:
|
||||
logger.trace(f"用户{person_id}消息间隔计算失败: {type(e).__name__}: {str(e)}")
|
||||
continue
|
||||
|
||||
@@ -168,7 +168,10 @@ async def _build_readable_messages_internal(
|
||||
user_info = msg.get("user_info", {})
|
||||
platform = user_info.get("platform")
|
||||
user_id = user_info.get("user_id")
|
||||
user_nickname = user_info.get("nickname")
|
||||
|
||||
user_nickname = user_info.get("user_nickname")
|
||||
user_cardname = user_info.get("user_cardname")
|
||||
|
||||
timestamp = msg.get("time")
|
||||
content = msg.get("processed_plain_text", "") # 默认空字符串
|
||||
|
||||
@@ -186,7 +189,12 @@ async def _build_readable_messages_internal(
|
||||
|
||||
# 如果 person_name 未设置,则使用消息中的 nickname 或默认名称
|
||||
if not person_name:
|
||||
person_name = user_nickname
|
||||
if user_cardname:
|
||||
person_name = f"昵称:{user_cardname}"
|
||||
elif user_nickname:
|
||||
person_name = f"{user_nickname}"
|
||||
else:
|
||||
person_name = "某人"
|
||||
|
||||
message_details.append((timestamp, person_name, content))
|
||||
|
||||
@@ -303,9 +311,7 @@ async def build_readable_messages(
|
||||
)
|
||||
|
||||
readable_read_mark = translate_timestamp_to_human_readable(read_mark, mode=timestamp_mode)
|
||||
read_mark_line = (
|
||||
f"\n\n--- 以上消息已读 (标记时间: {readable_read_mark}) ---\n--- 请关注你上次思考之后以下的新消息---\n"
|
||||
)
|
||||
read_mark_line = f"\n--- 以上消息是你已经思考过的内容已读 (标记时间: {readable_read_mark}) ---\n--- 请关注以下未读的新消息---\n"
|
||||
|
||||
# 组合结果,确保空部分不引入多余的标记或换行
|
||||
if formatted_before and formatted_after:
|
||||
|
||||
Reference in New Issue
Block a user