This commit is contained in:
UnCLASPrommer
2025-07-15 18:02:43 +08:00
29 changed files with 1487 additions and 295 deletions

View File

@@ -11,7 +11,7 @@ import pandas as pd
import faiss
# from .llm_client import LLMClient
from .lpmmconfig import global_config
# from .lpmmconfig import global_config
from .utils.hash import get_sha256
from .global_logger import logger
from rich.traceback import install
@@ -27,15 +27,12 @@ from rich.progress import (
)
from src.manager.local_store_manager import local_storage
from src.chat.utils.utils import get_embedding
from src.config.config import global_config
install(extra_lines=3)
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
EMBEDDING_DATA_DIR = (
os.path.join(ROOT_PATH, "data", "embedding")
if global_config["persistence"]["embedding_data_dir"] is None
else os.path.join(ROOT_PATH, global_config["persistence"]["embedding_data_dir"])
)
EMBEDDING_DATA_DIR = os.path.join(ROOT_PATH, "data", "embedding")
EMBEDDING_DATA_DIR_STR = str(EMBEDDING_DATA_DIR).replace("\\", "/")
TOTAL_EMBEDDING_TIMES = 3 # 统计嵌入次数
@@ -260,7 +257,7 @@ class EmbeddingStore:
# L2归一化
faiss.normalize_L2(embeddings)
# 构建索引
self.faiss_index = faiss.IndexFlatIP(global_config["embedding"]["dimension"])
self.faiss_index = faiss.IndexFlatIP(global_config.lpmm_knowledge.embedding_dimension)
self.faiss_index.add(embeddings)
def search_top_k(self, query: List[float], k: int) -> List[Tuple[str, float]]:

View File

@@ -9,7 +9,7 @@ from src.common.logger import get_logger
from src.config.config import global_config
from src.mood.mood_manager import mood_manager # 导入情绪管理器
from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream
from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
from src.chat.message_receive.storage import MessageStorage
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
@@ -141,6 +141,29 @@ class ChatBot:
logger.error(f"处理命令时出错: {e}")
return False, None, True # 出错时继续处理消息
async def do_s4u(self, message_data: Dict[str, Any]):
message = MessageRecvS4U(message_data)
group_info = message.message_info.group_info
user_info = message.message_info.user_info
get_chat_manager().register_message(message)
chat = await get_chat_manager().get_or_create_stream(
platform=message.message_info.platform, # type: ignore
user_info=user_info, # type: ignore
group_info=group_info,
)
message.update_chat_stream(chat)
# 处理消息内容
await message.process()
await self.s4u_message_processor.process_message(message)
return
async def message_process(self, message_data: Dict[str, Any]) -> None:
"""处理转化后的统一格式消息
这个函数本质是预处理一些数据,根据配置信息和消息内容,预处理消息,并分发到合适的消息处理器中
@@ -158,6 +181,10 @@ class ChatBot:
try:
# 确保所有任务已启动
await self._ensure_started()
if ENABLE_S4U_CHAT:
await self.do_s4u(message_data)
return
if message_data["message_info"].get("group_info") is not None:
message_data["message_info"]["group_info"]["group_id"] = str(
@@ -221,11 +248,6 @@ class ChatBot:
template_group_name = None
async def preprocess():
if ENABLE_S4U_CHAT:
logger.info("进入S4U流程")
await self.s4u_message_processor.process_message(message)
return
await self.heartflow_message_receiver.process_message(message)
if template_group_name:

View File

@@ -38,7 +38,6 @@ class Message(MessageBase):
message_segment: Optional[Seg] = None,
timestamp: Optional[float] = None,
reply: Optional["MessageRecv"] = None,
detailed_plain_text: str = "",
processed_plain_text: str = "",
):
# 使用传入的时间戳或当前时间
@@ -58,7 +57,6 @@ class Message(MessageBase):
self.chat_stream = chat_stream
# 文本处理相关属性
self.processed_plain_text = processed_plain_text
self.detailed_plain_text = detailed_plain_text
# 回复消息
self.reply = reply
@@ -104,7 +102,6 @@ class MessageRecv(Message):
self.message_segment = Seg.from_dict(message_dict.get("message_segment", {}))
self.raw_message = message_dict.get("raw_message")
self.processed_plain_text = message_dict.get("processed_plain_text", "")
self.detailed_plain_text = message_dict.get("detailed_plain_text", "")
self.is_emoji = False
self.has_emoji = False
self.is_picid = False
@@ -123,7 +120,6 @@ class MessageRecv(Message):
这个方法必须在创建实例后显式调用,因为它包含异步操作。
"""
self.processed_plain_text = await self._process_message_segments(self.message_segment)
self.detailed_plain_text = self._generate_detailed_text()
async def _process_single_segment(self, segment: Seg) -> str:
"""处理单个消息段
@@ -182,12 +178,97 @@ class MessageRecv(Message):
logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
return f"[处理失败的{segment.type}消息]"
def _generate_detailed_text(self) -> str:
"""生成详细文本,包含时间和用户信息"""
timestamp = self.message_info.time
user_info = self.message_info.user_info
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>" # type: ignore
return f"[{timestamp}] {name}: {self.processed_plain_text}\n"
@dataclass
class MessageRecvS4U(MessageRecv):
def __init__(self, message_dict: dict[str, Any]):
super().__init__(message_dict)
self.is_gift = False
self.is_superchat = False
self.gift_info = None
self.gift_name = None
self.gift_count = None
self.superchat_info = None
self.superchat_price = None
self.superchat_message_text = None
async def process(self) -> None:
self.processed_plain_text = await self._process_message_segments(self.message_segment)
async def _process_single_segment(self, segment: Seg) -> str:
"""处理单个消息段
Args:
segment: 消息段
Returns:
str: 处理后的文本
"""
try:
if segment.type == "text":
self.is_picid = False
self.is_emoji = False
return segment.data # type: ignore
elif segment.type == "image":
# 如果是base64图片数据
if isinstance(segment.data, str):
self.has_picid = True
self.is_picid = True
self.is_emoji = False
image_manager = get_image_manager()
# print(f"segment.data: {segment.data}")
_, processed_text = await image_manager.process_image(segment.data)
return processed_text
return "[发了一张图片,网卡了加载不出来]"
elif segment.type == "emoji":
self.has_emoji = True
self.is_emoji = True
self.is_picid = False
if isinstance(segment.data, str):
return await get_image_manager().get_emoji_description(segment.data)
return "[发了一个表情包,网卡了加载不出来]"
elif segment.type == "mention_bot":
self.is_picid = False
self.is_emoji = False
self.is_mentioned = float(segment.data) # type: ignore
return ""
elif segment.type == "priority_info":
self.is_picid = False
self.is_emoji = False
if isinstance(segment.data, dict):
# 处理优先级信息
self.priority_mode = "priority"
self.priority_info = segment.data
"""
{
'message_type': 'vip', # vip or normal
'message_priority': 1.0, # 优先级大为优先float
}
"""
return ""
elif segment.type == "gift":
self.is_gift = True
# 解析gift_info格式为"名称:数量"
name, count = segment.data.split(":", 1)
self.gift_info = segment.data
self.gift_name = name.strip()
self.gift_count = int(count.strip())
return ""
elif segment.type == "superchat":
self.is_superchat = True
self.superchat_info = segment.data
price,message_text = segment.data.split(":", 1)
self.superchat_price = price.strip()
self.superchat_message_text = message_text.strip()
self.processed_plain_text = str(self.superchat_message_text)
self.processed_plain_text += f"(注意:这是一条超级弹幕信息,价值{self.superchat_price}元,请你认真回复)"
return self.processed_plain_text
else:
return ""
except Exception as e:
logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
return f"[处理失败的{segment.type}消息]"
@dataclass
@@ -472,7 +553,6 @@ def message_from_db_dict(db_dict: dict) -> MessageRecv:
"message_segment": {"type": "text", "data": processed_text}, # 从纯文本重建消息段
"raw_message": None, # 数据库中未存储原始消息
"processed_plain_text": processed_text,
"detailed_plain_text": db_dict.get("detailed_plain_text", ""),
}
# 创建 MessageRecv 实例

View File

@@ -2,7 +2,6 @@ import traceback
import time
import asyncio
import random
import ast
import re
from typing import List, Optional, Dict, Any, Tuple

View File

@@ -121,27 +121,6 @@ async def get_embedding(text, request_type="embedding"):
return embedding
def get_recent_group_detailed_plain_text(chat_stream_id: str, limit: int = 12, combine=False):
filter_query = {"chat_id": chat_stream_id}
sort_order = [("time", -1)]
recent_messages = find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
if not recent_messages:
return []
# 反转消息列表,使最新的消息在最后
recent_messages.reverse()
if combine:
return "".join(str(msg_db_data["detailed_plain_text"]) for msg_db_data in recent_messages)
message_detailed_plain_text_list = []
for msg_db_data in recent_messages:
message_detailed_plain_text_list.append(msg_db_data["detailed_plain_text"])
return message_detailed_plain_text_list
def get_recent_group_speaker(chat_stream_id: str, sender, limit: int = 12) -> list:
# 获取当前群聊记录内发言的人
filter_query = {"chat_id": chat_stream_id}