Merge remote-tracking branch 'upstream/debug' into debug

This commit is contained in:
tcmofashi
2025-03-12 08:25:42 +08:00
45 changed files with 1842 additions and 659 deletions

View File

@@ -1,7 +1,6 @@
from typing import Optional
from pymongo import MongoClient
from pymongo.database import Database as MongoDatabase
class Database:
_instance: Optional["Database"] = None
@@ -27,7 +26,7 @@ class Database:
else:
# 否则使用无认证连接
self.client = MongoClient(host, port)
self.db = self.client[db_name]
self.db: MongoDatabase = self.client[db_name]
@classmethod
def initialize(
@@ -39,18 +38,18 @@ class Database:
password: Optional[str] = None,
auth_source: Optional[str] = None,
uri: Optional[str] = None,
) -> "Database":
) -> MongoDatabase:
if cls._instance is None:
cls._instance = cls(
host, port, db_name, username, password, auth_source, uri
)
return cls._instance
return cls._instance.db
@classmethod
def get_instance(cls) -> "Database":
def get_instance(cls) -> MongoDatabase:
if cls._instance is None:
raise RuntimeError("Database not initialized")
return cls._instance
return cls._instance.db
#测试用

View File

@@ -46,7 +46,7 @@ class ReasoningGUI:
# 初始化数据库连接
try:
self.db = Database.get_instance().db
self.db = Database.get_instance()
logger.success("数据库连接成功")
except RuntimeError:
logger.warning("数据库未初始化,正在尝试初始化...")
@@ -60,7 +60,7 @@ class ReasoningGUI:
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
self.db = Database.get_instance().db
self.db = Database.get_instance()
logger.success("数据库初始化成功")
except Exception:
logger.exception("数据库初始化失败")

View File

@@ -4,7 +4,7 @@ import os
from loguru import logger
from nonebot import get_driver, on_message, require
from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent, Message, MessageSegment
from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent, Message, MessageSegment,MessageEvent
from nonebot.typing import T_State
from ...common.database import Database
@@ -32,26 +32,14 @@ _message_manager_started = False
driver = get_driver()
config = driver.config
Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
logger.success("初始化数据库成功")
# 初始化表情管理器
emoji_manager.initialize()
logger.debug(f"正在唤醒{global_config.BOT_NICKNAME}......")
# 创建机器人实例
chat_bot = ChatBot()
# 注册消息处理器
group_msg = on_message(priority=5)
# 注册消息处理器
msg_in = on_message(priority=5)
# 创建定时任务
scheduler = require("nonebot_plugin_apscheduler").scheduler
@@ -103,8 +91,8 @@ async def _(bot: Bot):
asyncio.create_task(chat_manager._auto_save_task())
@group_msg.handle()
async def _(bot: Bot, event: GroupMessageEvent, state: T_State):
@msg_in.handle()
async def _(bot: Bot, event: MessageEvent, state: T_State):
await chat_bot.handle_message(event, bot)
@@ -127,7 +115,7 @@ async def build_memory_task():
async def forget_memory_task():
"""每30秒执行一次记忆构建"""
print("\033[1;32m[记忆遗忘]\033[0m 开始遗忘记忆...")
await hippocampus.operation_forget_topic(percentage=0.1)
await hippocampus.operation_forget_topic(percentage=global_config.memory_forget_percentage)
print("\033[1;32m[记忆遗忘]\033[0m 记忆遗忘完成")

View File

@@ -2,12 +2,16 @@ import re
import time
from random import random
from loguru import logger
from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent
from nonebot.adapters.onebot.v11 import (
Bot,
GroupMessageEvent,
MessageEvent,
PrivateMessageEvent,
)
from ..memory_system.memory import hippocampus
from ..moods.moods import MoodManager # 导入情绪管理器
from .config import global_config
from .cq_code import CQCode, cq_code_tool # 导入CQCode模块
from .emoji_manager import emoji_manager # 导入表情包管理器
from .llm_generator import ResponseGenerator
from .message import MessageSending, MessageRecv, MessageThinking, MessageSet
@@ -42,39 +46,53 @@ class ChatBot:
if not self._started:
self._started = True
async def handle_message(self, event: GroupMessageEvent, bot: Bot) -> None:
"""处理收到的消息"""
async def handle_message(self, event: MessageEvent, bot: Bot) -> None:
"""处理收到的消息"""
self.bot = bot # 更新 bot 实例
try:
group_info_api = await bot.get_group_info(group_id=event.group_id)
logger.info(f"成功获取群信息: {group_info_api}")
group_name = group_info_api["group_name"]
except Exception as e:
logger.error(f"获取群信息失败: {str(e)}")
group_name = None
# 白名单设定由nontbot侧完成
# 消息过滤涉及到config有待更新
if event.group_id:
if event.group_id not in global_config.talk_allowed_groups:
return
# 用户屏蔽,不区分私聊/群聊
if event.user_id in global_config.ban_user_id:
return
user_info = UserInfo(
user_id=event.user_id,
user_nickname=event.sender.nickname,
user_cardname=event.sender.card or None,
platform="qq",
)
# 处理私聊消息
if isinstance(event, PrivateMessageEvent):
if not global_config.enable_friend_chat: # 私聊过滤
return
else:
try:
user_info = UserInfo(
user_id=event.user_id,
user_nickname=(await bot.get_stranger_info(user_id=event.user_id, no_cache=True))["nickname"],
user_cardname=None,
platform="qq",
)
except Exception as e:
logger.error(f"获取陌生人信息失败: {e}")
return
logger.debug(user_info)
group_info = GroupInfo(
group_id=event.group_id,
group_name=group_name, # 使用获取到的群名称或None
platform="qq",
)
# group_info = GroupInfo(group_id=0, group_name="私聊", platform="qq")
group_info = None
# 处理群聊消息
else:
# 白名单设定由nontbot侧完成
if event.group_id:
if event.group_id not in global_config.talk_allowed_groups:
return
user_info = UserInfo(
user_id=event.user_id,
user_nickname=event.sender.nickname,
user_cardname=event.sender.card or None,
platform="qq",
)
group_info = GroupInfo(group_id=event.group_id, group_name=None, platform="qq")
# group_info = await bot.get_group_info(group_id=event.group_id)
# sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True)
message_cq = MessageRecvCQ(
message_id=event.message_id,
@@ -88,7 +106,6 @@ class ChatBot:
# 进入maimbot
message = MessageRecv(message_json)
groupinfo = message.message_info.group_info
userinfo = message.message_info.user_info
messageinfo = message.message_info
@@ -108,7 +125,9 @@ class ChatBot:
# 过滤词
for word in global_config.ban_words:
if word in message.processed_plain_text:
logger.info(f"[群{groupinfo.group_id}]{userinfo.user_nickname}:{message.processed_plain_text}")
logger.info(
f"[{chat.group_info.group_name if chat.group_info.group_id else '私聊'}]{userinfo.user_nickname}:{message.processed_plain_text}"
)
logger.info(f"[过滤词识别]消息中含有{word}filtered")
return
@@ -116,7 +135,7 @@ class ChatBot:
for pattern in global_config.ban_msgs_regex:
if re.search(pattern, message.raw_message):
logger.info(
f"[{message.message_info.group_info.group_id}]{message.user_nickname}:{message.raw_message}"
f"[{chat.group_info.group_name if chat.group_info.group_id else '私聊'}]{message.user_nickname}:{message.raw_message}"
)
logger.info(f"[正则表达式过滤]消息匹配到{pattern}filtered")
return
@@ -124,8 +143,8 @@ class ChatBot:
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(messageinfo.time))
# topic=await topic_identifier.identify_topic_llm(message.processed_plain_text)
topic = ""
interested_rate = 0
interested_rate = await hippocampus.memory_activate_value(message.processed_plain_text) / 100
logger.debug(f"{message.processed_plain_text}的激活度:{interested_rate}")
# logger.info(f"\033[1;32m[主题识别]\033[0m 使用{global_config.topic_extract}主题: {topic}")
@@ -144,7 +163,7 @@ class ChatBot:
current_willing = willing_manager.get_willing(chat_stream=chat)
logger.info(
f"[{current_time}][{chat.group_info.group_id}]{chat.user_info.user_nickname}:"
f"[{current_time}][{chat.group_info.group_name if chat.group_info.group_id else '私聊'}]{chat.user_info.user_nickname}:"
f"{message.processed_plain_text}[回复意愿:{current_willing:.2f}][概率:{reply_probability * 100:.1f}%]"
)
@@ -152,12 +171,17 @@ class ChatBot:
if random() < reply_probability:
bot_user_info = UserInfo(
user_id=global_config.BOT_QQ, user_nickname=global_config.BOT_NICKNAME, platform=messageinfo.platform
user_id=global_config.BOT_QQ,
user_nickname=global_config.BOT_NICKNAME,
platform=messageinfo.platform,
)
thinking_time_point = round(time.time(), 2)
think_id = "mt" + str(thinking_time_point)
thinking_message = MessageThinking(
message_id=think_id, chat_stream=chat, bot_user_info=bot_user_info, reply=message
message_id=think_id,
chat_stream=chat,
bot_user_info=bot_user_info,
reply=message,
)
message_manager.add_message(thinking_message)
@@ -196,15 +220,16 @@ class ChatBot:
# print(f"\033[1;32m[回复内容]\033[0m {msg}")
# 通过时间改变时间戳
typing_time = calculate_typing_time(msg)
print(f"typing_time: {typing_time}")
logger.debug(f"typing_time: {typing_time}")
accu_typing_time += typing_time
timepoint = thinking_time_point + accu_typing_time
message_segment = Seg(type="text", data=msg)
print(f"message_segment: {message_segment}")
# logger.debug(f"message_segment: {message_segment}")
bot_message = MessageSending(
message_id=think_id,
chat_stream=chat,
bot_user_info=bot_user_info,
sender_info=userinfo,
message_segment=message_segment,
reply=message,
is_head=not mark_head,
@@ -218,7 +243,9 @@ class ChatBot:
# message_set 可以直接加入 message_manager
# print(f"\033[1;32m[回复]\033[0m 将回复载入发送容器")
print(f"添加message_set到message_manager")
logger.debug("添加message_set到message_manager")
message_manager.add_message(message_set)
bot_response_time = thinking_time_point
@@ -242,6 +269,7 @@ class ChatBot:
message_id=think_id,
chat_stream=chat,
bot_user_info=bot_user_info,
sender_info=userinfo,
message_segment=message_segment,
reply=message,
is_head=False,

View File

@@ -111,11 +111,11 @@ class ChatManager:
def _ensure_collection(self):
"""确保数据库集合存在并创建索引"""
if "chat_streams" not in self.db.db.list_collection_names():
self.db.db.create_collection("chat_streams")
if "chat_streams" not in self.db.list_collection_names():
self.db.create_collection("chat_streams")
# 创建索引
self.db.db.chat_streams.create_index([("stream_id", 1)], unique=True)
self.db.db.chat_streams.create_index(
self.db.chat_streams.create_index([("stream_id", 1)], unique=True)
self.db.chat_streams.create_index(
[("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)]
)
@@ -168,7 +168,7 @@ class ChatManager:
return stream
# 检查数据库中是否存在
data = self.db.db.chat_streams.find_one({"stream_id": stream_id})
data = self.db.chat_streams.find_one({"stream_id": stream_id})
if data:
stream = ChatStream.from_dict(data)
# 更新用户信息和群组信息
@@ -204,7 +204,7 @@ class ChatManager:
async def _save_stream(self, stream: ChatStream):
"""保存聊天流到数据库"""
if not stream.saved:
self.db.db.chat_streams.update_one(
self.db.chat_streams.update_one(
{"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True
)
stream.saved = True
@@ -216,7 +216,7 @@ class ChatManager:
async def load_all_streams(self):
"""从数据库加载所有聊天流"""
all_streams = self.db.db.chat_streams.find({})
all_streams = self.db.chat_streams.find({})
for data in all_streams:
stream = ChatStream.from_dict(data)
self.streams[stream.stream_id] = stream

View File

@@ -37,8 +37,7 @@ class BotConfig:
ban_user_id = set()
build_memory_interval: int = 30 # 记忆构建间隔(秒)
forget_memory_interval: int = 300 # 记忆遗忘间隔(秒)
EMOJI_CHECK_INTERVAL: int = 120 # 表情包检查间隔(分钟)
EMOJI_REGISTER_INTERVAL: int = 10 # 表情包注册间隔(分钟)
EMOJI_SAVE: bool = True # 偷表情包
@@ -69,6 +68,7 @@ class BotConfig:
enable_advance_output: bool = False # 是否启用高级输出
enable_kuuki_read: bool = True # 是否启用读空气功能
enable_debug_output: bool = False # 是否启用调试输出
enable_friend_chat: bool = False # 是否启用好友聊天
mood_update_interval: float = 1.0 # 情绪更新间隔 单位秒
mood_decay_rate: float = 0.95 # 情绪衰减率
@@ -94,7 +94,13 @@ class BotConfig:
PERSONALITY_1: float = 0.6 # 第一种人格概率
PERSONALITY_2: float = 0.3 # 第二种人格概率
PERSONALITY_3: float = 0.1 # 第三种人格概率
build_memory_interval: int = 600 # 记忆构建间隔(秒)
forget_memory_interval: int = 600 # 记忆遗忘间隔(秒)
memory_forget_time: int = 24 # 记忆遗忘时间(小时)
memory_forget_percentage: float = 0.01 # 记忆遗忘比例
memory_compress_rate: float = 0.1 # 记忆压缩率
memory_ban_words: list = field(
default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"]
) # 添加新的配置项默认值
@@ -293,6 +299,11 @@ class BotConfig:
# 在版本 >= 0.0.4 时才处理新增的配置项
if config.INNER_VERSION in SpecifierSet(">=0.0.4"):
config.memory_ban_words = set(memory_config.get("memory_ban_words", []))
if config.INNER_VERSION in SpecifierSet(">=0.0.7"):
config.memory_forget_time = memory_config.get("memory_forget_time", config.memory_forget_time)
config.memory_forget_percentage = memory_config.get("memory_forget_percentage", config.memory_forget_percentage)
config.memory_compress_rate = memory_config.get("memory_compress_rate", config.memory_compress_rate)
def mood(parent: dict):
mood_config = parent["mood"]
@@ -327,7 +338,9 @@ class BotConfig:
others_config = parent["others"]
config.enable_advance_output = others_config.get("enable_advance_output", config.enable_advance_output)
config.enable_kuuki_read = others_config.get("enable_kuuki_read", config.enable_kuuki_read)
config.enable_debug_output = others_config.get("enable_debug_output", config.enable_debug_output)
if config.INNER_VERSION in SpecifierSet(">=0.0.7"):
config.enable_debug_output = others_config.get("enable_debug_output", config.enable_debug_output)
config.enable_friend_chat = others_config.get("enable_friend_chat", config.enable_friend_chat)
# 版本表达式:>=1.0.0,<2.0.0
# 允许字段func: method, support: str, notice: str, necessary: bool

View File

@@ -76,16 +76,16 @@ class EmojiManager:
没有索引的话,数据库每次查询都需要扫描全部数据,建立索引后可以大大提高查询效率。
"""
if 'emoji' not in self.db.db.list_collection_names():
self.db.db.create_collection('emoji')
self.db.db.emoji.create_index([('embedding', '2dsphere')])
self.db.db.emoji.create_index([('filename', 1)], unique=True)
if 'emoji' not in self.db.list_collection_names():
self.db.create_collection('emoji')
self.db.emoji.create_index([('embedding', '2dsphere')])
self.db.emoji.create_index([('filename', 1)], unique=True)
def record_usage(self, emoji_id: str):
"""记录表情使用次数"""
try:
self._ensure_db()
self.db.db.emoji.update_one(
self.db.emoji.update_one(
{'_id': emoji_id},
{'$inc': {'usage_count': 1}}
)
@@ -119,7 +119,7 @@ class EmojiManager:
try:
# 获取所有表情包
all_emojis = list(self.db.db.emoji.find({}, {'_id': 1, 'path': 1, 'embedding': 1, 'description': 1}))
all_emojis = list(self.db.emoji.find({}, {'_id': 1, 'path': 1, 'embedding': 1, 'description': 1}))
if not all_emojis:
logger.warning("数据库中没有任何表情包")
@@ -157,10 +157,11 @@ class EmojiManager:
if selected_emoji and 'path' in selected_emoji:
# 更新使用次数
self.db.db.emoji.update_one(
self.db.emoji.update_one(
{'_id': selected_emoji['_id']},
{'$inc': {'usage_count': 1}}
)
logger.success(
f"找到匹配的表情包: {selected_emoji.get('description', '无描述')} (相似度: {similarity:.4f})")
# 稍微改一下文本描述,不然容易产生幻觉,描述已经包含 表情包 了
@@ -176,8 +177,10 @@ class EmojiManager:
logger.error(f"获取表情包失败: {str(e)}")
return None
async def _get_emoji_discription(self, image_base64: str) -> str:
"""获取表情包的标签使用image_manager的描述生成功能"""
try:
# 使用image_manager获取描述去掉前后的方括号和"表情包:"前缀
description = await image_manager.get_emoji_description(image_base64)
@@ -236,7 +239,7 @@ class EmojiManager:
image_hash = hashlib.md5(image_bytes).hexdigest()
# 检查是否已经注册过
existing_emoji = self.db.db['emoji'].find_one({'filename': filename})
existing_emoji = self.db['emoji'].find_one({'filename': filename})
description = None
if existing_emoji:
@@ -272,11 +275,14 @@ class EmojiManager:
# 获取表情包的描述
description = await self._get_emoji_discription(image_base64)
if global_config.EMOJI_CHECK:
check = await self._check_emoji(image_base64)
if '' not in check:
os.remove(image_path)
logger.info(f"描述: {description}")
logger.info(f"描述: {description}")
logger.info(f"其不满足过滤规则,被剔除 {check}")
continue
@@ -287,6 +293,7 @@ class EmojiManager:
if description is not None:
embedding = await get_embedding(description)
# 准备数据库记录
emoji_record = {
'filename': filename,
@@ -298,9 +305,10 @@ class EmojiManager:
}
# 保存到emoji数据库
self.db.db['emoji'].insert_one(emoji_record)
self.db['emoji'].insert_one(emoji_record)
logger.success(f"注册新表情包: {filename}")
logger.info(f"描述: {description}")
# 保存到images数据库
image_doc = {
@@ -338,7 +346,7 @@ class EmojiManager:
try:
self._ensure_db()
# 获取所有表情包记录
all_emojis = list(self.db.db.emoji.find())
all_emojis = list(self.db.emoji.find())
removed_count = 0
total_count = len(all_emojis)
@@ -346,13 +354,13 @@ class EmojiManager:
try:
if 'path' not in emoji:
logger.warning(f"发现无效记录缺少path字段ID: {emoji.get('_id', 'unknown')}")
self.db.db.emoji.delete_one({'_id': emoji['_id']})
self.db.emoji.delete_one({'_id': emoji['_id']})
removed_count += 1
continue
if 'embedding' not in emoji:
logger.warning(f"发现过时记录缺少embedding字段ID: {emoji.get('_id', 'unknown')}")
self.db.db.emoji.delete_one({'_id': emoji['_id']})
self.db.emoji.delete_one({'_id': emoji['_id']})
removed_count += 1
continue
@@ -360,7 +368,7 @@ class EmojiManager:
if not os.path.exists(emoji['path']):
logger.warning(f"表情包文件已被删除: {emoji['path']}")
# 从数据库中删除记录
result = self.db.db.emoji.delete_one({'_id': emoji['_id']})
result = self.db.emoji.delete_one({'_id': emoji['_id']})
if result.deleted_count > 0:
logger.debug(f"成功删除数据库记录: {emoji['_id']}")
removed_count += 1
@@ -371,7 +379,7 @@ class EmojiManager:
continue
# 验证清理结果
remaining_count = self.db.db.emoji.count_documents({})
remaining_count = self.db.emoji.count_documents({})
if removed_count > 0:
logger.success(f"已清理 {removed_count} 个失效的表情包记录")
logger.info(f"清理前总数: {total_count} | 清理后总数: {remaining_count}")
@@ -389,5 +397,7 @@ class EmojiManager:
# 创建全局单例
emoji_manager = EmojiManager()

View File

@@ -8,7 +8,7 @@ from loguru import logger
from ...common.database import Database
from ..models.utils_model import LLM_request
from .config import global_config
from .message import MessageRecv, MessageThinking, MessageSending,Message
from .message import MessageRecv, MessageThinking, Message
from .prompt_builder import prompt_builder
from .relationship_manager import relationship_manager
from .utils import process_llm_response
@@ -154,7 +154,7 @@ class ResponseGenerator:
reasoning_content: str,
):
"""保存对话记录到数据库"""
self.db.db.reasoning_logs.insert_one(
self.db.reasoning_logs.insert_one(
{
"time": time.time(),
"chat_id": message.chat_stream.stream_id,

View File

@@ -1,19 +1,25 @@
import time
import html
import re
import json
from dataclasses import dataclass
from typing import Dict, ForwardRef, List, Optional, Union
from typing import Dict, List, Optional
import urllib3
from loguru import logger
from .utils_image import image_manager
from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase
from .chat_stream import ChatStream, chat_manager
# 禁用SSL警告
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
#这个类是消息数据类,用于存储和管理消息数据。
#它定义了消息的属性包括群组ID、用户ID、消息ID、原始消息内容、纯文本内容和时间戳。
#它还定义了两个辅助属性keywords用于提取消息的关键词is_plain_text用于判断消息是否为纯文本。
# 这个类是消息数据类,用于存储和管理消息数据。
# 它定义了消息的属性包括群组ID、用户ID、消息ID、原始消息内容、纯文本内容和时间戳。
# 它还定义了两个辅助属性keywords用于提取消息的关键词is_plain_text用于判断消息是否为纯文本。
@dataclass
class Message(MessageBase):
@@ -61,14 +67,28 @@ class Message(MessageBase):
@dataclass
class MessageRecv(Message):
"""接收消息类用于处理从MessageCQ序列化的消息"""
def __init__(self, message_dict: Dict):
"""从MessageCQ的字典初始化
Args:
message_dict: MessageCQ序列化后的字典
"""
self.message_info = BaseMessageInfo.from_dict(message_dict.get('message_info', {}))
message_segment = message_dict.get('message_segment', {})
if message_segment.get('data','') == '[json]':
# 提取json消息中的展示信息
pattern = r'\[CQ:json,data=(?P<json_data>.+?)\]'
match = re.search(pattern, message_dict.get('raw_message',''))
raw_json = html.unescape(match.group('json_data'))
try:
json_message = json.loads(raw_json)
except json.JSONDecodeError:
json_message = {}
message_segment['data'] = json_message.get('prompt','')
self.message_segment = Seg.from_dict(message_dict.get('message_segment', {}))
self.raw_message = message_dict.get('raw_message')
@@ -83,68 +103,74 @@ class MessageRecv(Message):
async def process(self) -> None:
"""处理消息内容,生成纯文本和详细文本
这个方法必须在创建实例后显式调用,因为它包含异步操作。
"""
self.processed_plain_text = await self._process_message_segments(self.message_segment)
self.processed_plain_text = await self._process_message_segments(
self.message_segment
)
self.detailed_plain_text = self._generate_detailed_text()
async def _process_message_segments(self, segment: Seg) -> str:
"""递归处理消息段,转换为文字描述
Args:
segment: 要处理的消息段
Returns:
str: 处理后的文本
"""
if segment.type == 'seglist':
if segment.type == "seglist":
# 处理消息段列表
segments_text = []
for seg in segment.data:
processed = await self._process_message_segments(seg)
if processed:
segments_text.append(processed)
return ' '.join(segments_text)
return " ".join(segments_text)
else:
# 处理单个消息段
return await self._process_single_segment(segment)
async def _process_single_segment(self, seg: Seg) -> str:
"""处理单个消息段
Args:
seg: 要处理的消息段
Returns:
str: 处理后的文本
"""
try:
if seg.type == 'text':
if seg.type == "text":
return seg.data
elif seg.type == 'image':
elif seg.type == "image":
# 如果是base64图片数据
if isinstance(seg.data, str):
return await image_manager.get_image_description(seg.data)
return '[图片]'
elif seg.type == 'emoji':
self.is_emoji=True
return "[图片]"
elif seg.type == "emoji":
self.is_emoji = True
if isinstance(seg.data, str):
return await image_manager.get_emoji_description(seg.data)
return '[表情]'
return "[表情]"
else:
return f"[{seg.type}:{str(seg.data)}]"
except Exception as e:
logger.error(f"处理消息段失败: {str(e)}, 类型: {seg.type}, 数据: {seg.data}")
logger.error(
f"处理消息段失败: {str(e)}, 类型: {seg.type}, 数据: {seg.data}"
)
return f"[处理失败的{seg.type}消息]"
def _generate_detailed_text(self) -> str:
"""生成详细文本,包含时间和用户信息"""
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.message_info.time))
time_str = time.strftime(
"%m-%d %H:%M:%S", time.localtime(self.message_info.time)
)
user_info = self.message_info.user_info
name = (
f"{user_info.user_nickname}(ta的昵称:{user_info.user_cardname},ta的id:{user_info.user_id})"
if user_info.user_cardname!=''
if user_info.user_cardname != ""
else f"{user_info.user_nickname}(ta的id:{user_info.user_id})"
)
return f"[{time_str}] {name}: {self.processed_plain_text}\n"
@@ -153,14 +179,14 @@ class MessageRecv(Message):
@dataclass
class MessageProcessBase(Message):
"""消息处理基类,用于处理中和发送中的消息"""
def __init__(
self,
message_id: str,
chat_stream: ChatStream,
bot_user_info: UserInfo,
message_segment: Optional[Seg] = None,
reply: Optional['MessageRecv'] = None
reply: Optional["MessageRecv"] = None,
):
# 调用父类初始化
super().__init__(
@@ -169,7 +195,7 @@ class MessageProcessBase(Message):
chat_stream=chat_stream,
user_info=bot_user_info,
message_segment=message_segment,
reply=reply
reply=reply,
)
# 处理状态相关属性
@@ -183,78 +209,83 @@ class MessageProcessBase(Message):
async def _process_message_segments(self, segment: Seg) -> str:
"""递归处理消息段,转换为文字描述
Args:
segment: 要处理的消息段
Returns:
str: 处理后的文本
"""
if segment.type == 'seglist':
if segment.type == "seglist":
# 处理消息段列表
segments_text = []
for seg in segment.data:
processed = await self._process_message_segments(seg)
if processed:
segments_text.append(processed)
return ' '.join(segments_text)
return " ".join(segments_text)
else:
# 处理单个消息段
return await self._process_single_segment(segment)
async def _process_single_segment(self, seg: Seg) -> str:
"""处理单个消息段
Args:
seg: 要处理的消息段
Returns:
str: 处理后的文本
"""
try:
if seg.type == 'text':
if seg.type == "text":
return seg.data
elif seg.type == 'image':
elif seg.type == "image":
# 如果是base64图片数据
if isinstance(seg.data, str):
return await image_manager.get_image_description(seg.data)
return '[图片]'
elif seg.type == 'emoji':
return "[图片]"
elif seg.type == "emoji":
if isinstance(seg.data, str):
return await image_manager.get_emoji_description(seg.data)
return '[表情]'
elif seg.type == 'at':
return "[表情]"
elif seg.type == "at":
return f"[@{seg.data}]"
elif seg.type == 'reply':
if self.reply and hasattr(self.reply, 'processed_plain_text'):
elif seg.type == "reply":
if self.reply and hasattr(self.reply, "processed_plain_text"):
return f"[回复:{self.reply.processed_plain_text}]"
else:
return f"[{seg.type}:{str(seg.data)}]"
except Exception as e:
logger.error(f"处理消息段失败: {str(e)}, 类型: {seg.type}, 数据: {seg.data}")
logger.error(
f"处理消息段失败: {str(e)}, 类型: {seg.type}, 数据: {seg.data}"
)
return f"[处理失败的{seg.type}消息]"
def _generate_detailed_text(self) -> str:
"""生成详细文本,包含时间和用户信息"""
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.message_info.time))
time_str = time.strftime(
"%m-%d %H:%M:%S", time.localtime(self.message_info.time)
)
user_info = self.message_info.user_info
name = (
f"{user_info.user_nickname}(ta的昵称:{user_info.user_cardname},ta的id:{user_info.user_id})"
if user_info.user_cardname != ''
if user_info.user_cardname != ""
else f"{user_info.user_nickname}(ta的id:{user_info.user_id})"
)
return f"[{time_str}] {name}: {self.processed_plain_text}\n"
@dataclass
class MessageThinking(MessageProcessBase):
"""思考状态的消息类"""
def __init__(
self,
message_id: str,
chat_stream: ChatStream,
bot_user_info: UserInfo,
reply: Optional['MessageRecv'] = None
reply: Optional["MessageRecv"] = None,
):
# 调用父类初始化
super().__init__(
@@ -262,25 +293,27 @@ class MessageThinking(MessageProcessBase):
chat_stream=chat_stream,
bot_user_info=bot_user_info,
message_segment=None, # 思考状态不需要消息段
reply=reply
reply=reply,
)
# 思考状态特有属性
self.interrupt = False
@dataclass
class MessageSending(MessageProcessBase):
"""发送状态的消息类"""
def __init__(
self,
message_id: str,
chat_stream: ChatStream,
bot_user_info: UserInfo,
sender_info: UserInfo, # 用来记录发送者信息,用于私聊回复
message_segment: Seg,
reply: Optional['MessageRecv'] = None,
reply: Optional["MessageRecv"] = None,
is_head: bool = False,
is_emoji: bool = False
is_emoji: bool = False,
):
# 调用父类初始化
super().__init__(
@@ -288,28 +321,34 @@ class MessageSending(MessageProcessBase):
chat_stream=chat_stream,
bot_user_info=bot_user_info,
message_segment=message_segment,
reply=reply
reply=reply,
)
# 发送状态特有属性
self.sender_info = sender_info
self.reply_to_message_id = reply.message_info.message_id if reply else None
self.is_head = is_head
self.is_emoji = is_emoji
def set_reply(self, reply: Optional['MessageRecv']) -> None:
def set_reply(self, reply: Optional["MessageRecv"]) -> None:
"""设置回复消息"""
if reply:
self.reply = reply
self.reply_to_message_id = self.reply.message_info.message_id
self.message_segment = Seg(type='seglist', data=[
Seg(type='reply', data=reply.message_info.message_id),
self.message_segment
])
self.message_segment = Seg(
type="seglist",
data=[
Seg(type="reply", data=reply.message_info.message_id),
self.message_segment,
],
)
async def process(self) -> None:
"""处理消息内容,生成纯文本和详细文本"""
if self.message_segment:
self.processed_plain_text = await self._process_message_segments(self.message_segment)
self.processed_plain_text = await self._process_message_segments(
self.message_segment
)
self.detailed_plain_text = self._generate_detailed_text()
@classmethod
@@ -318,8 +357,8 @@ class MessageSending(MessageProcessBase):
thinking: MessageThinking,
message_segment: Seg,
is_head: bool = False,
is_emoji: bool = False
) -> 'MessageSending':
is_emoji: bool = False,
) -> "MessageSending":
"""从思考状态消息创建发送状态消息"""
return cls(
message_id=thinking.message_info.message_id,
@@ -328,41 +367,50 @@ class MessageSending(MessageProcessBase):
bot_user_info=thinking.message_info.user_info,
reply=thinking.reply,
is_head=is_head,
is_emoji=is_emoji
is_emoji=is_emoji,
)
def to_dict(self):
ret= super().to_dict()
ret['message_info']['user_info']=self.chat_stream.user_info.to_dict()
ret = super().to_dict()
ret["message_info"]["user_info"] = self.chat_stream.user_info.to_dict()
return ret
def is_private_message(self) -> bool:
"""判断是否为私聊消息"""
return (
self.message_info.group_info is None
or self.message_info.group_info.group_id is None
)
@dataclass
class MessageSet:
"""消息集合类,可以存储多个发送消息"""
def __init__(self, chat_stream: ChatStream, message_id: str):
self.chat_stream = chat_stream
self.message_id = message_id
self.messages: List[MessageSending] = []
self.time = round(time.time(), 2)
def add_message(self, message: MessageSending) -> None:
"""添加消息到集合"""
if not isinstance(message, MessageSending):
raise TypeError("MessageSet只能添加MessageSending类型的消息")
self.messages.append(message)
self.messages.sort(key=lambda x: x.message_info.time)
def get_message_by_index(self, index: int) -> Optional[MessageSending]:
"""通过索引获取消息"""
if 0 <= index < len(self.messages):
return self.messages[index]
return None
def get_message_by_time(self, target_time: float) -> Optional[MessageSending]:
"""获取最接近指定时间的消息"""
if not self.messages:
return None
left, right = 0, len(self.messages) - 1
while left < right:
mid = (left + right) // 2
@@ -370,25 +418,22 @@ class MessageSet:
left = mid + 1
else:
right = mid
return self.messages[left]
def clear_messages(self) -> None:
"""清空所有消息"""
self.messages.clear()
def remove_message(self, message: MessageSending) -> bool:
"""移除指定消息"""
if message in self.messages:
self.messages.remove(message)
return True
return False
def __str__(self) -> str:
return f"MessageSet(id={self.message_id}, count={len(self.messages)})"
def __len__(self) -> int:
return len(self.messages)

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass, asdict
from typing import List, Optional, Union, Any, Dict
from typing import List, Optional, Union, Dict
@dataclass
class Seg:

View File

@@ -1,55 +1,47 @@
import time
from dataclasses import dataclass
from typing import Dict, ForwardRef, List, Optional, Union
from typing import Dict, Optional
import urllib3
from .cq_code import CQCode, cq_code_tool
from .cq_code import cq_code_tool
from .utils_cq import parse_cq_code
from .utils_user import get_groupname, get_user_cardname, get_user_nickname
from .utils_user import get_groupname
from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase
# 禁用SSL警告
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
#这个类是消息数据类,用于存储和管理消息数据。
#它定义了消息的属性包括群组ID、用户ID、消息ID、原始消息内容、纯文本内容和时间戳。
#它还定义了两个辅助属性keywords用于提取消息的关键词is_plain_text用于判断消息是否为纯文本。
# 这个类是消息数据类,用于存储和管理消息数据。
# 它定义了消息的属性包括群组ID、用户ID、消息ID、原始消息内容、纯文本内容和时间戳。
# 它还定义了两个辅助属性keywords用于提取消息的关键词is_plain_text用于判断消息是否为纯文本。
@dataclass
class MessageCQ(MessageBase):
"""QQ消息基类继承自MessageBase
最小必要参数:
- message_id: 消息ID
- user_id: 发送者/接收者ID
- platform: 平台标识(默认为"qq"
"""
def __init__(
self,
message_id: int,
user_info: UserInfo,
group_info: Optional[GroupInfo] = None,
platform: str = "qq"
self, message_id: int, user_info: UserInfo, group_info: Optional[GroupInfo] = None, platform: str = "qq"
):
# 构造基础消息信息
message_info = BaseMessageInfo(
platform=platform,
message_id=message_id,
time=int(time.time()),
group_info=group_info,
user_info=user_info
platform=platform, message_id=message_id, time=int(time.time()), group_info=group_info, user_info=user_info
)
# 调用父类初始化message_segment 由子类设置
super().__init__(
message_info=message_info,
message_segment=None,
raw_message=None
)
super().__init__(message_info=message_info, message_segment=None, raw_message=None)
@dataclass
class MessageRecvCQ(MessageCQ):
"""QQ接收消息类用于解析raw_message到Seg对象"""
def __init__(
self,
message_id: int,
@@ -62,9 +54,13 @@ class MessageRecvCQ(MessageCQ):
# 调用父类初始化
super().__init__(message_id, user_info, group_info, platform)
if group_info and group_info.group_name is None:
# 私聊消息不携带group_info
if group_info is None:
pass
elif group_info.group_name is None:
group_info.group_name = get_groupname(group_info.group_id)
# 解析消息段
self.message_segment = self._parse_message(raw_message, reply_message)
self.raw_message = raw_message
@@ -73,10 +69,10 @@ class MessageRecvCQ(MessageCQ):
"""解析消息内容为Seg对象"""
cq_code_dict_list = []
segments = []
start = 0
while True:
cq_start = message.find('[CQ:', start)
cq_start = message.find("[CQ:", start)
if cq_start == -1:
if start < len(message):
text = message[start:].strip()
@@ -89,81 +85,80 @@ class MessageRecvCQ(MessageCQ):
if text:
cq_code_dict_list.append(parse_cq_code(text))
cq_end = message.find(']', cq_start)
cq_end = message.find("]", cq_start)
if cq_end == -1:
text = message[cq_start:].strip()
if text:
cq_code_dict_list.append(parse_cq_code(text))
break
cq_code = message[cq_start:cq_end + 1]
cq_code = message[cq_start : cq_end + 1]
cq_code_dict_list.append(parse_cq_code(cq_code))
start = cq_end + 1
# 转换CQ码为Seg对象
for code_item in cq_code_dict_list:
message_obj = cq_code_tool.cq_from_dict_to_class(code_item,msg=self,reply=reply_message)
message_obj = cq_code_tool.cq_from_dict_to_class(code_item, msg=self, reply=reply_message)
if message_obj.translated_segments:
segments.append(message_obj.translated_segments)
# 如果只有一个segment直接返回
if len(segments) == 1:
return segments[0]
# 否则返回seglist类型的Seg
return Seg(type='seglist', data=segments)
return Seg(type="seglist", data=segments)
def to_dict(self) -> Dict:
"""转换为字典格式,包含所有必要信息"""
base_dict = super().to_dict()
return base_dict
@dataclass
class MessageSendCQ(MessageCQ):
"""QQ发送消息类用于将Seg对象转换为raw_message"""
def __init__(
self,
data: Dict
):
def __init__(self, data: Dict):
# 调用父类初始化
message_info = BaseMessageInfo.from_dict(data.get('message_info', {}))
message_segment = Seg.from_dict(data.get('message_segment', {}))
message_info = BaseMessageInfo.from_dict(data.get("message_info", {}))
message_segment = Seg.from_dict(data.get("message_segment", {}))
super().__init__(
message_info.message_id,
message_info.user_info,
message_info.group_info if message_info.group_info else None,
message_info.platform
)
message_info.message_id,
message_info.user_info,
message_info.group_info if message_info.group_info else None,
message_info.platform,
)
self.message_segment = message_segment
self.raw_message = self._generate_raw_message()
def _generate_raw_message(self, ) -> str:
def _generate_raw_message(
self,
) -> str:
"""将Seg对象转换为raw_message"""
segments = []
# 处理消息段
if self.message_segment.type == 'seglist':
if self.message_segment.type == "seglist":
for seg in self.message_segment.data:
segments.append(self._seg_to_cq_code(seg))
else:
segments.append(self._seg_to_cq_code(self.message_segment))
return ''.join(segments)
return "".join(segments)
def _seg_to_cq_code(self, seg: Seg) -> str:
"""将单个Seg对象转换为CQ码字符串"""
if seg.type == 'text':
if seg.type == "text":
return str(seg.data)
elif seg.type == 'image':
elif seg.type == "image":
return cq_code_tool.create_image_cq_base64(seg.data)
elif seg.type == 'emoji':
elif seg.type == "emoji":
return cq_code_tool.create_emoji_cq_base64(seg.data)
elif seg.type == 'at':
elif seg.type == "at":
return f"[CQ:at,qq={seg.data}]"
elif seg.type == 'reply':
elif seg.type == "reply":
return cq_code_tool.create_reply_cq(int(seg.data))
else:
return f"[{seg.data}]"

View File

@@ -5,12 +5,12 @@ from typing import Dict, List, Optional, Union
from loguru import logger
from nonebot.adapters.onebot.v11 import Bot
from .cq_code import cq_code_tool
from .message_cq import MessageSendCQ
from .message import MessageSending, MessageThinking, MessageRecv,MessageSet
from .message import MessageSending, MessageThinking, MessageRecv, MessageSet
from .storage import MessageStorage
from .config import global_config
from .chat_stream import chat_manager
from .utils import truncate_message
class Message_Sender:
@@ -26,49 +26,54 @@ class Message_Sender:
self._current_bot = bot
async def send_message(
self,
message: MessageSending,
self,
message: MessageSending,
) -> None:
"""发送消息"""
if isinstance(message, MessageSending):
message_json = message.to_dict()
message_send=MessageSendCQ(
data=message_json
)
if message_send.message_info.group_info:
message_send = MessageSendCQ(data=message_json)
# logger.debug(message_send.message_info,message_send.raw_message)
message_preview = truncate_message(message.processed_plain_text)
if (
message_send.message_info.group_info
and message_send.message_info.group_info.group_id
):
try:
await self._current_bot.send_group_msg(
group_id=message.message_info.group_info.group_id,
message=message_send.raw_message,
auto_escape=False
auto_escape=False,
)
logger.success(f"[调试] 发送消息{message.processed_plain_text}成功")
logger.success(f"[调试] 发送消息{message_preview}成功")
except Exception as e:
logger.error(f"[调试] 发生错误 {e}")
logger.error(f"[调试] 发送消息{message.processed_plain_text}失败")
logger.error(f"[调试] 发送消息{message_preview}失败")
else:
try:
logger.debug(message.message_info.user_info)
await self._current_bot.send_private_msg(
user_id=message.message_info.user_info.user_id,
user_id=message.sender_info.user_id,
message=message_send.raw_message,
auto_escape=False
auto_escape=False,
)
logger.success(f"[调试] 发送消息{message.processed_plain_text}成功")
logger.success(f"[调试] 发送消息{message_preview}成功")
except Exception as e:
logger.error(f"发生错误 {e}")
logger.error(f"[调试] 发送消息{message.processed_plain_text}失败")
logger.error(f"[调试] 发生错误 {e}")
logger.error(f"[调试] 发送消息{message_preview}失败")
class MessageContainer:
"""单个聊天流的发送/思考消息容器"""
def __init__(self, chat_id: str, max_size: int = 100):
self.chat_id = chat_id
self.max_size = max_size
self.messages = []
self.last_send_time = 0
self.thinking_timeout = 20 # 思考超时时间(秒)
def get_timeout_messages(self) -> List[MessageSending]:
"""获取所有超时的Message_Sending对象思考时间超过30秒按thinking_start_time排序"""
current_time = time.time()
@@ -83,12 +88,12 @@ class MessageContainer:
timeout_messages.sort(key=lambda x: x.thinking_start_time)
return timeout_messages
def get_earliest_message(self) -> Optional[Union[MessageThinking, MessageSending]]:
"""获取thinking_start_time最早的消息对象"""
if not self.messages:
return None
earliest_time = float('inf')
earliest_time = float("inf")
earliest_message = None
for msg in self.messages:
msg_time = msg.thinking_start_time
@@ -96,7 +101,7 @@ class MessageContainer:
earliest_time = msg_time
earliest_message = msg
return earliest_message
def add_message(self, message: Union[MessageThinking, MessageSending]) -> None:
"""添加消息到队列"""
if isinstance(message, MessageSet):
@@ -104,7 +109,7 @@ class MessageContainer:
self.messages.append(single_message)
else:
self.messages.append(message)
def remove_message(self, message: Union[MessageThinking, MessageSending]) -> bool:
"""移除消息如果消息存在则返回True否则返回False"""
try:
@@ -119,7 +124,7 @@ class MessageContainer:
def has_messages(self) -> bool:
"""检查是否有待发送的消息"""
return bool(self.messages)
def get_all_messages(self) -> List[Union[MessageSending, MessageThinking]]:
"""获取所有消息"""
return list(self.messages)
@@ -127,72 +132,91 @@ class MessageContainer:
class MessageManager:
"""管理所有聊天流的消息容器"""
def __init__(self):
self.containers: Dict[str, MessageContainer] = {} # chat_id -> MessageContainer
self.storage = MessageStorage()
self._running = True
def get_container(self, chat_id: str) -> MessageContainer:
"""获取或创建聊天流的消息容器"""
if chat_id not in self.containers:
self.containers[chat_id] = MessageContainer(chat_id)
return self.containers[chat_id]
def add_message(self, message: Union[MessageThinking, MessageSending, MessageSet]) -> None:
def add_message(
self, message: Union[MessageThinking, MessageSending, MessageSet]
) -> None:
chat_stream = message.chat_stream
if not chat_stream:
raise ValueError("无法找到对应的聊天流")
container = self.get_container(chat_stream.stream_id)
container.add_message(message)
async def process_chat_messages(self, chat_id: str):
"""处理聊天流消息"""
container = self.get_container(chat_id)
if container.has_messages():
# print(f"处理有message的容器chat_id: {chat_id}")
message_earliest = container.get_earliest_message()
if isinstance(message_earliest, MessageThinking):
message_earliest.update_thinking_time()
thinking_time = message_earliest.thinking_time
print(f"消息正在思考中,已思考{int(thinking_time)}\r", end='', flush=True)
print(
f"消息正在思考中,已思考{int(thinking_time)}\r",
end="",
flush=True,
)
# 检查是否超时
if thinking_time > global_config.thinking_timeout:
logger.warning(f"消息思考超时({thinking_time}秒),移除该消息")
container.remove_message(message_earliest)
else:
if message_earliest.is_head and message_earliest.update_thinking_time() > 30:
if (
message_earliest.is_head
and message_earliest.update_thinking_time() > 30
and not message_earliest.is_private_message() # 避免在私聊时插入reply
):
await message_sender.send_message(message_earliest.set_reply())
else:
await message_sender.send_message(message_earliest)
await message_earliest.process()
print(f"\033[1;34m[调试]\033[0m 消息'{message_earliest.processed_plain_text}'正在发送中")
await self.storage.store_message(message_earliest, message_earliest.chat_stream,None)
print(
f"\033[1;34m[调试]\033[0m 消息“{truncate_message(message_earliest.processed_plain_text)}”正在发送中"
)
await self.storage.store_message(
message_earliest, message_earliest.chat_stream, None
)
container.remove_message(message_earliest)
message_timeout = container.get_timeout_messages()
if message_timeout:
logger.warning(f"发现{len(message_timeout)}条超时消息")
for msg in message_timeout:
if msg == message_earliest:
continue
try:
if msg.is_head and msg.update_thinking_time() > 30:
if (
msg.is_head
and msg.update_thinking_time() > 30
and not message_earliest.is_private_message() # 避免在私聊时插入reply
):
await message_sender.send_message(msg.set_reply())
else:
await message_sender.send_message(msg)
# if msg.is_emoji:
# msg.processed_plain_text = "[表情包]"
await msg.process()
await self.storage.store_message(msg,msg.chat_stream, None)
await self.storage.store_message(msg, msg.chat_stream, None)
if not container.remove_message(msg):
logger.warning("尝试删除不存在的消息")
except Exception:
@@ -206,7 +230,7 @@ class MessageManager:
tasks = []
for chat_id in self.containers.keys():
tasks.append(self.process_chat_messages(chat_id))
await asyncio.gather(*tasks)

View File

@@ -9,7 +9,7 @@ from ..moods.moods import MoodManager
from ..schedule.schedule_generator import bot_schedule
from .config import global_config
from .utils import get_embedding, get_recent_group_detailed_plain_text
from .chat_stream import ChatStream, chat_manager
from .chat_stream import chat_manager
class PromptBuilder:
@@ -311,7 +311,7 @@ class PromptBuilder:
{"$project": {"content": 1, "similarity": 1}}
]
results = list(self.db.db.knowledges.aggregate(pipeline))
results = list(self.db.knowledges.aggregate(pipeline))
# print(f"\033[1;34m[调试]\033[0m获取知识库内容结果: {results}")
if not results:

View File

@@ -1,6 +1,5 @@
import asyncio
from typing import Optional, Union
from typing import Optional, Union
from typing import Optional
from loguru import logger
from ...common.database import Database
@@ -169,7 +168,7 @@ class RelationshipManager:
async def load_all_relationships(self):
"""加载所有关系对象"""
db = Database.get_instance()
all_relationships = db.db.relationships.find({})
all_relationships = db.relationships.find({})
for data in all_relationships:
await self.load_relationship(data)
@@ -177,7 +176,7 @@ class RelationshipManager:
"""每5分钟自动保存一次关系数据"""
db = Database.get_instance()
# 获取所有关系记录
all_relationships = db.db.relationships.find({})
all_relationships = db.relationships.find({})
# 依次加载每条记录
for data in all_relationships:
await self.load_relationship(data)
@@ -207,7 +206,7 @@ class RelationshipManager:
saved = relationship.saved
db = Database.get_instance()
db.db.relationships.update_one(
db.relationships.update_one(
{'user_id': user_id, 'platform': platform},
{'$set': {
'platform': platform,

View File

@@ -1,8 +1,6 @@
from typing import Optional, Union
from typing import Optional, Union
from ...common.database import Database
from .message_base import MessageBase
from .message import MessageSending, MessageRecv
from .chat_stream import ChatStream
from loguru import logger
@@ -25,7 +23,7 @@ class MessageStorage:
"detailed_plain_text": message.detailed_plain_text,
"topic": topic,
}
self.db.db.messages.insert_one(message_data)
self.db.messages.insert_one(message_data)
except Exception:
logger.exception("存储消息失败")

View File

@@ -12,8 +12,8 @@ from loguru import logger
from ..models.utils_model import LLM_request
from ..utils.typo_generator import ChineseTypoGenerator
from .config import global_config
from .message import MessageThinking, MessageRecv,MessageSending,MessageProcessBase,Message
from .message_base import MessageBase,BaseMessageInfo,UserInfo,GroupInfo
from .message import MessageRecv,Message
from .message_base import UserInfo
from .chat_stream import ChatStream
from ..moods.moods import MoodManager
@@ -39,9 +39,13 @@ def db_message_to_str(message_dict: Dict) -> str:
def is_mentioned_bot_in_message(message: MessageRecv) -> bool:
"""检查消息是否提到了机器人"""
keywords = [global_config.BOT_NICKNAME]
nicknames = global_config.BOT_ALIAS_NAMES
for keyword in keywords:
if keyword in message.processed_plain_text:
return True
for nickname in nicknames:
if nickname in message.processed_plain_text:
return True
return False
@@ -402,3 +406,10 @@ def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list:
# 按相似度降序排序并返回前k个
return sorted(similarities, key=lambda x: x[1], reverse=True)[:top_k]
def truncate_message(message: str, max_length=20) -> str:
"""截断消息,使其不超过指定长度"""
if len(message) > max_length:
return message[:max_length] + "..."
return message

View File

@@ -1,16 +1,12 @@
import base64
import io
import os
import time
import zlib
import aiohttp
import hashlib
from typing import Optional, Tuple, Union
from urllib.parse import urlparse
from typing import Optional, Union
from loguru import logger
from nonebot import get_driver
from PIL import Image
from ...common.database import Database
from ..chat.config import global_config
@@ -44,20 +40,20 @@ class ImageManager:
def _ensure_image_collection(self):
"""确保images集合存在并创建索引"""
if 'images' not in self.db.db.list_collection_names():
self.db.db.create_collection('images')
if 'images' not in self.db.list_collection_names():
self.db.create_collection('images')
# 创建索引
self.db.db.images.create_index([('hash', 1)], unique=True)
self.db.db.images.create_index([('url', 1)])
self.db.db.images.create_index([('path', 1)])
self.db.images.create_index([('hash', 1)], unique=True)
self.db.images.create_index([('url', 1)])
self.db.images.create_index([('path', 1)])
def _ensure_description_collection(self):
"""确保image_descriptions集合存在并创建索引"""
if 'image_descriptions' not in self.db.db.list_collection_names():
self.db.db.create_collection('image_descriptions')
if 'image_descriptions' not in self.db.list_collection_names():
self.db.create_collection('image_descriptions')
# 创建索引
self.db.db.image_descriptions.create_index([('hash', 1)], unique=True)
self.db.db.image_descriptions.create_index([('type', 1)])
self.db.image_descriptions.create_index([('hash', 1)], unique=True)
self.db.image_descriptions.create_index([('type', 1)])
def _get_description_from_db(self, image_hash: str, description_type: str) -> Optional[str]:
"""从数据库获取图片描述
@@ -69,7 +65,7 @@ class ImageManager:
Returns:
Optional[str]: 描述文本如果不存在则返回None
"""
result= self.db.db.image_descriptions.find_one({
result= self.db.image_descriptions.find_one({
'hash': image_hash,
'type': description_type
})
@@ -83,7 +79,7 @@ class ImageManager:
description: 描述文本
description_type: 描述类型 ('emoji''image')
"""
self.db.db.image_descriptions.update_one(
self.db.image_descriptions.update_one(
{'hash': image_hash, 'type': description_type},
{
'$set': {
@@ -125,7 +121,7 @@ class ImageManager:
image_hash = hashlib.md5(image_bytes).hexdigest()
# 查重
existing = self.db.db.images.find_one({'hash': image_hash})
existing = self.db.images.find_one({'hash': image_hash})
if existing:
return existing['path']
@@ -146,7 +142,7 @@ class ImageManager:
'description': description,
'timestamp': timestamp
}
self.db.db.images.insert_one(image_doc)
self.db.images.insert_one(image_doc)
return file_path
@@ -163,7 +159,7 @@ class ImageManager:
"""
try:
# 先查找是否已存在
existing = self.db.db.images.find_one({'url': url})
existing = self.db.images.find_one({'url': url})
if existing:
return existing['path']
@@ -207,7 +203,7 @@ class ImageManager:
Returns:
bool: 是否存在
"""
return self.db.db.images.find_one({'url': url}) is not None
return self.db.images.find_one({'url': url}) is not None
def check_hash_exists(self, image_data: Union[str, bytes], is_base64: bool = False) -> bool:
"""检查图像是否已存在
@@ -230,7 +226,7 @@ class ImageManager:
return False
image_hash = hashlib.md5(image_bytes).hexdigest()
return self.db.db.images.find_one({'hash': image_hash}) is not None
return self.db.images.find_one({'hash': image_hash}) is not None
except Exception as e:
logger.error(f"检查哈希失败: {str(e)}")
@@ -273,7 +269,7 @@ class ImageManager:
'description': description,
'timestamp': timestamp
}
self.db.db.images.update_one(
self.db.images.update_one(
{'hash': image_hash},
{'$set': image_doc},
upsert=True
@@ -330,7 +326,7 @@ class ImageManager:
'description': description,
'timestamp': timestamp
}
self.db.db.images.update_one(
self.db.images.update_one(
{'hash': image_hash},
{'$set': image_doc},
upsert=True

View File

@@ -1,13 +1,9 @@
import asyncio
from typing import Dict
from loguru import logger
from typing import Dict
from loguru import logger
from .config import global_config
from .message_base import UserInfo, GroupInfo
from .chat_stream import chat_manager,ChatStream
from .chat_stream import ChatStream
class WillingManager:

View File

@@ -0,0 +1,10 @@
from nonebot import get_app
from .api import router
from loguru import logger
# 获取主应用实例并挂载路由
app = get_app()
app.include_router(router, prefix="/api")
# 打印日志方便确认API已注册
logger.success("配置重载API已注册可通过 /api/reload-config 访问")

View File

@@ -0,0 +1,17 @@
from fastapi import APIRouter, HTTPException
from src.plugins.chat.config import BotConfig
import os
# 创建APIRouter而不是FastAPI实例
router = APIRouter()
@router.post("/reload-config")
async def reload_config():
try:
bot_config_path = os.path.join(BotConfig.get_config_dir(), "bot_config.toml")
global_config = BotConfig.load_config(config_path=bot_config_path)
return {"message": "配置重载成功", "status": "success"}
except FileNotFoundError as e:
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=f"重载配置时发生错误: {str(e)}")

View File

@@ -0,0 +1,3 @@
import requests
response = requests.post("http://localhost:8080/api/reload-config")
print(response.json())

View File

@@ -1,199 +0,0 @@
import os
import sys
import time
import requests
from dotenv import load_dotenv
# 添加项目根目录到 Python 路径
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path)
# 加载根目录下的env.edv文件
env_path = os.path.join(root_path, ".env.dev")
if not os.path.exists(env_path):
raise FileNotFoundError(f"配置文件不存在: {env_path}")
load_dotenv(env_path)
from src.common.database import Database
# 从环境变量获取配置
Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
class KnowledgeLibrary:
def __init__(self):
self.db = Database.get_instance()
self.raw_info_dir = "data/raw_info"
self._ensure_dirs()
self.api_key = os.getenv("SILICONFLOW_KEY")
if not self.api_key:
raise ValueError("SILICONFLOW_API_KEY 环境变量未设置")
def _ensure_dirs(self):
"""确保必要的目录存在"""
os.makedirs(self.raw_info_dir, exist_ok=True)
def get_embedding(self, text: str) -> list:
"""获取文本的embedding向量"""
url = "https://api.siliconflow.cn/v1/embeddings"
payload = {
"model": "BAAI/bge-m3",
"input": text,
"encoding_format": "float"
}
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
response = requests.post(url, json=payload, headers=headers)
if response.status_code != 200:
print(f"获取embedding失败: {response.text}")
return None
return response.json()['data'][0]['embedding']
def process_files(self):
"""处理raw_info目录下的所有txt文件"""
for filename in os.listdir(self.raw_info_dir):
if filename.endswith('.txt'):
file_path = os.path.join(self.raw_info_dir, filename)
self.process_single_file(file_path)
def process_single_file(self, file_path: str):
"""处理单个文件"""
try:
# 检查文件是否已处理
if self.db.db.processed_files.find_one({"file_path": file_path}):
print(f"文件已处理过,跳过: {file_path}")
return
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()
# 按1024字符分段
segments = [content[i:i+600] for i in range(0, len(content), 300)]
# 处理每个分段
for segment in segments:
if not segment.strip(): # 跳过空段
continue
# 获取embedding
embedding = self.get_embedding(segment)
if not embedding:
continue
# 存储到数据库
doc = {
"content": segment,
"embedding": embedding,
"file_path": file_path,
"segment_length": len(segment)
}
# 使用文本内容的哈希值作为唯一标识
content_hash = hash(segment)
# 更新或插入文档
self.db.db.knowledges.update_one(
{"content_hash": content_hash},
{"$set": doc},
upsert=True
)
# 记录文件已处理
self.db.db.processed_files.insert_one({
"file_path": file_path,
"processed_time": time.time()
})
print(f"成功处理文件: {file_path}")
except Exception as e:
print(f"处理文件 {file_path} 时出错: {str(e)}")
def search_similar_segments(self, query: str, limit: int = 5) -> list:
"""搜索与查询文本相似的片段"""
query_embedding = self.get_embedding(query)
if not query_embedding:
return []
# 使用余弦相似度计算
pipeline = [
{
"$addFields": {
"dotProduct": {
"$reduce": {
"input": {"$range": [0, {"$size": "$embedding"}]},
"initialValue": 0,
"in": {
"$add": [
"$$value",
{"$multiply": [
{"$arrayElemAt": ["$embedding", "$$this"]},
{"$arrayElemAt": [query_embedding, "$$this"]}
]}
]
}
}
},
"magnitude1": {
"$sqrt": {
"$reduce": {
"input": "$embedding",
"initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}
}
}
},
"magnitude2": {
"$sqrt": {
"$reduce": {
"input": query_embedding,
"initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}
}
}
}
}
},
{
"$addFields": {
"similarity": {
"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]
}
}
},
{"$sort": {"similarity": -1}},
{"$limit": limit},
{"$project": {"content": 1, "similarity": 1, "file_path": 1}}
]
results = list(self.db.db.knowledges.aggregate(pipeline))
return results
# 创建单例实例
knowledge_library = KnowledgeLibrary()
if __name__ == "__main__":
# 测试知识库功能
print("开始处理知识库文件...")
knowledge_library.process_files()
# 测试搜索功能
test_query = "麦麦评价一下僕と花"
print(f"\n搜索与'{test_query}'相似的内容:")
results = knowledge_library.search_similar_segments(test_query)
for result in results:
print(f"相似度: {result['similarity']:.4f}")
print(f"内容: {result['content'][:100]}...")
print("-" * 50)

View File

@@ -96,7 +96,7 @@ class Memory_graph:
dot_data = {
"concept": node
}
self.db.db.store_memory_dots.insert_one(dot_data)
self.db.store_memory_dots.insert_one(dot_data)
@property
def dots(self):
@@ -106,7 +106,7 @@ class Memory_graph:
def get_random_chat_from_db(self, length: int, timestamp: str):
# 从数据库中根据时间戳获取离其最近的聊天记录
chat_text = ''
closest_record = self.db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出
closest_record = self.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出
logger.info(
f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
@@ -115,7 +115,7 @@ class Memory_graph:
group_id = closest_record['group_id'] # 获取groupid
# 获取该时间戳之后的length条消息且groupid相同
chat_record = list(
self.db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(
self.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(
length))
for record in chat_record:
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
@@ -130,34 +130,34 @@ class Memory_graph:
def save_graph_to_db(self):
# 清空现有的图数据
self.db.db.graph_data.delete_many({})
self.db.graph_data.delete_many({})
# 保存节点
for node in self.G.nodes(data=True):
node_data = {
'concept': node[0],
'memory_items': node[1].get('memory_items', []) # 默认为空列表
}
self.db.db.graph_data.nodes.insert_one(node_data)
self.db.graph_data.nodes.insert_one(node_data)
# 保存边
for edge in self.G.edges():
edge_data = {
'source': edge[0],
'target': edge[1]
}
self.db.db.graph_data.edges.insert_one(edge_data)
self.db.graph_data.edges.insert_one(edge_data)
def load_graph_from_db(self):
# 清空当前图
self.G.clear()
# 加载节点
nodes = self.db.db.graph_data.nodes.find()
nodes = self.db.graph_data.nodes.find()
for node in nodes:
memory_items = node.get('memory_items', [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
self.G.add_node(node['concept'], memory_items=memory_items)
# 加载边
edges = self.db.db.graph_data.edges.find()
edges = self.db.graph_data.edges.find()
for edge in edges:
self.G.add_edge(edge['source'], edge['target'])

View File

@@ -303,7 +303,7 @@ class Hippocampus:
return topic_num
async def operation_build_memory(self, chat_size=20):
time_frequency = {'near': 3, 'mid': 8, 'far': 5}
time_frequency = {'near': 1, 'mid': 4, 'far': 4}
memory_samples = self.get_memory_sample(chat_size, time_frequency)
for i, messages in enumerate(memory_samples, 1):
@@ -315,7 +315,7 @@ class Hippocampus:
bar = '' * filled_length + '-' * (bar_length - filled_length)
logger.debug(f"进度: [{bar}] {progress:.1f}% ({i}/{len(memory_samples)})")
compress_rate = 0.1
compress_rate = global_config.memory_compress_rate
compressed_memory, similar_topics_dict = await self.memory_compress(messages, compress_rate)
logger.info(f"压缩后记忆数量: {len(compressed_memory)},似曾相识的话题: {len(similar_topics_dict)}")
@@ -523,9 +523,14 @@ class Hippocampus:
async def operation_forget_topic(self, percentage=0.1):
"""随机选择图中一定比例的节点和边进行检查,根据时间条件决定是否遗忘"""
# 检查数据库是否为空
all_nodes = list(self.memory_graph.G.nodes())
all_edges = list(self.memory_graph.G.edges())
if not all_nodes and not all_edges:
logger.info("记忆图为空,无需进行遗忘操作")
return
check_nodes_count = max(1, int(len(all_nodes) * percentage))
check_edges_count = max(1, int(len(all_edges) * percentage))
@@ -546,7 +551,7 @@ class Hippocampus:
# print(f"float(last_modified):{float(last_modified)}" )
# print(f"current_time:{current_time}")
# print(f"current_time - last_modified:{current_time - last_modified}")
if current_time - last_modified > 3600*24: # test
if current_time - last_modified > 3600*global_config.memory_forget_time: # test
current_strength = edge_data.get('strength', 1)
new_strength = current_strength - 1
@@ -887,15 +892,6 @@ config = driver.config
start_time = time.time()
Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
# 创建记忆图
memory_graph = Memory_graph()
# 创建海马体

View File

@@ -10,7 +10,6 @@ from pathlib import Path
import matplotlib.pyplot as plt
import networkx as nx
import pymongo
from dotenv import load_dotenv
from loguru import logger
import jieba

View File

@@ -41,10 +41,10 @@ class LLM_request:
"""初始化数据库集合"""
try:
# 创建llm_usage集合的索引
self.db.db.llm_usage.create_index([("timestamp", 1)])
self.db.db.llm_usage.create_index([("model_name", 1)])
self.db.db.llm_usage.create_index([("user_id", 1)])
self.db.db.llm_usage.create_index([("request_type", 1)])
self.db.llm_usage.create_index([("timestamp", 1)])
self.db.llm_usage.create_index([("model_name", 1)])
self.db.llm_usage.create_index([("user_id", 1)])
self.db.llm_usage.create_index([("request_type", 1)])
except Exception:
logger.error("创建数据库索引失败")
@@ -73,7 +73,7 @@ class LLM_request:
"status": "success",
"timestamp": datetime.now()
}
self.db.db.llm_usage.insert_one(usage_data)
self.db.llm_usage.insert_one(usage_data)
logger.info(
f"Token使用情况 - 模型: {self.model_name}, "
f"用户: {user_id}, 类型: {request_type}, "

View File

@@ -14,16 +14,6 @@ from ..models.utils_model import LLM_request
driver = get_driver()
config = driver.config
Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
class ScheduleGenerator:
def __init__(self):
# 根据global_config.llm_normal这一字典配置指定模型
@@ -56,7 +46,7 @@ class ScheduleGenerator:
schedule_text = str
existing_schedule = self.db.db.schedule.find_one({"date": date_str})
existing_schedule = self.db.schedule.find_one({"date": date_str})
if existing_schedule:
logger.debug(f"{date_str}的日程已存在:")
schedule_text = existing_schedule["schedule"]
@@ -73,7 +63,7 @@ class ScheduleGenerator:
try:
schedule_text, _ = await self.llm_scheduler.generate_response(prompt)
self.db.db.schedule.insert_one({"date": date_str, "schedule": schedule_text})
self.db.schedule.insert_one({"date": date_str, "schedule": schedule_text})
except Exception as e:
logger.error(f"生成日程失败: {str(e)}")
schedule_text = "生成日程时出错了"
@@ -153,7 +143,7 @@ class ScheduleGenerator:
"""打印完整的日程安排"""
if not self._parse_schedule(self.today_schedule_text):
logger.warning("今日日程有误,将在下次运行时重新生成")
self.db.db.schedule.delete_one({"date": datetime.datetime.now().strftime("%Y-%m-%d")})
self.db.schedule.delete_one({"date": datetime.datetime.now().strftime("%Y-%m-%d")})
else:
logger.info("=== 今日日程安排 ===")
for time_str, activity in self.today_schedule.items():

View File

@@ -53,7 +53,7 @@ class LLMStatistics:
"costs_by_model": defaultdict(float)
}
cursor = self.db.db.llm_usage.find({
cursor = self.db.llm_usage.find({
"timestamp": {"$gte": start_time}
})

View File

@@ -0,0 +1,383 @@
import os
import sys
import time
import requests
from dotenv import load_dotenv
import hashlib
from datetime import datetime
from tqdm import tqdm
from rich.console import Console
from rich.table import Table
# 添加项目根目录到 Python 路径
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path)
# 现在可以导入src模块
from src.common.database import Database
# 加载根目录下的env.edv文件
env_path = os.path.join(root_path, ".env.prod")
if not os.path.exists(env_path):
raise FileNotFoundError(f"配置文件不存在: {env_path}")
load_dotenv(env_path)
class KnowledgeLibrary:
def __init__(self):
# 初始化数据库连接
if Database._instance is None:
Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
self.db = Database.get_instance()
self.raw_info_dir = "data/raw_info"
self._ensure_dirs()
self.api_key = os.getenv("SILICONFLOW_KEY")
if not self.api_key:
raise ValueError("SILICONFLOW_API_KEY 环境变量未设置")
self.console = Console()
def _ensure_dirs(self):
"""确保必要的目录存在"""
os.makedirs(self.raw_info_dir, exist_ok=True)
def read_file(self, file_path: str) -> str:
"""读取文件内容"""
with open(file_path, 'r', encoding='utf-8') as f:
return f.read()
def split_content(self, content: str, max_length: int = 512) -> list:
"""将内容分割成适当大小的块,保持段落完整性
Args:
content: 要分割的文本内容
max_length: 每个块的最大长度
Returns:
list: 分割后的文本块列表
"""
# 首先按段落分割
paragraphs = [p.strip() for p in content.split('\n\n') if p.strip()]
chunks = []
current_chunk = []
current_length = 0
for para in paragraphs:
para_length = len(para)
# 如果单个段落就超过最大长度
if para_length > max_length:
# 如果当前chunk不为空先保存
if current_chunk:
chunks.append('\n'.join(current_chunk))
current_chunk = []
current_length = 0
# 将长段落按句子分割
sentences = [s.strip() for s in para.replace('', '\n').replace('', '\n').replace('', '\n').split('\n') if s.strip()]
temp_chunk = []
temp_length = 0
for sentence in sentences:
sentence_length = len(sentence)
if sentence_length > max_length:
# 如果单个句子超长,强制按长度分割
if temp_chunk:
chunks.append('\n'.join(temp_chunk))
temp_chunk = []
temp_length = 0
for i in range(0, len(sentence), max_length):
chunks.append(sentence[i:i + max_length])
elif temp_length + sentence_length + 1 <= max_length:
temp_chunk.append(sentence)
temp_length += sentence_length + 1
else:
chunks.append('\n'.join(temp_chunk))
temp_chunk = [sentence]
temp_length = sentence_length
if temp_chunk:
chunks.append('\n'.join(temp_chunk))
# 如果当前段落加上现有chunk不超过最大长度
elif current_length + para_length + 1 <= max_length:
current_chunk.append(para)
current_length += para_length + 1
else:
# 保存当前chunk并开始新的chunk
chunks.append('\n'.join(current_chunk))
current_chunk = [para]
current_length = para_length
# 添加最后一个chunk
if current_chunk:
chunks.append('\n'.join(current_chunk))
return chunks
def get_embedding(self, text: str) -> list:
"""获取文本的embedding向量"""
url = "https://api.siliconflow.cn/v1/embeddings"
payload = {
"model": "BAAI/bge-m3",
"input": text,
"encoding_format": "float"
}
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
response = requests.post(url, json=payload, headers=headers)
if response.status_code != 200:
print(f"获取embedding失败: {response.text}")
return None
return response.json()['data'][0]['embedding']
def process_files(self, knowledge_length:int=512):
"""处理raw_info目录下的所有txt文件"""
txt_files = [f for f in os.listdir(self.raw_info_dir) if f.endswith('.txt')]
if not txt_files:
self.console.print("[red]警告:在 {} 目录下没有找到任何txt文件[/red]".format(self.raw_info_dir))
self.console.print("[yellow]请将需要处理的文本文件放入该目录后再运行程序[/yellow]")
return
total_stats = {
"processed_files": 0,
"total_chunks": 0,
"failed_files": [],
"skipped_files": []
}
self.console.print(f"\n[bold blue]开始处理知识库文件 - 共{len(txt_files)}个文件[/bold blue]")
for filename in tqdm(txt_files, desc="处理文件进度"):
file_path = os.path.join(self.raw_info_dir, filename)
result = self.process_single_file(file_path, knowledge_length)
self._update_stats(total_stats, result, filename)
self._display_processing_results(total_stats)
def process_single_file(self, file_path: str, knowledge_length: int = 512):
"""处理单个文件"""
result = {
"status": "success",
"chunks_processed": 0,
"error": None
}
try:
current_hash = self.calculate_file_hash(file_path)
processed_record = self.db.db.processed_files.find_one({"file_path": file_path})
if processed_record:
if processed_record.get("hash") == current_hash:
if knowledge_length in processed_record.get("split_by", []):
result["status"] = "skipped"
return result
content = self.read_file(file_path)
chunks = self.split_content(content, knowledge_length)
for chunk in tqdm(chunks, desc=f"处理 {os.path.basename(file_path)} 的文本块", leave=False):
embedding = self.get_embedding(chunk)
if embedding:
knowledge = {
"content": chunk,
"embedding": embedding,
"source_file": file_path,
"split_length": knowledge_length,
"created_at": datetime.now()
}
self.db.db.knowledges.insert_one(knowledge)
result["chunks_processed"] += 1
split_by = processed_record.get("split_by", []) if processed_record else []
if knowledge_length not in split_by:
split_by.append(knowledge_length)
self.db.db.processed_files.update_one(
{"file_path": file_path},
{
"$set": {
"hash": current_hash,
"last_processed": datetime.now(),
"split_by": split_by
}
},
upsert=True
)
except Exception as e:
result["status"] = "failed"
result["error"] = str(e)
return result
def _update_stats(self, total_stats, result, filename):
"""更新总体统计信息"""
if result["status"] == "success":
total_stats["processed_files"] += 1
total_stats["total_chunks"] += result["chunks_processed"]
elif result["status"] == "failed":
total_stats["failed_files"].append((filename, result["error"]))
elif result["status"] == "skipped":
total_stats["skipped_files"].append(filename)
def _display_processing_results(self, stats):
"""显示处理结果统计"""
self.console.print("\n[bold green]处理完成!统计信息如下:[/bold green]")
table = Table(show_header=True, header_style="bold magenta")
table.add_column("统计项", style="dim")
table.add_column("数值")
table.add_row("成功处理文件数", str(stats["processed_files"]))
table.add_row("处理的知识块总数", str(stats["total_chunks"]))
table.add_row("跳过的文件数", str(len(stats["skipped_files"])))
table.add_row("失败的文件数", str(len(stats["failed_files"])))
self.console.print(table)
if stats["failed_files"]:
self.console.print("\n[bold red]处理失败的文件:[/bold red]")
for filename, error in stats["failed_files"]:
self.console.print(f"[red]- {filename}: {error}[/red]")
if stats["skipped_files"]:
self.console.print("\n[bold yellow]跳过的文件(已处理):[/bold yellow]")
for filename in stats["skipped_files"]:
self.console.print(f"[yellow]- {filename}[/yellow]")
def calculate_file_hash(self, file_path):
"""计算文件的MD5哈希值"""
hash_md5 = hashlib.md5()
with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
hash_md5.update(chunk)
return hash_md5.hexdigest()
def search_similar_segments(self, query: str, limit: int = 5) -> list:
"""搜索与查询文本相似的片段"""
query_embedding = self.get_embedding(query)
if not query_embedding:
return []
# 使用余弦相似度计算
pipeline = [
{
"$addFields": {
"dotProduct": {
"$reduce": {
"input": {"$range": [0, {"$size": "$embedding"}]},
"initialValue": 0,
"in": {
"$add": [
"$$value",
{"$multiply": [
{"$arrayElemAt": ["$embedding", "$$this"]},
{"$arrayElemAt": [query_embedding, "$$this"]}
]}
]
}
}
},
"magnitude1": {
"$sqrt": {
"$reduce": {
"input": "$embedding",
"initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}
}
}
},
"magnitude2": {
"$sqrt": {
"$reduce": {
"input": query_embedding,
"initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}
}
}
}
}
},
{
"$addFields": {
"similarity": {
"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]
}
}
},
{"$sort": {"similarity": -1}},
{"$limit": limit},
{"$project": {"content": 1, "similarity": 1, "file_path": 1}}
]
results = list(self.db.db.knowledges.aggregate(pipeline))
return results
# 创建单例实例
knowledge_library = KnowledgeLibrary()
if __name__ == "__main__":
console = Console()
console.print("[bold green]知识库处理工具[/bold green]")
while True:
console.print("\n请选择要执行的操作:")
console.print("[1] 麦麦开始学习")
console.print("[2] 麦麦全部忘光光(仅知识)")
console.print("[q] 退出程序")
choice = input("\n请输入选项: ").strip()
if choice.lower() == 'q':
console.print("[yellow]程序退出[/yellow]")
sys.exit(0)
elif choice == '2':
confirm = input("确定要删除所有知识吗?这个操作不可撤销!(y/n): ").strip().lower()
if confirm == 'y':
knowledge_library.db.db.knowledges.delete_many({})
console.print("[green]已清空所有知识![/green]")
continue
elif choice == '1':
if not os.path.exists(knowledge_library.raw_info_dir):
console.print(f"[yellow]创建目录:{knowledge_library.raw_info_dir}[/yellow]")
os.makedirs(knowledge_library.raw_info_dir, exist_ok=True)
# 询问分割长度
while True:
try:
length_input = input("请输入知识分割长度默认512输入q退出回车使用默认值: ").strip()
if length_input.lower() == 'q':
break
if not length_input: # 如果直接回车,使用默认值
knowledge_length = 512
break
knowledge_length = int(length_input)
if knowledge_length <= 0:
print("分割长度必须大于0请重新输入")
continue
break
except ValueError:
print("请输入有效的数字")
continue
if length_input.lower() == 'q':
continue
# 测试知识库功能
print(f"开始处理知识库文件,使用分割长度: {knowledge_length}...")
knowledge_library.process_files(knowledge_length=knowledge_length)
else:
console.print("[red]无效的选项,请重新选择[/red]")
continue