v0.3.1 实装了记忆系统和自动发言
哈哈哈
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -3,7 +3,7 @@ mongodb/
|
|||||||
NapCat.Framework.Windows.Once/
|
NapCat.Framework.Windows.Once/
|
||||||
log/
|
log/
|
||||||
src/plugins/memory
|
src/plugins/memory
|
||||||
src/plugins/chat/bot_config.toml
|
config/bot_config.toml
|
||||||
/test
|
/test
|
||||||
message_queue_content.txt
|
message_queue_content.txt
|
||||||
message_queue_content.bat
|
message_queue_content.bat
|
||||||
|
|||||||
14
README.md
14
README.md
@@ -16,11 +16,19 @@
|
|||||||
|
|
||||||
基于llm、napcat、nonebot和mongodb的专注于群聊天的qqbot
|
基于llm、napcat、nonebot和mongodb的专注于群聊天的qqbot
|
||||||
|
|
||||||
|
<div align="center">
|
||||||
|
<a href="https://www.bilibili.com/video/BV1amAneGE3P" target="_blank">
|
||||||
|
<img src="https://i0.hdslb.com/bfs/archive/7d9fa0a88e8a1aa01b92b8a5a743a2671c0e1798.jpg" width="500" alt="麦麦演示视频">
|
||||||
|
<br>
|
||||||
|
👆 点击观看麦麦演示视频 👆
|
||||||
|
</a>
|
||||||
|
</div>
|
||||||
|
|
||||||
> ⚠️ **警告**:代码可能随时更改,目前版本不一定是稳定版本
|
> ⚠️ **警告**:代码可能随时更改,目前版本不一定是稳定版本
|
||||||
> ⚠️ **警告**:请自行了解qqbot的风险,麦麦有时候一天被腾讯肘七八次
|
> ⚠️ **警告**:请自行了解qqbot的风险,麦麦有时候一天被腾讯肘七八次
|
||||||
> ⚠️ **警告**:由于麦麦一直在迭代,所以可能存在一些bug,请自行测试,包括胡言乱语(
|
> ⚠️ **警告**:由于麦麦一直在迭代,所以可能存在一些bug,请自行测试,包括胡言乱语(
|
||||||
|
|
||||||
关于麦麦的开发和部署相关的讨论群(不建议发布无关消息)这里不会有麦麦发言!
|
关于麦麦的开发和建议相关的讨论群(不建议发布无关消息)这里不会有麦麦发言!
|
||||||
|
|
||||||
## 开发计划TODO:LIST
|
## 开发计划TODO:LIST
|
||||||
|
|
||||||
@@ -29,6 +37,10 @@
|
|||||||
- 对思考链长度限制
|
- 对思考链长度限制
|
||||||
- 修复已知bug
|
- 修复已知bug
|
||||||
- 完善文档
|
- 完善文档
|
||||||
|
- 修复转发
|
||||||
|
- config自动生成和检测
|
||||||
|
- log别用print
|
||||||
|
- 给发送消息写专门的类
|
||||||
|
|
||||||
|
|
||||||
<div align="center">
|
<div align="center">
|
||||||
|
|||||||
@@ -29,6 +29,11 @@ model_r1_probability = 0.8 # 麦麦回答时选择R1模型的概率
|
|||||||
model_v3_probability = 0.1 # 麦麦回答时选择V3模型的概率
|
model_v3_probability = 0.1 # 麦麦回答时选择V3模型的概率
|
||||||
model_r1_distill_probability = 0.1 # 麦麦回答时选择R1蒸馏模型的概率
|
model_r1_distill_probability = 0.1 # 麦麦回答时选择R1蒸馏模型的概率
|
||||||
|
|
||||||
|
[memory]
|
||||||
|
build_memory_interval = 300 # 记忆构建间隔
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
[others]
|
[others]
|
||||||
enable_advance_output = true # 开启后输出更多日志,false关闭true开启
|
enable_advance_output = true # 开启后输出更多日志,false关闭true开启
|
||||||
|
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
|
from loguru import logger
|
||||||
from nonebot import on_message, on_command, require, get_driver
|
from nonebot import on_message, on_command, require, get_driver
|
||||||
from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent, Message, MessageSegment
|
from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent, Message, MessageSegment
|
||||||
from nonebot.typing import T_State
|
from nonebot.typing import T_State
|
||||||
@@ -10,9 +11,6 @@ from .relationship_manager import relationship_manager
|
|||||||
from ..schedule.schedule_generator import bot_schedule
|
from ..schedule.schedule_generator import bot_schedule
|
||||||
from .willing_manager import willing_manager
|
from .willing_manager import willing_manager
|
||||||
|
|
||||||
from ..memory_system.memory import memory_graph
|
|
||||||
|
|
||||||
|
|
||||||
# 获取驱动器
|
# 获取驱动器
|
||||||
driver = get_driver()
|
driver = get_driver()
|
||||||
|
|
||||||
@@ -21,10 +19,7 @@ Database.initialize(
|
|||||||
global_config.MONGODB_PORT,
|
global_config.MONGODB_PORT,
|
||||||
global_config.DATABASE_NAME
|
global_config.DATABASE_NAME
|
||||||
)
|
)
|
||||||
|
print("\033[1;32m[初始化数据库完成]\033[0m")
|
||||||
print("\033[1;32m[初始化配置和数据库完成]\033[0m")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 导入其他模块
|
# 导入其他模块
|
||||||
@@ -32,6 +27,7 @@ from .bot import ChatBot
|
|||||||
from .emoji_manager import emoji_manager
|
from .emoji_manager import emoji_manager
|
||||||
from .message_send_control import message_sender
|
from .message_send_control import message_sender
|
||||||
from .relationship_manager import relationship_manager
|
from .relationship_manager import relationship_manager
|
||||||
|
from ..memory_system.memory import memory_graph,hippocampus
|
||||||
|
|
||||||
# 初始化表情管理器
|
# 初始化表情管理器
|
||||||
emoji_manager.initialize()
|
emoji_manager.initialize()
|
||||||
@@ -39,21 +35,26 @@ emoji_manager.initialize()
|
|||||||
print(f"\033[1;32m正在唤醒{global_config.BOT_NICKNAME}......\033[0m")
|
print(f"\033[1;32m正在唤醒{global_config.BOT_NICKNAME}......\033[0m")
|
||||||
# 创建机器人实例
|
# 创建机器人实例
|
||||||
chat_bot = ChatBot(global_config)
|
chat_bot = ChatBot(global_config)
|
||||||
|
|
||||||
# 注册消息处理器
|
# 注册消息处理器
|
||||||
group_msg = on_message()
|
group_msg = on_message()
|
||||||
|
|
||||||
# 创建定时任务
|
# 创建定时任务
|
||||||
scheduler = require("nonebot_plugin_apscheduler").scheduler
|
scheduler = require("nonebot_plugin_apscheduler").scheduler
|
||||||
|
|
||||||
# 启动后台任务
|
|
||||||
|
|
||||||
@driver.on_startup
|
@driver.on_startup
|
||||||
async def start_background_tasks():
|
async def start_background_tasks():
|
||||||
"""启动后台任务"""
|
"""启动后台任务"""
|
||||||
# 只启动表情包管理任务
|
# 只启动表情包管理任务
|
||||||
asyncio.create_task(emoji_manager.start_periodic_check(interval_MINS=global_config.EMOJI_CHECK_INTERVAL))
|
asyncio.create_task(emoji_manager.start_periodic_check(interval_MINS=global_config.EMOJI_CHECK_INTERVAL))
|
||||||
|
|
||||||
bot_schedule.print_schedule()
|
bot_schedule.print_schedule()
|
||||||
|
|
||||||
|
@driver.on_startup
|
||||||
|
async def init_relationships():
|
||||||
|
"""在 NoneBot2 启动时初始化关系管理器"""
|
||||||
|
print("\033[1;32m[初始化]\033[0m 正在加载用户关系数据...")
|
||||||
|
await relationship_manager.load_all_relationships()
|
||||||
|
asyncio.create_task(relationship_manager._start_relationship_manager())
|
||||||
|
|
||||||
@driver.on_bot_connect
|
@driver.on_bot_connect
|
||||||
async def _(bot: Bot):
|
async def _(bot: Bot):
|
||||||
@@ -68,19 +69,23 @@ async def _(bot: Bot):
|
|||||||
print("\033[1;38;5;208m-----------开始偷表情包!-----------\033[0m")
|
print("\033[1;38;5;208m-----------开始偷表情包!-----------\033[0m")
|
||||||
# 启动消息发送控制任务
|
# 启动消息发送控制任务
|
||||||
|
|
||||||
@driver.on_startup
|
|
||||||
async def init_relationships():
|
|
||||||
"""在 NoneBot2 启动时初始化关系管理器"""
|
|
||||||
print("\033[1;32m[初始化]\033[0m 正在加载用户关系数据...")
|
|
||||||
await relationship_manager.load_all_relationships()
|
|
||||||
asyncio.create_task(relationship_manager._start_relationship_manager())
|
|
||||||
|
|
||||||
@group_msg.handle()
|
@group_msg.handle()
|
||||||
async def _(bot: Bot, event: GroupMessageEvent, state: T_State):
|
async def _(bot: Bot, event: GroupMessageEvent, state: T_State):
|
||||||
await chat_bot.handle_message(event, bot)
|
await chat_bot.handle_message(event, bot)
|
||||||
|
|
||||||
|
'''
|
||||||
@scheduler.scheduled_job("interval", seconds=300000, id="monitor_relationships")
|
@scheduler.scheduled_job("interval", seconds=300000, id="monitor_relationships")
|
||||||
async def monitor_relationships():
|
async def monitor_relationships():
|
||||||
"""每15秒打印一次关系数据"""
|
"""每15秒打印一次关系数据"""
|
||||||
relationship_manager.print_all_relationships()
|
relationship_manager.print_all_relationships()
|
||||||
|
'''
|
||||||
|
|
||||||
|
# 添加build_memory定时任务
|
||||||
|
@scheduler.scheduled_job("interval", seconds=global_config.build_memory_interval, id="build_memory")
|
||||||
|
async def build_memory_task():
|
||||||
|
"""每30秒执行一次记忆构建"""
|
||||||
|
print("\033[1;32m[记忆构建]\033[0m 开始构建记忆...")
|
||||||
|
hippocampus.build_memory(chat_size=12)
|
||||||
|
print("\033[1;32m[记忆构建]\033[0m 记忆构建完成")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -83,7 +83,7 @@ class ChatBot:
|
|||||||
|
|
||||||
await relationship_manager.update_relationship(user_id = event.user_id, data = sender_info)
|
await relationship_manager.update_relationship(user_id = event.user_id, data = sender_info)
|
||||||
await relationship_manager.update_relationship_value(user_id = event.user_id, relationship_value = 0.5)
|
await relationship_manager.update_relationship_value(user_id = event.user_id, relationship_value = 0.5)
|
||||||
print(f"\033[1;32m[关系管理]\033[0m 更新关系值: {relationship_manager.get_relationship(event.user_id).relationship_value}")
|
# print(f"\033[1;32m[关系管理]\033[0m 更新关系值: {relationship_manager.get_relationship(event.user_id).relationship_value}")
|
||||||
|
|
||||||
|
|
||||||
message = Message(
|
message = Message(
|
||||||
@@ -100,14 +100,19 @@ class ChatBot:
|
|||||||
topic = topic_identifier.identify_topic_jieba(message.processed_plain_text)
|
topic = topic_identifier.identify_topic_jieba(message.processed_plain_text)
|
||||||
print(f"\033[1;32m[主题识别]\033[0m 主题: {topic}")
|
print(f"\033[1;32m[主题识别]\033[0m 主题: {topic}")
|
||||||
|
|
||||||
|
all_num = 0
|
||||||
|
interested_num = 0
|
||||||
if topic:
|
if topic:
|
||||||
for current_topic in topic:
|
for current_topic in topic:
|
||||||
|
all_num += 1
|
||||||
first_layer_items, second_layer_items = memory_graph.get_related_item(current_topic, depth=2)
|
first_layer_items, second_layer_items = memory_graph.get_related_item(current_topic, depth=2)
|
||||||
if first_layer_items:
|
if first_layer_items:
|
||||||
print(f"\033[1;32m[记忆检索-bot]\033[0m 有印象:{current_topic}")
|
interested_num += 1
|
||||||
|
print(f"\033[1;32m[前额叶]\033[0m 对|{current_topic}|有印象")
|
||||||
|
interested_rate = interested_num / all_num if all_num > 0 else 0
|
||||||
|
|
||||||
|
|
||||||
await self.storage.store_message(message, topic[0] if topic else None)
|
await self.storage.store_message(message, topic[0] if topic else None)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
is_mentioned = is_mentioned_bot_in_txt(message.processed_plain_text)
|
is_mentioned = is_mentioned_bot_in_txt(message.processed_plain_text)
|
||||||
@@ -117,7 +122,8 @@ class ChatBot:
|
|||||||
is_mentioned,
|
is_mentioned,
|
||||||
self.config,
|
self.config,
|
||||||
event.user_id,
|
event.user_id,
|
||||||
message.is_emoji
|
message.is_emoji,
|
||||||
|
interested_rate
|
||||||
)
|
)
|
||||||
current_willing = willing_manager.get_willing(event.group_id)
|
current_willing = willing_manager.get_willing(event.group_id)
|
||||||
|
|
||||||
@@ -188,7 +194,8 @@ class ChatBot:
|
|||||||
user_nickname=global_config.BOT_NICKNAME,
|
user_nickname=global_config.BOT_NICKNAME,
|
||||||
group_name=message.group_name,
|
group_name=message.group_name,
|
||||||
time=bot_response_time,
|
time=bot_response_time,
|
||||||
is_emoji=True
|
is_emoji=True,
|
||||||
|
translate_cq=False
|
||||||
)
|
)
|
||||||
message_sender.send_temp_container.add_message(bot_message)
|
message_sender.send_temp_container.add_message(bot_message)
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ import logging
|
|||||||
import configparser
|
import configparser
|
||||||
import tomli
|
import tomli
|
||||||
import sys
|
import sys
|
||||||
|
from loguru import logger
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -21,7 +23,7 @@ class BotConfig:
|
|||||||
MONGODB_PASSWORD: Optional[str] = None # 默认空值
|
MONGODB_PASSWORD: Optional[str] = None # 默认空值
|
||||||
MONGODB_AUTH_SOURCE: Optional[str] = None # 默认空值
|
MONGODB_AUTH_SOURCE: Optional[str] = None # 默认空值
|
||||||
|
|
||||||
BOT_QQ: Optional[int] = None
|
BOT_QQ: Optional[int] = 1
|
||||||
BOT_NICKNAME: Optional[str] = None
|
BOT_NICKNAME: Optional[str] = None
|
||||||
|
|
||||||
# 消息处理相关配置
|
# 消息处理相关配置
|
||||||
@@ -35,6 +37,7 @@ class BotConfig:
|
|||||||
talk_frequency_down_groups = set()
|
talk_frequency_down_groups = set()
|
||||||
ban_user_id = set()
|
ban_user_id = set()
|
||||||
|
|
||||||
|
build_memory_interval: int = 60 # 记忆构建间隔(秒)
|
||||||
EMOJI_CHECK_INTERVAL: int = 120 # 表情包检查间隔(分钟)
|
EMOJI_CHECK_INTERVAL: int = 120 # 表情包检查间隔(分钟)
|
||||||
EMOJI_REGISTER_INTERVAL: int = 10 # 表情包注册间隔(分钟)
|
EMOJI_REGISTER_INTERVAL: int = 10 # 表情包注册间隔(分钟)
|
||||||
|
|
||||||
@@ -45,9 +48,21 @@ class BotConfig:
|
|||||||
|
|
||||||
enable_advance_output: bool = False # 是否启用高级输出
|
enable_advance_output: bool = False # 是否启用高级输出
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_default_config_path() -> str:
|
||||||
|
"""获取默认配置文件路径"""
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..'))
|
||||||
|
config_dir = os.path.join(root_dir, 'config')
|
||||||
|
return os.path.join(config_dir, 'bot_config.toml')
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load_config(cls, config_path: str = "bot_config.toml") -> "BotConfig":
|
def load_config(cls, config_path: str = None) -> "BotConfig":
|
||||||
"""从TOML配置文件加载配置"""
|
"""从TOML配置文件加载配置"""
|
||||||
|
if config_path is None:
|
||||||
|
config_path = cls.get_default_config_path()
|
||||||
|
logger.info(f"使用默认配置文件路径: {config_path}")
|
||||||
|
|
||||||
config = cls()
|
config = cls()
|
||||||
if os.path.exists(config_path):
|
if os.path.exists(config_path):
|
||||||
with open(config_path, "rb") as f:
|
with open(config_path, "rb") as f:
|
||||||
@@ -93,6 +108,10 @@ class BotConfig:
|
|||||||
config.MAX_CONTEXT_SIZE = msg_config.get("max_context_size", config.MAX_CONTEXT_SIZE)
|
config.MAX_CONTEXT_SIZE = msg_config.get("max_context_size", config.MAX_CONTEXT_SIZE)
|
||||||
config.emoji_chance = msg_config.get("emoji_chance", config.emoji_chance)
|
config.emoji_chance = msg_config.get("emoji_chance", config.emoji_chance)
|
||||||
|
|
||||||
|
if "memory" in toml_dict:
|
||||||
|
memory_config = toml_dict["memory"]
|
||||||
|
config.build_memory_interval = memory_config.get("build_memory_interval", config.build_memory_interval)
|
||||||
|
|
||||||
# 群组配置
|
# 群组配置
|
||||||
if "groups" in toml_dict:
|
if "groups" in toml_dict:
|
||||||
groups_config = toml_dict["groups"]
|
groups_config = toml_dict["groups"]
|
||||||
@@ -104,16 +123,26 @@ class BotConfig:
|
|||||||
others_config = toml_dict["others"]
|
others_config = toml_dict["others"]
|
||||||
config.enable_advance_output = others_config.get("enable_advance_output", config.enable_advance_output)
|
config.enable_advance_output = others_config.get("enable_advance_output", config.enable_advance_output)
|
||||||
|
|
||||||
print(f"\033[1;32m成功加载配置文件: {config_path}\033[0m")
|
logger.success(f"成功加载配置文件: {config_path}")
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
global_config = BotConfig.load_config(".bot_config.toml")
|
# 获取配置文件路径
|
||||||
|
bot_config_path = BotConfig.get_default_config_path()
|
||||||
|
config_dir = os.path.dirname(bot_config_path)
|
||||||
|
env_path = os.path.join(config_dir, '.env')
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
logger.info(f"尝试从 {bot_config_path} 加载机器人配置")
|
||||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
global_config = BotConfig.load_config(config_path=bot_config_path)
|
||||||
root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..'))
|
|
||||||
load_dotenv(os.path.join(root_dir, '.env'))
|
# 加载环境变量
|
||||||
|
|
||||||
|
logger.info(f"尝试从 {env_path} 加载环境变量配置")
|
||||||
|
if os.path.exists(env_path):
|
||||||
|
load_dotenv(env_path)
|
||||||
|
logger.success("成功加载环境变量配置")
|
||||||
|
else:
|
||||||
|
logger.error(f"环境变量配置文件不存在: {env_path}")
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LLMConfig:
|
class LLMConfig:
|
||||||
@@ -132,9 +161,5 @@ llm_config.DEEP_SEEK_BASE_URL = os.getenv('DEEP_SEEK_BASE_URL')
|
|||||||
|
|
||||||
|
|
||||||
if not global_config.enable_advance_output:
|
if not global_config.enable_advance_output:
|
||||||
# 只降低日志级别而不是完全移除
|
# logger.remove()
|
||||||
logger.remove()
|
pass
|
||||||
logger.add(sys.stderr, level="WARNING") # 添加一个只输出 WARNING 及以上级别的处理器
|
|
||||||
|
|
||||||
# 设置 nonebot 的日志级别
|
|
||||||
logging.getLogger('nonebot').setLevel(logging.WARNING)
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import asyncio
|
|||||||
import requests
|
import requests
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from .message import Message
|
from .message import Message
|
||||||
from .config import BotConfig
|
from .config import BotConfig, global_config
|
||||||
from ...common.database import Database
|
from ...common.database import Database
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
@@ -255,4 +255,4 @@ class LLMResponseGenerator:
|
|||||||
return processed_response, emotion_tags
|
return processed_response, emotion_tags
|
||||||
|
|
||||||
# 创建全局实例
|
# 创建全局实例
|
||||||
llm_response = LLMResponseGenerator(config=BotConfig())
|
llm_response = LLMResponseGenerator(global_config)
|
||||||
@@ -6,17 +6,13 @@ import os
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from ...common.database import Database
|
from ...common.database import Database
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from .config import BotConfig, global_config
|
from .config import global_config
|
||||||
import urllib3
|
import urllib3
|
||||||
from .utils_user import get_user_nickname
|
from .utils_user import get_user_nickname
|
||||||
from .utils_cq import parse_cq_code
|
from .utils_cq import parse_cq_code
|
||||||
from .cq_code import cq_code_tool,CQCode
|
from .cq_code import cq_code_tool,CQCode
|
||||||
|
|
||||||
Message = ForwardRef('Message') # 添加这行
|
Message = ForwardRef('Message') # 添加这行
|
||||||
|
|
||||||
# 加载配置
|
|
||||||
bot_config = BotConfig.load_config()
|
|
||||||
|
|
||||||
# 禁用SSL警告
|
# 禁用SSL警告
|
||||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||||
|
|
||||||
@@ -48,6 +44,8 @@ class Message:
|
|||||||
|
|
||||||
is_emoji: bool = False # 是否是表情包
|
is_emoji: bool = False # 是否是表情包
|
||||||
has_emoji: bool = False # 是否包含表情包
|
has_emoji: bool = False # 是否包含表情包
|
||||||
|
|
||||||
|
translate_cq: bool = True # 是否翻译cq码
|
||||||
|
|
||||||
|
|
||||||
reply_benefits: float = 0.0
|
reply_benefits: float = 0.0
|
||||||
@@ -99,7 +97,7 @@ class Message:
|
|||||||
- cq_code_list:分割出的聊天对象,包括文本和CQ码
|
- cq_code_list:分割出的聊天对象,包括文本和CQ码
|
||||||
- trans_list:翻译后的对象列表
|
- trans_list:翻译后的对象列表
|
||||||
"""
|
"""
|
||||||
print(f"\033[1;34m[调试信息]\033[0m 正在处理消息: {message}")
|
# print(f"\033[1;34m[调试信息]\033[0m 正在处理消息: {message}")
|
||||||
cq_code_dict_list = []
|
cq_code_dict_list = []
|
||||||
trans_list = []
|
trans_list = []
|
||||||
|
|
||||||
|
|||||||
@@ -208,7 +208,15 @@ class MessageSendControl:
|
|||||||
print(f"\033[1;34m[调试]\033[0m 消息发送时间: {cost_time}秒")
|
print(f"\033[1;34m[调试]\033[0m 消息发送时间: {cost_time}秒")
|
||||||
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time))
|
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time))
|
||||||
print(f"\033[1;32m群 {group_id} 消息, 用户 {global_config.BOT_NICKNAME}, 时间: {current_time}:\033[0m {str(message.processed_plain_text)}")
|
print(f"\033[1;32m群 {group_id} 消息, 用户 {global_config.BOT_NICKNAME}, 时间: {current_time}:\033[0m {str(message.processed_plain_text)}")
|
||||||
await self.storage.store_message(message, None)
|
|
||||||
|
if message.is_emoji:
|
||||||
|
message.processed_plain_text = "[表情包]"
|
||||||
|
await self.storage.store_message(message, None)
|
||||||
|
else:
|
||||||
|
await self.storage.store_message(message, None)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
queue.update_send_time()
|
queue.update_send_time()
|
||||||
if queue.has_messages():
|
if queue.has_messages():
|
||||||
await asyncio.sleep(
|
await asyncio.sleep(
|
||||||
|
|||||||
@@ -53,8 +53,8 @@ class PromptBuilder:
|
|||||||
# 遍历所有topic
|
# 遍历所有topic
|
||||||
for current_topic in topic:
|
for current_topic in topic:
|
||||||
first_layer_items, second_layer_items = memory_graph.get_related_item(current_topic, depth=2)
|
first_layer_items, second_layer_items = memory_graph.get_related_item(current_topic, depth=2)
|
||||||
if first_layer_items:
|
# if first_layer_items:
|
||||||
print(f"\033[1;32m[pb记忆检索]\033[0m 主题 '{current_topic}' 的第一层记忆: {first_layer_items}")
|
# print(f"\033[1;32m[前额叶]\033[0m 主题 '{current_topic}' 的第一层记忆: {first_layer_items}")
|
||||||
|
|
||||||
# 记录第一层数据
|
# 记录第一层数据
|
||||||
all_first_layer_items.extend(first_layer_items)
|
all_first_layer_items.extend(first_layer_items)
|
||||||
@@ -68,14 +68,14 @@ class PromptBuilder:
|
|||||||
# 找到重叠的记忆
|
# 找到重叠的记忆
|
||||||
overlap = set(second_layer_items) & set(other_second_layer)
|
overlap = set(second_layer_items) & set(other_second_layer)
|
||||||
if overlap:
|
if overlap:
|
||||||
print(f"\033[1;32m[pb记忆检索]\033[0m 发现主题 '{current_topic}' 和 '{other_topic}' 有共同的第二层记忆: {overlap}")
|
# print(f"\033[1;32m[前额叶]\033[0m 发现主题 '{current_topic}' 和 '{other_topic}' 有共同的第二层记忆: {overlap}")
|
||||||
overlapping_second_layer.update(overlap)
|
overlapping_second_layer.update(overlap)
|
||||||
|
|
||||||
# 合并所有需要的记忆
|
# 合并所有需要的记忆
|
||||||
if all_first_layer_items:
|
if all_first_layer_items:
|
||||||
print(f"\033[1;32m[pb记忆检索]\033[0m 合并所有需要的记忆1: {all_first_layer_items}")
|
print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆1: {all_first_layer_items}")
|
||||||
if overlapping_second_layer:
|
if overlapping_second_layer:
|
||||||
print(f"\033[1;32m[pb记忆检索]\033[0m 合并所有需要的记忆2: {list(overlapping_second_layer)}")
|
print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆2: {list(overlapping_second_layer)}")
|
||||||
|
|
||||||
all_memories = all_first_layer_items + list(overlapping_second_layer)
|
all_memories = all_first_layer_items + list(overlapping_second_layer)
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ import numpy as np
|
|||||||
from .config import llm_config, global_config
|
from .config import llm_config, global_config
|
||||||
import re
|
import re
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
from collections import Counter
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
def combine_messages(messages: List[Message]) -> str:
|
def combine_messages(messages: List[Message]) -> str:
|
||||||
@@ -81,6 +83,39 @@ def cosine_similarity(v1, v2):
|
|||||||
norm2 = np.linalg.norm(v2)
|
norm2 = np.linalg.norm(v2)
|
||||||
return dot_product / (norm1 * norm2)
|
return dot_product / (norm1 * norm2)
|
||||||
|
|
||||||
|
def calculate_information_content(text: str) -> float:
    """Return the Shannon entropy of *text*, in bits per character.

    The entropy is computed from the relative frequency of each distinct
    character in ``text``.

    Args:
        text: The string to measure.

    Returns:
        The entropy as a float. An empty string, or a string made of a
        single repeated character, yields ``0.0``.
    """
    total_chars = len(text)
    if not total_chars:
        # Nothing to count; return a float so the result type does not
        # depend on whether the input was empty.
        return 0.0

    entropy = 0.0
    for count in Counter(text).values():
        probability = count / total_chars
        entropy -= probability * math.log2(probability)
    return entropy
|
||||||
|
|
||||||
|
def get_cloest_chat_from_db(db, length: int, timestamp: str):
    """Fetch a short chat transcript that starts right after *timestamp*.

    The newest message whose ``time`` is less than or equal to ``timestamp``
    is looked up first. Up to ``length`` later messages from the same group
    are then formatted, oldest first, one per line.

    Args:
        db: Object exposing the message collection as ``db.db.messages``.
        length: Maximum number of messages to include.
        timestamp: Upper bound used to locate the anchor message.
            NOTE(review): annotated ``str`` but compared against the stored
            ``time`` field -- confirm the intended type with callers.

    Returns:
        A string of ``"[time] sender: text\\n"`` lines (possibly empty) when
        an anchor message exists; an empty list when none is found. The
        list return is kept as-is so existing callers are not affected.
    """
    messages = db.db.messages
    closest_record = messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])
    if not closest_record:
        return []  # no message at or before the timestamp

    closest_time = closest_record['time']
    group_id = closest_record['group_id']

    # Messages strictly after the anchor, same group, in chronological order.
    chat_records = messages.find(
        {"time": {"$gt": closest_time}, "group_id": group_id}
    ).sort('time', 1).limit(length)

    lines = []
    for record in chat_records:
        time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
        # Fall back to a synthetic name when the nickname is missing or empty.
        sender = record["user_nickname"] or "用户" + str(record["user_id"])
        lines.append(f'[{time_str}] {sender}: {record["processed_plain_text"]}\n')
    return ''.join(lines)
|
||||||
|
|
||||||
|
|
||||||
def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
|
def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
|
||||||
"""从数据库获取群组最近的消息记录
|
"""从数据库获取群组最近的消息记录
|
||||||
|
|
||||||
|
|||||||
@@ -4,11 +4,9 @@ import hashlib
|
|||||||
import time
|
import time
|
||||||
import os
|
import os
|
||||||
from ...common.database import Database
|
from ...common.database import Database
|
||||||
from .config import BotConfig
|
|
||||||
import zlib # 用于 CRC32
|
import zlib # 用于 CRC32
|
||||||
import base64
|
import base64
|
||||||
|
from .config import global_config
|
||||||
bot_config = BotConfig.load_config()
|
|
||||||
|
|
||||||
|
|
||||||
def storage_image(image_data: bytes,type: str, max_size: int = 200) -> bytes:
|
def storage_image(image_data: bytes,type: str, max_size: int = 200) -> bytes:
|
||||||
@@ -39,12 +37,12 @@ def storage_compress_image(image_data: bytes, max_size: int = 200) -> bytes:
|
|||||||
|
|
||||||
# 连接数据库
|
# 连接数据库
|
||||||
db = Database(
|
db = Database(
|
||||||
host=bot_config.MONGODB_HOST,
|
host=global_config.MONGODB_HOST,
|
||||||
port=bot_config.MONGODB_PORT,
|
port=global_config.MONGODB_PORT,
|
||||||
db_name=bot_config.DATABASE_NAME,
|
db_name=global_config.DATABASE_NAME,
|
||||||
username=bot_config.MONGODB_USERNAME,
|
username=global_config.MONGODB_USERNAME,
|
||||||
password=bot_config.MONGODB_PASSWORD,
|
password=global_config.MONGODB_PASSWORD,
|
||||||
auth_source=bot_config.MONGODB_AUTH_SOURCE
|
auth_source=global_config.MONGODB_AUTH_SOURCE
|
||||||
)
|
)
|
||||||
|
|
||||||
# 检查是否已存在相同哈希值的图片
|
# 检查是否已存在相同哈希值的图片
|
||||||
|
|||||||
@@ -22,22 +22,31 @@ class WillingManager:
|
|||||||
"""设置指定群组的回复意愿"""
|
"""设置指定群组的回复意愿"""
|
||||||
self.group_reply_willing[group_id] = willing
|
self.group_reply_willing[group_id] = willing
|
||||||
|
|
||||||
def change_reply_willing_received(self, group_id: int, topic: str, is_mentioned_bot: bool, config, user_id: int = None, is_emoji: bool = False) -> float:
|
def change_reply_willing_received(self, group_id: int, topic: str, is_mentioned_bot: bool, config, user_id: int = None, is_emoji: bool = False, interested_rate: float = 0) -> float:
|
||||||
"""改变指定群组的回复意愿并返回回复概率"""
|
"""改变指定群组的回复意愿并返回回复概率"""
|
||||||
current_willing = self.group_reply_willing.get(group_id, 0)
|
current_willing = self.group_reply_willing.get(group_id, 0)
|
||||||
|
|
||||||
if topic and current_willing < 1:
|
print(f"初始意愿: {current_willing}")
|
||||||
current_willing += 0.2
|
|
||||||
elif topic:
|
# if topic and current_willing < 1:
|
||||||
current_willing += 0.05
|
# current_willing += 0.2
|
||||||
|
# elif topic:
|
||||||
|
# current_willing += 0.05
|
||||||
|
|
||||||
if is_mentioned_bot and current_willing < 1.0:
|
if is_mentioned_bot and current_willing < 1.0:
|
||||||
current_willing += 0.9
|
current_willing += 0.9
|
||||||
|
print(f"被提及, 当前意愿: {current_willing}")
|
||||||
elif is_mentioned_bot:
|
elif is_mentioned_bot:
|
||||||
current_willing += 0.05
|
current_willing += 0.05
|
||||||
|
print(f"被重复提及, 当前意愿: {current_willing}")
|
||||||
|
|
||||||
if is_emoji:
|
if is_emoji:
|
||||||
current_willing *= 0.2
|
current_willing *= 0.15
|
||||||
|
print(f"表情包, 当前意愿: {current_willing}")
|
||||||
|
|
||||||
|
if interested_rate > 0.6:
|
||||||
|
print(f"兴趣度: {interested_rate}, 当前意愿: {current_willing}")
|
||||||
|
current_willing += interested_rate-0.45
|
||||||
|
|
||||||
self.group_reply_willing[group_id] = min(current_willing, 3.0)
|
self.group_reply_willing[group_id] = min(current_willing, 3.0)
|
||||||
|
|
||||||
@@ -55,15 +64,15 @@ class WillingManager:
|
|||||||
return reply_probability
|
return reply_probability
|
||||||
|
|
||||||
def change_reply_willing_sent(self, group_id: int):
|
def change_reply_willing_sent(self, group_id: int):
|
||||||
"""发送消息后降低群组的回复意愿"""
|
"""开始思考后降低群组的回复意愿"""
|
||||||
current_willing = self.group_reply_willing.get(group_id, 0)
|
current_willing = self.group_reply_willing.get(group_id, 0)
|
||||||
self.group_reply_willing[group_id] = max(0, current_willing - 1.8)
|
self.group_reply_willing[group_id] = max(0, current_willing - 2)
|
||||||
|
|
||||||
def change_reply_willing_after_sent(self, group_id: int):
|
def change_reply_willing_after_sent(self, group_id: int):
|
||||||
"""发送消息后提高群组的回复意愿"""
|
"""发送消息后提高群组的回复意愿"""
|
||||||
current_willing = self.group_reply_willing.get(group_id, 0)
|
current_willing = self.group_reply_willing.get(group_id, 0)
|
||||||
if current_willing < 1:
|
if current_willing < 1:
|
||||||
self.group_reply_willing[group_id] = min(1, current_willing + 0.4)
|
self.group_reply_willing[group_id] = min(1, current_willing + 0.3)
|
||||||
|
|
||||||
async def ensure_started(self):
|
async def ensure_started(self):
|
||||||
"""确保衰减任务已启动"""
|
"""确保衰减任务已启动"""
|
||||||
|
|||||||
264
src/plugins/memory_system/draw_memory.py
Normal file
264
src/plugins/memory_system/draw_memory.py
Normal file
@@ -0,0 +1,264 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
import sys
|
||||||
|
import jieba
|
||||||
|
from llm_module import LLMModel
|
||||||
|
import networkx as nx
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import math
|
||||||
|
from collections import Counter
|
||||||
|
import datetime
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
# from chat.config import global_config
|
||||||
|
import sys
|
||||||
|
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
|
||||||
|
from src.common.database import Database # 使用正确的导入语法
|
||||||
|
|
||||||
|
class Memory_graph:
|
||||||
|
def __init__(self):
|
||||||
|
self.G = nx.Graph() # 使用 networkx 的图结构
|
||||||
|
self.db = Database.get_instance()
|
||||||
|
|
||||||
|
def connect_dot(self, concept1, concept2):
|
||||||
|
self.G.add_edge(concept1, concept2)
|
||||||
|
|
||||||
|
def add_dot(self, concept, memory):
    """Attach *memory* to the node for *concept*, creating the node if needed.

    Args:
        concept: Node key in ``self.G``.
        memory: The memory item to record under that node.

    Each node keeps its memories in a ``memory_items`` list. A node whose
    ``memory_items`` attribute holds a single non-list value is converted
    to a one-element list before the new memory is appended.
    """
    if concept not in self.G:
        # New node: start its memory list with this item.
        self.G.add_node(concept, memory_items=[memory])
        return

    node = self.G.nodes[concept]
    if 'memory_items' not in node:
        node['memory_items'] = [memory]
        return

    items = node['memory_items']
    if not isinstance(items, list):
        # Normalise a stored scalar into a list so it can be appended to.
        items = node['memory_items'] = [items]
    items.append(memory)
|
||||||
|
|
||||||
|
def get_dot(self, concept):
|
||||||
|
# 检查节点是否存在于图中
|
||||||
|
if concept in self.G:
|
||||||
|
# 从图中获取节点数据
|
||||||
|
node_data = self.G.nodes[concept]
|
||||||
|
# print(node_data)
|
||||||
|
# 创建新的Memory_dot对象
|
||||||
|
return concept,node_data
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_related_item(self, topic, depth=1):
    """Collect the memories stored on *topic* and, optionally, its neighbours.

    Args:
        topic: Node key to look up in ``self.G``.
        depth: ``1`` returns only the topic's own memories; any value
            ``>= 2`` also gathers the memories of directly adjacent nodes.

    Returns:
        A ``(first_layer_items, second_layer_items)`` tuple of lists. Both
        lists are empty when ``topic`` is not in the graph, and the second
        list is empty when ``depth`` is below 2.
    """
    if topic not in self.G:
        return [], []

    def collect(concept, into):
        # Append the memory items stored on `concept` (list or scalar) to `into`.
        node_data = self.get_dot(concept)
        if not node_data:
            return
        _, data = node_data
        if 'memory_items' not in data:
            return
        memory_items = data['memory_items']
        if isinstance(memory_items, list):
            into.extend(memory_items)
        else:
            into.append(memory_items)

    first_layer_items = []
    second_layer_items = []

    collect(topic, first_layer_items)

    # The second layer is gathered for any depth of 2 or more.
    if depth >= 2:
        for neighbor in self.G.neighbors(topic):
            collect(neighbor, second_layer_items)

    return first_layer_items, second_layer_items
|
||||||
|
|
||||||
|
def store_memory(self):
|
||||||
|
for node in self.G.nodes():
|
||||||
|
dot_data = {
|
||||||
|
"concept": node
|
||||||
|
}
|
||||||
|
self.db.db.store_memory_dots.insert_one(dot_data)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dots(self):
|
||||||
|
# 返回所有节点对应的 Memory_dot 对象
|
||||||
|
return [self.get_dot(node) for node in self.G.nodes()]
|
||||||
|
|
||||||
|
|
||||||
|
def get_random_chat_from_db(self, length: int, timestamp: str):
    """Fetch a chat transcript that starts right after a given point in time.

    The newest message whose ``time`` is ``<= timestamp`` is used as the
    anchor; up to ``length`` later messages from the same ``group_id`` are
    then formatted, oldest first, one per line as
    ``[YYYY-mm-dd HH:MM:SS] <sender>: <text>``.

    Args:
        length: Maximum number of messages to include.
        timestamp: Upper bound for the anchor message's ``time`` field.

    Returns:
        The transcript as a single string (possibly empty), or ``[]`` when
        no anchor message exists. The empty list is kept for backward
        compatibility with existing callers.
    """
    messages = self.db.db.messages
    closest_record = messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])

    # Guard first: the debug print below used to run before this check and
    # crashed with a TypeError whenever no message matched.
    if not closest_record:
        return []

    closest_time = closest_record['time']
    print(f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_time)))}")

    group_id = closest_record['group_id']
    chat_records = messages.find(
        {"time": {"$gt": closest_time}, "group_id": group_id}
    ).sort('time', 1).limit(length)

    lines = []
    for record in chat_records:
        time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
        # Fall back to "用户<id>" when the nickname is empty or missing a value.
        sender = record["user_nickname"] or "用户" + str(record["user_id"])
        lines.append(f'[{time_str}] {sender}: {record["processed_plain_text"]}\n')
    return ''.join(lines)
|
||||||
|
|
||||||
|
def save_graph_to_db(self):
    """Replace the persisted graph with the current contents of ``self.G``.

    Nodes are written to ``graph_data.nodes`` as
    ``{'concept': ..., 'memory_items': [...]}`` (``memory_items`` defaults
    to an empty list) and edges to ``graph_data.edges`` as
    ``{'source': ..., 'target': ...}``.
    """
    graph_data = self.db.db.graph_data

    # Assuming self.db.db is a PyMongo database: attribute access on a
    # collection yields a separate collection ("graph_data.nodes",
    # "graph_data.edges"), so clearing "graph_data" alone leaves the old
    # nodes and edges in place and every save would append duplicates.
    graph_data.delete_many({})
    graph_data.nodes.delete_many({})
    graph_data.edges.delete_many({})

    for concept, attrs in self.G.nodes(data=True):
        graph_data.nodes.insert_one({
            'concept': concept,
            'memory_items': attrs.get('memory_items', []),
        })

    for source, target in self.G.edges():
        graph_data.edges.insert_one({'source': source, 'target': target})
|
||||||
|
|
||||||
|
def load_graph_from_db(self):
    """Rebuild ``self.G`` from the ``graph_data.nodes`` / ``graph_data.edges`` collections.

    The in-memory graph is cleared first. Each stored node's
    ``memory_items`` is normalised to a list: a missing or falsy value
    becomes ``[]`` and a single non-list value becomes a one-element list.
    """
    self.G.clear()

    for node_doc in self.db.db.graph_data.nodes.find():
        stored = node_doc.get('memory_items', [])
        if isinstance(stored, list):
            items = stored
        elif stored:
            items = [stored]
        else:
            items = []
        self.G.add_node(node_doc['concept'], memory_items=items)

    for edge_doc in self.db.db.graph_data.edges.find():
        self.G.add_edge(edge_doc['source'], edge_doc['target'])
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Load the stored memory graph, plot it twice, then answer concept queries.

    Connects to the local MongoDB database "MegBot", loads the persisted
    graph, shows it coloured by connection count and by memory count, and
    finally enters an interactive loop that prints the memory items stored
    for each entered concept until the user types '退出'.
    """
    # Connect to the local database.
    Database.initialize(
        "127.0.0.1",
        27017,
        "MegBot"
    )

    memory_graph = Memory_graph()
    memory_graph.load_graph_from_db()

    # Show both colouring schemes.
    print("\n按连接数量着色的图谱:")
    visualize_graph(memory_graph, color_by_memory=False)

    print("\n按记忆数量着色的图谱:")
    visualize_graph(memory_graph, color_by_memory=True)

    while True:
        query = input("请输入新的查询概念(输入'退出'以结束):")
        if query.lower() == '退出':
            break

        # get_related_item returns a (first_layer, second_layer) pair. The
        # pair itself is always truthy, so test the lists, otherwise the
        # "not found" branch can never run.
        first_layer, second_layer = memory_graph.get_related_item(query)
        if first_layer or second_layer:
            for memory_item in first_layer + second_layer:
                print(memory_item)
        else:
            print("未找到相关记忆。")
|
||||||
|
|
||||||
|
|
||||||
|
def segment_text(text):
    """Split ``text`` into a list of word tokens using jieba."""
    return jieba.lcut(text)
|
||||||
|
|
||||||
|
def find_topic(text, topic_num):
    """Build the LLM prompt asking for ``topic_num`` comma-separated topics from ``text``.

    Returns the prompt string only; no model is called here.
    """
    return (
        f'这是一段文字:{text}。'
        f'请你从这段话中总结出{topic_num}个话题,帮我列出来,用逗号隔开,尽可能精简。'
        f'只需要列举{topic_num}个话题就好,不要告诉我其他内容。'
    )
|
||||||
|
|
||||||
|
def topic_what(text, topic):
    """Build the LLM prompt asking for a one-sentence summary of ``topic`` within ``text``.

    Returns the prompt string only; no model is called here.
    """
    return (
        f'这是一段文字:{text}。'
        f'我想知道这记忆里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物。'
        '只输出这句话就好'
    )
|
||||||
|
|
||||||
|
def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
    """Draw the memory graph with nodes shaded from blue (low) to red (high).

    Args:
        memory_graph: Object whose ``G`` attribute is the graph to draw.
        color_by_memory: When true, shade each node by how many
            ``memory_items`` it holds; otherwise shade by node degree.

    Side effects: sets the Matplotlib font options for Chinese labels,
    writes the graph to ``memory_graph.gml``, and calls ``plt.show()``.
    """
    # Font settings so Chinese labels and the minus sign render correctly.
    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.rcParams['axes.unicode_minus'] = False

    G = memory_graph.G

    # Keep a copy of the graph on disk in GML format.
    nx.write_gml(G, "memory_graph.gml")

    def blue_to_red(value, ceiling):
        # Map value/ceiling (capped at 1.0) onto an RGB tuple running from
        # pure blue to pure red; a non-positive ceiling means plain blue.
        if ceiling > 0:
            share = min(1.0, value / ceiling)
            return (share, 0, 1.0 - share)
        return (0, 0, 1)

    nodes = list(G.nodes())

    if color_by_memory:
        memory_counts = []
        for node in nodes:
            stored = G.nodes[node].get('memory_items', [])
            if isinstance(stored, list):
                memory_counts.append(len(stored))
            else:
                # A single non-list value counts as one memory.
                memory_counts.append(1 if stored else 0)
        ceiling = max(memory_counts) if memory_counts else 1
        node_colors = [blue_to_red(count, ceiling) for count in memory_counts]
    else:
        ceiling = max(G.degree(), key=lambda pair: pair[1])[1] if G.degree() else 1
        node_colors = [blue_to_red(G.degree(node), ceiling) for node in nodes]

    plt.figure(figsize=(12, 8))
    pos = nx.spring_layout(G, k=1, iterations=50)
    nx.draw(
        G,
        pos,
        with_labels=True,
        node_color=node_colors,
        node_size=2000,
        font_size=10,
        font_family='SimHei',
        font_weight='bold',
    )

    if color_by_memory:
        title = '记忆图谱可视化 - ' + '按记忆数量着色'
    else:
        title = '记忆图谱可视化 - ' + '按连接数量着色'
    plt.title(title, fontsize=16, fontfamily='SimHei')
    plt.show()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
||||||
|
|
||||||
82
src/plugins/memory_system/llm_module_memory_make.py
Normal file
82
src/plugins/memory_system/llm_module_memory_make.py
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
import os
|
||||||
|
import requests
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from typing import Tuple, Union
|
||||||
|
import time
|
||||||
|
from ..chat.config import BotConfig
|
||||||
|
|
||||||
|
# Locate <three directories above this file>/config/.env.
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.abspath(os.path.join(current_dir, os.pardir, os.pardir, os.pardir))
env_path = os.path.join(root_dir, 'config', '.env')

# Load the environment variables from that file when it exists.
print(f"尝试从 {env_path} 加载环境变量配置")
if not os.path.exists(env_path):
    print(f"环境变量配置文件不存在: {env_path}")
else:
    load_dotenv(env_path)
    print("成功加载环境变量配置")
|
||||||
|
|
||||||
|
class LLMModel:
    """Minimal client for a chat-completions HTTP endpoint.

    The API key and base URL are read from the ``SILICONFLOW_KEY`` and
    ``SILICONFLOW_BASE_URL`` environment variables when an instance is
    created.
    """

    # (connect, read) timeout in seconds passed to requests.post. Without a
    # timeout a stalled connection would block the caller forever; the read
    # limit is generous because completions can take a long time.
    request_timeout = (10, 300)

    # Alternative default model: "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B".
    def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-V3", **kwargs):
        """Store the model name and extra request parameters.

        Args:
            model_name: Model identifier sent in the request body.
            **kwargs: Extra fields merged into every request body; they
                override the built-in defaults such as ``temperature``.

        Raises:
            ValueError: If either environment variable is unset or empty.
        """
        self.model_name = model_name
        self.params = kwargs
        self.api_key = os.getenv("SILICONFLOW_KEY")
        self.base_url = os.getenv("SILICONFLOW_BASE_URL")

        if not self.api_key or not self.base_url:
            raise ValueError("环境变量未正确加载:SILICONFLOW_KEY 或 SILICONFLOW_BASE_URL 未设置")

        print(f"API URL: {self.base_url}")  # printed for debugging

    def generate_response(self, prompt: str) -> Tuple[str, str]:
        """Send ``prompt`` as a single user message and return the model's reply.

        Up to three attempts are made. HTTP 429 responses and request
        errors are retried with exponential back-off (15s, then 30s).

        Returns:
            ``(content, reasoning_content)``. ``reasoning_content`` is ``""``
            when the response does not include it. On failure the first
            element is an error message and the second is ``""``.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }

        # Request body; self.params comes last so callers can override defaults.
        data = {
            "model": self.model_name,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0.5,
            **self.params
        }

        api_url = f"{self.base_url.rstrip('/')}/chat/completions"

        max_retries = 3
        base_wait_time = 15  # base back-off in seconds

        for retry in range(max_retries):
            is_last_attempt = retry == max_retries - 1
            wait_time = base_wait_time * (2 ** retry)  # exponential back-off
            try:
                response = requests.post(
                    api_url,
                    headers=headers,
                    json=data,
                    timeout=self.request_timeout,
                )

                if response.status_code == 429:
                    if is_last_attempt:
                        # No attempt left, so sleeping would only delay the
                        # failure result.
                        break
                    print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
                    time.sleep(wait_time)
                    continue

                response.raise_for_status()  # surface any other HTTP error

                result = response.json()
                if "choices" in result and len(result["choices"]) > 0:
                    message = result["choices"][0]["message"]
                    content = message["content"]
                    reasoning_content = message.get("reasoning_content", "")
                    return content, reasoning_content
                return "没有返回结果", ""

            except requests.exceptions.RequestException as e:
                if is_last_attempt:
                    return f"请求失败: {str(e)}", ""
                print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
                time.sleep(wait_time)

        return "达到最大重试次数,请求仍然失败", ""
|
||||||
@@ -1,5 +1,4 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import sys
|
|
||||||
import jieba
|
import jieba
|
||||||
from .llm_module import LLMModel
|
from .llm_module import LLMModel
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
@@ -11,8 +10,8 @@ import random
|
|||||||
import time
|
import time
|
||||||
from ..chat.config import global_config
|
from ..chat.config import global_config
|
||||||
import sys
|
import sys
|
||||||
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
|
from ...common.database import Database # 使用正确的导入语法
|
||||||
from src.common.database import Database # 使用正确的导入语法
|
from ..chat.utils import calculate_information_content, get_cloest_chat_from_db
|
||||||
|
|
||||||
class Memory_graph:
|
class Memory_graph:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -85,54 +84,66 @@ class Memory_graph:
|
|||||||
|
|
||||||
return first_layer_items, second_layer_items
|
return first_layer_items, second_layer_items
|
||||||
|
|
||||||
def store_memory(self):
|
|
||||||
for node in self.G.nodes():
|
|
||||||
dot_data = {
|
|
||||||
"concept": node
|
|
||||||
}
|
|
||||||
self.db.db.store_memory_dots.insert_one(dot_data)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dots(self):
|
def dots(self):
|
||||||
# 返回所有节点对应的 Memory_dot 对象
|
# 返回所有节点对应的 Memory_dot 对象
|
||||||
return [self.get_dot(node) for node in self.G.nodes()]
|
return [self.get_dot(node) for node in self.G.nodes()]
|
||||||
|
|
||||||
|
|
||||||
def get_random_chat_from_db(self, length: int, timestamp: str):
|
|
||||||
# 从数据库中根据时间戳获取离其最近的聊天记录
|
|
||||||
chat_text = ''
|
|
||||||
closest_record = self.db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出
|
|
||||||
print(f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
|
|
||||||
|
|
||||||
if closest_record:
|
|
||||||
closest_time = closest_record['time']
|
|
||||||
group_id = closest_record['group_id'] # 获取groupid
|
|
||||||
# 获取该时间戳之后的length条消息,且groupid相同
|
|
||||||
chat_record = list(self.db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
|
|
||||||
for record in chat_record:
|
|
||||||
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
|
|
||||||
chat_text += f'[{time_str}] {record["user_nickname"] or "用户" + str(record["user_id"])}: {record["processed_plain_text"]}\n' # 添加发送者和时间信息
|
|
||||||
return chat_text
|
|
||||||
|
|
||||||
return [] # 如果没有找到记录,返回空列表
|
|
||||||
|
|
||||||
def save_graph_to_db(self):
|
def save_graph_to_db(self):
|
||||||
# 清空现有的图数据
|
|
||||||
self.db.db.graph_data.delete_many({})
|
|
||||||
# 保存节点
|
# 保存节点
|
||||||
for node in self.G.nodes(data=True):
|
for node in self.G.nodes(data=True):
|
||||||
node_data = {
|
concept = node[0]
|
||||||
'concept': node[0],
|
memory_items = node[1].get('memory_items', [])
|
||||||
'memory_items': node[1].get('memory_items', []) # 默认为空列表
|
|
||||||
}
|
# 查找是否存在同名节点
|
||||||
self.db.db.graph_data.nodes.insert_one(node_data)
|
existing_node = self.db.db.graph_data.nodes.find_one({'concept': concept})
|
||||||
|
if existing_node:
|
||||||
|
# 如果存在,合并memory_items并去重
|
||||||
|
existing_items = existing_node.get('memory_items', [])
|
||||||
|
if not isinstance(existing_items, list):
|
||||||
|
existing_items = [existing_items] if existing_items else []
|
||||||
|
|
||||||
|
# 合并并去重
|
||||||
|
all_items = list(set(existing_items + memory_items))
|
||||||
|
|
||||||
|
# 更新节点
|
||||||
|
self.db.db.graph_data.nodes.update_one(
|
||||||
|
{'concept': concept},
|
||||||
|
{'$set': {'memory_items': all_items}}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 如果不存在,创建新节点
|
||||||
|
node_data = {
|
||||||
|
'concept': concept,
|
||||||
|
'memory_items': memory_items
|
||||||
|
}
|
||||||
|
self.db.db.graph_data.nodes.insert_one(node_data)
|
||||||
|
|
||||||
# 保存边
|
# 保存边
|
||||||
for edge in self.G.edges():
|
for edge in self.G.edges():
|
||||||
edge_data = {
|
source, target = edge
|
||||||
'source': edge[0],
|
|
||||||
'target': edge[1]
|
# 查找是否存在同样的边
|
||||||
}
|
existing_edge = self.db.db.graph_data.edges.find_one({
|
||||||
self.db.db.graph_data.edges.insert_one(edge_data)
|
'source': source,
|
||||||
|
'target': target
|
||||||
|
})
|
||||||
|
|
||||||
|
if existing_edge:
|
||||||
|
# 如果存在,增加num属性
|
||||||
|
num = existing_edge.get('num', 1) + 1
|
||||||
|
self.db.db.graph_data.edges.update_one(
|
||||||
|
{'source': source, 'target': target},
|
||||||
|
{'$set': {'num': num}}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 如果不存在,创建新边
|
||||||
|
edge_data = {
|
||||||
|
'source': source,
|
||||||
|
'target': target,
|
||||||
|
'num': 1
|
||||||
|
}
|
||||||
|
self.db.db.graph_data.edges.insert_one(edge_data)
|
||||||
|
|
||||||
def load_graph_from_db(self):
|
def load_graph_from_db(self):
|
||||||
# 清空当前图
|
# 清空当前图
|
||||||
@@ -147,150 +158,92 @@ class Memory_graph:
|
|||||||
# 加载边
|
# 加载边
|
||||||
edges = self.db.db.graph_data.edges.find()
|
edges = self.db.db.graph_data.edges.find()
|
||||||
for edge in edges:
|
for edge in edges:
|
||||||
self.G.add_edge(edge['source'], edge['target'])
|
self.G.add_edge(edge['source'], edge['target'], num=edge.get('num', 1))
|
||||||
|
|
||||||
def calculate_information_content(text):
|
|
||||||
|
|
||||||
"""计算文本的信息量(熵)"""
|
|
||||||
# 统计字符频率
|
|
||||||
char_count = Counter(text)
|
|
||||||
total_chars = len(text)
|
|
||||||
|
|
||||||
# 计算熵
|
|
||||||
entropy = 0
|
|
||||||
for count in char_count.values():
|
|
||||||
probability = count / total_chars
|
|
||||||
entropy -= probability * math.log2(probability)
|
|
||||||
|
|
||||||
return entropy
|
|
||||||
|
|
||||||
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
Database.initialize(
|
|
||||||
global_config.MONGODB_HOST,
|
|
||||||
global_config.MONGODB_PORT,
|
|
||||||
global_config.DATABASE_NAME
|
|
||||||
)
|
|
||||||
memory_graph = Memory_graph()
|
|
||||||
|
|
||||||
llm_model = LLMModel()
|
|
||||||
llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
|
|
||||||
|
|
||||||
memory_graph.load_graph_from_db()
|
|
||||||
|
|
||||||
end_time = time.time()
|
|
||||||
print(f"加载海马体耗时: {end_time - start_time:.2f} 秒")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
# 初始化数据库
|
|
||||||
Database.initialize(
|
|
||||||
"127.0.0.1",
|
|
||||||
27017,
|
|
||||||
"MegBot"
|
|
||||||
)
|
|
||||||
|
|
||||||
memory_graph = Memory_graph()
|
|
||||||
# 创建LLM模型实例
|
|
||||||
llm_model = LLMModel()
|
|
||||||
llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
|
|
||||||
|
|
||||||
# 使用当前时间戳进行测试
|
|
||||||
current_timestamp = datetime.datetime.now().timestamp()
|
|
||||||
chat_text = []
|
|
||||||
|
|
||||||
chat_size =40
|
|
||||||
|
|
||||||
for _ in range(100): # 循环10次
|
|
||||||
random_time = current_timestamp - random.randint(1, 3600*39) # 随机时间
|
|
||||||
print(f"随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}")
|
|
||||||
chat_ = memory_graph.get_random_chat_from_db(chat_size, random_time)
|
|
||||||
chat_text.append(chat_) # 拼接所有text
|
|
||||||
time.sleep(5)
|
|
||||||
|
|
||||||
|
|
||||||
|
# 海马体
|
||||||
for input_text in chat_text:
|
class Hippocampus:
|
||||||
print(input_text)
|
def __init__(self,memory_graph:Memory_graph):
|
||||||
first_memory = set()
|
self.memory_graph = memory_graph
|
||||||
first_memory = memory_compress(input_text, llm_model_small, llm_model_small, rate=2.5)
|
self.llm_model = LLMModel()
|
||||||
|
self.llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
|
||||||
|
|
||||||
#将记忆加入到图谱中
|
def get_memory_sample(self,chat_size=20,time_frequency:dict={'near':2,'mid':4,'far':3}):
|
||||||
for topic, memory in first_memory:
|
current_timestamp = datetime.datetime.now().timestamp()
|
||||||
topics = segment_text(topic)
|
chat_text = []
|
||||||
print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
|
#短期:1h 中期:4h 长期:24h
|
||||||
for split_topic in topics:
|
for _ in range(time_frequency.get('near')): # 循环10次
|
||||||
memory_graph.add_dot(split_topic,memory)
|
random_time = current_timestamp - random.randint(1, 3600) # 随机时间
|
||||||
for split_topic in topics:
|
# print(f"获得 最近 随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}")
|
||||||
for other_split_topic in topics:
|
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
|
||||||
if split_topic != other_split_topic:
|
chat_text.append(chat_)
|
||||||
memory_graph.connect_dot(split_topic, other_split_topic)
|
for _ in range(time_frequency.get('mid')): # 循环10次
|
||||||
|
random_time = current_timestamp - random.randint(3600, 3600*4) # 随机时间
|
||||||
|
# print(f"获得 最近 随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}")
|
||||||
|
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
|
||||||
|
chat_text.append(chat_)
|
||||||
|
for _ in range(time_frequency.get('far')): # 循环10次
|
||||||
|
random_time = current_timestamp - random.randint(3600*4, 3600*24) # 随机时间
|
||||||
|
# print(f"获得 最近 随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}")
|
||||||
|
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
|
||||||
|
chat_text.append(chat_)
|
||||||
|
return chat_text
|
||||||
|
|
||||||
# memory_graph.store_memory()
|
def build_memory(self,chat_size=12):
|
||||||
|
#最近消息获取频率
|
||||||
# 展示两种不同的可视化方式
|
time_frequency = {'near':1,'mid':2,'far':2}
|
||||||
print("\n按连接数量着色的图谱:")
|
memory_sample = self.get_memory_sample(chat_size,time_frequency)
|
||||||
visualize_graph(memory_graph, color_by_memory=False)
|
# print(f"\033[1;32m[记忆构建]\033[0m 获取记忆样本: {memory_sample}")
|
||||||
|
|
||||||
print("\n按记忆数量着色的图谱:")
|
|
||||||
visualize_graph(memory_graph, color_by_memory=True)
|
for i, input_text in enumerate(memory_sample, 1):
|
||||||
|
#加载进度可视化
|
||||||
memory_graph.save_graph_to_db()
|
progress = (i / len(memory_sample)) * 100
|
||||||
# memory_graph.load_graph_from_db()
|
bar_length = 30
|
||||||
|
filled_length = int(bar_length * i // len(memory_sample))
|
||||||
while True:
|
bar = '█' * filled_length + '-' * (bar_length - filled_length)
|
||||||
query = input("请输入新的查询概念(输入'退出'以结束):")
|
print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})")
|
||||||
if query.lower() == '退出':
|
|
||||||
break
|
|
||||||
items_list = memory_graph.get_related_item(query)
|
|
||||||
if items_list:
|
|
||||||
# print(items_list)
|
|
||||||
for memory_item in items_list:
|
|
||||||
print(memory_item)
|
|
||||||
else:
|
|
||||||
print("未找到相关记忆。")
|
|
||||||
|
|
||||||
while True:
|
# 生成压缩后记忆
|
||||||
query = input("请输入问题:")
|
first_memory = set()
|
||||||
|
first_memory = self.memory_compress(input_text, 2.5)
|
||||||
if query.lower() == '退出':
|
# 延时防止访问超频
|
||||||
break
|
# time.sleep(5)
|
||||||
|
#将记忆加入到图谱中
|
||||||
topic_prompt = find_topic(query, 3)
|
for topic, memory in first_memory:
|
||||||
topic_response = llm_model.generate_response(topic_prompt)
|
topics = segment_text(topic)
|
||||||
|
print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
|
||||||
|
for split_topic in topics:
|
||||||
|
self.memory_graph.add_dot(split_topic,memory)
|
||||||
|
for split_topic in topics:
|
||||||
|
for other_split_topic in topics:
|
||||||
|
if split_topic != other_split_topic:
|
||||||
|
self.memory_graph.connect_dot(split_topic, other_split_topic)
|
||||||
|
|
||||||
|
self.memory_graph.save_graph_to_db()
|
||||||
|
|
||||||
|
def memory_compress(self, input_text, rate=1):
|
||||||
|
information_content = calculate_information_content(input_text)
|
||||||
|
print(f"文本的信息量(熵): {information_content:.4f} bits")
|
||||||
|
topic_num = max(1, min(5, int(information_content * rate / 4)))
|
||||||
|
# print(topic_num)
|
||||||
|
topic_prompt = find_topic(input_text, topic_num)
|
||||||
|
topic_response = self.llm_model.generate_response(topic_prompt)
|
||||||
# 检查 topic_response 是否为元组
|
# 检查 topic_response 是否为元组
|
||||||
if isinstance(topic_response, tuple):
|
if isinstance(topic_response, tuple):
|
||||||
topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串
|
topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串
|
||||||
else:
|
else:
|
||||||
topics = topic_response.split(",")
|
topics = topic_response.split(",")
|
||||||
print(topics)
|
# print(topics)
|
||||||
|
compressed_memory = set()
|
||||||
for keyword in topics:
|
for topic in topics:
|
||||||
items_list = memory_graph.get_related_item(keyword)
|
topic_what_prompt = topic_what(input_text,topic)
|
||||||
if items_list:
|
topic_what_response = self.llm_model_small.generate_response(topic_what_prompt)
|
||||||
print(items_list)
|
compressed_memory.add((topic.strip(), topic_what_response[0])) # 将话题和记忆作为元组存储
|
||||||
|
return compressed_memory
|
||||||
def memory_compress(input_text, llm_model, llm_model_small, rate=1):
|
|
||||||
information_content = calculate_information_content(input_text)
|
|
||||||
print(f"文本的信息量(熵): {information_content:.4f} bits")
|
|
||||||
topic_num = max(1, min(5, int(information_content * rate / 4)))
|
|
||||||
print(topic_num)
|
|
||||||
topic_prompt = find_topic(input_text, topic_num)
|
|
||||||
topic_response = llm_model.generate_response(topic_prompt)
|
|
||||||
# 检查 topic_response 是否为元组
|
|
||||||
if isinstance(topic_response, tuple):
|
|
||||||
topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串
|
|
||||||
else:
|
|
||||||
topics = topic_response.split(",")
|
|
||||||
print(topics)
|
|
||||||
compressed_memory = set()
|
|
||||||
for topic in topics:
|
|
||||||
topic_what_prompt = topic_what(input_text,topic)
|
|
||||||
topic_what_response = llm_model_small.generate_response(topic_what_prompt)
|
|
||||||
compressed_memory.add((topic.strip(), topic_what_response[0])) # 将话题和记忆作为元组存储
|
|
||||||
return compressed_memory
|
|
||||||
|
|
||||||
|
|
||||||
def segment_text(text):
|
def segment_text(text):
|
||||||
@@ -305,69 +258,21 @@ def topic_what(text, topic):
|
|||||||
prompt = f'这是一段文字:{text}。我想知道这记忆里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物。只输出这句话就好'
|
prompt = f'这是一段文字:{text}。我想知道这记忆里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物。只输出这句话就好'
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
|
|
||||||
# 设置中文字体
|
|
||||||
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
|
|
||||||
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
|
|
||||||
|
|
||||||
G = memory_graph.G
|
|
||||||
|
|
||||||
# 保存图到本地
|
|
||||||
nx.write_gml(G, "memory_graph.gml") # 保存为 GML 格式
|
|
||||||
|
|
||||||
# 根据连接条数或记忆数量设置节点颜色
|
|
||||||
node_colors = []
|
|
||||||
nodes = list(G.nodes()) # 获取图中实际的节点列表
|
|
||||||
|
|
||||||
if color_by_memory:
|
|
||||||
# 计算每个节点的记忆数量
|
|
||||||
memory_counts = []
|
|
||||||
for node in nodes:
|
|
||||||
memory_items = G.nodes[node].get('memory_items', [])
|
|
||||||
if isinstance(memory_items, list):
|
|
||||||
count = len(memory_items)
|
|
||||||
else:
|
|
||||||
count = 1 if memory_items else 0
|
|
||||||
memory_counts.append(count)
|
|
||||||
max_memories = max(memory_counts) if memory_counts else 1
|
|
||||||
|
|
||||||
for count in memory_counts:
|
|
||||||
# 使用不同的颜色方案:红色表示记忆多,蓝色表示记忆少
|
|
||||||
if max_memories > 0:
|
|
||||||
intensity = min(1.0, count / max_memories)
|
|
||||||
color = (intensity, 0, 1.0 - intensity) # 从蓝色渐变到红色
|
|
||||||
else:
|
|
||||||
color = (0, 0, 1) # 如果没有记忆,则为蓝色
|
|
||||||
node_colors.append(color)
|
|
||||||
else:
|
|
||||||
# 使用原来的连接数量着色方案
|
|
||||||
max_degree = max(G.degree(), key=lambda x: x[1])[1] if G.degree() else 1
|
|
||||||
for node in nodes:
|
|
||||||
degree = G.degree(node)
|
|
||||||
if max_degree > 0:
|
|
||||||
red = min(1.0, degree / max_degree)
|
|
||||||
blue = 1.0 - red
|
|
||||||
color = (red, 0, blue)
|
|
||||||
else:
|
|
||||||
color = (0, 0, 1)
|
|
||||||
node_colors.append(color)
|
|
||||||
|
|
||||||
# 绘制图形
|
|
||||||
plt.figure(figsize=(12, 8))
|
|
||||||
pos = nx.spring_layout(G, k=1, iterations=50)
|
|
||||||
nx.draw(G, pos,
|
|
||||||
with_labels=True,
|
|
||||||
node_color=node_colors,
|
|
||||||
node_size=2000,
|
|
||||||
font_size=10,
|
|
||||||
font_family='SimHei',
|
|
||||||
font_weight='bold')
|
|
||||||
|
|
||||||
title = '记忆图谱可视化 - ' + ('按记忆数量着色' if color_by_memory else '按连接数量着色')
|
|
||||||
plt.title(title, fontsize=16, fontfamily='SimHei')
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
|
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
Database.initialize(
|
||||||
|
global_config.MONGODB_HOST,
|
||||||
|
global_config.MONGODB_PORT,
|
||||||
|
global_config.DATABASE_NAME
|
||||||
|
)
|
||||||
|
#创建记忆图
|
||||||
|
memory_graph = Memory_graph()
|
||||||
|
#加载数据库中存储的记忆图
|
||||||
|
memory_graph.load_graph_from_db()
|
||||||
|
#创建海马体
|
||||||
|
hippocampus = Hippocampus(memory_graph)
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
|
print(f"\033[32m[加载海马体耗时: {end_time - start_time:.2f} 秒]\033[0m")
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import sys
|
import sys
|
||||||
import jieba
|
import jieba
|
||||||
from llm_module import LLMModel
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import math
|
import math
|
||||||
@@ -9,10 +8,12 @@ from collections import Counter
|
|||||||
import datetime
|
import datetime
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
|
import os
|
||||||
|
from dotenv import load_dotenv
|
||||||
# from chat.config import global_config
|
# from chat.config import global_config
|
||||||
import sys
|
|
||||||
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
|
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
|
||||||
from src.common.database import Database # 使用正确的导入语法
|
from src.common.database import Database # 使用正确的导入语法
|
||||||
|
from src.plugins.memory_system.llm_module import LLMModel
|
||||||
|
|
||||||
class Memory_graph:
|
class Memory_graph:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -117,22 +118,60 @@ class Memory_graph:
|
|||||||
return [] # 如果没有找到记录,返回空列表
|
return [] # 如果没有找到记录,返回空列表
|
||||||
|
|
||||||
def save_graph_to_db(self):
|
def save_graph_to_db(self):
|
||||||
# 清空现有的图数据
|
|
||||||
self.db.db.graph_data.delete_many({})
|
|
||||||
# 保存节点
|
# 保存节点
|
||||||
for node in self.G.nodes(data=True):
|
for node in self.G.nodes(data=True):
|
||||||
node_data = {
|
concept = node[0]
|
||||||
'concept': node[0],
|
memory_items = node[1].get('memory_items', [])
|
||||||
'memory_items': node[1].get('memory_items', []) # 默认为空列表
|
|
||||||
}
|
# 查找是否存在同名节点
|
||||||
self.db.db.graph_data.nodes.insert_one(node_data)
|
existing_node = self.db.db.graph_data.nodes.find_one({'concept': concept})
|
||||||
|
if existing_node:
|
||||||
|
# 如果存在,合并memory_items并去重
|
||||||
|
existing_items = existing_node.get('memory_items', [])
|
||||||
|
if not isinstance(existing_items, list):
|
||||||
|
existing_items = [existing_items] if existing_items else []
|
||||||
|
|
||||||
|
# 合并并去重
|
||||||
|
all_items = list(set(existing_items + memory_items))
|
||||||
|
|
||||||
|
# 更新节点
|
||||||
|
self.db.db.graph_data.nodes.update_one(
|
||||||
|
{'concept': concept},
|
||||||
|
{'$set': {'memory_items': all_items}}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 如果不存在,创建新节点
|
||||||
|
node_data = {
|
||||||
|
'concept': concept,
|
||||||
|
'memory_items': memory_items
|
||||||
|
}
|
||||||
|
self.db.db.graph_data.nodes.insert_one(node_data)
|
||||||
|
|
||||||
# 保存边
|
# 保存边
|
||||||
for edge in self.G.edges():
|
for edge in self.G.edges():
|
||||||
edge_data = {
|
source, target = edge
|
||||||
'source': edge[0],
|
|
||||||
'target': edge[1]
|
# 查找是否存在同样的边
|
||||||
}
|
existing_edge = self.db.db.graph_data.edges.find_one({
|
||||||
self.db.db.graph_data.edges.insert_one(edge_data)
|
'source': source,
|
||||||
|
'target': target
|
||||||
|
})
|
||||||
|
|
||||||
|
if existing_edge:
|
||||||
|
# 如果存在,增加num属性
|
||||||
|
num = existing_edge.get('num', 1) + 1
|
||||||
|
self.db.db.graph_data.edges.update_one(
|
||||||
|
{'source': source, 'target': target},
|
||||||
|
{'$set': {'num': num}}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 如果不存在,创建新边
|
||||||
|
edge_data = {
|
||||||
|
'source': source,
|
||||||
|
'target': target,
|
||||||
|
'num': 1
|
||||||
|
}
|
||||||
|
self.db.db.graph_data.edges.insert_one(edge_data)
|
||||||
|
|
||||||
def load_graph_from_db(self):
|
def load_graph_from_db(self):
|
||||||
# 清空当前图
|
# 清空当前图
|
||||||
@@ -147,7 +186,7 @@ class Memory_graph:
|
|||||||
# 加载边
|
# 加载边
|
||||||
edges = self.db.db.graph_data.edges.find()
|
edges = self.db.db.graph_data.edges.find()
|
||||||
for edge in edges:
|
for edge in edges:
|
||||||
self.G.add_edge(edge['source'], edge['target'])
|
self.G.add_edge(edge['source'], edge['target'], num=edge.get('num', 1))
|
||||||
|
|
||||||
def calculate_information_content(text):
|
def calculate_information_content(text):
|
||||||
|
|
||||||
@@ -180,6 +219,19 @@ def calculate_information_content(text):
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
# 获取当前文件的绝对路径
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..'))
|
||||||
|
env_path = os.path.join(root_dir, 'config', '.env')
|
||||||
|
|
||||||
|
# 加载环境变量
|
||||||
|
print(f"尝试从 {env_path} 加载环境变量配置")
|
||||||
|
if os.path.exists(env_path):
|
||||||
|
load_dotenv(env_path)
|
||||||
|
print("成功加载环境变量配置")
|
||||||
|
else:
|
||||||
|
print(f"环境变量配置文件不存在: {env_path}")
|
||||||
|
|
||||||
# 初始化数据库
|
# 初始化数据库
|
||||||
Database.initialize(
|
Database.initialize(
|
||||||
"127.0.0.1",
|
"127.0.0.1",
|
||||||
@@ -196,10 +248,10 @@ def main():
|
|||||||
current_timestamp = datetime.datetime.now().timestamp()
|
current_timestamp = datetime.datetime.now().timestamp()
|
||||||
chat_text = []
|
chat_text = []
|
||||||
|
|
||||||
chat_size =20
|
chat_size =25
|
||||||
|
|
||||||
for _ in range(10): # 循环10次
|
for _ in range(30): # 循环10次
|
||||||
random_time = current_timestamp - random.randint(1, 3600*3) # 随机时间
|
random_time = current_timestamp - random.randint(1, 3600*10) # 随机时间
|
||||||
print(f"随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}")
|
print(f"随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}")
|
||||||
chat_ = memory_graph.get_random_chat_from_db(chat_size, random_time)
|
chat_ = memory_graph.get_random_chat_from_db(chat_size, random_time)
|
||||||
chat_text.append(chat_) # 拼接所有text
|
chat_text.append(chat_) # 拼接所有text
|
||||||
@@ -218,7 +270,7 @@ def main():
|
|||||||
# print(input_text)
|
# print(input_text)
|
||||||
first_memory = set()
|
first_memory = set()
|
||||||
first_memory = memory_compress(input_text, llm_model_small, llm_model_small, rate=2.5)
|
first_memory = memory_compress(input_text, llm_model_small, llm_model_small, rate=2.5)
|
||||||
time.sleep(5)
|
# time.sleep(5)
|
||||||
|
|
||||||
#将记忆加入到图谱中
|
#将记忆加入到图谱中
|
||||||
for topic, memory in first_memory:
|
for topic, memory in first_memory:
|
||||||
Reference in New Issue
Block a user