v0.3.1 实装了记忆系统和自动发言

哈哈哈
This commit is contained in:
SengokuCola
2025-03-02 00:14:25 +08:00
parent ba5837503e
commit 50c1765b81
19 changed files with 732 additions and 327 deletions

View File

@@ -1,3 +1,4 @@
from loguru import logger
from nonebot import on_message, on_command, require, get_driver
from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent, Message, MessageSegment
from nonebot.typing import T_State
@@ -10,9 +11,6 @@ from .relationship_manager import relationship_manager
from ..schedule.schedule_generator import bot_schedule
from .willing_manager import willing_manager
from ..memory_system.memory import memory_graph
# 获取驱动器
driver = get_driver()
@@ -21,10 +19,7 @@ Database.initialize(
global_config.MONGODB_PORT,
global_config.DATABASE_NAME
)
print("\033[1;32m[初始化配置和数据库完成]\033[0m")
print("\033[1;32m[初始化数据库完成]\033[0m")
# 导入其他模块
@@ -32,6 +27,7 @@ from .bot import ChatBot
from .emoji_manager import emoji_manager
from .message_send_control import message_sender
from .relationship_manager import relationship_manager
from ..memory_system.memory import memory_graph,hippocampus
# 初始化表情管理器
emoji_manager.initialize()
@@ -39,21 +35,26 @@ emoji_manager.initialize()
print(f"\033[1;32m正在唤醒{global_config.BOT_NICKNAME}......\033[0m")
# 创建机器人实例
chat_bot = ChatBot(global_config)
# 注册消息处理器
group_msg = on_message()
# 创建定时任务
scheduler = require("nonebot_plugin_apscheduler").scheduler
# Start background tasks.
@driver.on_startup
async def start_background_tasks():
    """NoneBot2 startup hook: launch long-running background jobs.

    Currently only the periodic emoji-manager check is started; the check
    interval comes from global_config.EMOJI_CHECK_INTERVAL (minutes, per the
    parameter name interval_MINS).
    """
    # Only the emoji management task is started here.
    asyncio.create_task(emoji_manager.start_periodic_check(interval_MINS=global_config.EMOJI_CHECK_INTERVAL))
    # Print today's generated schedule to the console.
    bot_schedule.print_schedule()
@driver.on_startup
async def init_relationships():
    """NoneBot2 startup hook: load user relationship data into memory.

    NOTE(review): an identically named hook appears again later in this file;
    both function objects are registered with driver.on_startup at decoration
    time, so the load likely runs twice per startup — confirm and remove one.
    """
    print("\033[1;32m[初始化]\033[0m 正在加载用户关系数据...")
    await relationship_manager.load_all_relationships()
    # Kick off the background relationship-maintenance loop without awaiting it.
    asyncio.create_task(relationship_manager._start_relationship_manager())
@driver.on_bot_connect
async def _(bot: Bot):
@@ -68,19 +69,23 @@ async def _(bot: Bot):
print("\033[1;38;5;208m-----------开始偷表情包!-----------\033[0m")
# 启动消息发送控制任务
@driver.on_startup
async def init_relationships():
    """NoneBot2 startup hook: load user relationship data into memory.

    NOTE(review): this is a duplicate of an identically named hook defined
    earlier in this file. Python rebinding keeps only this later function under
    the name, but driver.on_startup registered BOTH objects, so the
    relationship load and the background task are presumably started twice —
    confirm and delete one copy.
    """
    print("\033[1;32m[初始化]\033[0m 正在加载用户关系数据...")
    await relationship_manager.load_all_relationships()
    # Spawn the relationship-maintenance loop as a fire-and-forget task.
    asyncio.create_task(relationship_manager._start_relationship_manager())
@group_msg.handle()
async def _(bot: Bot, event: GroupMessageEvent, state: T_State):
    """Forward every incoming group message to the ChatBot instance."""
    await chat_bot.handle_message(event, bot)
'''
@scheduler.scheduled_job("interval", seconds=300000, id="monitor_relationships")
async def monitor_relationships():
"""每15秒打印一次关系数据"""
relationship_manager.print_all_relationships()
'''
# Register the periodic build_memory job with the scheduler.
@scheduler.scheduled_job("interval", seconds=global_config.build_memory_interval, id="build_memory")
async def build_memory_task():
    """Build memories from recent chat every build_memory_interval seconds.

    NOTE(review): the original docstring claimed "every 30 seconds", but the
    actual interval is global_config.build_memory_interval (default 60s).
    hippocampus.build_memory looks synchronous — if it is, it will block the
    event loop for the duration of the build; confirm.
    """
    print("\033[1;32m[记忆构建]\033[0m 开始构建记忆...")
    hippocampus.build_memory(chat_size=12)
    print("\033[1;32m[记忆构建]\033[0m 记忆构建完成")

View File

@@ -83,7 +83,7 @@ class ChatBot:
await relationship_manager.update_relationship(user_id = event.user_id, data = sender_info)
await relationship_manager.update_relationship_value(user_id = event.user_id, relationship_value = 0.5)
print(f"\033[1;32m[关系管理]\033[0m 更新关系值: {relationship_manager.get_relationship(event.user_id).relationship_value}")
# print(f"\033[1;32m[关系管理]\033[0m 更新关系值: {relationship_manager.get_relationship(event.user_id).relationship_value}")
message = Message(
@@ -100,14 +100,19 @@ class ChatBot:
topic = topic_identifier.identify_topic_jieba(message.processed_plain_text)
print(f"\033[1;32m[主题识别]\033[0m 主题: {topic}")
all_num = 0
interested_num = 0
if topic:
for current_topic in topic:
all_num += 1
first_layer_items, second_layer_items = memory_graph.get_related_item(current_topic, depth=2)
if first_layer_items:
print(f"\033[1;32m[记忆检索-bot]\033[0m 有印象:{current_topic}")
interested_num += 1
print(f"\033[1;32m[前额叶]\033[0m 对|{current_topic}|有印象")
interested_rate = interested_num / all_num if all_num > 0 else 0
await self.storage.store_message(message, topic[0] if topic else None)
is_mentioned = is_mentioned_bot_in_txt(message.processed_plain_text)
@@ -117,7 +122,8 @@ class ChatBot:
is_mentioned,
self.config,
event.user_id,
message.is_emoji
message.is_emoji,
interested_rate
)
current_willing = willing_manager.get_willing(event.group_id)
@@ -188,7 +194,8 @@ class ChatBot:
user_nickname=global_config.BOT_NICKNAME,
group_name=message.group_name,
time=bot_response_time,
is_emoji=True
is_emoji=True,
translate_cq=False
)
message_sender.send_temp_container.add_message(bot_message)

View File

@@ -6,6 +6,8 @@ import logging
import configparser
import tomli
import sys
from loguru import logger
from dotenv import load_dotenv
@@ -21,7 +23,7 @@ class BotConfig:
MONGODB_PASSWORD: Optional[str] = None # 默认空值
MONGODB_AUTH_SOURCE: Optional[str] = None # 默认空值
BOT_QQ: Optional[int] = None
BOT_QQ: Optional[int] = 1
BOT_NICKNAME: Optional[str] = None
# 消息处理相关配置
@@ -35,6 +37,7 @@ class BotConfig:
talk_frequency_down_groups = set()
ban_user_id = set()
build_memory_interval: int = 60 # 记忆构建间隔(秒)
EMOJI_CHECK_INTERVAL: int = 120 # 表情包检查间隔(分钟)
EMOJI_REGISTER_INTERVAL: int = 10 # 表情包注册间隔(分钟)
@@ -45,9 +48,21 @@ class BotConfig:
enable_advance_output: bool = False # 是否启用高级输出
@staticmethod
def get_default_config_path() -> str:
"""获取默认配置文件路径"""
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..'))
config_dir = os.path.join(root_dir, 'config')
return os.path.join(config_dir, 'bot_config.toml')
@classmethod
def load_config(cls, config_path: str = "bot_config.toml") -> "BotConfig":
def load_config(cls, config_path: str = None) -> "BotConfig":
"""从TOML配置文件加载配置"""
if config_path is None:
config_path = cls.get_default_config_path()
logger.info(f"使用默认配置文件路径: {config_path}")
config = cls()
if os.path.exists(config_path):
with open(config_path, "rb") as f:
@@ -93,6 +108,10 @@ class BotConfig:
config.MAX_CONTEXT_SIZE = msg_config.get("max_context_size", config.MAX_CONTEXT_SIZE)
config.emoji_chance = msg_config.get("emoji_chance", config.emoji_chance)
if "memory" in toml_dict:
memory_config = toml_dict["memory"]
config.build_memory_interval = memory_config.get("build_memory_interval", config.build_memory_interval)
# 群组配置
if "groups" in toml_dict:
groups_config = toml_dict["groups"]
@@ -104,16 +123,26 @@ class BotConfig:
others_config = toml_dict["others"]
config.enable_advance_output = others_config.get("enable_advance_output", config.enable_advance_output)
print(f"\033[1;32m成功加载配置文件: {config_path}\033[0m")
logger.success(f"成功加载配置文件: {config_path}")
return config
global_config = BotConfig.load_config(".bot_config.toml")
# 获取配置文件路径
bot_config_path = BotConfig.get_default_config_path()
config_dir = os.path.dirname(bot_config_path)
env_path = os.path.join(config_dir, '.env')
from dotenv import load_dotenv
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..'))
load_dotenv(os.path.join(root_dir, '.env'))
logger.info(f"尝试从 {bot_config_path} 加载机器人配置")
global_config = BotConfig.load_config(config_path=bot_config_path)
# 加载环境变量
logger.info(f"尝试从 {env_path} 加载环境变量配置")
if os.path.exists(env_path):
load_dotenv(env_path)
logger.success("成功加载环境变量配置")
else:
logger.error(f"环境变量配置文件不存在: {env_path}")
@dataclass
class LLMConfig:
@@ -132,9 +161,5 @@ llm_config.DEEP_SEEK_BASE_URL = os.getenv('DEEP_SEEK_BASE_URL')
if not global_config.enable_advance_output:
# 只降低日志级别而不是完全移除
logger.remove()
logger.add(sys.stderr, level="WARNING") # 添加一个只输出 WARNING 及以上级别的处理器
# 设置 nonebot 的日志级别
logging.getLogger('nonebot').setLevel(logging.WARNING)
# logger.remove()
pass

View File

@@ -1,186 +0,0 @@
import os
import sys
import numpy as np
import requests
import time
# Add the project root to the Python path so that `src.*` absolute imports
# resolve when this file is executed directly as a script.
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path)

from src.common.database import Database
from src.plugins.chat.config import llm_config

# Configure the database connection directly.
# NOTE(review): host/port/database are hard-coded rather than read from bot
# config — presumably intentional for this standalone ingestion script; confirm.
Database.initialize(
    "127.0.0.1",  # MongoDB host
    27017,  # MongoDB port
    "MegBot"  # database name
)
class KnowledgeLibrary:
    """Embedding-backed knowledge store.

    Reads .txt files from ``data/raw_info``, splits each file into fixed-size
    character segments, embeds every segment through the SiliconFlow embedding
    API, and stores the vectors in MongoDB (collection ``knowledges``) so they
    can later be ranked by cosine similarity against a query.
    """

    def __init__(self):
        # Database.initialize(...) must have been called before construction.
        self.db = Database.get_instance()
        self.raw_info_dir = "data/raw_info"
        self._ensure_dirs()

    def _ensure_dirs(self):
        """Ensure the raw-input directory exists."""
        os.makedirs(self.raw_info_dir, exist_ok=True)

    def get_embedding(self, text: str) -> list:
        """Return the embedding vector for *text*, or None if the API call fails."""
        url = "https://api.siliconflow.cn/v1/embeddings"
        payload = {
            "model": "BAAI/bge-m3",
            "input": text,
            "encoding_format": "float"
        }
        headers = {
            "Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}",
            "Content-Type": "application/json"
        }
        response = requests.post(url, json=payload, headers=headers)
        if response.status_code != 200:
            print(f"获取embedding失败: {response.text}")
            return None
        return response.json()['data'][0]['embedding']

    def process_files(self):
        """Ingest every .txt file under the raw_info directory."""
        for filename in os.listdir(self.raw_info_dir):
            if filename.endswith('.txt'):
                file_path = os.path.join(self.raw_info_dir, filename)
                self.process_single_file(file_path)

    def process_single_file(self, file_path: str):
        """Segment, embed and store one file; already-processed files are skipped.

        Any exception is caught and reported so a single bad file does not
        abort the whole ingestion run.
        """
        import hashlib  # local import: stable content hashing (see below)

        try:
            # Skip files that were already ingested in a previous run.
            if self.db.db.processed_files.find_one({"file_path": file_path}):
                print(f"文件已处理过,跳过: {file_path}")
                return

            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()

            # Split into 300-character segments.
            # (The original comment claimed 1024 characters — the code slices 300.)
            segments = [content[i:i+300] for i in range(0, len(content), 300)]

            for segment in segments:
                if not segment.strip():  # skip whitespace-only segments
                    continue

                embedding = self.get_embedding(segment)
                if not embedding:
                    continue

                doc = {
                    "content": segment,
                    "embedding": embedding,
                    "file_path": file_path,
                    "segment_length": len(segment)
                }

                # BUGFIX: the builtin hash() of a str is salted per interpreter
                # process (PYTHONHASHSEED randomization), so the same segment
                # produced a different key on every run and the upsert never
                # deduplicated. Use a stable content digest instead.
                content_hash = hashlib.sha256(segment.encode('utf-8')).hexdigest()

                # Upsert keyed on the stable content hash.
                self.db.db.knowledges.update_one(
                    {"content_hash": content_hash},
                    {"$set": doc},
                    upsert=True
                )

            # Record that the whole file has been processed.
            self.db.db.processed_files.insert_one({
                "file_path": file_path,
                "processed_time": time.time()
            })
            print(f"成功处理文件: {file_path}")

        except Exception as e:
            print(f"处理文件 {file_path} 时出错: {str(e)}")

    def search_similar_segments(self, query: str, limit: int = 5) -> list:
        """Return up to *limit* stored segments ranked by cosine similarity to *query*.

        Returns an empty list if the query embedding cannot be obtained.
        The cosine similarity is computed inside MongoDB via an aggregation
        pipeline (dot product divided by the vector magnitudes).
        """
        query_embedding = self.get_embedding(query)
        if not query_embedding:
            return []

        pipeline = [
            {
                "$addFields": {
                    # dot(embedding, query_embedding)
                    "dotProduct": {
                        "$reduce": {
                            "input": {"$range": [0, {"$size": "$embedding"}]},
                            "initialValue": 0,
                            "in": {
                                "$add": [
                                    "$$value",
                                    {"$multiply": [
                                        {"$arrayElemAt": ["$embedding", "$$this"]},
                                        {"$arrayElemAt": [query_embedding, "$$this"]}
                                    ]}
                                ]
                            }
                        }
                    },
                    # ||embedding||
                    "magnitude1": {
                        "$sqrt": {
                            "$reduce": {
                                "input": "$embedding",
                                "initialValue": 0,
                                "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}
                            }
                        }
                    },
                    # ||query_embedding||
                    "magnitude2": {
                        "$sqrt": {
                            "$reduce": {
                                "input": query_embedding,
                                "initialValue": 0,
                                "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}
                            }
                        }
                    }
                }
            },
            {
                "$addFields": {
                    "similarity": {
                        "$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]
                    }
                }
            },
            {"$sort": {"similarity": -1}},
            {"$limit": limit},
            {"$project": {"content": 1, "similarity": 1, "file_path": 1}}
        ]

        results = list(self.db.db.knowledges.aggregate(pipeline))
        return results
# Module-level singleton instance.
# NOTE(review): constructing this at import time touches the database via
# Database.get_instance() — any importer pays that cost; confirm intended.
knowledge_library = KnowledgeLibrary()

if __name__ == "__main__":
    # Exercise the knowledge library end to end: ingest, then search.
    print("开始处理知识库文件...")
    knowledge_library.process_files()

    # Try a similarity search against the ingested segments.
    test_query = "麦麦评价一下僕と花"
    print(f"\n搜索与'{test_query}'相似的内容:")
    results = knowledge_library.search_similar_segments(test_query)
    for result in results:
        print(f"相似度: {result['similarity']:.4f}")
        print(f"内容: {result['content'][:100]}...")
        print("-" * 50)

View File

@@ -4,7 +4,7 @@ import asyncio
import requests
from functools import partial
from .message import Message
from .config import BotConfig
from .config import BotConfig, global_config
from ...common.database import Database
import random
import time
@@ -255,4 +255,4 @@ class LLMResponseGenerator:
return processed_response, emotion_tags
# 创建全局实例
llm_response = LLMResponseGenerator(config=BotConfig())
llm_response = LLMResponseGenerator(global_config)

View File

@@ -6,17 +6,13 @@ import os
from datetime import datetime
from ...common.database import Database
from PIL import Image
from .config import BotConfig, global_config
from .config import global_config
import urllib3
from .utils_user import get_user_nickname
from .utils_cq import parse_cq_code
from .cq_code import cq_code_tool,CQCode
Message = ForwardRef('Message') # 添加这行
# 加载配置
bot_config = BotConfig.load_config()
# 禁用SSL警告
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
@@ -48,6 +44,8 @@ class Message:
is_emoji: bool = False # 是否是表情包
has_emoji: bool = False # 是否包含表情包
translate_cq: bool = True # 是否翻译cq码
reply_benefits: float = 0.0
@@ -99,7 +97,7 @@ class Message:
- cq_code_list:分割出的聊天对象包括文本和CQ码
- trans_list:翻译后的对象列表
"""
print(f"\033[1;34m[调试信息]\033[0m 正在处理消息: {message}")
# print(f"\033[1;34m[调试信息]\033[0m 正在处理消息: {message}")
cq_code_dict_list = []
trans_list = []

View File

@@ -208,7 +208,15 @@ class MessageSendControl:
print(f"\033[1;34m[调试]\033[0m 消息发送时间: {cost_time}")
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time))
print(f"\033[1;32m群 {group_id} 消息, 用户 {global_config.BOT_NICKNAME}, 时间: {current_time}:\033[0m {str(message.processed_plain_text)}")
await self.storage.store_message(message, None)
if message.is_emoji:
message.processed_plain_text = "[表情包]"
await self.storage.store_message(message, None)
else:
await self.storage.store_message(message, None)
queue.update_send_time()
if queue.has_messages():
await asyncio.sleep(

View File

@@ -53,8 +53,8 @@ class PromptBuilder:
# 遍历所有topic
for current_topic in topic:
first_layer_items, second_layer_items = memory_graph.get_related_item(current_topic, depth=2)
if first_layer_items:
print(f"\033[1;32m[pb记忆检索]\033[0m 主题 '{current_topic}' 的第一层记忆: {first_layer_items}")
# if first_layer_items:
# print(f"\033[1;32m[前额叶]\033[0m 主题 '{current_topic}' 的第一层记忆: {first_layer_items}")
# 记录第一层数据
all_first_layer_items.extend(first_layer_items)
@@ -68,14 +68,14 @@ class PromptBuilder:
# 找到重叠的记忆
overlap = set(second_layer_items) & set(other_second_layer)
if overlap:
print(f"\033[1;32m[pb记忆检索]\033[0m 发现主题 '{current_topic}''{other_topic}' 有共同的第二层记忆: {overlap}")
# print(f"\033[1;32m[前额叶]\033[0m 发现主题 '{current_topic}' 和 '{other_topic}' 有共同的第二层记忆: {overlap}")
overlapping_second_layer.update(overlap)
# 合并所有需要的记忆
if all_first_layer_items:
print(f"\033[1;32m[pb记忆检索]\033[0m 合并所有需要的记忆1: {all_first_layer_items}")
print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆1: {all_first_layer_items}")
if overlapping_second_layer:
print(f"\033[1;32m[pb记忆检索]\033[0m 合并所有需要的记忆2: {list(overlapping_second_layer)}")
print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆2: {list(overlapping_second_layer)}")
all_memories = all_first_layer_items + list(overlapping_second_layer)

View File

@@ -7,6 +7,8 @@ import numpy as np
from .config import llm_config, global_config
import re
from typing import Dict
from collections import Counter
import math
def combine_messages(messages: List[Message]) -> str:
@@ -81,6 +83,39 @@ def cosine_similarity(v1, v2):
norm2 = np.linalg.norm(v2)
return dot_product / (norm1 * norm2)
def calculate_information_content(text):
    """Return the Shannon entropy (bits per character) of *text*.

    Treats the string as an empirical distribution over its characters and
    computes -sum(p * log2(p)). An empty or single-symbol string yields 0.
    """
    total = len(text)
    return -sum(
        (count / total) * math.log2(count / total)
        for count in Counter(text).values()
    )
def get_cloest_chat_from_db(db, length: int, timestamp: str):
    """Fetch a chat snippet anchored at the message closest to *timestamp*.

    Finds the newest message whose ``time`` is <= *timestamp*, then collects up
    to *length* subsequent messages from the same group and formats them as
    one newline-terminated line per message: ``[time] nickname: text``.

    NOTE(review): *timestamp* is annotated ``str`` but is compared numerically
    against ``time`` via ``$lte`` — callers appear to pass a numeric epoch;
    confirm and fix the annotation at the call sites.

    Returns:
        str: the formatted chat text; an empty string when no message exists
        at or before *timestamp*.
        (BUGFIX: previously returned a list ``[]`` on the miss path but a
        ``str`` on the hit path, an inconsistent return type.)
    """
    chat_text = ''
    # Newest message at or before the requested timestamp.
    closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])
    if closest_record:
        closest_time = closest_record['time']
        group_id = closest_record['group_id']  # restrict follow-up messages to the same group
        # Up to `length` messages after the anchor, oldest first.
        chat_record = list(db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
        for record in chat_record:
            time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
            # Fall back to "用户<id>" when the nickname is empty.
            chat_text += f'[{time_str}] {record["user_nickname"] or "用户" + str(record["user_id"])}: {record["processed_plain_text"]}\n'
    return chat_text
def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
"""从数据库获取群组最近的消息记录

View File

@@ -4,11 +4,9 @@ import hashlib
import time
import os
from ...common.database import Database
from .config import BotConfig
import zlib # 用于 CRC32
import base64
bot_config = BotConfig.load_config()
from .config import global_config
def storage_image(image_data: bytes,type: str, max_size: int = 200) -> bytes:
@@ -39,12 +37,12 @@ def storage_compress_image(image_data: bytes, max_size: int = 200) -> bytes:
# 连接数据库
db = Database(
host=bot_config.MONGODB_HOST,
port=bot_config.MONGODB_PORT,
db_name=bot_config.DATABASE_NAME,
username=bot_config.MONGODB_USERNAME,
password=bot_config.MONGODB_PASSWORD,
auth_source=bot_config.MONGODB_AUTH_SOURCE
host=global_config.MONGODB_HOST,
port=global_config.MONGODB_PORT,
db_name=global_config.DATABASE_NAME,
username=global_config.MONGODB_USERNAME,
password=global_config.MONGODB_PASSWORD,
auth_source=global_config.MONGODB_AUTH_SOURCE
)
# 检查是否已存在相同哈希值的图片

View File

@@ -22,22 +22,31 @@ class WillingManager:
"""设置指定群组的回复意愿"""
self.group_reply_willing[group_id] = willing
def change_reply_willing_received(self, group_id: int, topic: str, is_mentioned_bot: bool, config, user_id: int = None, is_emoji: bool = False) -> float:
def change_reply_willing_received(self, group_id: int, topic: str, is_mentioned_bot: bool, config, user_id: int = None, is_emoji: bool = False, interested_rate: float = 0) -> float:
"""改变指定群组的回复意愿并返回回复概率"""
current_willing = self.group_reply_willing.get(group_id, 0)
if topic and current_willing < 1:
current_willing += 0.2
elif topic:
current_willing += 0.05
print(f"初始意愿: {current_willing}")
# if topic and current_willing < 1:
# current_willing += 0.2
# elif topic:
# current_willing += 0.05
if is_mentioned_bot and current_willing < 1.0:
current_willing += 0.9
print(f"被提及, 当前意愿: {current_willing}")
elif is_mentioned_bot:
current_willing += 0.05
print(f"被重复提及, 当前意愿: {current_willing}")
if is_emoji:
current_willing *= 0.2
current_willing *= 0.15
print(f"表情包, 当前意愿: {current_willing}")
if interested_rate > 0.6:
print(f"兴趣度: {interested_rate}, 当前意愿: {current_willing}")
current_willing += interested_rate-0.45
self.group_reply_willing[group_id] = min(current_willing, 3.0)
@@ -55,15 +64,15 @@ class WillingManager:
return reply_probability
def change_reply_willing_sent(self, group_id: int):
"""发送消息后降低群组的回复意愿"""
"""开始思考后降低群组的回复意愿"""
current_willing = self.group_reply_willing.get(group_id, 0)
self.group_reply_willing[group_id] = max(0, current_willing - 1.8)
self.group_reply_willing[group_id] = max(0, current_willing - 2)
def change_reply_willing_after_sent(self, group_id: int):
"""发送消息后提高群组的回复意愿"""
current_willing = self.group_reply_willing.get(group_id, 0)
if current_willing < 1:
self.group_reply_willing[group_id] = min(1, current_willing + 0.4)
self.group_reply_willing[group_id] = min(1, current_willing + 0.3)
async def ensure_started(self):
"""确保衰减任务已启动"""