v0.3.1 实装了记忆系统和自动发言

哈哈哈
This commit is contained in:
SengokuCola
2025-03-02 00:14:25 +08:00
parent ba5837503e
commit 50c1765b81
19 changed files with 732 additions and 327 deletions

2
.gitignore vendored
View File

@@ -3,7 +3,7 @@ mongodb/
NapCat.Framework.Windows.Once/ NapCat.Framework.Windows.Once/
log/ log/
src/plugins/memory src/plugins/memory
src/plugins/chat/bot_config.toml config/bot_config.toml
/test /test
message_queue_content.txt message_queue_content.txt
message_queue_content.bat message_queue_content.bat

View File

@@ -16,11 +16,19 @@
基于llm、napcat、nonebot和mongodb的专注于群聊天的qqbot 基于llm、napcat、nonebot和mongodb的专注于群聊天的qqbot
<div align="center">
<a href="https://www.bilibili.com/video/BV1amAneGE3P" target="_blank">
<img src="https://i0.hdslb.com/bfs/archive/7d9fa0a88e8a1aa01b92b8a5a743a2671c0e1798.jpg" width="500" alt="麦麦演示视频">
<br>
👆 点击观看麦麦演示视频 👆
</a>
</div>
> ⚠️ **警告**:代码可能随时更改,目前版本不一定是稳定版本 > ⚠️ **警告**:代码可能随时更改,目前版本不一定是稳定版本
> ⚠️ **警告**请自行了解qqbot的风险麦麦有时候一天被腾讯肘七八次 > ⚠️ **警告**请自行了解qqbot的风险麦麦有时候一天被腾讯肘七八次
> ⚠️ **警告**由于麦麦一直在迭代所以可能存在一些bug请自行测试包括胡言乱语 > ⚠️ **警告**由于麦麦一直在迭代所以可能存在一些bug请自行测试包括胡言乱语
关于麦麦的开发和部署相关的讨论群(不建议发布无关消息)这里不会有麦麦发言! 关于麦麦的开发和建议相关的讨论群(不建议发布无关消息)这里不会有麦麦发言!
## 开发计划TODOLIST ## 开发计划TODOLIST
@@ -29,6 +37,10 @@
- 对思考链长度限制 - 对思考链长度限制
- 修复已知bug - 修复已知bug
- 完善文档 - 完善文档
- 修复转发
- config自动生成和检测
- log别用print
- 给发送消息写专门的类
<div align="center"> <div align="center">

View File

@@ -29,6 +29,11 @@ model_r1_probability = 0.8 # 麦麦回答时选择R1模型的概率
model_v3_probability = 0.1 # 麦麦回答时选择V3模型的概率 model_v3_probability = 0.1 # 麦麦回答时选择V3模型的概率
model_r1_distill_probability = 0.1 # 麦麦回答时选择R1蒸馏模型的概率 model_r1_distill_probability = 0.1 # 麦麦回答时选择R1蒸馏模型的概率
[memory]
build_memory_interval = 300 # 记忆构建间隔
[others] [others]
enable_advance_output = true # 开启后输出更多日志,false关闭true开启 enable_advance_output = true # 开启后输出更多日志,false关闭true开启

View File

@@ -1,3 +1,4 @@
from loguru import logger
from nonebot import on_message, on_command, require, get_driver from nonebot import on_message, on_command, require, get_driver
from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent, Message, MessageSegment from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent, Message, MessageSegment
from nonebot.typing import T_State from nonebot.typing import T_State
@@ -10,9 +11,6 @@ from .relationship_manager import relationship_manager
from ..schedule.schedule_generator import bot_schedule from ..schedule.schedule_generator import bot_schedule
from .willing_manager import willing_manager from .willing_manager import willing_manager
from ..memory_system.memory import memory_graph
# 获取驱动器 # 获取驱动器
driver = get_driver() driver = get_driver()
@@ -21,10 +19,7 @@ Database.initialize(
global_config.MONGODB_PORT, global_config.MONGODB_PORT,
global_config.DATABASE_NAME global_config.DATABASE_NAME
) )
print("\033[1;32m[初始化数据库完成]\033[0m")
print("\033[1;32m[初始化配置和数据库完成]\033[0m")
# 导入其他模块 # 导入其他模块
@@ -32,6 +27,7 @@ from .bot import ChatBot
from .emoji_manager import emoji_manager from .emoji_manager import emoji_manager
from .message_send_control import message_sender from .message_send_control import message_sender
from .relationship_manager import relationship_manager from .relationship_manager import relationship_manager
from ..memory_system.memory import memory_graph,hippocampus
# 初始化表情管理器 # 初始化表情管理器
emoji_manager.initialize() emoji_manager.initialize()
@@ -39,21 +35,26 @@ emoji_manager.initialize()
print(f"\033[1;32m正在唤醒{global_config.BOT_NICKNAME}......\033[0m") print(f"\033[1;32m正在唤醒{global_config.BOT_NICKNAME}......\033[0m")
# 创建机器人实例 # 创建机器人实例
chat_bot = ChatBot(global_config) chat_bot = ChatBot(global_config)
# 注册消息处理器 # 注册消息处理器
group_msg = on_message() group_msg = on_message()
# 创建定时任务 # 创建定时任务
scheduler = require("nonebot_plugin_apscheduler").scheduler scheduler = require("nonebot_plugin_apscheduler").scheduler
# 启动后台任务
@driver.on_startup @driver.on_startup
async def start_background_tasks(): async def start_background_tasks():
"""启动后台任务""" """启动后台任务"""
# 只启动表情包管理任务 # 只启动表情包管理任务
asyncio.create_task(emoji_manager.start_periodic_check(interval_MINS=global_config.EMOJI_CHECK_INTERVAL)) asyncio.create_task(emoji_manager.start_periodic_check(interval_MINS=global_config.EMOJI_CHECK_INTERVAL))
bot_schedule.print_schedule() bot_schedule.print_schedule()
@driver.on_startup
async def init_relationships():
"""在 NoneBot2 启动时初始化关系管理器"""
print("\033[1;32m[初始化]\033[0m 正在加载用户关系数据...")
await relationship_manager.load_all_relationships()
asyncio.create_task(relationship_manager._start_relationship_manager())
@driver.on_bot_connect @driver.on_bot_connect
async def _(bot: Bot): async def _(bot: Bot):
@@ -68,19 +69,23 @@ async def _(bot: Bot):
print("\033[1;38;5;208m-----------开始偷表情包!-----------\033[0m") print("\033[1;38;5;208m-----------开始偷表情包!-----------\033[0m")
# 启动消息发送控制任务 # 启动消息发送控制任务
@driver.on_startup
async def init_relationships():
"""在 NoneBot2 启动时初始化关系管理器"""
print("\033[1;32m[初始化]\033[0m 正在加载用户关系数据...")
await relationship_manager.load_all_relationships()
asyncio.create_task(relationship_manager._start_relationship_manager())
@group_msg.handle() @group_msg.handle()
async def _(bot: Bot, event: GroupMessageEvent, state: T_State): async def _(bot: Bot, event: GroupMessageEvent, state: T_State):
await chat_bot.handle_message(event, bot) await chat_bot.handle_message(event, bot)
'''
@scheduler.scheduled_job("interval", seconds=300000, id="monitor_relationships") @scheduler.scheduled_job("interval", seconds=300000, id="monitor_relationships")
async def monitor_relationships(): async def monitor_relationships():
"""每15秒打印一次关系数据""" """每15秒打印一次关系数据"""
relationship_manager.print_all_relationships() relationship_manager.print_all_relationships()
'''
# 添加build_memory定时任务
@scheduler.scheduled_job("interval", seconds=global_config.build_memory_interval, id="build_memory")
async def build_memory_task():
"""每30秒执行一次记忆构建"""
print("\033[1;32m[记忆构建]\033[0m 开始构建记忆...")
hippocampus.build_memory(chat_size=12)
print("\033[1;32m[记忆构建]\033[0m 记忆构建完成")

View File

@@ -83,7 +83,7 @@ class ChatBot:
await relationship_manager.update_relationship(user_id = event.user_id, data = sender_info) await relationship_manager.update_relationship(user_id = event.user_id, data = sender_info)
await relationship_manager.update_relationship_value(user_id = event.user_id, relationship_value = 0.5) await relationship_manager.update_relationship_value(user_id = event.user_id, relationship_value = 0.5)
print(f"\033[1;32m[关系管理]\033[0m 更新关系值: {relationship_manager.get_relationship(event.user_id).relationship_value}") # print(f"\033[1;32m[关系管理]\033[0m 更新关系值: {relationship_manager.get_relationship(event.user_id).relationship_value}")
message = Message( message = Message(
@@ -100,14 +100,19 @@ class ChatBot:
topic = topic_identifier.identify_topic_jieba(message.processed_plain_text) topic = topic_identifier.identify_topic_jieba(message.processed_plain_text)
print(f"\033[1;32m[主题识别]\033[0m 主题: {topic}") print(f"\033[1;32m[主题识别]\033[0m 主题: {topic}")
all_num = 0
interested_num = 0
if topic: if topic:
for current_topic in topic: for current_topic in topic:
all_num += 1
first_layer_items, second_layer_items = memory_graph.get_related_item(current_topic, depth=2) first_layer_items, second_layer_items = memory_graph.get_related_item(current_topic, depth=2)
if first_layer_items: if first_layer_items:
print(f"\033[1;32m[记忆检索-bot]\033[0m 有印象:{current_topic}") interested_num += 1
print(f"\033[1;32m[前额叶]\033[0m 对|{current_topic}|有印象")
interested_rate = interested_num / all_num if all_num > 0 else 0
await self.storage.store_message(message, topic[0] if topic else None) await self.storage.store_message(message, topic[0] if topic else None)
is_mentioned = is_mentioned_bot_in_txt(message.processed_plain_text) is_mentioned = is_mentioned_bot_in_txt(message.processed_plain_text)
@@ -117,7 +122,8 @@ class ChatBot:
is_mentioned, is_mentioned,
self.config, self.config,
event.user_id, event.user_id,
message.is_emoji message.is_emoji,
interested_rate
) )
current_willing = willing_manager.get_willing(event.group_id) current_willing = willing_manager.get_willing(event.group_id)
@@ -188,7 +194,8 @@ class ChatBot:
user_nickname=global_config.BOT_NICKNAME, user_nickname=global_config.BOT_NICKNAME,
group_name=message.group_name, group_name=message.group_name,
time=bot_response_time, time=bot_response_time,
is_emoji=True is_emoji=True,
translate_cq=False
) )
message_sender.send_temp_container.add_message(bot_message) message_sender.send_temp_container.add_message(bot_message)

View File

@@ -6,6 +6,8 @@ import logging
import configparser import configparser
import tomli import tomli
import sys import sys
from loguru import logger
from dotenv import load_dotenv
@@ -21,7 +23,7 @@ class BotConfig:
MONGODB_PASSWORD: Optional[str] = None # 默认空值 MONGODB_PASSWORD: Optional[str] = None # 默认空值
MONGODB_AUTH_SOURCE: Optional[str] = None # 默认空值 MONGODB_AUTH_SOURCE: Optional[str] = None # 默认空值
BOT_QQ: Optional[int] = None BOT_QQ: Optional[int] = 1
BOT_NICKNAME: Optional[str] = None BOT_NICKNAME: Optional[str] = None
# 消息处理相关配置 # 消息处理相关配置
@@ -35,6 +37,7 @@ class BotConfig:
talk_frequency_down_groups = set() talk_frequency_down_groups = set()
ban_user_id = set() ban_user_id = set()
build_memory_interval: int = 60 # 记忆构建间隔(秒)
EMOJI_CHECK_INTERVAL: int = 120 # 表情包检查间隔(分钟) EMOJI_CHECK_INTERVAL: int = 120 # 表情包检查间隔(分钟)
EMOJI_REGISTER_INTERVAL: int = 10 # 表情包注册间隔(分钟) EMOJI_REGISTER_INTERVAL: int = 10 # 表情包注册间隔(分钟)
@@ -45,9 +48,21 @@ class BotConfig:
enable_advance_output: bool = False # 是否启用高级输出 enable_advance_output: bool = False # 是否启用高级输出
@staticmethod
def get_default_config_path() -> str:
"""获取默认配置文件路径"""
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..'))
config_dir = os.path.join(root_dir, 'config')
return os.path.join(config_dir, 'bot_config.toml')
@classmethod @classmethod
def load_config(cls, config_path: str = "bot_config.toml") -> "BotConfig": def load_config(cls, config_path: str = None) -> "BotConfig":
"""从TOML配置文件加载配置""" """从TOML配置文件加载配置"""
if config_path is None:
config_path = cls.get_default_config_path()
logger.info(f"使用默认配置文件路径: {config_path}")
config = cls() config = cls()
if os.path.exists(config_path): if os.path.exists(config_path):
with open(config_path, "rb") as f: with open(config_path, "rb") as f:
@@ -93,6 +108,10 @@ class BotConfig:
config.MAX_CONTEXT_SIZE = msg_config.get("max_context_size", config.MAX_CONTEXT_SIZE) config.MAX_CONTEXT_SIZE = msg_config.get("max_context_size", config.MAX_CONTEXT_SIZE)
config.emoji_chance = msg_config.get("emoji_chance", config.emoji_chance) config.emoji_chance = msg_config.get("emoji_chance", config.emoji_chance)
if "memory" in toml_dict:
memory_config = toml_dict["memory"]
config.build_memory_interval = memory_config.get("build_memory_interval", config.build_memory_interval)
# 群组配置 # 群组配置
if "groups" in toml_dict: if "groups" in toml_dict:
groups_config = toml_dict["groups"] groups_config = toml_dict["groups"]
@@ -104,16 +123,26 @@ class BotConfig:
others_config = toml_dict["others"] others_config = toml_dict["others"]
config.enable_advance_output = others_config.get("enable_advance_output", config.enable_advance_output) config.enable_advance_output = others_config.get("enable_advance_output", config.enable_advance_output)
print(f"\033[1;32m成功加载配置文件: {config_path}\033[0m") logger.success(f"成功加载配置文件: {config_path}")
return config return config
global_config = BotConfig.load_config(".bot_config.toml") # 获取配置文件路径
bot_config_path = BotConfig.get_default_config_path()
config_dir = os.path.dirname(bot_config_path)
env_path = os.path.join(config_dir, '.env')
from dotenv import load_dotenv logger.info(f"尝试从 {bot_config_path} 加载机器人配置")
current_dir = os.path.dirname(os.path.abspath(__file__)) global_config = BotConfig.load_config(config_path=bot_config_path)
root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..'))
load_dotenv(os.path.join(root_dir, '.env')) # 加载环境变量
logger.info(f"尝试从 {env_path} 加载环境变量配置")
if os.path.exists(env_path):
load_dotenv(env_path)
logger.success("成功加载环境变量配置")
else:
logger.error(f"环境变量配置文件不存在: {env_path}")
@dataclass @dataclass
class LLMConfig: class LLMConfig:
@@ -132,9 +161,5 @@ llm_config.DEEP_SEEK_BASE_URL = os.getenv('DEEP_SEEK_BASE_URL')
if not global_config.enable_advance_output: if not global_config.enable_advance_output:
# 只降低日志级别而不是完全移除 # logger.remove()
logger.remove() pass
logger.add(sys.stderr, level="WARNING") # 添加一个只输出 WARNING 及以上级别的处理器
# 设置 nonebot 的日志级别
logging.getLogger('nonebot').setLevel(logging.WARNING)

View File

@@ -4,7 +4,7 @@ import asyncio
import requests import requests
from functools import partial from functools import partial
from .message import Message from .message import Message
from .config import BotConfig from .config import BotConfig, global_config
from ...common.database import Database from ...common.database import Database
import random import random
import time import time
@@ -255,4 +255,4 @@ class LLMResponseGenerator:
return processed_response, emotion_tags return processed_response, emotion_tags
# 创建全局实例 # 创建全局实例
llm_response = LLMResponseGenerator(config=BotConfig()) llm_response = LLMResponseGenerator(global_config)

View File

@@ -6,17 +6,13 @@ import os
from datetime import datetime from datetime import datetime
from ...common.database import Database from ...common.database import Database
from PIL import Image from PIL import Image
from .config import BotConfig, global_config from .config import global_config
import urllib3 import urllib3
from .utils_user import get_user_nickname from .utils_user import get_user_nickname
from .utils_cq import parse_cq_code from .utils_cq import parse_cq_code
from .cq_code import cq_code_tool,CQCode from .cq_code import cq_code_tool,CQCode
Message = ForwardRef('Message') # 添加这行 Message = ForwardRef('Message') # 添加这行
# 加载配置
bot_config = BotConfig.load_config()
# 禁用SSL警告 # 禁用SSL警告
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
@@ -48,6 +44,8 @@ class Message:
is_emoji: bool = False # 是否是表情包 is_emoji: bool = False # 是否是表情包
has_emoji: bool = False # 是否包含表情包 has_emoji: bool = False # 是否包含表情包
translate_cq: bool = True # 是否翻译cq码
reply_benefits: float = 0.0 reply_benefits: float = 0.0
@@ -99,7 +97,7 @@ class Message:
- cq_code_list:分割出的聊天对象包括文本和CQ码 - cq_code_list:分割出的聊天对象包括文本和CQ码
- trans_list:翻译后的对象列表 - trans_list:翻译后的对象列表
""" """
print(f"\033[1;34m[调试信息]\033[0m 正在处理消息: {message}") # print(f"\033[1;34m[调试信息]\033[0m 正在处理消息: {message}")
cq_code_dict_list = [] cq_code_dict_list = []
trans_list = [] trans_list = []

View File

@@ -208,7 +208,15 @@ class MessageSendControl:
print(f"\033[1;34m[调试]\033[0m 消息发送时间: {cost_time}") print(f"\033[1;34m[调试]\033[0m 消息发送时间: {cost_time}")
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time)) current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time))
print(f"\033[1;32m群 {group_id} 消息, 用户 {global_config.BOT_NICKNAME}, 时间: {current_time}:\033[0m {str(message.processed_plain_text)}") print(f"\033[1;32m群 {group_id} 消息, 用户 {global_config.BOT_NICKNAME}, 时间: {current_time}:\033[0m {str(message.processed_plain_text)}")
await self.storage.store_message(message, None)
if message.is_emoji:
message.processed_plain_text = "[表情包]"
await self.storage.store_message(message, None)
else:
await self.storage.store_message(message, None)
queue.update_send_time() queue.update_send_time()
if queue.has_messages(): if queue.has_messages():
await asyncio.sleep( await asyncio.sleep(

View File

@@ -53,8 +53,8 @@ class PromptBuilder:
# 遍历所有topic # 遍历所有topic
for current_topic in topic: for current_topic in topic:
first_layer_items, second_layer_items = memory_graph.get_related_item(current_topic, depth=2) first_layer_items, second_layer_items = memory_graph.get_related_item(current_topic, depth=2)
if first_layer_items: # if first_layer_items:
print(f"\033[1;32m[pb记忆检索]\033[0m 主题 '{current_topic}' 的第一层记忆: {first_layer_items}") # print(f"\033[1;32m[前额叶]\033[0m 主题 '{current_topic}' 的第一层记忆: {first_layer_items}")
# 记录第一层数据 # 记录第一层数据
all_first_layer_items.extend(first_layer_items) all_first_layer_items.extend(first_layer_items)
@@ -68,14 +68,14 @@ class PromptBuilder:
# 找到重叠的记忆 # 找到重叠的记忆
overlap = set(second_layer_items) & set(other_second_layer) overlap = set(second_layer_items) & set(other_second_layer)
if overlap: if overlap:
print(f"\033[1;32m[pb记忆检索]\033[0m 发现主题 '{current_topic}''{other_topic}' 有共同的第二层记忆: {overlap}") # print(f"\033[1;32m[前额叶]\033[0m 发现主题 '{current_topic}' 和 '{other_topic}' 有共同的第二层记忆: {overlap}")
overlapping_second_layer.update(overlap) overlapping_second_layer.update(overlap)
# 合并所有需要的记忆 # 合并所有需要的记忆
if all_first_layer_items: if all_first_layer_items:
print(f"\033[1;32m[pb记忆检索]\033[0m 合并所有需要的记忆1: {all_first_layer_items}") print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆1: {all_first_layer_items}")
if overlapping_second_layer: if overlapping_second_layer:
print(f"\033[1;32m[pb记忆检索]\033[0m 合并所有需要的记忆2: {list(overlapping_second_layer)}") print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆2: {list(overlapping_second_layer)}")
all_memories = all_first_layer_items + list(overlapping_second_layer) all_memories = all_first_layer_items + list(overlapping_second_layer)

View File

@@ -7,6 +7,8 @@ import numpy as np
from .config import llm_config, global_config from .config import llm_config, global_config
import re import re
from typing import Dict from typing import Dict
from collections import Counter
import math
def combine_messages(messages: List[Message]) -> str: def combine_messages(messages: List[Message]) -> str:
@@ -81,6 +83,39 @@ def cosine_similarity(v1, v2):
norm2 = np.linalg.norm(v2) norm2 = np.linalg.norm(v2)
return dot_product / (norm1 * norm2) return dot_product / (norm1 * norm2)
def calculate_information_content(text):
"""计算文本的信息量(熵)"""
# 统计字符频率
char_count = Counter(text)
total_chars = len(text)
# 计算熵
entropy = 0
for count in char_count.values():
probability = count / total_chars
entropy -= probability * math.log2(probability)
return entropy
def get_cloest_chat_from_db(db, length: int, timestamp: str):
# 从数据库中根据时间戳获取离其最近的聊天记录
chat_text = ''
closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出
# print(f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
if closest_record:
closest_time = closest_record['time']
group_id = closest_record['group_id'] # 获取groupid
# 获取该时间戳之后的length条消息且groupid相同
chat_record = list(db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
for record in chat_record:
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
chat_text += f'[{time_str}] {record["user_nickname"] or "用户" + str(record["user_id"])}: {record["processed_plain_text"]}\n' # 添加发送者和时间信息
return chat_text
return [] # 如果没有找到记录,返回空列表
def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list: def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
"""从数据库获取群组最近的消息记录 """从数据库获取群组最近的消息记录

View File

@@ -4,11 +4,9 @@ import hashlib
import time import time
import os import os
from ...common.database import Database from ...common.database import Database
from .config import BotConfig
import zlib # 用于 CRC32 import zlib # 用于 CRC32
import base64 import base64
from .config import global_config
bot_config = BotConfig.load_config()
def storage_image(image_data: bytes,type: str, max_size: int = 200) -> bytes: def storage_image(image_data: bytes,type: str, max_size: int = 200) -> bytes:
@@ -39,12 +37,12 @@ def storage_compress_image(image_data: bytes, max_size: int = 200) -> bytes:
# 连接数据库 # 连接数据库
db = Database( db = Database(
host=bot_config.MONGODB_HOST, host=global_config.MONGODB_HOST,
port=bot_config.MONGODB_PORT, port=global_config.MONGODB_PORT,
db_name=bot_config.DATABASE_NAME, db_name=global_config.DATABASE_NAME,
username=bot_config.MONGODB_USERNAME, username=global_config.MONGODB_USERNAME,
password=bot_config.MONGODB_PASSWORD, password=global_config.MONGODB_PASSWORD,
auth_source=bot_config.MONGODB_AUTH_SOURCE auth_source=global_config.MONGODB_AUTH_SOURCE
) )
# 检查是否已存在相同哈希值的图片 # 检查是否已存在相同哈希值的图片

View File

@@ -22,22 +22,31 @@ class WillingManager:
"""设置指定群组的回复意愿""" """设置指定群组的回复意愿"""
self.group_reply_willing[group_id] = willing self.group_reply_willing[group_id] = willing
def change_reply_willing_received(self, group_id: int, topic: str, is_mentioned_bot: bool, config, user_id: int = None, is_emoji: bool = False) -> float: def change_reply_willing_received(self, group_id: int, topic: str, is_mentioned_bot: bool, config, user_id: int = None, is_emoji: bool = False, interested_rate: float = 0) -> float:
"""改变指定群组的回复意愿并返回回复概率""" """改变指定群组的回复意愿并返回回复概率"""
current_willing = self.group_reply_willing.get(group_id, 0) current_willing = self.group_reply_willing.get(group_id, 0)
if topic and current_willing < 1: print(f"初始意愿: {current_willing}")
current_willing += 0.2
elif topic: # if topic and current_willing < 1:
current_willing += 0.05 # current_willing += 0.2
# elif topic:
# current_willing += 0.05
if is_mentioned_bot and current_willing < 1.0: if is_mentioned_bot and current_willing < 1.0:
current_willing += 0.9 current_willing += 0.9
print(f"被提及, 当前意愿: {current_willing}")
elif is_mentioned_bot: elif is_mentioned_bot:
current_willing += 0.05 current_willing += 0.05
print(f"被重复提及, 当前意愿: {current_willing}")
if is_emoji: if is_emoji:
current_willing *= 0.2 current_willing *= 0.15
print(f"表情包, 当前意愿: {current_willing}")
if interested_rate > 0.6:
print(f"兴趣度: {interested_rate}, 当前意愿: {current_willing}")
current_willing += interested_rate-0.45
self.group_reply_willing[group_id] = min(current_willing, 3.0) self.group_reply_willing[group_id] = min(current_willing, 3.0)
@@ -55,15 +64,15 @@ class WillingManager:
return reply_probability return reply_probability
def change_reply_willing_sent(self, group_id: int): def change_reply_willing_sent(self, group_id: int):
"""发送消息后降低群组的回复意愿""" """开始思考后降低群组的回复意愿"""
current_willing = self.group_reply_willing.get(group_id, 0) current_willing = self.group_reply_willing.get(group_id, 0)
self.group_reply_willing[group_id] = max(0, current_willing - 1.8) self.group_reply_willing[group_id] = max(0, current_willing - 2)
def change_reply_willing_after_sent(self, group_id: int): def change_reply_willing_after_sent(self, group_id: int):
"""发送消息后提高群组的回复意愿""" """发送消息后提高群组的回复意愿"""
current_willing = self.group_reply_willing.get(group_id, 0) current_willing = self.group_reply_willing.get(group_id, 0)
if current_willing < 1: if current_willing < 1:
self.group_reply_willing[group_id] = min(1, current_willing + 0.4) self.group_reply_willing[group_id] = min(1, current_willing + 0.3)
async def ensure_started(self): async def ensure_started(self):
"""确保衰减任务已启动""" """确保衰减任务已启动"""

View File

@@ -0,0 +1,264 @@
# -*- coding: utf-8 -*-
import sys
import jieba
from llm_module import LLMModel
import networkx as nx
import matplotlib.pyplot as plt
import math
from collections import Counter
import datetime
import random
import time
# from chat.config import global_config
import sys
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
from src.common.database import Database # 使用正确的导入语法
class Memory_graph:
def __init__(self):
self.G = nx.Graph() # 使用 networkx 的图结构
self.db = Database.get_instance()
def connect_dot(self, concept1, concept2):
self.G.add_edge(concept1, concept2)
def add_dot(self, concept, memory):
if concept in self.G:
# 如果节点已存在,将新记忆添加到现有列表中
if 'memory_items' in self.G.nodes[concept]:
if not isinstance(self.G.nodes[concept]['memory_items'], list):
# 如果当前不是列表,将其转换为列表
self.G.nodes[concept]['memory_items'] = [self.G.nodes[concept]['memory_items']]
self.G.nodes[concept]['memory_items'].append(memory)
else:
self.G.nodes[concept]['memory_items'] = [memory]
else:
# 如果是新节点,创建新的记忆列表
self.G.add_node(concept, memory_items=[memory])
def get_dot(self, concept):
# 检查节点是否存在于图中
if concept in self.G:
# 从图中获取节点数据
node_data = self.G.nodes[concept]
# print(node_data)
# 创建新的Memory_dot对象
return concept,node_data
return None
def get_related_item(self, topic, depth=1):
if topic not in self.G:
return [], []
first_layer_items = []
second_layer_items = []
# 获取相邻节点
neighbors = list(self.G.neighbors(topic))
# print(f"第一层: {topic}")
# 获取当前节点的记忆项
node_data = self.get_dot(topic)
if node_data:
concept, data = node_data
if 'memory_items' in data:
memory_items = data['memory_items']
if isinstance(memory_items, list):
first_layer_items.extend(memory_items)
else:
first_layer_items.append(memory_items)
# 只在depth=2时获取第二层记忆
if depth >= 2:
# 获取相邻节点的记忆项
for neighbor in neighbors:
# print(f"第二层: {neighbor}")
node_data = self.get_dot(neighbor)
if node_data:
concept, data = node_data
if 'memory_items' in data:
memory_items = data['memory_items']
if isinstance(memory_items, list):
second_layer_items.extend(memory_items)
else:
second_layer_items.append(memory_items)
return first_layer_items, second_layer_items
def store_memory(self):
for node in self.G.nodes():
dot_data = {
"concept": node
}
self.db.db.store_memory_dots.insert_one(dot_data)
@property
def dots(self):
# 返回所有节点对应的 Memory_dot 对象
return [self.get_dot(node) for node in self.G.nodes()]
def get_random_chat_from_db(self, length: int, timestamp: str):
# 从数据库中根据时间戳获取离其最近的聊天记录
chat_text = ''
closest_record = self.db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出
print(f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
if closest_record:
closest_time = closest_record['time']
group_id = closest_record['group_id'] # 获取groupid
# 获取该时间戳之后的length条消息且groupid相同
chat_record = list(self.db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
for record in chat_record:
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
chat_text += f'[{time_str}] {record["user_nickname"] or "用户" + str(record["user_id"])}: {record["processed_plain_text"]}\n' # 添加发送者和时间信息
return chat_text
return [] # 如果没有找到记录,返回空列表
def save_graph_to_db(self):
# 清空现有的图数据
self.db.db.graph_data.delete_many({})
# 保存节点
for node in self.G.nodes(data=True):
node_data = {
'concept': node[0],
'memory_items': node[1].get('memory_items', []) # 默认为空列表
}
self.db.db.graph_data.nodes.insert_one(node_data)
# 保存边
for edge in self.G.edges():
edge_data = {
'source': edge[0],
'target': edge[1]
}
self.db.db.graph_data.edges.insert_one(edge_data)
def load_graph_from_db(self):
# 清空当前图
self.G.clear()
# 加载节点
nodes = self.db.db.graph_data.nodes.find()
for node in nodes:
memory_items = node.get('memory_items', [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
self.G.add_node(node['concept'], memory_items=memory_items)
# 加载边
edges = self.db.db.graph_data.edges.find()
for edge in edges:
self.G.add_edge(edge['source'], edge['target'])
def main():
# 初始化数据库
Database.initialize(
"127.0.0.1",
27017,
"MegBot"
)
memory_graph = Memory_graph()
# 创建LLM模型实例
memory_graph.load_graph_from_db()
# 展示两种不同的可视化方式
print("\n按连接数量着色的图谱:")
visualize_graph(memory_graph, color_by_memory=False)
print("\n按记忆数量着色的图谱:")
visualize_graph(memory_graph, color_by_memory=True)
# memory_graph.save_graph_to_db()
while True:
query = input("请输入新的查询概念(输入'退出'以结束):")
if query.lower() == '退出':
break
items_list = memory_graph.get_related_item(query)
if items_list:
# print(items_list)
for memory_item in items_list:
print(memory_item)
else:
print("未找到相关记忆。")
def segment_text(text):
seg_text = list(jieba.cut(text))
return seg_text
def find_topic(text, topic_num):
prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个话题,帮我列出来,用逗号隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要告诉我其他内容。'
return prompt
def topic_what(text, topic):
prompt = f'这是一段文字:{text}。我想知道这记忆里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物。只输出这句话就好'
return prompt
def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
G = memory_graph.G
# 保存图到本地
nx.write_gml(G, "memory_graph.gml") # 保存为 GML 格式
# 根据连接条数或记忆数量设置节点颜色
node_colors = []
nodes = list(G.nodes()) # 获取图中实际的节点列表
if color_by_memory:
# 计算每个节点的记忆数量
memory_counts = []
for node in nodes:
memory_items = G.nodes[node].get('memory_items', [])
if isinstance(memory_items, list):
count = len(memory_items)
else:
count = 1 if memory_items else 0
memory_counts.append(count)
max_memories = max(memory_counts) if memory_counts else 1
for count in memory_counts:
# 使用不同的颜色方案:红色表示记忆多,蓝色表示记忆少
if max_memories > 0:
intensity = min(1.0, count / max_memories)
color = (intensity, 0, 1.0 - intensity) # 从蓝色渐变到红色
else:
color = (0, 0, 1) # 如果没有记忆,则为蓝色
node_colors.append(color)
else:
# 使用原来的连接数量着色方案
max_degree = max(G.degree(), key=lambda x: x[1])[1] if G.degree() else 1
for node in nodes:
degree = G.degree(node)
if max_degree > 0:
red = min(1.0, degree / max_degree)
blue = 1.0 - red
color = (red, 0, blue)
else:
color = (0, 0, 1)
node_colors.append(color)
# 绘制图形
plt.figure(figsize=(12, 8))
pos = nx.spring_layout(G, k=1, iterations=50)
nx.draw(G, pos,
with_labels=True,
node_color=node_colors,
node_size=2000,
font_size=10,
font_family='SimHei',
font_weight='bold')
title = '记忆图谱可视化 - ' + ('按记忆数量着色' if color_by_memory else '按连接数量着色')
plt.title(title, fontsize=16, fontfamily='SimHei')
plt.show()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,82 @@
import os
import requests
from dotenv import load_dotenv
from typing import Tuple, Union
import time
from ..chat.config import BotConfig
# 获取当前文件的绝对路径
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..'))
env_path = os.path.join(root_dir, 'config', '.env')
# 加载环境变量
print(f"尝试从 {env_path} 加载环境变量配置")
if os.path.exists(env_path):
load_dotenv(env_path)
print("成功加载环境变量配置")
else:
print(f"环境变量配置文件不存在: {env_path}")
class LLMModel:
# def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", **kwargs):
def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-V3", **kwargs):
self.model_name = model_name
self.params = kwargs
self.api_key = os.getenv("SILICONFLOW_KEY")
self.base_url = os.getenv("SILICONFLOW_BASE_URL")
if not self.api_key or not self.base_url:
raise ValueError("环境变量未正确加载SILICONFLOW_KEY 或 SILICONFLOW_BASE_URL 未设置")
print(f"API URL: {self.base_url}") # 打印 base_url 用于调试
def generate_response(self, prompt: str) -> Tuple[str, str]:
"""根据输入的提示生成模型的响应"""
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
# 构建请求体
data = {
"model": self.model_name,
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.5,
**self.params
}
# 发送请求到完整的chat/completions端点
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
max_retries = 3
base_wait_time = 15 # 基础等待时间(秒)
for retry in range(max_retries):
try:
response = requests.post(api_url, headers=headers, json=data)
if response.status_code == 429:
wait_time = base_wait_time * (2 ** retry) # 指数退避
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
time.sleep(wait_time)
continue
response.raise_for_status() # 检查其他响应状态
result = response.json()
if "choices" in result and len(result["choices"]) > 0:
content = result["choices"][0]["message"]["content"]
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
return content, reasoning_content
return "没有返回结果", ""
except requests.exceptions.RequestException as e:
if retry < max_retries - 1: # 如果还有重试机会
wait_time = base_wait_time * (2 ** retry)
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
time.sleep(wait_time)
else:
return f"请求失败: {str(e)}", ""
return "达到最大重试次数,请求仍然失败", ""

View File

@@ -1,5 +1,4 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import sys
import jieba import jieba
from .llm_module import LLMModel from .llm_module import LLMModel
import networkx as nx import networkx as nx
@@ -11,8 +10,8 @@ import random
import time import time
from ..chat.config import global_config from ..chat.config import global_config
import sys import sys
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径 from ...common.database import Database # 使用正确的导入语法
from src.common.database import Database # 使用正确的导入语法 from ..chat.utils import calculate_information_content, get_cloest_chat_from_db
class Memory_graph: class Memory_graph:
def __init__(self): def __init__(self):
@@ -85,54 +84,66 @@ class Memory_graph:
return first_layer_items, second_layer_items return first_layer_items, second_layer_items
def store_memory(self):
for node in self.G.nodes():
dot_data = {
"concept": node
}
self.db.db.store_memory_dots.insert_one(dot_data)
@property @property
def dots(self): def dots(self):
# 返回所有节点对应的 Memory_dot 对象 # 返回所有节点对应的 Memory_dot 对象
return [self.get_dot(node) for node in self.G.nodes()] return [self.get_dot(node) for node in self.G.nodes()]
def get_random_chat_from_db(self, length: int, timestamp: str):
# 从数据库中根据时间戳获取离其最近的聊天记录
chat_text = ''
closest_record = self.db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出
print(f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
if closest_record:
closest_time = closest_record['time']
group_id = closest_record['group_id'] # 获取groupid
# 获取该时间戳之后的length条消息且groupid相同
chat_record = list(self.db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
for record in chat_record:
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
chat_text += f'[{time_str}] {record["user_nickname"] or "用户" + str(record["user_id"])}: {record["processed_plain_text"]}\n' # 添加发送者和时间信息
return chat_text
return [] # 如果没有找到记录,返回空列表
def save_graph_to_db(self): def save_graph_to_db(self):
# 清空现有的图数据
self.db.db.graph_data.delete_many({})
# 保存节点 # 保存节点
for node in self.G.nodes(data=True): for node in self.G.nodes(data=True):
node_data = { concept = node[0]
'concept': node[0], memory_items = node[1].get('memory_items', [])
'memory_items': node[1].get('memory_items', []) # 默认为空列表
} # 查找是否存在同名节点
self.db.db.graph_data.nodes.insert_one(node_data) existing_node = self.db.db.graph_data.nodes.find_one({'concept': concept})
if existing_node:
# 如果存在,合并memory_items并去重
existing_items = existing_node.get('memory_items', [])
if not isinstance(existing_items, list):
existing_items = [existing_items] if existing_items else []
# 合并并去重
all_items = list(set(existing_items + memory_items))
# 更新节点
self.db.db.graph_data.nodes.update_one(
{'concept': concept},
{'$set': {'memory_items': all_items}}
)
else:
# 如果不存在,创建新节点
node_data = {
'concept': concept,
'memory_items': memory_items
}
self.db.db.graph_data.nodes.insert_one(node_data)
# 保存边 # 保存边
for edge in self.G.edges(): for edge in self.G.edges():
edge_data = { source, target = edge
'source': edge[0],
'target': edge[1] # 查找是否存在同样的边
} existing_edge = self.db.db.graph_data.edges.find_one({
self.db.db.graph_data.edges.insert_one(edge_data) 'source': source,
'target': target
})
if existing_edge:
# 如果存在,增加num属性
num = existing_edge.get('num', 1) + 1
self.db.db.graph_data.edges.update_one(
{'source': source, 'target': target},
{'$set': {'num': num}}
)
else:
# 如果不存在,创建新边
edge_data = {
'source': source,
'target': target,
'num': 1
}
self.db.db.graph_data.edges.insert_one(edge_data)
def load_graph_from_db(self): def load_graph_from_db(self):
# 清空当前图 # 清空当前图
@@ -147,150 +158,92 @@ class Memory_graph:
# 加载边 # 加载边
edges = self.db.db.graph_data.edges.find() edges = self.db.db.graph_data.edges.find()
for edge in edges: for edge in edges:
self.G.add_edge(edge['source'], edge['target']) self.G.add_edge(edge['source'], edge['target'], num=edge.get('num', 1))
def calculate_information_content(text):
"""计算文本的信息量(熵)"""
# 统计字符频率
char_count = Counter(text)
total_chars = len(text)
# 计算熵
entropy = 0
for count in char_count.values():
probability = count / total_chars
entropy -= probability * math.log2(probability)
return entropy
start_time = time.time()
Database.initialize(
global_config.MONGODB_HOST,
global_config.MONGODB_PORT,
global_config.DATABASE_NAME
)
memory_graph = Memory_graph()
llm_model = LLMModel()
llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
memory_graph.load_graph_from_db()
end_time = time.time()
print(f"加载海马体耗时: {end_time - start_time:.2f}")
def main():
# 初始化数据库
Database.initialize(
"127.0.0.1",
27017,
"MegBot"
)
memory_graph = Memory_graph()
# 创建LLM模型实例
llm_model = LLMModel()
llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
# 使用当前时间戳进行测试
current_timestamp = datetime.datetime.now().timestamp()
chat_text = []
chat_size =40
for _ in range(100): # 循环10次
random_time = current_timestamp - random.randint(1, 3600*39) # 随机时间
print(f"随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}")
chat_ = memory_graph.get_random_chat_from_db(chat_size, random_time)
chat_text.append(chat_) # 拼接所有text
time.sleep(5)
# 海马体
for input_text in chat_text: class Hippocampus:
print(input_text) def __init__(self,memory_graph:Memory_graph):
first_memory = set() self.memory_graph = memory_graph
first_memory = memory_compress(input_text, llm_model_small, llm_model_small, rate=2.5) self.llm_model = LLMModel()
self.llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
#将记忆加入到图谱中 def get_memory_sample(self,chat_size=20,time_frequency:dict={'near':2,'mid':4,'far':3}):
for topic, memory in first_memory: current_timestamp = datetime.datetime.now().timestamp()
topics = segment_text(topic) chat_text = []
print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}") #短期1h 中期4h 长期24h
for split_topic in topics: for _ in range(time_frequency.get('near')): # 循环10次
memory_graph.add_dot(split_topic,memory) random_time = current_timestamp - random.randint(1, 3600) # 随机时间
for split_topic in topics: # print(f"获得 最近 随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}")
for other_split_topic in topics: chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
if split_topic != other_split_topic: chat_text.append(chat_)
memory_graph.connect_dot(split_topic, other_split_topic) for _ in range(time_frequency.get('mid')): # 循环10次
random_time = current_timestamp - random.randint(3600, 3600*4) # 随机时间
# print(f"获得 最近 随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}")
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
chat_text.append(chat_)
for _ in range(time_frequency.get('far')): # 循环10次
random_time = current_timestamp - random.randint(3600*4, 3600*24) # 随机时间
# print(f"获得 最近 随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}")
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
chat_text.append(chat_)
return chat_text
# memory_graph.store_memory() def build_memory(self,chat_size=12):
#最近消息获取频率
# 展示两种不同的可视化方式 time_frequency = {'near':1,'mid':2,'far':2}
print("\n按连接数量着色的图谱:") memory_sample = self.get_memory_sample(chat_size,time_frequency)
visualize_graph(memory_graph, color_by_memory=False) # print(f"\033[1;32m[记忆构建]\033[0m 获取记忆样本: {memory_sample}")
print("\n按记忆数量着色的图谱:")
visualize_graph(memory_graph, color_by_memory=True) for i, input_text in enumerate(memory_sample, 1):
#加载进度可视化
memory_graph.save_graph_to_db() progress = (i / len(memory_sample)) * 100
# memory_graph.load_graph_from_db() bar_length = 30
filled_length = int(bar_length * i // len(memory_sample))
while True: bar = '' * filled_length + '-' * (bar_length - filled_length)
query = input("请输入新的查询概念(输入'退出'以结束):") print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})")
if query.lower() == '退出':
break
items_list = memory_graph.get_related_item(query)
if items_list:
# print(items_list)
for memory_item in items_list:
print(memory_item)
else:
print("未找到相关记忆。")
while True: # 生成压缩后记忆
query = input("请输入问题:") first_memory = set()
first_memory = self.memory_compress(input_text, 2.5)
if query.lower() == '退出': # 延时防止访问超频
break # time.sleep(5)
#将记忆加入到图谱中
topic_prompt = find_topic(query, 3) for topic, memory in first_memory:
topic_response = llm_model.generate_response(topic_prompt) topics = segment_text(topic)
print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
for split_topic in topics:
self.memory_graph.add_dot(split_topic,memory)
for split_topic in topics:
for other_split_topic in topics:
if split_topic != other_split_topic:
self.memory_graph.connect_dot(split_topic, other_split_topic)
self.memory_graph.save_graph_to_db()
def memory_compress(self, input_text, rate=1):
information_content = calculate_information_content(input_text)
print(f"文本的信息量(熵): {information_content:.4f} bits")
topic_num = max(1, min(5, int(information_content * rate / 4)))
# print(topic_num)
topic_prompt = find_topic(input_text, topic_num)
topic_response = self.llm_model.generate_response(topic_prompt)
# 检查 topic_response 是否为元组 # 检查 topic_response 是否为元组
if isinstance(topic_response, tuple): if isinstance(topic_response, tuple):
topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串 topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串
else: else:
topics = topic_response.split(",") topics = topic_response.split(",")
print(topics) # print(topics)
compressed_memory = set()
for keyword in topics: for topic in topics:
items_list = memory_graph.get_related_item(keyword) topic_what_prompt = topic_what(input_text,topic)
if items_list: topic_what_response = self.llm_model_small.generate_response(topic_what_prompt)
print(items_list) compressed_memory.add((topic.strip(), topic_what_response[0])) # 将话题和记忆作为元组存储
return compressed_memory
def memory_compress(input_text, llm_model, llm_model_small, rate=1):
information_content = calculate_information_content(input_text)
print(f"文本的信息量(熵): {information_content:.4f} bits")
topic_num = max(1, min(5, int(information_content * rate / 4)))
print(topic_num)
topic_prompt = find_topic(input_text, topic_num)
topic_response = llm_model.generate_response(topic_prompt)
# 检查 topic_response 是否为元组
if isinstance(topic_response, tuple):
topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串
else:
topics = topic_response.split(",")
print(topics)
compressed_memory = set()
for topic in topics:
topic_what_prompt = topic_what(input_text,topic)
topic_what_response = llm_model_small.generate_response(topic_what_prompt)
compressed_memory.add((topic.strip(), topic_what_response[0])) # 将话题和记忆作为元组存储
return compressed_memory
def segment_text(text): def segment_text(text):
@@ -305,69 +258,21 @@ def topic_what(text, topic):
prompt = f'这是一段文字:{text}。我想知道这记忆里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物。只输出这句话就好' prompt = f'这是一段文字:{text}。我想知道这记忆里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物。只输出这句话就好'
return prompt return prompt
def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
G = memory_graph.G
# 保存图到本地
nx.write_gml(G, "memory_graph.gml") # 保存为 GML 格式
# 根据连接条数或记忆数量设置节点颜色
node_colors = []
nodes = list(G.nodes()) # 获取图中实际的节点列表
if color_by_memory:
# 计算每个节点的记忆数量
memory_counts = []
for node in nodes:
memory_items = G.nodes[node].get('memory_items', [])
if isinstance(memory_items, list):
count = len(memory_items)
else:
count = 1 if memory_items else 0
memory_counts.append(count)
max_memories = max(memory_counts) if memory_counts else 1
for count in memory_counts:
# 使用不同的颜色方案:红色表示记忆多,蓝色表示记忆少
if max_memories > 0:
intensity = min(1.0, count / max_memories)
color = (intensity, 0, 1.0 - intensity) # 从蓝色渐变到红色
else:
color = (0, 0, 1) # 如果没有记忆,则为蓝色
node_colors.append(color)
else:
# 使用原来的连接数量着色方案
max_degree = max(G.degree(), key=lambda x: x[1])[1] if G.degree() else 1
for node in nodes:
degree = G.degree(node)
if max_degree > 0:
red = min(1.0, degree / max_degree)
blue = 1.0 - red
color = (red, 0, blue)
else:
color = (0, 0, 1)
node_colors.append(color)
# 绘制图形
plt.figure(figsize=(12, 8))
pos = nx.spring_layout(G, k=1, iterations=50)
nx.draw(G, pos,
with_labels=True,
node_color=node_colors,
node_size=2000,
font_size=10,
font_family='SimHei',
font_weight='bold')
title = '记忆图谱可视化 - ' + ('按记忆数量着色' if color_by_memory else '按连接数量着色')
plt.title(title, fontsize=16, fontfamily='SimHei')
plt.show()
if __name__ == "__main__":
main()
start_time = time.time()
Database.initialize(
global_config.MONGODB_HOST,
global_config.MONGODB_PORT,
global_config.DATABASE_NAME
)
#创建记忆图
memory_graph = Memory_graph()
#加载数据库中存储的记忆图
memory_graph.load_graph_from_db()
#创建海马体
hippocampus = Hippocampus(memory_graph)
end_time = time.time()
print(f"\033[32m[加载海马体耗时: {end_time - start_time:.2f} 秒]\033[0m")

View File

@@ -1,7 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import sys import sys
import jieba import jieba
from llm_module import LLMModel
import networkx as nx import networkx as nx
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import math import math
@@ -9,10 +8,12 @@ from collections import Counter
import datetime import datetime
import random import random
import time import time
import os
from dotenv import load_dotenv
# from chat.config import global_config # from chat.config import global_config
import sys
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径 sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
from src.common.database import Database # 使用正确的导入语法 from src.common.database import Database # 使用正确的导入语法
from src.plugins.memory_system.llm_module import LLMModel
class Memory_graph: class Memory_graph:
def __init__(self): def __init__(self):
@@ -117,22 +118,60 @@ class Memory_graph:
return [] # 如果没有找到记录,返回空列表 return [] # 如果没有找到记录,返回空列表
def save_graph_to_db(self): def save_graph_to_db(self):
# 清空现有的图数据
self.db.db.graph_data.delete_many({})
# 保存节点 # 保存节点
for node in self.G.nodes(data=True): for node in self.G.nodes(data=True):
node_data = { concept = node[0]
'concept': node[0], memory_items = node[1].get('memory_items', [])
'memory_items': node[1].get('memory_items', []) # 默认为空列表
} # 查找是否存在同名节点
self.db.db.graph_data.nodes.insert_one(node_data) existing_node = self.db.db.graph_data.nodes.find_one({'concept': concept})
if existing_node:
# 如果存在,合并memory_items并去重
existing_items = existing_node.get('memory_items', [])
if not isinstance(existing_items, list):
existing_items = [existing_items] if existing_items else []
# 合并并去重
all_items = list(set(existing_items + memory_items))
# 更新节点
self.db.db.graph_data.nodes.update_one(
{'concept': concept},
{'$set': {'memory_items': all_items}}
)
else:
# 如果不存在,创建新节点
node_data = {
'concept': concept,
'memory_items': memory_items
}
self.db.db.graph_data.nodes.insert_one(node_data)
# 保存边 # 保存边
for edge in self.G.edges(): for edge in self.G.edges():
edge_data = { source, target = edge
'source': edge[0],
'target': edge[1] # 查找是否存在同样的边
} existing_edge = self.db.db.graph_data.edges.find_one({
self.db.db.graph_data.edges.insert_one(edge_data) 'source': source,
'target': target
})
if existing_edge:
# 如果存在,增加num属性
num = existing_edge.get('num', 1) + 1
self.db.db.graph_data.edges.update_one(
{'source': source, 'target': target},
{'$set': {'num': num}}
)
else:
# 如果不存在,创建新边
edge_data = {
'source': source,
'target': target,
'num': 1
}
self.db.db.graph_data.edges.insert_one(edge_data)
def load_graph_from_db(self): def load_graph_from_db(self):
# 清空当前图 # 清空当前图
@@ -147,7 +186,7 @@ class Memory_graph:
# 加载边 # 加载边
edges = self.db.db.graph_data.edges.find() edges = self.db.db.graph_data.edges.find()
for edge in edges: for edge in edges:
self.G.add_edge(edge['source'], edge['target']) self.G.add_edge(edge['source'], edge['target'], num=edge.get('num', 1))
def calculate_information_content(text): def calculate_information_content(text):
@@ -180,6 +219,19 @@ def calculate_information_content(text):
def main(): def main():
# 获取当前文件的绝对路径
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..'))
env_path = os.path.join(root_dir, 'config', '.env')
# 加载环境变量
print(f"尝试从 {env_path} 加载环境变量配置")
if os.path.exists(env_path):
load_dotenv(env_path)
print("成功加载环境变量配置")
else:
print(f"环境变量配置文件不存在: {env_path}")
# 初始化数据库 # 初始化数据库
Database.initialize( Database.initialize(
"127.0.0.1", "127.0.0.1",
@@ -196,10 +248,10 @@ def main():
current_timestamp = datetime.datetime.now().timestamp() current_timestamp = datetime.datetime.now().timestamp()
chat_text = [] chat_text = []
chat_size =20 chat_size =25
for _ in range(10): # 循环10次 for _ in range(30): # 循环10次
random_time = current_timestamp - random.randint(1, 3600*3) # 随机时间 random_time = current_timestamp - random.randint(1, 3600*10) # 随机时间
print(f"随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}") print(f"随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}")
chat_ = memory_graph.get_random_chat_from_db(chat_size, random_time) chat_ = memory_graph.get_random_chat_from_db(chat_size, random_time)
chat_text.append(chat_) # 拼接所有text chat_text.append(chat_) # 拼接所有text
@@ -218,7 +270,7 @@ def main():
# print(input_text) # print(input_text)
first_memory = set() first_memory = set()
first_memory = memory_compress(input_text, llm_model_small, llm_model_small, rate=2.5) first_memory = memory_compress(input_text, llm_model_small, llm_model_small, rate=2.5)
time.sleep(5) # time.sleep(5)
#将记忆加入到图谱中 #将记忆加入到图谱中
for topic, memory in first_memory: for topic, memory in first_memory: